fsm: run steps in order #120

Merged
vtolstov merged 1 commits from fsm into v3 2025-01-02 23:56:08 +03:00
2 changed files with 44 additions and 22 deletions
Showing only changes of commit 08aaf14a79 - Show all commits

View File

@ -23,10 +23,10 @@ type Options struct {
} }
// HookBeforeFunc func signature // HookBeforeFunc func signature
type HookBeforeFunc func(ctx context.Context, state string, args map[string]interface{}) type HookBeforeFunc func(ctx context.Context, state string, args interface{})
// HookAfterFunc func signature // HookAfterFunc func signature
type HookAfterFunc func(ctx context.Context, state string, args map[string]interface{}) type HookAfterFunc func(ctx context.Context, state string, args interface{})
// Option func signature // Option func signature
type Option func(*Options) type Option func(*Options)
@ -53,14 +53,15 @@ func StateHookAfter(fns ...HookAfterFunc) Option {
} }
// StateFunc called on state transition and return next step and error // StateFunc called on state transition and return next step and error
type StateFunc func(ctx context.Context, args map[string]interface{}) (string, map[string]interface{}, error) type StateFunc func(ctx context.Context, args interface{}) (string, interface{}, error)
// FSM is a finite state machine // FSM is a finite state machine
type FSM struct { type FSM struct {
mu sync.Mutex mu sync.Mutex
states map[string]StateFunc statesMap map[string]StateFunc
opts *Options statesOrder []string
current string opts *Options
current string
} }
// New creates a new finite state machine having the specified initial state // New creates a new finite state machine having the specified initial state
@ -73,8 +74,8 @@ func New(opts ...Option) *FSM {
} }
return &FSM{ return &FSM{
states: map[string]StateFunc{}, statesMap: map[string]StateFunc{},
opts: options, opts: options,
} }
} }
@ -95,12 +96,13 @@ func (f *FSM) Reset() {
// State adds state to fsm // State adds state to fsm
func (f *FSM) State(state string, fn StateFunc) { func (f *FSM) State(state string, fn StateFunc) {
f.mu.Lock() f.mu.Lock()
f.states[state] = fn f.statesMap[state] = fn
f.statesOrder = append(f.statesOrder, state)
f.mu.Unlock() f.mu.Unlock()
} }
// Start runs state machine with provided data // Start runs state machine with provided data
func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Option) (map[string]interface{}, error) { func (f *FSM) Start(ctx context.Context, args interface{}, opts ...Option) (interface{}, error) {
var err error var err error
var ok bool var ok bool
var fn StateFunc var fn StateFunc
@ -114,8 +116,8 @@ func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Op
} }
cstate := options.Initial cstate := options.Initial
states := make(map[string]StateFunc, len(f.states)) states := make(map[string]StateFunc, len(f.statesMap))
for k, v := range f.states { for k, v := range f.statesMap {
states[k] = v states[k] = v
} }
f.current = cstate f.current = cstate
@ -142,8 +144,14 @@ func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Op
} }
if err != nil { if err != nil {
return args, err return args, err
} else if nstate == "" || nstate == StateEnd { } else if nstate == StateEnd {
return args, nil return args, nil
} else if nstate == "" {
for idx := range f.statesOrder {
if f.statesOrder[idx] == cstate && len(f.statesOrder) > idx+1 {
nstate = f.statesOrder[idx+1]
}
}
} }
cstate = nstate cstate = nstate
} }

View File

@ -10,40 +10,54 @@ import (
func TestFSMStart(t *testing.T) { func TestFSMStart(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
pfb := func(_ context.Context, state string, _ map[string]interface{}) { pfb := func(_ context.Context, state string, _ interface{}) {
fmt.Fprintf(buf, "before state %s\n", state) fmt.Fprintf(buf, "before state %s\n", state)
} }
pfa := func(_ context.Context, state string, _ map[string]interface{}) { pfa := func(_ context.Context, state string, _ interface{}) {
fmt.Fprintf(buf, "after state %s\n", state) fmt.Fprintf(buf, "after state %s\n", state)
} }
f := New(StateInitial("1"), StateHookBefore(pfb), StateHookAfter(pfa)) f := New(StateInitial("1"), StateHookBefore(pfb), StateHookAfter(pfa))
f1 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) { f1 := func(_ context.Context, req interface{}) (string, interface{}, error) {
args := req.(map[string]interface{})
if v, ok := args["request"].(string); !ok || v == "" { if v, ok := args["request"].(string); !ok || v == "" {
return "", nil, fmt.Errorf("empty request") return "", nil, fmt.Errorf("empty request")
} }
return "2", map[string]interface{}{"response": "test2"}, nil return "2", map[string]interface{}{"response": "test2"}, nil
} }
f2 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) { f2 := func(_ context.Context, req interface{}) (string, interface{}, error) {
args := req.(map[string]interface{})
if v, ok := args["response"].(string); !ok || v == "" { if v, ok := args["response"].(string); !ok || v == "" {
return "", nil, fmt.Errorf("empty response") return "", nil, fmt.Errorf("empty response")
} }
return "", map[string]interface{}{"response": "test"}, nil return "", map[string]interface{}{"response": "test"}, nil
} }
f3 := func(_ context.Context, req interface{}) (string, interface{}, error) {
args := req.(map[string]interface{})
if v, ok := args["response"].(string); !ok || v == "" {
return "", nil, fmt.Errorf("empty response")
}
return StateEnd, map[string]interface{}{"response": "test_last"}, nil
}
f.State("1", f1) f.State("1", f1)
f.State("2", f2) f.State("2", f2)
args, err := f.Start(ctx, map[string]interface{}{"request": "test1"}) f.State("3", f3)
rsp, err := f.Start(ctx, map[string]interface{}{"request": "test1"})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} else if v, ok := args["response"].(string); !ok || v == "" { }
args := rsp.(map[string]interface{})
if v, ok := args["response"].(string); !ok || v == "" {
t.Fatalf("nil rsp: %#+v", args) t.Fatalf("nil rsp: %#+v", args)
} else if v != "test" { } else if v != "test_last" {
t.Fatalf("invalid rsp %#+v", args) t.Fatalf("invalid rsp %#+v", args)
} }
if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) || if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) ||
!bytes.Contains(buf.Bytes(), []byte(`before state 2`)) || !bytes.Contains(buf.Bytes(), []byte(`before state 2`)) ||
!bytes.Contains(buf.Bytes(), []byte(`after state 1`)) || !bytes.Contains(buf.Bytes(), []byte(`after state 1`)) ||
!bytes.Contains(buf.Bytes(), []byte(`after state 2`)) { !bytes.Contains(buf.Bytes(), []byte(`after state 2`)) ||
!bytes.Contains(buf.Bytes(), []byte(`after state 3`)) ||
!bytes.Contains(buf.Bytes(), []byte(`after state 3`)) {
t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes()) t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes())
} }
} }