fsm: improve and convert to interface

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2023-01-30 00:17:29 +03:00
parent 4debc392d1
commit d18952951c
6 changed files with 273 additions and 204 deletions

126
fsm/default.go Normal file
View File

@ -0,0 +1,126 @@
package fsm
import (
"context"
"fmt"
"sync"
)
type state struct {
body interface{}
name string
}
var _ State = &state{}
func (s *state) Name() string {
return s.name
}
func (s *state) Body() interface{} {
return s.body
}
// fsm is a finite state machine
type fsm struct {
statesMap map[string]StateFunc
current string
statesOrder []string
opts Options
mu sync.Mutex
}
// NewFSM creates a new finite state machine having the specified initial state
// with specified options
func NewFSM(opts ...Option) *fsm {
return &fsm{
statesMap: map[string]StateFunc{},
opts: NewOptions(opts...),
}
}
// Current returns the current state
func (f *fsm) Current() string {
f.mu.Lock()
s := f.current
f.mu.Unlock()
return s
}
// Current returns the current state
func (f *fsm) Reset() {
f.mu.Lock()
f.current = f.opts.Initial
f.mu.Unlock()
}
// State adds state to fsm
func (f *fsm) State(state string, fn StateFunc) {
f.mu.Lock()
f.statesMap[state] = fn
f.statesOrder = append(f.statesOrder, state)
f.mu.Unlock()
}
// Start runs state machine with provided data
func (f *fsm) Start(ctx context.Context, args interface{}, opts ...Option) (interface{}, error) {
var err error
f.mu.Lock()
options := f.opts
for _, opt := range opts {
opt(&options)
}
sopts := []StateOption{StateDryRun(options.DryRun)}
cstate := options.Initial
states := make(map[string]StateFunc, len(f.statesMap))
for k, v := range f.statesMap {
states[k] = v
}
f.current = cstate
f.mu.Unlock()
var s State
s = &state{name: cstate, body: args}
nstate := s.Name()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
fn, ok := states[nstate]
if !ok {
return nil, fmt.Errorf(`state "%s" %w`, nstate, ErrInvalidState)
}
f.mu.Lock()
f.current = nstate
f.mu.Unlock()
// wrap the handler func
for i := len(options.Wrappers); i > 0; i-- {
fn = options.Wrappers[i-1](fn)
}
s, err = fn(ctx, s, sopts...)
switch {
case err != nil:
return s.Body(), err
case s.Name() == StateEnd:
return s.Body(), nil
case s.Name() == "":
for idx := range f.statesOrder {
if f.statesOrder[idx] == nstate && len(f.statesOrder) > idx+1 {
nstate = f.statesOrder[idx+1]
}
}
default:
nstate = s.Name()
}
}
}
}

View File

