502 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			502 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Package registry provides a dynamic api service router
 | |
| package registry
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net/http"
 | |
| 	"regexp"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/micro/go-micro/v3/api"
 | |
| 	"github.com/micro/go-micro/v3/api/router"
 | |
| 	"github.com/micro/go-micro/v3/api/router/util"
 | |
| 	"github.com/micro/go-micro/v3/logger"
 | |
| 	"github.com/micro/go-micro/v3/metadata"
 | |
| 	"github.com/micro/go-micro/v3/registry"
 | |
| 	"github.com/micro/go-micro/v3/registry/cache"
 | |
| )
 | |
| 
 | |
| // endpoint struct, that holds compiled pcre
 | |
| type endpoint struct {
 | |
| 	hostregs []*regexp.Regexp
 | |
| 	pathregs []util.Pattern
 | |
| 	pcreregs []*regexp.Regexp
 | |
| }
 | |
| 
 | |
| // router is the default router
 | |
| type registryRouter struct {
 | |
| 	exit chan bool
 | |
| 	opts router.Options
 | |
| 
 | |
| 	// registry cache
 | |
| 	rc cache.Cache
 | |
| 
 | |
| 	sync.RWMutex
 | |
| 	eps map[string]*api.Service
 | |
| 	// compiled regexp for host and path
 | |
| 	ceps map[string]*endpoint
 | |
| }
 | |
| 
 | |
