From 66ef6b67ca42f2c68c5858f12cbb83dbc88bd03b Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Thu, 12 Jan 2017 14:11:25 +0000 Subject: [PATCH] add support for port range in http broker --- broker/http_broker.go | 40 +++++++++++++++++++------ transport/http_transport.go | 60 ++----------------------------------- 2 files changed, 34 insertions(+), 66 deletions(-) diff --git a/broker/http_broker.go b/broker/http_broker.go index bfbc4496..9be2f87b 100644 --- a/broker/http_broker.go +++ b/broker/http_broker.go @@ -20,7 +20,8 @@ import ( "github.com/micro/go-micro/broker/codec/json" "github.com/micro/go-micro/errors" "github.com/micro/go-micro/registry" - "github.com/micro/misc/lib/addr" + maddr "github.com/micro/misc/lib/addr" + mnet "github.com/micro/misc/lib/net" mls "github.com/micro/misc/lib/tls" "github.com/pborman/uuid" @@ -211,16 +212,37 @@ func (h *httpBroker) start() error { if h.opts.Secure || h.opts.TLSConfig != nil { config := h.opts.TLSConfig - if config == nil { - cert, err := mls.Certificate(h.address) - if err != nil { - return err + + fn := func(addr string) (net.Listener, error) { + if config == nil { + hosts := []string{addr} + + // check if its a valid host:port + if host, _, err := net.SplitHostPort(addr); err == nil { + if len(host) == 0 { + hosts = maddr.IPs() + } else { + hosts = []string{host} + } + } + + // generate a certificate + cert, err := mls.Certificate(hosts...) + if err != nil { + return nil, err + } + config = &tls.Config{Certificates: []tls.Certificate{cert}} } - config = &tls.Config{Certificates: []tls.Certificate{cert}} + return tls.Listen("tcp", addr, config) } - l, err = tls.Listen("tcp", h.address, config) + + l, err = mnet.Listen(h.address, fn) } else { - l, err = net.Listen("tcp", h.address) + fn := func(addr string) (net.Listener, error) { + return net.Listen("tcp", addr) + } + + l, err = mnet.Listen(h.address, fn) } if err != nil { @@ -412,7 +434,7 @@ func (h *httpBroker) Subscribe(topic string, handler Handler, opts ...SubscribeO host := strings.Join(parts[:len(parts)-1], ":") port, _ := strconv.Atoi(parts[len(parts)-1]) - addr, err := addr.Extract(host) + addr, err := maddr.Extract(host) if err != nil { return nil, err } diff --git a/transport/http_transport.go b/transport/http_transport.go index ce562129..35ec21d6 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -5,19 +5,17 @@ import ( "bytes" "crypto/tls" "errors" - "fmt" "io" "io/ioutil" "log" "net" "net/http" "net/url" - "strconv" - "strings" "sync" "time" maddr "github.com/micro/misc/lib/addr" + mnet "github.com/micro/misc/lib/net" mls "github.com/micro/misc/lib/tls" ) @@ -57,58 +55,6 @@ type httpTransportListener struct { listener net.Listener } -func listen(addr string, fn func(string) (net.Listener, error)) (net.Listener, error) { - // host:port || host:min-max - parts := strings.Split(addr, ":") - - // - if len(parts) < 2 { - return fn(addr) - } - - // try to extract port range - ports := strings.Split(parts[len(parts)-1], "-") - - // single port - if len(ports) < 2 { - return fn(addr) - } - - // we have a port range - - // extract min port - min, err := strconv.Atoi(ports[0]) - if err != nil { - return nil, errors.New("unable to extract port range") - } - - // extract max port - max, err := strconv.Atoi(ports[1]) - if err != nil { - return nil, errors.New("unable to extract port range") - } - - // set host - host := parts[:len(parts)-1] - - // range the ports - for port := min; port <= max; port++ { - // try bind to host:port - ln, err := fn(fmt.Sprintf("%s:%d", host, port)) - if err == nil { - return ln, nil - } - - // hit max port - if port == max { - return nil, err - } - } - - // why are we here? - return nil, fmt.Errorf("unable to bind to %s", addr) -} - func (b *buffer) Close() error { return nil } @@ -457,13 +403,13 @@ func (h *httpTransport) Listen(addr string, opts ...ListenOption) (Listener, err return tls.Listen("tcp", addr, config) } - l, err = listen(addr, fn) + l, err = mnet.Listen(addr, fn) } else { fn := func(addr string) (net.Listener, error) { return net.Listen("tcp", addr) } - l, err = listen(addr, fn) + l, err = mnet.Listen(addr, fn) } if err != nil {