@ -3,8 +3,6 @@ package fsm // import "go.unistack.org/micro/v3/fsm"
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sync"
) )
var ( var (
@ -12,170 +10,20 @@ var (
StateEnd = "end" StateEnd = "end"
) )
// Options struct holding fsm options type State interface {
type Options struct { Name() string
// DryRun mode Body() interface{}
DryRun bool
// Initial state
Initial string
// HooksBefore func slice runs in order before state
HooksBefore []HookBeforeFunc
// HooksAfter func slice runs in order after state
HooksAfter []HookAfterFunc
} }
// HookBeforeFunc func signature // StateWrapper wraps the StateFunc and returns the equivalent
type HookBeforeFunc func(ctx context.Context, state string, args interface{}) type StateWrapper func(StateFunc) StateFunc
// HookAfterFunc func signature
type HookAfterFunc func(ctx context.Context, state string, args interface{})
// Option func signature
type Option func(*Options)
// StateOptions holds state options
type StateOptions struct {
DryRun bool
}
// StateDryRun says that state executes in dry run mode
func StateDryRun(b bool) StateOption {
return func(o *StateOptions) {
o.DryRun = b
}
}
// StateOption func signature
type StateOption func(*StateOptions)
// InitialState sets init state for state machine
func InitialState(initial string) Option {
return func(o *Options) {
o.Initial = initial
}
}
// HookBefore provides hook func slice
func HookBefore(fns ...HookBeforeFunc) Option {
return func(o *Options) {
o.HooksBefore = fns
}
}
// HookAfter provides hook func slice
func HookAfter(fns ...HookAfterFunc) Option {
return func(o *Options) {
o.HooksAfter = fns
}
}
// StateFunc called on state transition and return next step and error // StateFunc called on state transition and return next step and error
type StateFunc func(ctx context.Context, args interface{}, opts ...StateOption) (string, interface{}, error) type StateFunc func(ctx context.Context, state State, opts ...StateOption) (State, error)
// FSM is a finite state machine type FSM interface {
type FSM struct { Start(context.Context, interface{}, ...Option) (interface{}, error)
mu sync.Mutex Current() string
statesMap map[string]StateFunc Reset()
statesOrder []string State(string, StateFunc)
opts *Options
current string
}
// New creates a new finite state machine having the specified initial state
// with specified options
func New(opts ...Option) *FSM {
options := &Options{}
for _, opt := range opts {
opt(options)
}
return &FSM{
statesMap: map[string]StateFunc{},
opts: options,
}
}
// Current returns the current state
func (f *FSM) Current() string {
f.mu.Lock()
defer f.mu.Unlock()
return f.current
}
// Current returns the current state
func (f *FSM) Reset() {
f.mu.Lock()
f.current = f.opts.Initial
f.mu.Unlock()
}
// State adds state to fsm
func (f *FSM) State(state string, fn StateFunc) {
f.mu.Lock()
f.statesMap[state] = fn
f.statesOrder = append(f.statesOrder, state)
f.mu.Unlock()
}
// Init initialize fsm and check states
// Start runs state machine with provided data
func (f *FSM) Start(ctx context.Context, args interface{}, opts ...Option) (interface{}, error) {
var err error
var ok bool
var fn StateFunc
var nstate string
f.mu.Lock()
options := f.opts
for _, opt := range opts {
opt(options)
}
sopts := []StateOption{StateDryRun(options.DryRun)}
cstate := options.Initial
states := make(map[string]StateFunc, len(f.statesMap))
for k, v := range f.statesMap {
states[k] = v
}
f.current = cstate
f.mu.Unlock()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
fn, ok = states[cstate]
if !ok {
return nil, fmt.Errorf(`state "%s" %w`, cstate, ErrInvalidState)
}
f.mu.Lock()
f.current = cstate
f.mu.Unlock()
for _, fn := range options.HooksBefore {
fn(ctx, cstate, args)
}
nstate, args, err = fn(ctx, args, sopts...)
for _, fn := range options.HooksAfter {
fn(ctx, cstate, args)
}
switch {
case err != nil:
return args, err
case nstate == StateEnd:
return args, nil
case nstate == "":
for idx := range f.statesOrder {
if f.statesOrder[idx] == cstate && len(f.statesOrder) > idx+1 {
nstate = f.statesOrder[idx+1]
}
}
}
cstate = nstate
}
}
} }

View File

