util/http: trie support method not allowed
Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
		| @@ -7,6 +7,7 @@ package http | |||||||
| // Modified by Unistack LLC to support interface{} type handler and parameters in map[string]string | // Modified by Unistack LLC to support interface{} type handler and parameters in map[string]string | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| @@ -15,6 +16,11 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	ErrNotFound         = errors.New("route not found") | ||||||
|  | 	ErrMethodNotAllowed = errors.New("method not allowed") | ||||||
|  | ) | ||||||
|  |  | ||||||
| type methodTyp uint | type methodTyp uint | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -399,16 +405,19 @@ func (n *Trie) setEndpoint(method methodTyp, handler interface{}, pattern string | |||||||
| } | } | ||||||
|  |  | ||||||
| // Search try to find element in tree with path and method | // Search try to find element in tree with path and method | ||||||
| func (n *Trie) Search(method string, path string) (interface{}, map[string]string, bool) { | func (n *Trie) Search(method string, path string) (interface{}, map[string]string, error) { | ||||||
| 	params := &routeParams{} | 	params := &routeParams{} | ||||||
| 	// Find the routing handlers for the path | 	// Find the routing handlers for the path | ||||||
| 	rn := n.findRoute(params, methodMap[method], path) | 	rn := n.findRoute(params, methodMap[method], path) | ||||||
| 	if rn == nil { | 	if rn == nil && !params.methodNotAllowed { | ||||||
| 		return nil, nil, false | 		return nil, nil, ErrNotFound | ||||||
|  | 	} | ||||||
|  | 	if params.methodNotAllowed { | ||||||
|  | 		return nil, nil, ErrMethodNotAllowed | ||||||
| 	} | 	} | ||||||
| 	ep, ok := rn.endpoints[methodMap[method]] | 	ep, ok := rn.endpoints[methodMap[method]] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, nil, false | 		return nil, nil, ErrMethodNotAllowed | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	eparams := make(map[string]string, len(params.keys)) | 	eparams := make(map[string]string, len(params.keys)) | ||||||
| @@ -416,12 +425,13 @@ func (n *Trie) Search(method string, path string) (interface{}, map[string]strin | |||||||
| 		eparams[key] = params.vals[idx] | 		eparams[key] = params.vals[idx] | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return ep.handler, eparams, true | 	return ep.handler, eparams, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type routeParams struct { | type routeParams struct { | ||||||
| 	keys             []string | 	keys             []string | ||||||
| 	vals             []string | 	vals             []string | ||||||
|  | 	methodNotAllowed bool | ||||||
| } | } | ||||||
|  |  | ||||||
| // Recursive edge traversal by checking all nodeTyp groups along the way. | // Recursive edge traversal by checking all nodeTyp groups along the way. | ||||||
| @@ -495,6 +505,7 @@ func (n *Trie) findRoute(params *routeParams, method methodTyp, path string) *Tr | |||||||
| 							params.keys = append(params.keys, h.paramKeys...) | 							params.keys = append(params.keys, h.paramKeys...) | ||||||
| 							return xn | 							return xn | ||||||
| 						} | 						} | ||||||
|  | 						params.methodNotAllowed = true | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| @@ -530,6 +541,7 @@ func (n *Trie) findRoute(params *routeParams, method methodTyp, path string) *Tr | |||||||
| 					params.keys = append(params.keys, h.paramKeys...) | 					params.keys = append(params.keys, h.paramKeys...) | ||||||
| 					return xn | 					return xn | ||||||
| 				} | 				} | ||||||
|  | 				params.methodNotAllowed = true | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -21,22 +21,22 @@ func TestTrieWildcardPathPrefix(t *testing.T) { | |||||||
| 	if err = tr.Insert([]string{http.MethodPost}, "/v1/*", &handler{name: "post_create"}); err != nil { | 	if err = tr.Insert([]string{http.MethodPost}, "/v1/*", &handler{name: "post_create"}); err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	h, _, ok := tr.Search(http.MethodPost, "/v1/test/one") | 	h, _, err := tr.Search(http.MethodPost, "/v1/test/one") | ||||||
| 	if !ok { | 	if err != nil { | ||||||
| 		t.Fatalf("unexpected error handler not found") | 		t.Fatalf("unexpected error handler not found") | ||||||
| 	} | 	} | ||||||
| 	if h.(*handler).name != "post_create" { | 	if h.(*handler).name != "post_create" { | ||||||
| 		t.Fatalf("invalid handler %v", h) | 		t.Fatalf("invalid handler %v", h) | ||||||
| 	} | 	} | ||||||
| 	h, _, ok = tr.Search(http.MethodPost, "/v1/update") | 	h, _, err = tr.Search(http.MethodPost, "/v1/update") | ||||||
| 	if !ok { | 	if err != nil { | ||||||
| 		t.Fatalf("unexpected error") | 		t.Fatalf("unexpected error") | ||||||
| 	} | 	} | ||||||
| 	if h.(*handler).name != "post_update" { | 	if h.(*handler).name != "post_update" { | ||||||
| 		t.Fatalf("invalid handler %v", h) | 		t.Fatalf("invalid handler %v", h) | ||||||
| 	} | 	} | ||||||
| 	h, _, ok = tr.Search(http.MethodPost, "/v1/update/some/{x}") | 	h, _, err = tr.Search(http.MethodPost, "/v1/update/some/{x}") | ||||||
| 	if !ok { | 	if err != nil { | ||||||
| 		t.Fatalf("unexpected error") | 		t.Fatalf("unexpected error") | ||||||
| 	} | 	} | ||||||
| 	if h.(*handler).name != "post_create" { | 	if h.(*handler).name != "post_create" { | ||||||
| @@ -52,8 +52,8 @@ func TestTriePathPrefix(t *testing.T) { | |||||||
| 	_ = tr.Insert([]string{http.MethodPost}, "/v1/create/{id}", &handler{name: "post_create"}) | 	_ = tr.Insert([]string{http.MethodPost}, "/v1/create/{id}", &handler{name: "post_create"}) | ||||||
| 	_ = tr.Insert([]string{http.MethodPost}, "/v1/update/{id}", &handler{name: "post_update"}) | 	_ = tr.Insert([]string{http.MethodPost}, "/v1/update/{id}", &handler{name: "post_update"}) | ||||||
| 	_ = tr.Insert([]string{http.MethodPost}, "/", &handler{name: "post_wildcard"}) | 	_ = tr.Insert([]string{http.MethodPost}, "/", &handler{name: "post_wildcard"}) | ||||||
| 	h, _, ok := tr.Search(http.MethodPost, "/") | 	h, _, err := tr.Search(http.MethodPost, "/") | ||||||
| 	if !ok { | 	if err != nil { | ||||||
| 		t.Fatalf("unexpected error") | 		t.Fatalf("unexpected error") | ||||||
| 	} | 	} | ||||||
| 	if h.(*handler).name != "post_wildcard" { | 	if h.(*handler).name != "post_wildcard" { | ||||||
| @@ -68,8 +68,8 @@ func TestTrieFixedPattern(t *testing.T) { | |||||||
| 	tr := NewTrie() | 	tr := NewTrie() | ||||||
| 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", &handler{name: "pattern"}) | 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", &handler{name: "pattern"}) | ||||||
| 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/12", &handler{name: "fixed"}) | 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/12", &handler{name: "fixed"}) | ||||||
| 	h, _, ok := tr.Search(http.MethodPut, "/v1/create/12") | 	h, _, err := tr.Search(http.MethodPut, "/v1/create/12") | ||||||
| 	if !ok { | 	if err != nil { | ||||||
| 		t.Fatalf("unexpected error") | 		t.Fatalf("unexpected error") | ||||||
| 	} | 	} | ||||||
| 	if h.(*handler).name != "fixed" { | 	if h.(*handler).name != "fixed" { | ||||||
| @@ -80,8 +80,8 @@ func TestTrieFixedPattern(t *testing.T) { | |||||||
| func TestTrieNoMatchMethod(t *testing.T) { | func TestTrieNoMatchMethod(t *testing.T) { | ||||||
| 	tr := NewTrie() | 	tr := NewTrie() | ||||||
| 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", nil) | 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", nil) | ||||||
| 	_, _, ok := tr.Search(http.MethodPost, "/v1/create") | 	_, _, err := tr.Search(http.MethodPost, "/v1/create") | ||||||
| 	if ok { | 	if err == nil && err != ErrNotFound { | ||||||
| 		t.Fatalf("must be not found error") | 		t.Fatalf("must be not found error") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -90,9 +90,9 @@ func TestTrieMatchRegexp(t *testing.T) { | |||||||
| 	type handler struct{} | 	type handler struct{} | ||||||
| 	tr := NewTrie() | 	tr := NewTrie() | ||||||
| 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{category}/{id:[0-9]+}", &handler{}) | 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{category}/{id:[0-9]+}", &handler{}) | ||||||
| 	_, params, ok := tr.Search(http.MethodPut, "/v1/create/test_cat/12345") | 	_, params, err := tr.Search(http.MethodPut, "/v1/create/test_cat/12345") | ||||||
| 	switch { | 	switch { | ||||||
| 	case !ok: | 	case err != nil: | ||||||
| 		t.Fatalf("route not found") | 		t.Fatalf("route not found") | ||||||
| 	case len(params) != 2: | 	case len(params) != 2: | ||||||
| 		t.Fatalf("param matching error %v", params) | 		t.Fatalf("param matching error %v", params) | ||||||
| @@ -105,8 +105,8 @@ func TestTrieMatchRegexpFail(t *testing.T) { | |||||||
| 	type handler struct{} | 	type handler struct{} | ||||||
| 	tr := NewTrie() | 	tr := NewTrie() | ||||||
| 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[a-z]+}", &handler{}) | 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[a-z]+}", &handler{}) | ||||||
| 	_, _, ok := tr.Search(http.MethodPut, "/v1/create/12345") | 	_, _, err := tr.Search(http.MethodPut, "/v1/create/12345") | ||||||
| 	if ok { | 	if err != ErrNotFound { | ||||||
| 		t.Fatalf("route must not be not found") | 		t.Fatalf("route must not be not found") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -118,14 +118,28 @@ func TestTrieMatchLongest(t *testing.T) { | |||||||
| 	tr := NewTrie() | 	tr := NewTrie() | ||||||
| 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create", &handler{name: "first"}) | 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create", &handler{name: "first"}) | ||||||
| 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[0-9]+}", &handler{name: "second"}) | 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[0-9]+}", &handler{name: "second"}) | ||||||
| 	if h, _, ok := tr.Search(http.MethodPut, "/v1/create/12345"); !ok { | 	if h, _, err := tr.Search(http.MethodPut, "/v1/create/12345"); err != nil { | ||||||
| 		t.Fatalf("route must be found") | 		t.Fatalf("route must be found") | ||||||
| 	} else if h.(*handler).name != "second" { | 	} else if h.(*handler).name != "second" { | ||||||
| 		t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "second") | 		t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "second") | ||||||
| 	} | 	} | ||||||
| 	if h, _, ok := tr.Search(http.MethodPut, "/v1/create"); !ok { | 	if h, _, err := tr.Search(http.MethodPut, "/v1/create"); err != nil { | ||||||
| 		t.Fatalf("route must be found") | 		t.Fatalf("route must be found") | ||||||
| 	} else if h.(*handler).name != "first" { | 	} else if h.(*handler).name != "first" { | ||||||
| 		t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "first") | 		t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "first") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestMethodNotAllowed(t *testing.T) { | ||||||
|  | 	type handler struct{} | ||||||
|  | 	tr := NewTrie() | ||||||
|  | 	_ = tr.Insert([]string{http.MethodPut}, "/v1/create", &handler{}) | ||||||
|  | 	_, _, err := tr.Search(http.MethodPost, "/v1/create") | ||||||
|  | 	if err != ErrMethodNotAllowed { | ||||||
|  | 		t.Fatalf("route must be method not allowed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	_, _, err = tr.Search(http.MethodPut, "/v1/create") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("route must be found: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user