package fsm

import (
	"context"
	"fmt"
	"sync"
)

// state is a plain value implementing State: a name identifying the state
// together with an arbitrary payload.
type state struct {
	// body is the payload handed to and returned from state handlers.
	body interface{}
	// name identifies which state this is.
	name string
}

// Compile-time check that *state satisfies the State interface.
var _ State = (*state)(nil)

// Name reports the identifier of this state.
func (s *state) Name() string {
	return s.name
}

// Body reports the payload carried by this state.
func (s *state) Body() interface{} {
	return s.body
}

// fsm is a finite state machine
type fsm struct {
	statesMap   map[string]StateFunc
	current     string
	statesOrder []string
	opts        Options
	mu          sync.Mutex
}

// NewFSM returns a finite state machine with no registered states,
// configured by opts. The initial state, dry-run flag and handler wrappers
// used by Start all come from these options.
func NewFSM(opts ...Option) FSM {
	m := &fsm{opts: NewOptions(opts...)}
	m.statesMap = make(map[string]StateFunc)
	return m
}

// Current reports the name of the state the machine is in. The read is
// guarded by the machine's mutex, so it may be called while Start runs.
func (f *fsm) Current() string {
	f.mu.Lock()
	defer f.mu.Unlock()
	return f.current
}

// Reset sets the current state back to the initial state configured by the
// options given to NewFSM. It does not invoke any state handler.
func (f *fsm) Reset() {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.current = f.opts.Initial
}

// State registers fn as the handler for the state named state.
//
// The first registration of a name fixes its position in the order Start
// follows when a handler returns a state with an empty name. Registering the
// same name again replaces its handler but keeps that position; appending the
// name a second time would make the order ambiguous and could form a cycle.
func (f *fsm) State(state string, fn StateFunc) {
	f.mu.Lock()
	defer f.mu.Unlock()
	if _, exists := f.statesMap[state]; !exists {
		f.statesOrder = append(f.statesOrder, state)
	}
	f.statesMap[state] = fn
}

// Start runs the state machine, passing args to the initial state as its body.
// It returns the body of the state that finished the run.
//
// opts are applied on top of the machine's own options for this run only.
// The run begins at the resulting Initial state. Every handler is wrapped with
// the resulting Wrappers (Wrappers[0] outermost) and called with
// StateDryRun(DryRun). Handlers and their registration order are snapshotted
// when the run begins, so concurrent calls to State do not affect it.
//
// After each handler returns:
//   - a non-nil error stops the run; the error is returned together with the
//     body of the returned state (nil if the handler returned no state);
//   - a state named StateEnd stops the run successfully;
//   - a state with an empty name advances to the state registered right after
//     the current one, or stops the run successfully if the current state was
//     registered last;
//   - any other name selects that state next.
//
// Selecting a state with no registered handler returns an error wrapping
// ErrInvalidState. Cancellation of ctx is checked before every step and
// returns ctx.Err().
func (f *fsm) Start(ctx context.Context, args interface{}, opts ...Option) (interface{}, error) {
	f.mu.Lock()
	options := f.opts
	for _, opt := range opts {
		opt(&options)
	}

	// Snapshot handlers and their order while holding the lock. The run must
	// not race with, or be affected by, concurrent State calls.
	states := make(map[string]StateFunc, len(f.statesMap))
	for k, v := range f.statesMap {
		states[k] = v
	}
	order := make([]string, len(f.statesOrder))
	copy(order, f.statesOrder)

	f.current = options.Initial
	f.mu.Unlock()

	sopts := []StateOption{StateDryRun(options.DryRun)}

	var (
		s   State = &state{name: options.Initial, body: args}
		err error
	)
	nstate := options.Initial

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		fn, ok := states[nstate]
		if !ok {
			return nil, fmt.Errorf(`state "%s" %w`, nstate, ErrInvalidState)
		}
		f.mu.Lock()
		f.current = nstate
		f.mu.Unlock()

		// Wrap from the last wrapper to the first so Wrappers[0] is outermost.
		for i := len(options.Wrappers); i > 0; i-- {
			fn = options.Wrappers[i-1](fn)
		}

		s, err = fn(ctx, s, sopts...)

		switch {
		case err != nil:
			// A failing handler may return no state at all; calling Body on
			// a nil interface would panic.
			if s == nil {
				return nil, err
			}
			return s.Body(), err
		case s.Name() == StateEnd:
			return s.Body(), nil
		case s.Name() == "":
			// Stop at the first match. Continuing the scan would match the
			// newly selected state too and skip every state in between.
			advanced := false
			for idx, name := range order {
				if name == nstate {
					if idx+1 < len(order) {
						nstate = order[idx+1]
						advanced = true
					}
					break
				}
			}
			if !advanced {
				// The current state was registered last, so there is nothing
				// to advance to. Finish rather than re-running it until ctx
				// is cancelled.
				return s.Body(), nil
			}
		default:
			nstate = s.Name()
		}
	}
}