@ -1,63 +1,72 @@
package fsm package fsm
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"testing" "testing"
"go.unistack.org/micro/v3/logger"
) )
func TestFSMStart(t *testing.T) { func TestFSMStart(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
buf := bytes.NewBuffer(nil)
pfb := func(_ context.Context, state string, _ interface{}) { if err := logger.DefaultLogger.Init(); err != nil {
fmt.Fprintf(buf, "before state %s\n", state) t.Fatal(err)
} }
pfa := func(_ context.Context, state string, _ interface{}) {
fmt.Fprintf(buf, "after state %s\n", state) wrapper := func(next StateFunc) StateFunc {
return func(sctx context.Context, s State, opts ...StateOption) (State, error) {
sctx = logger.NewContext(sctx, logger.Fields("state", s.Name()))
return next(sctx, s, opts...)
} }
f := New(InitialState("1"), HookBefore(pfb), HookAfter(pfa)) }
f1 := func(_ context.Context, req interface{}, _ ...StateOption) (string, interface{}, error) {
args := req.(map[string]interface{}) f := NewFSM(InitialState("1"), WrapState(wrapper))
f1 := func(sctx context.Context, s State, opts ...StateOption) (State, error) {
_, ok := logger.FromContext(sctx)
if !ok {
t.Fatal("f1 context does not have logger")
}
args := s.Body().(map[string]interface{})
if v, ok := args["request"].(string); !ok || v == "" { if v, ok := args["request"].(string); !ok || v == "" {
return "", nil, fmt.Errorf("empty request") return nil, fmt.Errorf("empty request")
} }
return "2", map[string]interface{}{"response": "test2"}, nil return &state{name: "", body: map[string]interface{}{"response": "state1"}}, nil
} }
f2 := func(_ context.Context, req interface{}, _ ...StateOption) (string, interface{}, error) { f2 := func(sctx context.Context, s State, opts ...StateOption) (State, error) {
args := req.(map[string]interface{}) _, ok := logger.FromContext(sctx)
if !ok {
t.Fatal("f2 context does not have logger")
}
args := s.Body().(map[string]interface{})
if v, ok := args["response"].(string); !ok || v == "" { if v, ok := args["response"].(string); !ok || v == "" {
return "", nil, fmt.Errorf("empty response") return nil, fmt.Errorf("empty response")
} }
return "", map[string]interface{}{"response": "test"}, nil return &state{name: "", body: map[string]interface{}{"response": "state2"}}, nil
} }
f3 := func(_ context.Context, req interface{}, _ ...StateOption) (string, interface{}, error) { f3 := func(sctx context.Context, s State, opts ...StateOption) (State, error) {
args := req.(map[string]interface{}) _, ok := logger.FromContext(sctx)
if !ok {
t.Fatal("f3 context does not have logger")
}
args := s.Body().(map[string]interface{})
if v, ok := args["response"].(string); !ok || v == "" { if v, ok := args["response"].(string); !ok || v == "" {
return "", nil, fmt.Errorf("empty response") return nil, fmt.Errorf("empty response")
} }
return StateEnd, map[string]interface{}{"response": "test_last"}, nil return &state{name: StateEnd, body: map[string]interface{}{"response": "state3"}}, nil
} }
f.State("1", f1) f.State("1", f1)
f.State("2", f2) f.State("2", f2)
f.State("3", f3) f.State("3", f3)
rsp, err := f.Start(ctx, map[string]interface{}{"request": "test1"}) rsp, err := f.Start(ctx, map[string]interface{}{"request": "state"})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
args := rsp.(map[string]interface{}) args := rsp.(map[string]interface{})
if v, ok := args["response"].(string); !ok || v == "" { if v, ok := args["response"].(string); !ok || v == "" {
t.Fatalf("nil rsp: %#+v", args) t.Fatalf("nil rsp: %#+v", args)
} else if v != "test_last" { } else if v != "state3" {
t.Fatalf("invalid rsp %#+v", args) t.Fatalf("invalid rsp %#+v", args)
} }
if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) ||
!bytes.Contains(buf.Bytes(), []byte(`before state 2`)) ||
!bytes.Contains(buf.Bytes(), []byte(`after state 1`)) ||
!bytes.Contains(buf.Bytes(), []byte(`after state 2`)) ||
!bytes.Contains(buf.Bytes(), []byte(`after state 3`)) ||
!bytes.Contains(buf.Bytes(), []byte(`after state 3`)) {
t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes())
}
} }

52
fsm/options.go Normal file
View File

@ -0,0 +1,52 @@
package fsm
// Options struct holding fsm options
type Options struct {
// Initial state
Initial string
// Wrappers runs before state
Wrappers []StateWrapper
// DryRun mode
DryRun bool
}
// Option func signature
type Option func(*Options)
// StateOptions holds state options
type StateOptions struct {
DryRun bool
}
// StateDryRun says that state executes in dry run mode
func StateDryRun(b bool) StateOption {
return func(o *StateOptions) {
o.DryRun = b
}
}
// StateOption func signature
type StateOption func(*StateOptions)
// InitialState sets init state for state machine
func InitialState(initial string) Option {
return func(o *Options) {
o.Initial = initial
}
}
// WrapState adds a state Wrapper to a list of options passed into the fsm
func WrapState(w StateWrapper) Option {
return func(o *Options) {
o.Wrappers = append(o.Wrappers, w)
}
}
// NewOptions returns new Options struct filled by passed Option
func NewOptions(opts ...Option) Options {
options := Options{}
for _, o := range opts {
o(&options)
}
return options
}

View File

