diff --git a/fsm/fsm.go b/fsm/fsm.go index 364c4365..c0c72e73 100644 --- a/fsm/fsm.go +++ b/fsm/fsm.go @@ -23,10 +23,10 @@ type Options struct { } // HookBeforeFunc func signature -type HookBeforeFunc func(ctx context.Context, state string, args map[string]interface{}) +type HookBeforeFunc func(ctx context.Context, state string, args interface{}) // HookAfterFunc func signature -type HookAfterFunc func(ctx context.Context, state string, args map[string]interface{}) +type HookAfterFunc func(ctx context.Context, state string, args interface{}) // Option func signature type Option func(*Options) @@ -53,14 +53,15 @@ func StateHookAfter(fns ...HookAfterFunc) Option { } // StateFunc called on state transition and return next step and error -type StateFunc func(ctx context.Context, args map[string]interface{}) (string, map[string]interface{}, error) +type StateFunc func(ctx context.Context, args interface{}) (string, interface{}, error) // FSM is a finite state machine type FSM struct { - mu sync.Mutex - states map[string]StateFunc - opts *Options - current string + mu sync.Mutex + statesMap map[string]StateFunc + statesOrder []string + opts *Options + current string } // New creates a new finite state machine having the specified initial state @@ -73,8 +74,8 @@ func New(opts ...Option) *FSM { } return &FSM{ - states: map[string]StateFunc{}, - opts: options, + statesMap: map[string]StateFunc{}, + opts: options, } } @@ -95,12 +96,13 @@ func (f *FSM) Reset() { // State adds state to fsm func (f *FSM) State(state string, fn StateFunc) { f.mu.Lock() - f.states[state] = fn + f.statesMap[state] = fn + f.statesOrder = append(f.statesOrder, state) f.mu.Unlock() } // Start runs state machine with provided data -func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Option) (map[string]interface{}, error) { +func (f *FSM) Start(ctx context.Context, args interface{}, opts ...Option) (interface{}, error) { var err error var ok bool var fn StateFunc @@ -114,8 +116,8 @@ func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Op } cstate := options.Initial - states := make(map[string]StateFunc, len(f.states)) - for k, v := range f.states { + states := make(map[string]StateFunc, len(f.statesMap)) + for k, v := range f.statesMap { states[k] = v } f.current = cstate @@ -142,8 +144,14 @@ func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Op } if err != nil { return args, err - } else if nstate == "" || nstate == StateEnd { + } else if nstate == StateEnd { return args, nil + } else if nstate == "" { + for idx := range f.statesOrder { + if f.statesOrder[idx] == cstate && len(f.statesOrder) > idx+1 { + nstate = f.statesOrder[idx+1] + } + } } cstate = nstate } diff --git a/fsm/fsm_test.go b/fsm/fsm_test.go index 0024db91..57202f5f 100644 --- a/fsm/fsm_test.go +++ b/fsm/fsm_test.go @@ -10,40 +10,54 @@ import ( func TestFSMStart(t *testing.T) { ctx := context.TODO() buf := bytes.NewBuffer(nil) - pfb := func(_ context.Context, state string, _ map[string]interface{}) { + pfb := func(_ context.Context, state string, _ interface{}) { fmt.Fprintf(buf, "before state %s\n", state) } - pfa := func(_ context.Context, state string, _ map[string]interface{}) { + pfa := func(_ context.Context, state string, _ interface{}) { fmt.Fprintf(buf, "after state %s\n", state) } f := New(StateInitial("1"), StateHookBefore(pfb), StateHookAfter(pfa)) - f1 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) { + f1 := func(_ context.Context, req interface{}) (string, interface{}, error) { + args := req.(map[string]interface{}) if v, ok := args["request"].(string); !ok || v == "" { return "", nil, fmt.Errorf("empty request") } return "2", map[string]interface{}{"response": "test2"}, nil } - f2 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) { + f2 := func(_ context.Context, req interface{}) (string, interface{}, error) { + args := req.(map[string]interface{}) if v, ok := args["response"].(string); !ok || v == "" { return "", nil, fmt.Errorf("empty response") } return "", map[string]interface{}{"response": "test"}, nil } + f3 := func(_ context.Context, req interface{}) (string, interface{}, error) { + args := req.(map[string]interface{}) + if v, ok := args["response"].(string); !ok || v == "" { + return "", nil, fmt.Errorf("empty response") + } + return StateEnd, map[string]interface{}{"response": "test_last"}, nil + } f.State("1", f1) f.State("2", f2) - args, err := f.Start(ctx, map[string]interface{}{"request": "test1"}) + f.State("3", f3) + rsp, err := f.Start(ctx, map[string]interface{}{"request": "test1"}) if err != nil { t.Fatal(err) - } else if v, ok := args["response"].(string); !ok || v == "" { + } + args := rsp.(map[string]interface{}) + if v, ok := args["response"].(string); !ok || v == "" { t.Fatalf("nil rsp: %#+v", args) - } else if v != "test" { + } else if v != "test_last" { t.Fatalf("invalid rsp %#+v", args) } if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) || !bytes.Contains(buf.Bytes(), []byte(`before state 2`)) || !bytes.Contains(buf.Bytes(), []byte(`after state 1`)) || - !bytes.Contains(buf.Bytes(), []byte(`after state 2`)) { + !bytes.Contains(buf.Bytes(), []byte(`after state 2`)) || + !bytes.Contains(buf.Bytes(), []byte(`after state 3`)) || + !bytes.Contains(buf.Bytes(), []byte(`after state 3`)) { t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes()) } }