diff --git a/fsm/fsm.go b/fsm/fsm.go new file mode 100644 index 00000000..364c4365 --- /dev/null +++ b/fsm/fsm.go @@ -0,0 +1,151 @@ +package fsm // import "go.unistack.org/micro/v3/fsm" + +import ( + "context" + "errors" + "fmt" + "sync" +) + +var ( + ErrInvalidState = errors.New("does not exists") + StateEnd = "end" +) + +// Options struct holding fsm options +type Options struct { + // Initial state + Initial string + // HooksBefore func slice runs in order before state + HooksBefore []HookBeforeFunc + // HooksAfter func slice runs in order after state + HooksAfter []HookAfterFunc +} + +// HookBeforeFunc func signature +type HookBeforeFunc func(ctx context.Context, state string, args map[string]interface{}) + +// HookAfterFunc func signature +type HookAfterFunc func(ctx context.Context, state string, args map[string]interface{}) + +// Option func signature +type Option func(*Options) + +// StateInitial sets init state for state machine +func StateInitial(initial string) Option { + return func(o *Options) { + o.Initial = initial + } +} + +// StateHookBefore provides hook func slice +func StateHookBefore(fns ...HookBeforeFunc) Option { + return func(o *Options) { + o.HooksBefore = fns + } +} + +// StateHookAfter provides hook func slice +func StateHookAfter(fns ...HookAfterFunc) Option { + return func(o *Options) { + o.HooksAfter = fns + } +} + +// StateFunc called on state transition and return next step and error +type StateFunc func(ctx context.Context, args map[string]interface{}) (string, map[string]interface{}, error) + +// FSM is a finite state machine +type FSM struct { + mu sync.Mutex + states map[string]StateFunc + opts *Options + current string +} + +// New creates a new finite state machine having the specified initial state +// with specified options +func New(opts ...Option) *FSM { + options := &Options{} + + for _, opt := range opts { + opt(options) + } + + return &FSM{ + states: map[string]StateFunc{}, + opts: options, + } +} + +// Current returns the current state +func (f *FSM) Current() string { + f.mu.Lock() + defer f.mu.Unlock() + return f.current +} + +// Current returns the current state +func (f *FSM) Reset() { + f.mu.Lock() + f.current = f.opts.Initial + f.mu.Unlock() +} + +// State adds state to fsm +func (f *FSM) State(state string, fn StateFunc) { + f.mu.Lock() + f.states[state] = fn + f.mu.Unlock() +} + +// Start runs state machine with provided data +func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Option) (map[string]interface{}, error) { + var err error + var ok bool + var fn StateFunc + var nstate string + + f.mu.Lock() + options := f.opts + + for _, opt := range opts { + opt(options) + } + + cstate := options.Initial + states := make(map[string]StateFunc, len(f.states)) + for k, v := range f.states { + states[k] = v + } + f.current = cstate + f.mu.Unlock() + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + fn, ok = states[cstate] + if !ok { + return nil, fmt.Errorf(`state "%s" %w`, cstate, ErrInvalidState) + } + f.mu.Lock() + f.current = cstate + f.mu.Unlock() + for _, fn := range options.HooksBefore { + fn(ctx, cstate, args) + } + nstate, args, err = fn(ctx, args) + for _, fn := range options.HooksAfter { + fn(ctx, cstate, args) + } + if err != nil { + return args, err + } else if nstate == "" || nstate == StateEnd { + return args, nil + } + cstate = nstate + } + } +} diff --git a/fsm/fsm_test.go b/fsm/fsm_test.go new file mode 100644 index 00000000..0024db91 --- /dev/null +++ b/fsm/fsm_test.go @@ -0,0 +1,49 @@ +package fsm + +import ( + "bytes" + "context" + "fmt" + "testing" +) + +func TestFSMStart(t *testing.T) { + ctx := context.TODO() + buf := bytes.NewBuffer(nil) + pfb := func(_ context.Context, state string, _ map[string]interface{}) { + fmt.Fprintf(buf, "before state %s\n", state) + } + pfa := func(_ context.Context, state string, _ map[string]interface{}) { + fmt.Fprintf(buf, "after state %s\n", state) + } + f := New(StateInitial("1"), StateHookBefore(pfb), StateHookAfter(pfa)) + f1 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) { + if v, ok := args["request"].(string); !ok || v == "" { + return "", nil, fmt.Errorf("empty request") + } + return "2", map[string]interface{}{"response": "test2"}, nil + } + f2 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) { + if v, ok := args["response"].(string); !ok || v == "" { + return "", nil, fmt.Errorf("empty response") + } + return "", map[string]interface{}{"response": "test"}, nil + } + f.State("1", f1) + f.State("2", f2) + args, err := f.Start(ctx, map[string]interface{}{"request": "test1"}) + if err != nil { + t.Fatal(err) + } else if v, ok := args["response"].(string); !ok || v == "" { + t.Fatalf("nil rsp: %#+v", args) + } else if v != "test" { + t.Fatalf("invalid rsp %#+v", args) + } + + if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) || + !bytes.Contains(buf.Bytes(), []byte(`before state 2`)) || + !bytes.Contains(buf.Bytes(), []byte(`after state 1`)) || + !bytes.Contains(buf.Bytes(), []byte(`after state 2`)) { + t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes()) + } +}