diff --git a/util/http/trie.go b/util/http/trie.go new file mode 100644 index 00000000..05cee6a0 --- /dev/null +++ b/util/http/trie.go @@ -0,0 +1,201 @@ +package http + +import ( + "regexp" + "strings" + "sync" +) + +// Tree is a trie tree. +type Trie struct { + rmu sync.RWMutex + node *node + rcache map[string]*regexp.Regexp +} + +// node is a node of tree +type node struct { + actions map[string]interface{} // key is method, val is handler interface + children map[string]*node // key is label of next nodes + label string +} + +const ( + pathRoot string = "/" + pathDelimiter string = "/" + paramDelimiter string = ":" + leftPtnDelimiter string = "{" + rightPtnDelimiter string = "}" + ptnWildcard string = "(.+)" +) + +// NewTree creates a new trie tree. +func NewTrie() *Trie { + return &Trie{ + node: &node{ + label: pathRoot, + actions: make(map[string]interface{}), + children: make(map[string]*node), + }, + rcache: make(map[string]*regexp.Regexp), + } +} + +// Insert inserts a route definition to tree. +func (t *Trie) Insert(methods []string, path string, handler interface{}) { + curNode := t.node + if path == pathRoot { + curNode.label = path + for _, method := range methods { + curNode.actions[method] = handler + } + return + } + ep := splitPath(path) + for i, p := range ep { + nextNode, ok := curNode.children[p] + if ok { + curNode = nextNode + } + // Create a new node. + if !ok { + curNode.children[p] = &node{ + label: p, + actions: make(map[string]interface{}), + children: make(map[string]*node), + } + curNode = curNode.children[p] + } + // last loop. + // If there is already registered data, overwrite it. + if i == len(ep)-1 { + curNode.label = p + for _, method := range methods { + curNode.actions[method] = handler + } + break + } + } +} + +// Search searches a path from a tree. +func (t *Trie) Search(method string, path string) (interface{}, map[string]string, bool) { + params := make(map[string]string) + + curNode := t.node + for _, p := range splitPath(path) { + nextNode, ok := curNode.children[p] + if ok { + curNode = nextNode + continue + } + if len(curNode.children) == 0 { + if curNode.label != p { + // no matching path was found. + return nil, nil, false + } + break + } + isParamMatch := false + for c := range curNode.children { + if string([]rune(c)[0]) == leftPtnDelimiter { + ptn := getPattern(c) + t.rmu.RLock() + reg, ok := t.rcache[ptn] + t.rmu.RUnlock() + if !ok { + var err error + reg, err = regexp.Compile(ptn) + if err != nil { + return nil, nil, false + } + t.rmu.Lock() + t.rcache[ptn] = reg + t.rmu.Unlock() + } + if reg.Match([]byte(p)) { + pn := getParamName(c) + params[pn] = p + curNode = curNode.children[c] + isParamMatch = true + break + } + // no matching param was found. + return nil, nil, false + } + } + if !isParamMatch { + return nil, nil, false + } + } + if path == pathRoot { + if len(curNode.actions) == 0 { + return nil, nil, false + } + } + + handler, ok := curNode.actions[method] + if !ok || handler == nil { + return nil, nil, false + } + return handler, params, true +} + +// getPattern gets a pattern from a label +// {id:[^\d+$]} -> ^\d+$ +// {id} -> (.+) +func getPattern(label string) string { + leftI := strings.Index(label, leftPtnDelimiter) + rightI := strings.Index(label, paramDelimiter) + // if label doesn't have any pattern, return wild card pattern as default. + if leftI == -1 || rightI == -1 { + return ptnWildcard + } + return label[rightI+1 : len(label)-1] +} + +// getParamName gets a parameter from a label +// {id:[^\d+$]} -> id +// {id} -> id +func getParamName(label string) string { + leftI := strings.Index(label, leftPtnDelimiter) + rightI := func(l string) int { + r := []rune(l) + + var n int + + loop: + for i := 0; i < len(r); i++ { + n = i + switch string(r[i]) { + case paramDelimiter: + n = i + break loop + case rightPtnDelimiter: + n = i + break loop + } + + if i == len(r)-1 { + n = i + 1 + break loop + } + } + + return n + }(label) + + return label[leftI+1 : rightI] +} + +// splitPath removes an empty value in slice. +func splitPath(path string) []string { + s := strings.Split(path, pathDelimiter) + var r []string + for _, str := range s { + if str != "" { + r = append(r, str) + } + } + return r +} diff --git a/util/http/trie_test.go b/util/http/trie_test.go new file mode 100644 index 00000000..94c805e8 --- /dev/null +++ b/util/http/trie_test.go @@ -0,0 +1,42 @@ +package http + +import ( + "net/http" + "testing" +) + +func TestTrieNoMatchMethod(t *testing.T) { + tr := NewTrie() + tr.Insert([]string{http.MethodPut}, "/v1/create/{id}", nil) + + _, _, ok := tr.Search(http.MethodPost, "/v1/create") + if ok { + t.Fatalf("must be not found error") + } +} + +type handler struct{} + +func TestTrieMatchRegexp(t *testing.T) { + tr := NewTrie() + tr.Insert([]string{http.MethodPut}, "/v1/create/{category}/{id:[0-9]+}", &handler{}) + + _, params, ok := tr.Search(http.MethodPut, "/v1/create/test_cat/12345") + if !ok { + t.Fatalf("route not found") + } else if len(params) != 2 { + t.Fatalf("param matching error %v", params) + } else if params["category"] != "test_cat" { + t.Fatalf("param matching error %v", params) + } +} + +func TestTrieMatchRegexpFail(t *testing.T) { + tr := NewTrie() + tr.Insert([]string{http.MethodPut}, "/v1/create/{id:[a-z]+}", &handler{}) + + _, _, ok := tr.Search(http.MethodPut, "/v1/create/12345") + if ok { + t.Fatalf("route must not be not found") + } +}