commit
12f188e3ad
36
fsm/fsm.go
36
fsm/fsm.go
@ -23,10 +23,10 @@ type Options struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HookBeforeFunc func signature
|
// HookBeforeFunc func signature
|
||||||
type HookBeforeFunc func(ctx context.Context, state string, args map[string]interface{})
|
type HookBeforeFunc func(ctx context.Context, state string, args interface{})
|
||||||
|
|
||||||
// HookAfterFunc func signature
|
// HookAfterFunc func signature
|
||||||
type HookAfterFunc func(ctx context.Context, state string, args map[string]interface{})
|
type HookAfterFunc func(ctx context.Context, state string, args interface{})
|
||||||
|
|
||||||
// Option func signature
|
// Option func signature
|
||||||
type Option func(*Options)
|
type Option func(*Options)
|
||||||
@ -53,14 +53,15 @@ func StateHookAfter(fns ...HookAfterFunc) Option {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// StateFunc called on state transition and return next step and error
|
// StateFunc called on state transition and return next step and error
|
||||||
type StateFunc func(ctx context.Context, args map[string]interface{}) (string, map[string]interface{}, error)
|
type StateFunc func(ctx context.Context, args interface{}) (string, interface{}, error)
|
||||||
|
|
||||||
// FSM is a finite state machine
|
// FSM is a finite state machine
|
||||||
type FSM struct {
|
type FSM struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
states map[string]StateFunc
|
statesMap map[string]StateFunc
|
||||||
opts *Options
|
statesOrder []string
|
||||||
current string
|
opts *Options
|
||||||
|
current string
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new finite state machine having the specified initial state
|
// New creates a new finite state machine having the specified initial state
|
||||||
@ -73,8 +74,8 @@ func New(opts ...Option) *FSM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &FSM{
|
return &FSM{
|
||||||
states: map[string]StateFunc{},
|
statesMap: map[string]StateFunc{},
|
||||||
opts: options,
|
opts: options,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,12 +96,13 @@ func (f *FSM) Reset() {
|
|||||||
// State adds state to fsm
|
// State adds state to fsm
|
||||||
func (f *FSM) State(state string, fn StateFunc) {
|
func (f *FSM) State(state string, fn StateFunc) {
|
||||||
f.mu.Lock()
|
f.mu.Lock()
|
||||||
f.states[state] = fn
|
f.statesMap[state] = fn
|
||||||
|
f.statesOrder = append(f.statesOrder, state)
|
||||||
f.mu.Unlock()
|
f.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start runs state machine with provided data
|
// Start runs state machine with provided data
|
||||||
func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Option) (map[string]interface{}, error) {
|
func (f *FSM) Start(ctx context.Context, args interface{}, opts ...Option) (interface{}, error) {
|
||||||
var err error
|
var err error
|
||||||
var ok bool
|
var ok bool
|
||||||
var fn StateFunc
|
var fn StateFunc
|
||||||
@ -114,8 +116,8 @@ func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Op
|
|||||||
}
|
}
|
||||||
|
|
||||||
cstate := options.Initial
|
cstate := options.Initial
|
||||||
states := make(map[string]StateFunc, len(f.states))
|
states := make(map[string]StateFunc, len(f.statesMap))
|
||||||
for k, v := range f.states {
|
for k, v := range f.statesMap {
|
||||||
states[k] = v
|
states[k] = v
|
||||||
}
|
}
|
||||||
f.current = cstate
|
f.current = cstate
|
||||||
@ -142,8 +144,14 @@ func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Op
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return args, err
|
return args, err
|
||||||
} else if nstate == "" || nstate == StateEnd {
|
} else if nstate == StateEnd {
|
||||||
return args, nil
|
return args, nil
|
||||||
|
} else if nstate == "" {
|
||||||
|
for idx := range f.statesOrder {
|
||||||
|
if f.statesOrder[idx] == cstate && len(f.statesOrder) > idx+1 {
|
||||||
|
nstate = f.statesOrder[idx+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cstate = nstate
|
cstate = nstate
|
||||||
}
|
}
|
||||||
|
@ -10,40 +10,54 @@ import (
|
|||||||
func TestFSMStart(t *testing.T) {
|
func TestFSMStart(t *testing.T) {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
buf := bytes.NewBuffer(nil)
|
buf := bytes.NewBuffer(nil)
|
||||||
pfb := func(_ context.Context, state string, _ map[string]interface{}) {
|
pfb := func(_ context.Context, state string, _ interface{}) {
|
||||||
fmt.Fprintf(buf, "before state %s\n", state)
|
fmt.Fprintf(buf, "before state %s\n", state)
|
||||||
}
|
}
|
||||||
pfa := func(_ context.Context, state string, _ map[string]interface{}) {
|
pfa := func(_ context.Context, state string, _ interface{}) {
|
||||||
fmt.Fprintf(buf, "after state %s\n", state)
|
fmt.Fprintf(buf, "after state %s\n", state)
|
||||||
}
|
}
|
||||||
f := New(StateInitial("1"), StateHookBefore(pfb), StateHookAfter(pfa))
|
f := New(StateInitial("1"), StateHookBefore(pfb), StateHookAfter(pfa))
|
||||||
f1 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) {
|
f1 := func(_ context.Context, req interface{}) (string, interface{}, error) {
|
||||||
|
args := req.(map[string]interface{})
|
||||||
if v, ok := args["request"].(string); !ok || v == "" {
|
if v, ok := args["request"].(string); !ok || v == "" {
|
||||||
return "", nil, fmt.Errorf("empty request")
|
return "", nil, fmt.Errorf("empty request")
|
||||||
}
|
}
|
||||||
return "2", map[string]interface{}{"response": "test2"}, nil
|
return "2", map[string]interface{}{"response": "test2"}, nil
|
||||||
}
|
}
|
||||||
f2 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) {
|
f2 := func(_ context.Context, req interface{}) (string, interface{}, error) {
|
||||||
|
args := req.(map[string]interface{})
|
||||||
if v, ok := args["response"].(string); !ok || v == "" {
|
if v, ok := args["response"].(string); !ok || v == "" {
|
||||||
return "", nil, fmt.Errorf("empty response")
|
return "", nil, fmt.Errorf("empty response")
|
||||||
}
|
}
|
||||||
return "", map[string]interface{}{"response": "test"}, nil
|
return "", map[string]interface{}{"response": "test"}, nil
|
||||||
}
|
}
|
||||||
|
f3 := func(_ context.Context, req interface{}) (string, interface{}, error) {
|
||||||
|
args := req.(map[string]interface{})
|
||||||
|
if v, ok := args["response"].(string); !ok || v == "" {
|
||||||
|
return "", nil, fmt.Errorf("empty response")
|
||||||
|
}
|
||||||
|
return StateEnd, map[string]interface{}{"response": "test_last"}, nil
|
||||||
|
}
|
||||||
f.State("1", f1)
|
f.State("1", f1)
|
||||||
f.State("2", f2)
|
f.State("2", f2)
|
||||||
args, err := f.Start(ctx, map[string]interface{}{"request": "test1"})
|
f.State("3", f3)
|
||||||
|
rsp, err := f.Start(ctx, map[string]interface{}{"request": "test1"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else if v, ok := args["response"].(string); !ok || v == "" {
|
}
|
||||||
|
args := rsp.(map[string]interface{})
|
||||||
|
if v, ok := args["response"].(string); !ok || v == "" {
|
||||||
t.Fatalf("nil rsp: %#+v", args)
|
t.Fatalf("nil rsp: %#+v", args)
|
||||||
} else if v != "test" {
|
} else if v != "test_last" {
|
||||||
t.Fatalf("invalid rsp %#+v", args)
|
t.Fatalf("invalid rsp %#+v", args)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) ||
|
if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) ||
|
||||||
!bytes.Contains(buf.Bytes(), []byte(`before state 2`)) ||
|
!bytes.Contains(buf.Bytes(), []byte(`before state 2`)) ||
|
||||||
!bytes.Contains(buf.Bytes(), []byte(`after state 1`)) ||
|
!bytes.Contains(buf.Bytes(), []byte(`after state 1`)) ||
|
||||||
!bytes.Contains(buf.Bytes(), []byte(`after state 2`)) {
|
!bytes.Contains(buf.Bytes(), []byte(`after state 2`)) ||
|
||||||
|
!bytes.Contains(buf.Bytes(), []byte(`after state 3`)) ||
|
||||||
|
!bytes.Contains(buf.Bytes(), []byte(`after state 3`)) {
|
||||||
t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes())
|
t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user