446 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			446 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package web
 | |
| 
 | |
| import (
 | |
| 	"crypto/tls"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"os/signal"
 | |
| 	"path/filepath"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"syscall"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/micro/cli"
 | |
| 	"github.com/micro/go-micro"
 | |
| 	"github.com/micro/go-micro/registry"
 | |
| 	maddr "github.com/micro/go-micro/util/addr"
 | |
| 	mhttp "github.com/micro/go-micro/util/http"
 | |
| 	"github.com/micro/go-micro/util/log"
 | |
| 	mnet "github.com/micro/go-micro/util/net"
 | |
| 	mls "github.com/micro/go-micro/util/tls"
 | |
| )
 | |
| 
 | |
| type service struct {
 | |
| 	opts Options
 | |
| 
 | |
| 	mux *http.ServeMux
 | |
| 	srv *registry.Service
 | |
| 
 | |
| 	sync.Mutex
 | |
| 	running bool
 | |
| 	static  bool
 | |
| 	exit    chan chan error
 | |
| }
 | |
| 
 | |
| func newService(opts ...Option) Service {
 | |
| 	options := newOptions(opts...)
 | |
| 	s := &service{
 | |
| 		opts:   options,
 | |
| 		mux:    http.NewServeMux(),
 | |
| 		static: true,
 | |
| 	}
 | |
| 	s.srv = s.genSrv()
 | |
| 	return s
 | |
| }
 | |
| 
 | |
