fsm: run steps in order
Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
		
							
								
								
									
										28
									
								
								fsm/fsm.go
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								fsm/fsm.go
									
									
									
									
									
								
							| @@ -23,10 +23,10 @@ type Options struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| // HookBeforeFunc func signature | // HookBeforeFunc func signature | ||||||
| type HookBeforeFunc func(ctx context.Context, state string, args map[string]interface{}) | type HookBeforeFunc func(ctx context.Context, state string, args interface{}) | ||||||
|  |  | ||||||
| // HookAfterFunc func signature | // HookAfterFunc func signature | ||||||
| type HookAfterFunc func(ctx context.Context, state string, args map[string]interface{}) | type HookAfterFunc func(ctx context.Context, state string, args interface{}) | ||||||
|  |  | ||||||
| // Option func signature | // Option func signature | ||||||
| type Option func(*Options) | type Option func(*Options) | ||||||
| @@ -53,12 +53,13 @@ func StateHookAfter(fns ...HookAfterFunc) Option { | |||||||
| } | } | ||||||
|  |  | ||||||
| // StateFunc called on state transition and return next step and error | // StateFunc called on state transition and return next step and error | ||||||
| type StateFunc func(ctx context.Context, args map[string]interface{}) (string, map[string]interface{}, error) | type StateFunc func(ctx context.Context, args interface{}) (string, interface{}, error) | ||||||
|  |  | ||||||
| // FSM is a finite state machine | // FSM is a finite state machine | ||||||
| type FSM struct { | type FSM struct { | ||||||
| 	mu          sync.Mutex | 	mu          sync.Mutex | ||||||
| 	states  map[string]StateFunc | 	statesMap   map[string]StateFunc | ||||||
|  | 	statesOrder []string | ||||||
| 	opts        *Options | 	opts        *Options | ||||||
| 	current     string | 	current     string | ||||||
| } | } | ||||||
| @@ -73,7 +74,7 @@ func New(opts ...Option) *FSM { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return &FSM{ | 	return &FSM{ | ||||||
| 		states: map[string]StateFunc{}, | 		statesMap: map[string]StateFunc{}, | ||||||
| 		opts:      options, | 		opts:      options, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -95,12 +96,13 @@ func (f *FSM) Reset() { | |||||||
| // State adds state to fsm | // State adds state to fsm | ||||||
| func (f *FSM) State(state string, fn StateFunc) { | func (f *FSM) State(state string, fn StateFunc) { | ||||||
| 	f.mu.Lock() | 	f.mu.Lock() | ||||||
| 	f.states[state] = fn | 	f.statesMap[state] = fn | ||||||
|  | 	f.statesOrder = append(f.statesOrder, state) | ||||||
| 	f.mu.Unlock() | 	f.mu.Unlock() | ||||||
| } | } | ||||||
|  |  | ||||||
| // Start runs state machine with provided data | // Start runs state machine with provided data | ||||||
| func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Option) (map[string]interface{}, error) { | func (f *FSM) Start(ctx context.Context, args interface{}, opts ...Option) (interface{}, error) { | ||||||
| 	var err error | 	var err error | ||||||
| 	var ok bool | 	var ok bool | ||||||
| 	var fn StateFunc | 	var fn StateFunc | ||||||
| @@ -114,8 +116,8 @@ func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Op | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cstate := options.Initial | 	cstate := options.Initial | ||||||
| 	states := make(map[string]StateFunc, len(f.states)) | 	states := make(map[string]StateFunc, len(f.statesMap)) | ||||||
| 	for k, v := range f.states { | 	for k, v := range f.statesMap { | ||||||
| 		states[k] = v | 		states[k] = v | ||||||
| 	} | 	} | ||||||
| 	f.current = cstate | 	f.current = cstate | ||||||
| @@ -142,8 +144,14 @@ func (f *FSM) Start(ctx context.Context, args map[string]interface{}, opts ...Op | |||||||
| 			} | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return args, err | 				return args, err | ||||||
| 			} else if nstate == "" || nstate == StateEnd { | 			} else if nstate == StateEnd { | ||||||
| 				return args, nil | 				return args, nil | ||||||
|  | 			} else if nstate == "" { | ||||||
|  | 				for idx := range f.statesOrder { | ||||||
|  | 					if f.statesOrder[idx] == cstate && len(f.statesOrder) > idx+1 { | ||||||
|  | 						nstate = f.statesOrder[idx+1] | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			cstate = nstate | 			cstate = nstate | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -10,40 +10,54 @@ import ( | |||||||
| func TestFSMStart(t *testing.T) { | func TestFSMStart(t *testing.T) { | ||||||
| 	ctx := context.TODO() | 	ctx := context.TODO() | ||||||
| 	buf := bytes.NewBuffer(nil) | 	buf := bytes.NewBuffer(nil) | ||||||
| 	pfb := func(_ context.Context, state string, _ map[string]interface{}) { | 	pfb := func(_ context.Context, state string, _ interface{}) { | ||||||
| 		fmt.Fprintf(buf, "before state %s\n", state) | 		fmt.Fprintf(buf, "before state %s\n", state) | ||||||
| 	} | 	} | ||||||
| 	pfa := func(_ context.Context, state string, _ map[string]interface{}) { | 	pfa := func(_ context.Context, state string, _ interface{}) { | ||||||
| 		fmt.Fprintf(buf, "after state %s\n", state) | 		fmt.Fprintf(buf, "after state %s\n", state) | ||||||
| 	} | 	} | ||||||
| 	f := New(StateInitial("1"), StateHookBefore(pfb), StateHookAfter(pfa)) | 	f := New(StateInitial("1"), StateHookBefore(pfb), StateHookAfter(pfa)) | ||||||
| 	f1 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) { | 	f1 := func(_ context.Context, req interface{}) (string, interface{}, error) { | ||||||
|  | 		args := req.(map[string]interface{}) | ||||||
| 		if v, ok := args["request"].(string); !ok || v == "" { | 		if v, ok := args["request"].(string); !ok || v == "" { | ||||||
| 			return "", nil, fmt.Errorf("empty request") | 			return "", nil, fmt.Errorf("empty request") | ||||||
| 		} | 		} | ||||||
| 		return "2", map[string]interface{}{"response": "test2"}, nil | 		return "2", map[string]interface{}{"response": "test2"}, nil | ||||||
| 	} | 	} | ||||||
| 	f2 := func(_ context.Context, args map[string]interface{}) (string, map[string]interface{}, error) { | 	f2 := func(_ context.Context, req interface{}) (string, interface{}, error) { | ||||||
|  | 		args := req.(map[string]interface{}) | ||||||
| 		if v, ok := args["response"].(string); !ok || v == "" { | 		if v, ok := args["response"].(string); !ok || v == "" { | ||||||
| 			return "", nil, fmt.Errorf("empty response") | 			return "", nil, fmt.Errorf("empty response") | ||||||
| 		} | 		} | ||||||
| 		return "", map[string]interface{}{"response": "test"}, nil | 		return "", map[string]interface{}{"response": "test"}, nil | ||||||
| 	} | 	} | ||||||
|  | 	f3 := func(_ context.Context, req interface{}) (string, interface{}, error) { | ||||||
|  | 		args := req.(map[string]interface{}) | ||||||
|  | 		if v, ok := args["response"].(string); !ok || v == "" { | ||||||
|  | 			return "", nil, fmt.Errorf("empty response") | ||||||
|  | 		} | ||||||
|  | 		return StateEnd, map[string]interface{}{"response": "test_last"}, nil | ||||||
|  | 	} | ||||||
| 	f.State("1", f1) | 	f.State("1", f1) | ||||||
| 	f.State("2", f2) | 	f.State("2", f2) | ||||||
| 	args, err := f.Start(ctx, map[string]interface{}{"request": "test1"}) | 	f.State("3", f3) | ||||||
|  | 	rsp, err := f.Start(ctx, map[string]interface{}{"request": "test1"}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} else if v, ok := args["response"].(string); !ok || v == "" { | 	} | ||||||
|  | 	args := rsp.(map[string]interface{}) | ||||||
|  | 	if v, ok := args["response"].(string); !ok || v == "" { | ||||||
| 		t.Fatalf("nil rsp: %#+v", args) | 		t.Fatalf("nil rsp: %#+v", args) | ||||||
| 	} else if v != "test" { | 	} else if v != "test_last" { | ||||||
| 		t.Fatalf("invalid rsp %#+v", args) | 		t.Fatalf("invalid rsp %#+v", args) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) || | 	if !bytes.Contains(buf.Bytes(), []byte(`before state 1`)) || | ||||||
| 		!bytes.Contains(buf.Bytes(), []byte(`before state 2`)) || | 		!bytes.Contains(buf.Bytes(), []byte(`before state 2`)) || | ||||||
| 		!bytes.Contains(buf.Bytes(), []byte(`after state 1`)) || | 		!bytes.Contains(buf.Bytes(), []byte(`after state 1`)) || | ||||||
| 		!bytes.Contains(buf.Bytes(), []byte(`after state 2`)) { | 		!bytes.Contains(buf.Bytes(), []byte(`after state 2`)) || | ||||||
|  | 		!bytes.Contains(buf.Bytes(), []byte(`after state 3`)) || | ||||||
|  | 		!bytes.Contains(buf.Bytes(), []byte(`after state 3`)) { | ||||||
| 		t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes()) | 		t.Fatalf("fsm not works properly or hooks error, buf: %s", buf.Bytes()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user