From 5e47cc7e8c9d68cbf7e0af69fab0ec21c7a695fd Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Tue, 27 Dec 2022 23:47:11 +0300 Subject: [PATCH] util/http: trie support method not allowed Signed-off-by: Vasiliy Tolstov --- util/http/trie.go | 26 ++++++++++++++++------ util/http/trie_test.go | 50 +++++++++++++++++++++++++++--------------- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/util/http/trie.go b/util/http/trie.go index 471bfc02..864c4a44 100644 --- a/util/http/trie.go +++ b/util/http/trie.go @@ -7,6 +7,7 @@ package http // Modified by Unistack LLC to support interface{} type handler and parameters in map[string]string import ( + "errors" "fmt" "net/http" "regexp" @@ -15,6 +16,11 @@ import ( "strings" ) +var ( + ErrNotFound = errors.New("route not found") + ErrMethodNotAllowed = errors.New("method not allowed") +) + type methodTyp uint const ( @@ -399,16 +405,19 @@ func (n *Trie) setEndpoint(method methodTyp, handler interface{}, pattern string } // Search try to find element in tree with path and method -func (n *Trie) Search(method string, path string) (interface{}, map[string]string, bool) { +func (n *Trie) Search(method string, path string) (interface{}, map[string]string, error) { params := &routeParams{} // Find the routing handlers for the path rn := n.findRoute(params, methodMap[method], path) - if rn == nil { - return nil, nil, false + if rn == nil && !params.methodNotAllowed { + return nil, nil, ErrNotFound + } + if params.methodNotAllowed { + return nil, nil, ErrMethodNotAllowed } ep, ok := rn.endpoints[methodMap[method]] if !ok { - return nil, nil, false + return nil, nil, ErrMethodNotAllowed } eparams := make(map[string]string, len(params.keys)) @@ -416,12 +425,13 @@ func (n *Trie) Search(method string, path string) (interface{}, map[string]strin eparams[key] = params.vals[idx] } - return ep.handler, eparams, true + return ep.handler, eparams, nil } type routeParams struct { - keys []string - vals []string + keys []string + vals []string + methodNotAllowed bool } // Recursive edge traversal by checking all nodeTyp groups along the way. @@ -495,6 +505,7 @@ func (n *Trie) findRoute(params *routeParams, method methodTyp, path string) *Tr params.keys = append(params.keys, h.paramKeys...) return xn } + params.methodNotAllowed = true } } @@ -530,6 +541,7 @@ func (n *Trie) findRoute(params *routeParams, method methodTyp, path string) *Tr params.keys = append(params.keys, h.paramKeys...) return xn } + params.methodNotAllowed = true } } diff --git a/util/http/trie_test.go b/util/http/trie_test.go index d4e0a4db..0dd58c81 100644 --- a/util/http/trie_test.go +++ b/util/http/trie_test.go @@ -21,22 +21,22 @@ func TestTrieWildcardPathPrefix(t *testing.T) { if err = tr.Insert([]string{http.MethodPost}, "/v1/*", &handler{name: "post_create"}); err != nil { t.Fatal(err) } - h, _, ok := tr.Search(http.MethodPost, "/v1/test/one") - if !ok { + h, _, err := tr.Search(http.MethodPost, "/v1/test/one") + if err != nil { t.Fatalf("unexpected error handler not found") } if h.(*handler).name != "post_create" { t.Fatalf("invalid handler %v", h) } - h, _, ok = tr.Search(http.MethodPost, "/v1/update") - if !ok { + h, _, err = tr.Search(http.MethodPost, "/v1/update") + if err != nil { t.Fatalf("unexpected error") } if h.(*handler).name != "post_update" { t.Fatalf("invalid handler %v", h) } - h, _, ok = tr.Search(http.MethodPost, "/v1/update/some/{x}") - if !ok { + h, _, err = tr.Search(http.MethodPost, "/v1/update/some/{x}") + if err != nil { t.Fatalf("unexpected error") } if h.(*handler).name != "post_create" { @@ -52,8 +52,8 @@ func TestTriePathPrefix(t *testing.T) { _ = tr.Insert([]string{http.MethodPost}, "/v1/create/{id}", &handler{name: "post_create"}) _ = tr.Insert([]string{http.MethodPost}, "/v1/update/{id}", &handler{name: "post_update"}) _ = tr.Insert([]string{http.MethodPost}, "/", &handler{name: "post_wildcard"}) - h, _, ok := tr.Search(http.MethodPost, "/") - if !ok { + h, _, err := tr.Search(http.MethodPost, "/") + if err != nil { t.Fatalf("unexpected error") } if h.(*handler).name != "post_wildcard" { @@ -68,8 +68,8 @@ func TestTrieFixedPattern(t *testing.T) { tr := NewTrie() _ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", &handler{name: "pattern"}) _ = tr.Insert([]string{http.MethodPut}, "/v1/create/12", &handler{name: "fixed"}) - h, _, ok := tr.Search(http.MethodPut, "/v1/create/12") - if !ok { + h, _, err := tr.Search(http.MethodPut, "/v1/create/12") + if err != nil { t.Fatalf("unexpected error") } if h.(*handler).name != "fixed" { @@ -80,8 +80,8 @@ func TestTrieFixedPattern(t *testing.T) { func TestTrieNoMatchMethod(t *testing.T) { tr := NewTrie() _ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", nil) - _, _, ok := tr.Search(http.MethodPost, "/v1/create") - if ok { + _, _, err := tr.Search(http.MethodPost, "/v1/create") + if err == nil && err != ErrNotFound { t.Fatalf("must be not found error") } } @@ -90,9 +90,9 @@ func TestTrieMatchRegexp(t *testing.T) { type handler struct{} tr := NewTrie() _ = tr.Insert([]string{http.MethodPut}, "/v1/create/{category}/{id:[0-9]+}", &handler{}) - _, params, ok := tr.Search(http.MethodPut, "/v1/create/test_cat/12345") + _, params, err := tr.Search(http.MethodPut, "/v1/create/test_cat/12345") switch { - case !ok: + case err != nil: t.Fatalf("route not found") case len(params) != 2: t.Fatalf("param matching error %v", params) @@ -105,8 +105,8 @@ func TestTrieMatchRegexpFail(t *testing.T) { type handler struct{} tr := NewTrie() _ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[a-z]+}", &handler{}) - _, _, ok := tr.Search(http.MethodPut, "/v1/create/12345") - if ok { + _, _, err := tr.Search(http.MethodPut, "/v1/create/12345") + if err != ErrNotFound { t.Fatalf("route must not be not found") } } @@ -118,14 +118,28 @@ func TestTrieMatchLongest(t *testing.T) { tr := NewTrie() _ = tr.Insert([]string{http.MethodPut}, "/v1/create", &handler{name: "first"}) _ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[0-9]+}", &handler{name: "second"}) - if h, _, ok := tr.Search(http.MethodPut, "/v1/create/12345"); !ok { + if h, _, err := tr.Search(http.MethodPut, "/v1/create/12345"); err != nil { t.Fatalf("route must be found") } else if h.(*handler).name != "second" { t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "second") } - if h, _, ok := tr.Search(http.MethodPut, "/v1/create"); !ok { + if h, _, err := tr.Search(http.MethodPut, "/v1/create"); err != nil { t.Fatalf("route must be found") } else if h.(*handler).name != "first" { t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "first") } } + +func TestMethodNotAllowed(t *testing.T) { + type handler struct{} + tr := NewTrie() + _ = tr.Insert([]string{http.MethodPut}, "/v1/create", &handler{}) + _, _, err := tr.Search(http.MethodPost, "/v1/create") + if err != ErrMethodNotAllowed { + t.Fatalf("route must be method not allowed: %v", err) + } + _, _, err = tr.Search(http.MethodPut, "/v1/create") + if err != nil { + t.Fatalf("route must be found: %v", err) + } +}