util/http: trie support method not allowed #162
@@ -7,6 +7,7 @@ package http
 | 
				
			|||||||
// Modified by Unistack LLC to support interface{} type handler and parameters in map[string]string
 | 
					// Modified by Unistack LLC to support interface{} type handler and parameters in map[string]string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
@@ -15,6 +16,11 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						ErrNotFound         = errors.New("route not found")
 | 
				
			||||||
 | 
						ErrMethodNotAllowed = errors.New("method not allowed")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type methodTyp uint
 | 
					type methodTyp uint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@@ -399,16 +405,19 @@ func (n *Trie) setEndpoint(method methodTyp, handler interface{}, pattern string
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Search try to find element in tree with path and method
 | 
					// Search try to find element in tree with path and method
 | 
				
			||||||
func (n *Trie) Search(method string, path string) (interface{}, map[string]string, bool) {
 | 
					func (n *Trie) Search(method string, path string) (interface{}, map[string]string, error) {
 | 
				
			||||||
	params := &routeParams{}
 | 
						params := &routeParams{}
 | 
				
			||||||
	// Find the routing handlers for the path
 | 
						// Find the routing handlers for the path
 | 
				
			||||||
	rn := n.findRoute(params, methodMap[method], path)
 | 
						rn := n.findRoute(params, methodMap[method], path)
 | 
				
			||||||
	if rn == nil {
 | 
						if rn == nil && !params.methodNotAllowed {
 | 
				
			||||||
		return nil, nil, false
 | 
							return nil, nil, ErrNotFound
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if params.methodNotAllowed {
 | 
				
			||||||
 | 
							return nil, nil, ErrMethodNotAllowed
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	ep, ok := rn.endpoints[methodMap[method]]
 | 
						ep, ok := rn.endpoints[methodMap[method]]
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		return nil, nil, false
 | 
							return nil, nil, ErrMethodNotAllowed
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	eparams := make(map[string]string, len(params.keys))
 | 
						eparams := make(map[string]string, len(params.keys))
 | 
				
			||||||
@@ -416,12 +425,13 @@ func (n *Trie) Search(method string, path string) (interface{}, map[string]strin
 | 
				
			|||||||
		eparams[key] = params.vals[idx]
 | 
							eparams[key] = params.vals[idx]
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return ep.handler, eparams, true
 | 
						return ep.handler, eparams, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type routeParams struct {
 | 
					type routeParams struct {
 | 
				
			||||||
	keys             []string
 | 
						keys             []string
 | 
				
			||||||
	vals             []string
 | 
						vals             []string
 | 
				
			||||||
 | 
						methodNotAllowed bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Recursive edge traversal by checking all nodeTyp groups along the way.
 | 
					// Recursive edge traversal by checking all nodeTyp groups along the way.
 | 
				
			||||||
@@ -495,6 +505,7 @@ func (n *Trie) findRoute(params *routeParams, method methodTyp, path string) *Tr
 | 
				
			|||||||
							params.keys = append(params.keys, h.paramKeys...)
 | 
												params.keys = append(params.keys, h.paramKeys...)
 | 
				
			||||||
							return xn
 | 
												return xn
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
 | 
											params.methodNotAllowed = true
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -530,6 +541,7 @@ func (n *Trie) findRoute(params *routeParams, method methodTyp, path string) *Tr
 | 
				
			|||||||
					params.keys = append(params.keys, h.paramKeys...)
 | 
										params.keys = append(params.keys, h.paramKeys...)
 | 
				
			||||||
					return xn
 | 
										return xn
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									params.methodNotAllowed = true
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,22 +21,22 @@ func TestTrieWildcardPathPrefix(t *testing.T) {
 | 
				
			|||||||
	if err = tr.Insert([]string{http.MethodPost}, "/v1/*", &handler{name: "post_create"}); err != nil {
 | 
						if err = tr.Insert([]string{http.MethodPost}, "/v1/*", &handler{name: "post_create"}); err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	h, _, ok := tr.Search(http.MethodPost, "/v1/test/one")
 | 
						h, _, err := tr.Search(http.MethodPost, "/v1/test/one")
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("unexpected error handler not found")
 | 
							t.Fatalf("unexpected error handler not found")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if h.(*handler).name != "post_create" {
 | 
						if h.(*handler).name != "post_create" {
 | 
				
			||||||
		t.Fatalf("invalid handler %v", h)
 | 
							t.Fatalf("invalid handler %v", h)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	h, _, ok = tr.Search(http.MethodPost, "/v1/update")
 | 
						h, _, err = tr.Search(http.MethodPost, "/v1/update")
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("unexpected error")
 | 
							t.Fatalf("unexpected error")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if h.(*handler).name != "post_update" {
 | 
						if h.(*handler).name != "post_update" {
 | 
				
			||||||
		t.Fatalf("invalid handler %v", h)
 | 
							t.Fatalf("invalid handler %v", h)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	h, _, ok = tr.Search(http.MethodPost, "/v1/update/some/{x}")
 | 
						h, _, err = tr.Search(http.MethodPost, "/v1/update/some/{x}")
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("unexpected error")
 | 
							t.Fatalf("unexpected error")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if h.(*handler).name != "post_create" {
 | 
						if h.(*handler).name != "post_create" {
 | 
				
			||||||
@@ -52,8 +52,8 @@ func TestTriePathPrefix(t *testing.T) {
 | 
				
			|||||||
	_ = tr.Insert([]string{http.MethodPost}, "/v1/create/{id}", &handler{name: "post_create"})
 | 
						_ = tr.Insert([]string{http.MethodPost}, "/v1/create/{id}", &handler{name: "post_create"})
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPost}, "/v1/update/{id}", &handler{name: "post_update"})
 | 
						_ = tr.Insert([]string{http.MethodPost}, "/v1/update/{id}", &handler{name: "post_update"})
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPost}, "/", &handler{name: "post_wildcard"})
 | 
						_ = tr.Insert([]string{http.MethodPost}, "/", &handler{name: "post_wildcard"})
 | 
				
			||||||
	h, _, ok := tr.Search(http.MethodPost, "/")
 | 
						h, _, err := tr.Search(http.MethodPost, "/")
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("unexpected error")
 | 
							t.Fatalf("unexpected error")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if h.(*handler).name != "post_wildcard" {
 | 
						if h.(*handler).name != "post_wildcard" {
 | 
				
			||||||
@@ -68,8 +68,8 @@ func TestTrieFixedPattern(t *testing.T) {
 | 
				
			|||||||
	tr := NewTrie()
 | 
						tr := NewTrie()
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", &handler{name: "pattern"})
 | 
						_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", &handler{name: "pattern"})
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/12", &handler{name: "fixed"})
 | 
						_ = tr.Insert([]string{http.MethodPut}, "/v1/create/12", &handler{name: "fixed"})
 | 
				
			||||||
	h, _, ok := tr.Search(http.MethodPut, "/v1/create/12")
 | 
						h, _, err := tr.Search(http.MethodPut, "/v1/create/12")
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("unexpected error")
 | 
							t.Fatalf("unexpected error")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if h.(*handler).name != "fixed" {
 | 
						if h.(*handler).name != "fixed" {
 | 
				
			||||||
@@ -80,8 +80,8 @@ func TestTrieFixedPattern(t *testing.T) {
 | 
				
			|||||||
func TestTrieNoMatchMethod(t *testing.T) {
 | 
					func TestTrieNoMatchMethod(t *testing.T) {
 | 
				
			||||||
	tr := NewTrie()
 | 
						tr := NewTrie()
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", nil)
 | 
						_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", nil)
 | 
				
			||||||
	_, _, ok := tr.Search(http.MethodPost, "/v1/create")
 | 
						_, _, err := tr.Search(http.MethodPost, "/v1/create")
 | 
				
			||||||
	if ok {
 | 
						if err == nil && err != ErrNotFound {
 | 
				
			||||||
		t.Fatalf("must be not found error")
 | 
							t.Fatalf("must be not found error")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -90,9 +90,9 @@ func TestTrieMatchRegexp(t *testing.T) {
 | 
				
			|||||||
	type handler struct{}
 | 
						type handler struct{}
 | 
				
			||||||
	tr := NewTrie()
 | 
						tr := NewTrie()
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{category}/{id:[0-9]+}", &handler{})
 | 
						_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{category}/{id:[0-9]+}", &handler{})
 | 
				
			||||||
	_, params, ok := tr.Search(http.MethodPut, "/v1/create/test_cat/12345")
 | 
						_, params, err := tr.Search(http.MethodPut, "/v1/create/test_cat/12345")
 | 
				
			||||||
	switch {
 | 
						switch {
 | 
				
			||||||
	case !ok:
 | 
						case err != nil:
 | 
				
			||||||
		t.Fatalf("route not found")
 | 
							t.Fatalf("route not found")
 | 
				
			||||||
	case len(params) != 2:
 | 
						case len(params) != 2:
 | 
				
			||||||
		t.Fatalf("param matching error %v", params)
 | 
							t.Fatalf("param matching error %v", params)
 | 
				
			||||||
@@ -105,8 +105,8 @@ func TestTrieMatchRegexpFail(t *testing.T) {
 | 
				
			|||||||
	type handler struct{}
 | 
						type handler struct{}
 | 
				
			||||||
	tr := NewTrie()
 | 
						tr := NewTrie()
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[a-z]+}", &handler{})
 | 
						_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[a-z]+}", &handler{})
 | 
				
			||||||
	_, _, ok := tr.Search(http.MethodPut, "/v1/create/12345")
 | 
						_, _, err := tr.Search(http.MethodPut, "/v1/create/12345")
 | 
				
			||||||
	if ok {
 | 
						if err != ErrNotFound {
 | 
				
			||||||
		t.Fatalf("route must not be not found")
 | 
							t.Fatalf("route must not be not found")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -118,14 +118,28 @@ func TestTrieMatchLongest(t *testing.T) {
 | 
				
			|||||||
	tr := NewTrie()
 | 
						tr := NewTrie()
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPut}, "/v1/create", &handler{name: "first"})
 | 
						_ = tr.Insert([]string{http.MethodPut}, "/v1/create", &handler{name: "first"})
 | 
				
			||||||
	_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[0-9]+}", &handler{name: "second"})
 | 
						_ = tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[0-9]+}", &handler{name: "second"})
 | 
				
			||||||
	if h, _, ok := tr.Search(http.MethodPut, "/v1/create/12345"); !ok {
 | 
						if h, _, err := tr.Search(http.MethodPut, "/v1/create/12345"); err != nil {
 | 
				
			||||||
		t.Fatalf("route must be found")
 | 
							t.Fatalf("route must be found")
 | 
				
			||||||
	} else if h.(*handler).name != "second" {
 | 
						} else if h.(*handler).name != "second" {
 | 
				
			||||||
		t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "second")
 | 
							t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "second")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if h, _, ok := tr.Search(http.MethodPut, "/v1/create"); !ok {
 | 
						if h, _, err := tr.Search(http.MethodPut, "/v1/create"); err != nil {
 | 
				
			||||||
		t.Fatalf("route must be found")
 | 
							t.Fatalf("route must be found")
 | 
				
			||||||
	} else if h.(*handler).name != "first" {
 | 
						} else if h.(*handler).name != "first" {
 | 
				
			||||||
		t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "first")
 | 
							t.Fatalf("invalid handler found: %s != %s", h.(*handler).name, "first")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMethodNotAllowed(t *testing.T) {
 | 
				
			||||||
 | 
						type handler struct{}
 | 
				
			||||||
 | 
						tr := NewTrie()
 | 
				
			||||||
 | 
						_ = tr.Insert([]string{http.MethodPut}, "/v1/create", &handler{})
 | 
				
			||||||
 | 
						_, _, err := tr.Search(http.MethodPost, "/v1/create")
 | 
				
			||||||
 | 
						if err != ErrMethodNotAllowed {
 | 
				
			||||||
 | 
							t.Fatalf("route must be method not allowed: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_, _, err = tr.Search(http.MethodPut, "/v1/create")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("route must be found: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user