| func (s *service) genSrv() *registry.Service {
 | |
| 	// default host:port
 | |
| 	parts := strings.Split(s.opts.Address, ":")
 | |
| 	host := strings.Join(parts[:len(parts)-1], ":")
 | |
| 	port, _ := strconv.Atoi(parts[len(parts)-1])
 | |
| 
 | |
| 	// check the advertise address first
 | |
| 	// if it exists then use it, otherwise
 | |
| 	// use the address
 | |
| 	if len(s.opts.Advertise) > 0 {
 | |
| 		parts = strings.Split(s.opts.Advertise, ":")
 | |
| 
 | |
| 		// we have host:port
 | |
| 		if len(parts) > 1 {
 | |
| 			// set the host
 | |
| 			host = strings.Join(parts[:len(parts)-1], ":")
 | |
| 
 | |
| 			// get the port
 | |
| 			if aport, _ := strconv.Atoi(parts[len(parts)-1]); aport > 0 {
 | |
| 				port = aport
 | |
| 			}
 | |
| 		} else {
 | |
| 			host = parts[0]
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	addr, err := maddr.Extract(host)
 | |
| 	if err != nil {
 | |
| 		// best effort localhost
 | |
| 		addr = "127.0.0.1"
 | |
| 	}
 | |
| 
 | |
| 	return ®istry.Service{
 | |
| 		Name:    s.opts.Name,
 | |
| 		Version: s.opts.Version,
 | |
| 		Nodes: []*registry.Node{®istry.Node{
 | |
| 			Id:       s.opts.Id,
 | |
| 			Address:  addr,
 | |
| 			Port:     port,
 | |
| 			Metadata: s.opts.Metadata,
 | |
| 		}},
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (s *service) run(exit chan bool) {
 | |
| 	if s.opts.RegisterInterval <= time.Duration(0) {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	t := time.NewTicker(s.opts.RegisterInterval)
 | |
| 
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-t.C:
 | |
| 			s.register()
 | |
| 		case <-exit:
 | |
| 			t.Stop()
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (s *service) register() error {
 | |
| 	if s.srv == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	// default to service registry
 | |
| 	r := s.opts.Service.Client().Options().Registry
 | |
| 	// switch to option if specified
 | |
| 	if s.opts.Registry != nil {
 | |
| 		r = s.opts.Registry
 | |
| 	}
 | |
| 	return r.Register(s.srv, registry.RegisterTTL(s.opts.RegisterTTL))
 | |
| }
 | |
| 
 | |
| func (s *service) deregister() error {
 | |
| 	if s.srv == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	// default to service registry
 | |
| 	r := s.opts.Service.Client().Options().Registry
 | |
| 	// switch to option if specified
 | |
| 	if s.opts.Registry != nil {
 | |
| 		r = s.opts.Registry
 | |
| 	}
 | |
| 	return r.Deregister(s.srv)
 | |
| }
 | |
| 
 | |
| func (s *service) start() error {
 | |
| 	s.Lock()
 | |
| 	defer s.Unlock()
 | |
| 
 | |
| 	if s.running {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	l, err := s.listen("tcp", s.opts.Address)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	s.opts.Address = l.Addr().String()
 | |
| 	srv := s.genSrv()
 | |
| 	srv.Endpoints = s.srv.Endpoints
 | |
| 	s.srv = srv
 | |
| 
 | |
| 	var h http.Handler
 | |
| 
 | |
| 	if s.opts.Handler != nil {
 | |
| 		h = s.opts.Handler
 | |
| 	} else {
 | |
| 		h = s.mux
 | |
| 		var r sync.Once
 | |
| 
 | |
| 		// register the html dir
 | |
| 		r.Do(func() {
 | |
| 			// static dir
 | |
| 			static := s.opts.StaticDir
 | |
| 			if s.opts.StaticDir[0] != '/' {
 | |
| 				dir, _ := os.Getwd()
 | |
| 				static = filepath.Join(dir, static)
 | |
| 			}
 | |
| 
 | |
| 			// set static if no / handler is registered
 | |
| 			if s.static {
 | |
| 				_, err := os.Stat(static)
 | |
| 				if err == nil {
 | |
| 					log.Logf("Enabling static file serving from %s", static)
 | |
| 					s.mux.Handle("/", http.FileServer(http.Dir(static)))
 | |
| 				}
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	for _, fn := range s.opts.BeforeStart {
 | |
| 		if err := fn(); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	var httpSrv *http.Server
 | |
| 	if s.opts.Server != nil {
 | |
| 		httpSrv = s.opts.Server
 | |
| 	} else {
 | |
| 		httpSrv = &http.Server{}
 | |
| 	}
 | |
| 
 | |
| 	httpSrv.Handler = h
 | |
| 
 | |
| 	go httpSrv.Serve(l)
 | |
| 
 | |
| 	for _, fn := range s.opts.AfterStart {
 | |
| 		if err := fn(); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	s.exit = make(chan chan error, 1)
 | |
| 	s.running = true
 | |
| 
 | |
| 	go func() {
 | |
| 		ch := <-s.exit
 | |
| 		ch <- l.Close()
 | |
| 	}()
 | |
| 
 | |
| 	log.Logf("Listening on %v\n", l.Addr().String())
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (s *service) stop() error {
 | |
| 	s.Lock()
 | |
| 	defer s.Unlock()
 | |
| 
 | |
| 	if !s.running {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	for _, fn := range s.opts.BeforeStop {
 | |
| 		if err := fn(); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	ch := make(chan error, 1)
 | |
| 	s.exit <- ch
 | |
| 	s.running = false
 | |
| 
 | |
| 	log.Log("Stopping")
 | |
| 
 | |
| 	for _, fn := range s.opts.AfterStop {
 | |
| 		if err := fn(); err != nil {
 | |
| 			if chErr := <-ch; chErr != nil {
 | |
| 				return chErr
 | |
| 			}
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return <-ch
 | |
| }
 | |
| 
 | |
| func (s *service) Client() *http.Client {
 | |
| 	rt := mhttp.NewRoundTripper(
 | |
| 		mhttp.WithRegistry(registry.DefaultRegistry),
 | |
| 	)
 | |
| 	return &http.Client{
 | |
| 		Transport: rt,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (s *service) Handle(pattern string, handler http.Handler) {
 | |
| 	var seen bool
 | |
| 	for _, ep := range s.srv.Endpoints {
 | |
| 		if ep.Name == pattern {
 | |
| 			seen = true
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// if its unseen then add an endpoint
 | |
| 	if !seen {
 | |
| 		s.srv.Endpoints = append(s.srv.Endpoints, ®istry.Endpoint{
 | |
| 			Name: pattern,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	// disable static serving
 | |
| 	if pattern == "/" {
 | |
| 		s.Lock()
 | |
| 		s.static = false
 | |
| 		s.Unlock()
 | |
| 	}
 | |
| 
 | |
| 	// register the handler
 | |
| 	s.mux.Handle(pattern, handler)
 | |
| }
 | |
| 
 | |
| func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
 | |
| 	var seen bool
 | |
| 	for _, ep := range s.srv.Endpoints {
 | |
| 		if ep.Name == pattern {
 | |
| 			seen = true
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 	if !seen {
 | |
| 		s.srv.Endpoints = append(s.srv.Endpoints, ®istry.Endpoint{
 | |
| 			Name: pattern,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	s.mux.HandleFunc(pattern, handler)
 | |
| }
 | |
| 
 | |
| func (s *service) Init(opts ...Option) error {
 | |
| 	for _, o := range opts {
 | |
| 		o(&s.opts)
 | |
| 	}
 | |
| 
 | |
| 	serviceOpts := []micro.Option{}
 | |
| 
 | |
| 	if len(s.opts.Flags) > 0 {
 | |
| 		serviceOpts = append(serviceOpts, micro.Flags(s.opts.Flags...))
 | |
| 	}
 | |
| 
 | |
| 	if s.opts.Registry != nil {
 | |
| 		serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry))
 | |
| 	}
 | |
| 
 | |
| 	serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) {
 | |
| 		if ttl := ctx.Int("register_ttl"); ttl > 0 {
 | |
| 			s.opts.RegisterTTL = time.Duration(ttl) * time.Second
 | |
| 		}
 | |
| 
 | |
| 		if interval := ctx.Int("register_interval"); interval > 0 {
 | |
| 			s.opts.RegisterInterval = time.Duration(interval) * time.Second
 | |
| 		}
 | |
| 
 | |
| 		if name := ctx.String("server_name"); len(name) > 0 {
 | |
| 			s.opts.Name = name
 | |
| 		}
 | |
| 
 | |
| 		if ver := ctx.String("server_version"); len(ver) > 0 {
 | |
| 			s.opts.Version = ver
 | |
| 		}
 | |
| 
 | |
| 		if id := ctx.String("server_id"); len(id) > 0 {
 | |
| 			s.opts.Id = id
 | |
| 		}
 | |
| 
 | |
| 		if addr := ctx.String("server_address"); len(addr) > 0 {
 | |
| 			s.opts.Address = addr
 | |
| 		}
 | |
| 
 | |
| 		if adv := ctx.String("server_advertise"); len(adv) > 0 {
 | |
| 			s.opts.Advertise = adv
 | |
| 		}
 | |
| 
 | |
| 		if s.opts.Action != nil {
 | |
| 			s.opts.Action(ctx)
 | |
| 		}
 | |
| 	}))
 | |
| 
 | |
| 	s.opts.Service.Init(serviceOpts...)
 | |
| 	srv := s.genSrv()
 | |
| 	srv.Endpoints = s.srv.Endpoints
 | |
| 	s.srv = srv
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (s *service) Run() error {
 | |
| 	if err := s.start(); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if err := s.register(); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// start reg loop
 | |
| 	ex := make(chan bool)
 | |
| 	go s.run(ex)
 | |
| 
 | |
| 	ch := make(chan os.Signal, 1)
 | |
| 	signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL)
 | |
| 
 | |
| 	select {
 | |
| 	// wait on kill signal
 | |
| 	case sig := <-ch:
 | |
| 		log.Logf("Received signal %s\n", sig)
 | |
| 	// wait on context cancel
 | |
| 	case <-s.opts.Context.Done():
 | |
| 		log.Logf("Received context shutdown")
 | |
| 	}
 | |
| 
 | |
| 	// exit reg loop
 | |
| 	close(ex)
 | |
| 
 | |
| 	if err := s.deregister(); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return s.stop()
 | |
| }
 | |
| 
 | |
| // Options returns the options for the given service
 | |
| func (s *service) Options() Options {
 | |
| 	return s.opts
 | |
| }
 | |
| 
 | |
| func (s *service) listen(network, addr string) (net.Listener, error) {
 | |
| 	var l net.Listener
 | |
| 	var err error
 | |
| 
 | |
| 	// TODO: support use of listen options
 | |
| 	if s.opts.Secure || s.opts.TLSConfig != nil {
 | |
| 		config := s.opts.TLSConfig
 | |
| 
 | |
| 		fn := func(addr string) (net.Listener, error) {
 | |
| 			if config == nil {
 | |
| 				hosts := []string{addr}
 | |
| 
 | |
| 				// check if its a valid host:port
 | |
| 				if host, _, err := net.SplitHostPort(addr); err == nil {
 | |
| 					if len(host) == 0 {
 | |
| 						hosts = maddr.IPs()
 | |
| 					} else {
 | |
| 						hosts = []string{host}
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				// generate a certificate
 | |
| 				cert, err := mls.Certificate(hosts...)
 | |
| 				if err != nil {
 | |
| 					return nil, err
 | |
| 				}
 | |
| 				config = &tls.Config{Certificates: []tls.Certificate{cert}}
 | |
| 			}
 | |
| 			return tls.Listen(network, addr, config)
 | |
| 		}
 | |
| 
 | |
| 		l, err = mnet.Listen(addr, fn)
 | |
| 	} else {
 | |
| 		fn := func(addr string) (net.Listener, error) {
 | |
| 			return net.Listen(network, addr)
 | |
| 		}
 | |
| 
 | |
| 		l, err = mnet.Listen(addr, fn)
 | |
| 	}
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return l, nil
 | |
| }
 |