@ -75,15 +75,23 @@ func (l *defaultLogger) Level(level Level) {
} }
func (l *defaultLogger) Fields(fields ...interface{}) Logger { func (l *defaultLogger) Fields(fields ...interface{}) Logger {
l.RLock()
nl := &defaultLogger{opts: l.opts, enc: l.enc} nl := &defaultLogger{opts: l.opts, enc: l.enc}
if len(fields) == 0 { if len(fields) == 0 {
l.RUnlock()
return nl return nl
} else if len(fields)%2 != 0 { } else if len(fields)%2 != 0 {
fields = fields[:len(fields)-1] fields = fields[:len(fields)-1]
} }
nl.logFunc = l.logFunc nl.logFunc = nl.Log
nl.logfFunc = l.logfFunc nl.logfFunc = nl.Logf
for i := len(nl.opts.Wrappers); i > 0; i-- {
nl.logFunc = nl.opts.Wrappers[i-1].Log(nl.logFunc)
nl.logfFunc = nl.opts.Wrappers[i-1].Logf(nl.logfFunc)
}
nl.opts.Fields = copyFields(l.opts.Fields)
nl.opts.Fields = append(nl.opts.Fields, fields...) nl.opts.Fields = append(nl.opts.Fields, fields...)
l.RUnlock()
return nl return nl
} }
@ -118,27 +126,27 @@ func logCallerfilePath(loggingFilePath string) string {
} }
func (l *defaultLogger) Info(ctx context.Context, args ...interface{}) { func (l *defaultLogger) Info(ctx context.Context, args ...interface{}) {
l.Log(ctx, InfoLevel, args...) l.logFunc(ctx, InfoLevel, args...)
} }
func (l *defaultLogger) Error(ctx context.Context, args ...interface{}) { func (l *defaultLogger) Error(ctx context.Context, args ...interface{}) {
l.Log(ctx, ErrorLevel, args...) l.logFunc(ctx, ErrorLevel, args...)
} }
func (l *defaultLogger) Debug(ctx context.Context, args ...interface{}) { func (l *defaultLogger) Debug(ctx context.Context, args ...interface{}) {
l.Log(ctx, DebugLevel, args...) l.logFunc(ctx, DebugLevel, args...)
} }
func (l *defaultLogger) Warn(ctx context.Context, args ...interface{}) { func (l *defaultLogger) Warn(ctx context.Context, args ...interface{}) {
l.Log(ctx, WarnLevel, args...) l.logFunc(ctx, WarnLevel, args...)
} }
func (l *defaultLogger) Trace(ctx context.Context, args ...interface{}) { func (l *defaultLogger) Trace(ctx context.Context, args ...interface{}) {
l.Log(ctx, TraceLevel, args...) l.logFunc(ctx, TraceLevel, args...)
} }
func (l *defaultLogger) Fatal(ctx context.Context, args ...interface{}) { func (l *defaultLogger) Fatal(ctx context.Context, args ...interface{}) {
l.Log(ctx, FatalLevel, args...) l.logFunc(ctx, FatalLevel, args...)
os.Exit(1) os.Exit(1)
} }

View File

@ -32,7 +32,33 @@ func TestFields(t *testing.T) {
if err := l.Init(); err != nil { if err := l.Init(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
l.Fields("key", "val").Info(ctx, "message")
nl := l.Fields("key", "val")
nl.Info(ctx, "message")
if !bytes.Contains(buf.Bytes(), []byte(`"key":"val"`)) {
t.Fatalf("logger fields not works, buf contains: %s", buf.Bytes())
}
}
func TestFromContextWithFields(t *testing.T) {
ctx := context.TODO()
buf := bytes.NewBuffer(nil)
var ok bool
l := NewLogger(WithLevel(TraceLevel), WithOutput(buf))
if err := l.Init(); err != nil {
t.Fatal(err)
}
nl := l.Fields("key", "val")
ctx = NewContext(ctx, nl)
l, ok = FromContext(ctx)
if !ok {
t.Fatalf("context does not have logger")
}
l.Info(ctx, "message")
if !bytes.Contains(buf.Bytes(), []byte(`"key":"val"`)) { if !bytes.Contains(buf.Bytes(), []byte(`"key":"val"`)) {
t.Fatalf("logger fields not works, buf contains: %s", buf.Bytes()) t.Fatalf("logger fields not works, buf contains: %s", buf.Bytes())
} }