| func (r *registryRouter) isClosed() bool {
 | |
| 	select {
 | |
| 	case <-r.exit:
 | |
| 		return true
 | |
| 	default:
 | |
| 		return false
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // refresh list of api services
 | |
| func (r *registryRouter) refresh() {
 | |
| 	var attempts int
 | |
| 
 | |
| 	for {
 | |
| 		services, err := r.opts.Registry.ListServices()
 | |
| 		if err != nil {
 | |
| 			attempts++
 | |
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
 | |
| 				logger.Errorf("unable to list services: %v", err)
 | |
| 			}
 | |
| 			time.Sleep(time.Duration(attempts) * time.Second)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		attempts = 0
 | |
| 
 | |
| 		// for each service, get service and store endpoints
 | |
| 		for _, s := range services {
 | |
| 			service, err := r.rc.GetService(s.Name)
 | |
| 			if err != nil {
 | |
| 				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
 | |
| 					logger.Errorf("unable to get service: %v", err)
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 			r.store(service)
 | |
| 		}
 | |
| 
 | |
| 		// refresh list in 10 minutes... cruft
 | |
| 		// use registry watching
 | |
| 		select {
 | |
| 		case <-time.After(time.Minute * 10):
 | |
| 		case <-r.exit:
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // process watch event
 | |
| func (r *registryRouter) process(res *registry.Result) {
 | |
| 	// skip these things
 | |
| 	if res == nil || res.Service == nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// get entry from cache
 | |
| 	service, err := r.rc.GetService(res.Service.Name)
 | |
| 	if err != nil {
 | |
| 		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
 | |
| 			logger.Errorf("unable to get %v service: %v", res.Service.Name, err)
 | |
| 		}
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// update our local endpoints
 | |
| 	r.store(service)
 | |
| }
 | |
| 
 | |
| // store local endpoint cache
 | |
| func (r *registryRouter) store(services []*registry.Service) {
 | |
| 	// endpoints
 | |
| 	eps := map[string]*api.Service{}
 | |
| 
 | |
| 	// services
 | |
| 	names := map[string]bool{}
 | |
| 
 | |
| 	// create a new endpoint mapping
 | |
| 	for _, service := range services {
 | |
| 		// set names we need later
 | |
| 		names[service.Name] = true
 | |
| 
 | |
| 		// map per endpoint
 | |
| 		for _, sep := range service.Endpoints {
 | |
| 			// create a key service:endpoint_name
 | |
| 			key := fmt.Sprintf("%s.%s", service.Name, sep.Name)
 | |
| 			// decode endpoint
 | |
| 			end := api.Decode(sep.Metadata)
 | |
| 			// no endpoint or no name
 | |
| 			if end == nil || len(end.Name) == 0 {
 | |
| 				continue
 | |
| 			}
 | |
| 			// if we got nothing skip
 | |
| 			if err := api.Validate(end); err != nil {
 | |
| 				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
 | |
| 					logger.Tracef("endpoint validation failed: %v", err)
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			// try get endpoint
 | |
| 			ep, ok := eps[key]
 | |
| 			if !ok {
 | |
| 				ep = &api.Service{Name: service.Name}
 | |
| 			}
 | |
| 
 | |
| 			// overwrite the endpoint
 | |
| 			ep.Endpoint = end
 | |
| 			// append services
 | |
| 			ep.Services = append(ep.Services, service)
 | |
| 			// store it
 | |
| 			eps[key] = ep
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	r.Lock()
 | |
| 	defer r.Unlock()
 | |
| 
 | |
| 	// delete any existing eps for services we know
 | |
| 	for key, service := range r.eps {
 | |
| 		// skip what we don't care about
 | |
| 		if !names[service.Name] {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		// ok we know this thing
 | |
| 		// delete delete delete
 | |
| 		delete(r.eps, key)
 | |
| 	}
 | |
| 
 | |
| 	// now set the eps we have
 | |
| 	for name, ep := range eps {
 | |
| 		r.eps[name] = ep
 | |
| 		cep := &endpoint{}
 | |
| 
 | |
| 		for _, h := range ep.Endpoint.Host {
 | |
| 			if h == "" || h == "*" {
 | |
| 				continue
 | |
| 			}
 | |
| 			hostreg, err := regexp.CompilePOSIX(h)
 | |
| 			if err != nil {
 | |
| 				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
 | |
| 					logger.Tracef("endpoint have invalid host regexp: %v", err)
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 			cep.hostregs = append(cep.hostregs, hostreg)
 | |
| 		}
 | |
| 
 | |
| 		for _, p := range ep.Endpoint.Path {
 | |
| 			var pcreok bool
 | |
| 
 | |
| 			if p[0] == '^' && p[len(p)-1] == '$' {
 | |
| 				pcrereg, err := regexp.CompilePOSIX(p)
 | |
| 				if err == nil {
 | |
| 					cep.pcreregs = append(cep.pcreregs, pcrereg)
 | |
| 					pcreok = true
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			rule, err := util.Parse(p)
 | |
| 			if err != nil && !pcreok {
 | |
| 				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
 | |
| 					logger.Tracef("endpoint have invalid path pattern: %v", err)
 | |
| 				}
 | |
| 				continue
 | |
| 			} else if err != nil && pcreok {
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			tpl := rule.Compile()
 | |
| 			pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "")
 | |
| 			if err != nil {
 | |
| 				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
 | |
| 					logger.Tracef("endpoint have invalid path pattern: %v", err)
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 			cep.pathregs = append(cep.pathregs, pathreg)
 | |
| 		}
 | |
| 
 | |
| 		r.ceps[name] = cep
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // watch for endpoint changes
 | |
| func (r *registryRouter) watch() {
 | |
| 	var attempts int
 | |
| 
 | |
| 	for {
 | |
| 		if r.isClosed() {
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// watch for changes
 | |
| 		w, err := r.opts.Registry.Watch()
 | |
| 		if err != nil {
 | |
| 			attempts++
 | |
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
 | |
| 				logger.Errorf("error watching endpoints: %v", err)
 | |
| 			}
 | |
| 			time.Sleep(time.Duration(attempts) * time.Second)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		ch := make(chan bool)
 | |
| 
 | |
| 		go func() {
 | |
| 			select {
 | |
| 			case <-ch:
 | |
| 				w.Stop()
 | |
| 			case <-r.exit:
 | |
| 				w.Stop()
 | |
| 			}
 | |
| 		}()
 | |
| 
 | |
| 		// reset if we get here
 | |
| 		attempts = 0
 | |
| 
 | |
| 		for {
 | |
| 			// process next event
 | |
| 			res, err := w.Next()
 | |
| 			if err != nil {
 | |
| 				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
 | |
| 					logger.Errorf("error getting next endpoint: %v", err)
 | |
| 				}
 | |
| 				close(ch)
 | |
| 				break
 | |
| 			}
 | |
| 			r.process(res)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (r *registryRouter) Options() router.Options {
 | |
| 	return r.opts
 | |
| }
 | |
| 
 | |
| func (r *registryRouter) Close() error {
 | |
| 	select {
 | |
| 	case <-r.exit:
 | |
| 		return nil
 | |
| 	default:
 | |
| 		close(r.exit)
 | |
| 		r.rc.Stop()
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (r *registryRouter) Register(ep *api.Endpoint) error {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (r *registryRouter) Deregister(ep *api.Endpoint) error {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (r *registryRouter) Endpoint(req *http.Request) (*api.Service, error) {
 | |
| 	if r.isClosed() {
 | |
| 		return nil, errors.New("router closed")
 | |
| 	}
 | |
| 
 | |
| 	r.RLock()
 | |
| 	defer r.RUnlock()
 | |
| 
 | |
| 	var idx int
 | |
| 	if len(req.URL.Path) > 0 && req.URL.Path != "/" {
 | |
| 		idx = 1
 | |
| 	}
 | |
| 	path := strings.Split(req.URL.Path[idx:], "/")
 | |
| 
 | |
| 	// use the first match
 | |
| 	// TODO: weighted matching
 | |
| 	for n, e := range r.eps {
 | |
| 		cep, ok := r.ceps[n]
 | |
| 		if !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 		ep := e.Endpoint
 | |
| 		var mMatch, hMatch, pMatch bool
 | |
| 		// 1. try method
 | |
| 		for _, m := range ep.Method {
 | |
| 			if m == req.Method {
 | |
| 				mMatch = true
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 		if !mMatch {
 | |
| 			continue
 | |
| 		}
 | |
| 		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
 | |
| 			logger.Debugf("api method match %s", req.Method)
 | |
| 		}
 | |
| 
 | |
| 		// 2. try host
 | |
| 		if len(ep.Host) == 0 {
 | |
| 			hMatch = true
 | |
| 		} else {
 | |
| 			for idx, h := range ep.Host {
 | |
| 				if h == "" || h == "*" {
 | |
| 					hMatch = true
 | |
| 					break
 | |
| 				} else {
 | |
| 					if cep.hostregs[idx].MatchString(req.URL.Host) {
 | |
| 						hMatch = true
 | |
| 						break
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		if !hMatch {
 | |
| 			continue
 | |
| 		}
 | |
| 		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
 | |
| 			logger.Debugf("api host match %s", req.URL.Host)
 | |
| 		}
 | |
| 
 | |
| 		// 3. try path via google.api path matching
 | |
| 		for _, pathreg := range cep.pathregs {
 | |
| 			matches, err := pathreg.Match(path, "")
 | |
| 			if err != nil {
 | |
| 				if logger.V(logger.DebugLevel, logger.DefaultLogger) {
 | |
| 					logger.Debugf("api gpath not match %s != %v", path, pathreg)
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 			if logger.V(logger.DebugLevel, logger.DefaultLogger) {
 | |
| 				logger.Debugf("api gpath match %s = %v", path, pathreg)
 | |
| 			}
 | |
| 			pMatch = true
 | |
| 			ctx := req.Context()
 | |
| 			md, ok := metadata.FromContext(ctx)
 | |
| 			if !ok {
 | |
| 				md = make(metadata.Metadata)
 | |
| 			}
 | |
| 			for k, v := range matches {
 | |
| 				md[fmt.Sprintf("x-api-field-%s", k)] = v
 | |
| 			}
 | |
| 			md["x-api-body"] = ep.Body
 | |
| 			*req = *req.Clone(metadata.NewContext(ctx, md))
 | |
| 			break
 | |
| 		}
 | |
| 
 | |
| 		if !pMatch {
 | |
| 			// 4. try path via pcre path matching
 | |
| 			for _, pathreg := range cep.pcreregs {
 | |
| 				if !pathreg.MatchString(req.URL.Path) {
 | |
| 					if logger.V(logger.DebugLevel, logger.DefaultLogger) {
 | |
| 						logger.Debugf("api pcre path not match %s != %v", path, pathreg)
 | |
| 					}
 | |
| 					continue
 | |
| 				}
 | |
| 				if logger.V(logger.DebugLevel, logger.DefaultLogger) {
 | |
| 					logger.Debugf("api pcre path match %s != %v", path, pathreg)
 | |
| 				}
 | |
| 				pMatch = true
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if !pMatch {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		// TODO: Percentage traffic
 | |
| 		// we got here, so its a match
 | |
| 		return e, nil
 | |
| 	}
 | |
| 
 | |
| 	// no match
 | |
| 	return nil, errors.New("not found")
 | |
| }
 | |
| 
 | |
| func (r *registryRouter) Route(req *http.Request) (*api.Service, error) {
 | |
| 	if r.isClosed() {
 | |
| 		return nil, errors.New("router closed")
 | |
| 	}
 | |
| 
 | |
| 	// try get an endpoint
 | |
| 	ep, err := r.Endpoint(req)
 | |
| 	if err == nil {
 | |
| 		return ep, nil
 | |
| 	}
 | |
| 
 | |
| 	// error not nil
 | |
| 	// ignore that shit
 | |
| 	// TODO: don't ignore that shit
 | |
| 
 | |
| 	// get the service name
 | |
| 	rp, err := r.opts.Resolver.Resolve(req)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// service name
 | |
| 	name := rp.Name
 | |
| 
 | |
| 	// get service
 | |
| 	services, err := r.rc.GetService(name, registry.GetDomain(rp.Domain))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// only use endpoint matching when the meta handler is set aka api.Default
 | |
| 	switch r.opts.Handler {
 | |
| 	// rpc handlers
 | |
| 	case "meta", "api", "rpc":
 | |
| 		handler := r.opts.Handler
 | |
| 
 | |
| 		// set default handler to api
 | |
| 		if r.opts.Handler == "meta" {
 | |
| 			handler = "rpc"
 | |
| 		}
 | |
| 
 | |
| 		// construct api service
 | |
| 		return &api.Service{
 | |
| 			Name: name,
 | |
| 			Endpoint: &api.Endpoint{
 | |
| 				Name:    rp.Method,
 | |
| 				Handler: handler,
 | |
| 			},
 | |
| 			Services: services,
 | |
| 		}, nil
 | |
| 	// http handler
 | |
| 	case "http", "proxy", "web":
 | |
| 		// construct api service
 | |
| 		return &api.Service{
 | |
| 			Name: name,
 | |
| 			Endpoint: &api.Endpoint{
 | |
| 				Name:    req.URL.String(),
 | |
| 				Handler: r.opts.Handler,
 | |
| 				Host:    []string{req.Host},
 | |
| 				Method:  []string{req.Method},
 | |
| 				Path:    []string{req.URL.Path},
 | |
| 			},
 | |
| 			Services: services,
 | |
| 		}, nil
 | |
| 	}
 | |
| 
 | |
| 	return nil, errors.New("unknown handler")
 | |
| }
 | |
| 
 | |
| func newRouter(opts ...router.Option) *registryRouter {
 | |
| 	options := router.NewOptions(opts...)
 | |
| 	r := ®istryRouter{
 | |
| 		exit: make(chan bool),
 | |
| 		opts: options,
 | |
| 		rc:   cache.New(options.Registry),
 | |
| 		eps:  make(map[string]*api.Service),
 | |
| 		ceps: make(map[string]*endpoint),
 | |
| 	}
 | |
| 	go r.watch()
 | |
| 	go r.refresh()
 | |
| 	return r
 | |
| }
 | |
| 
 | |
| // NewRouter returns the default router
 | |
| func NewRouter(opts ...router.Option) router.Router {
 | |
| 	return newRouter(opts...)
 | |
| }
 |