diff --git a/network/default.go b/network/default.go index 535b9a83..d0a8a683 100644 --- a/network/default.go +++ b/network/default.go @@ -846,7 +846,7 @@ func (n *network) Connect() error { } // dial into ControlChannel to send route adverts - ctrlClient, err := n.tunnel.Dial(ControlChannel, tunnel.DialMulticast()) + ctrlClient, err := n.tunnel.Dial(ControlChannel, tunnel.DialMode(tunnel.Multicast)) if err != nil { n.Unlock() return err @@ -855,14 +855,14 @@ func (n *network) Connect() error { n.tunClient[ControlChannel] = ctrlClient // listen on ControlChannel - ctrlListener, err := n.tunnel.Listen(ControlChannel) + ctrlListener, err := n.tunnel.Listen(ControlChannel, tunnel.ListenMode(tunnel.Multicast)) if err != nil { n.Unlock() return err } // dial into NetworkChannel to send network messages - netClient, err := n.tunnel.Dial(NetworkChannel, tunnel.DialMulticast()) + netClient, err := n.tunnel.Dial(NetworkChannel, tunnel.DialMode(tunnel.Multicast)) if err != nil { n.Unlock() return err @@ -871,7 +871,7 @@ func (n *network) Connect() error { n.tunClient[NetworkChannel] = netClient // listen on NetworkChannel - netListener, err := n.tunnel.Listen(NetworkChannel) + netListener, err := n.tunnel.Listen(NetworkChannel, tunnel.ListenMode(tunnel.Multicast)) if err != nil { n.Unlock() return err diff --git a/tunnel/broker/broker.go b/tunnel/broker/broker.go index 6778dfaa..d11e160c 100644 --- a/tunnel/broker/broker.go +++ b/tunnel/broker/broker.go @@ -58,7 +58,7 @@ func (t *tunBroker) Disconnect() error { func (t *tunBroker) Publish(topic string, m *broker.Message, opts ...broker.PublishOption) error { // TODO: this is probably inefficient, we might want to just maintain an open connection // it may be easier to add broadcast to the tunnel - c, err := t.tunnel.Dial(topic, tunnel.DialMulticast()) + c, err := t.tunnel.Dial(topic, tunnel.DialMode(tunnel.Multicast)) if err != nil { return err } @@ -71,7 +71,7 @@ func (t *tunBroker) Publish(topic string, m *broker.Message, opts ...broker.Publ } func (t *tunBroker) Subscribe(topic string, h broker.Handler, opts ...broker.SubscribeOption) (broker.Subscriber, error) { - l, err := t.tunnel.Listen(topic) + l, err := t.tunnel.Listen(topic, tunnel.ListenMode(tunnel.Multicast)) if err != nil { return nil, err } diff --git a/tunnel/default.go b/tunnel/default.go index 16ca339e..38f0f100 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -326,7 +326,7 @@ func (t *tun) process() { } // check the multicast mappings - if msg.multicast { + if msg.mode == Multicast { link.RLock() _, ok := link.channels[msg.channel] link.RUnlock() @@ -366,7 +366,7 @@ func (t *tun) process() { sent = true // keep sending broadcast messages - if msg.broadcast || msg.multicast { + if msg.mode > Unicast { continue } @@ -523,7 +523,7 @@ func (t *tun) listen(link *link) { case "accept": s, exists := t.getSession(channel, sessionId) // we don't need this - if exists && s.multicast { + if exists && s.mode > Unicast { s.accepted = true continue } @@ -963,7 +963,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } // set the multicast option - c.multicast = options.Multicast + c.mode = options.Mode // set the dial timeout c.timeout = options.Timeout @@ -1009,7 +1009,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // discovered so set the link if not multicast // TODO: pick the link efficiently based // on link status and saturation. - if c.discovered && !c.multicast { + if c.discovered && c.mode == Unicast { // set the link i := rand.Intn(len(links)) c.link = links[i] @@ -1019,7 +1019,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { if !c.discovered { // create a new discovery message for this channel msg := c.newMessage("discover") - msg.broadcast = true + msg.mode = Broadcast msg.outbound = true msg.link = "" @@ -1041,7 +1041,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { dialTimeout := after() // set a shorter delay for multicast - if c.multicast { + if c.mode > Unicast { // shorten this dialTimeout = time.Millisecond * 500 } @@ -1057,7 +1057,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } // if its multicast just go ahead because this is best effort - if c.multicast { + if c.mode > Unicast { c.discovered = true c.accepted = true return c, nil @@ -1086,9 +1086,14 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } // Accept a connection on the address -func (t *tun) Listen(channel string) (Listener, error) { +func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { log.Debugf("Tunnel listening on %s", channel) + var options ListenOptions + for _, o := range opts { + o(&options) + } + // create a new session by hashing the address c, ok := t.newSession(channel, "listener") if !ok { @@ -1103,6 +1108,8 @@ func (t *tun) Listen(channel string) (Listener, error) { c.remote = "remote" // set local c.local = channel + // set mode + c.mode = options.Mode tl := &tunListener{ channel: channel, diff --git a/tunnel/listener.go b/tunnel/listener.go index f154b2a6..c893297d 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -82,8 +82,8 @@ func (t *tunListener) process() { loopback: m.loopback, // the link the message was received on link: m.link, - // set multicast - multicast: m.multicast, + // set the connection mode + mode: m.mode, // close chan closed: make(chan bool), // recv called by the acceptor diff --git a/tunnel/options.go b/tunnel/options.go index 903f7fb7..7d6360ab 100644 --- a/tunnel/options.go +++ b/tunnel/options.go @@ -36,12 +36,19 @@ type DialOption func(*DialOptions) type DialOptions struct { // Link specifies the link to use Link string - // specify a multicast connection - Multicast bool + // specify mode of the session + Mode Mode // the dial timeout Timeout time.Duration } +type ListenOption func(*ListenOptions) + +type ListenOptions struct { + // specify mode of the session + Mode Mode +} + // The tunnel id func Id(id string) Option { return func(o *Options) { @@ -87,12 +94,19 @@ func DefaultOptions() Options { } } +// Listen options +func ListenMode(m Mode) ListenOption { + return func(o *ListenOptions) { + o.Mode = m + } +} + // Dial options // Dial multicast sets the multicast option to send only to those mapped -func DialMulticast() DialOption { +func DialMode(m Mode) DialOption { return func(o *DialOptions) { - o.Multicast = true + o.Mode = m } } diff --git a/tunnel/session.go b/tunnel/session.go index 6757f150..a185c3ce 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -37,10 +37,8 @@ type session struct { outbound bool // lookback marks the session as a loopback on the inbound loopback bool - // if the session is multicast - multicast bool - // if the session is broadcast - broadcast bool + // mode of the connection + mode Mode // the timeout timeout time.Duration // the link on which this message was received @@ -63,10 +61,8 @@ type message struct { outbound bool // loopback marks the message intended for loopback loopback bool - // whether to send as multicast - multicast bool - // broadcast sets the broadcast type - broadcast bool + // mode of the connection + mode Mode // the link to send the message on link string // transport data @@ -98,15 +94,15 @@ func (s *session) Channel() string { // newMessage creates a new message based on the session func (s *session) newMessage(typ string) *message { return &message{ - typ: typ, - tunnel: s.tunnel, - channel: s.channel, - session: s.session, - outbound: s.outbound, - loopback: s.loopback, - multicast: s.multicast, - link: s.link, - errChan: s.errChan, + typ: typ, + tunnel: s.tunnel, + channel: s.channel, + session: s.session, + outbound: s.outbound, + loopback: s.loopback, + mode: s.mode, + link: s.link, + errChan: s.errChan, } } @@ -128,8 +124,8 @@ func (s *session) Open() error { return io.EOF } - // we don't wait on multicast - if s.multicast { + // don't wait on multicast/broadcast + if s.mode == Multicast { s.accepted = true return nil } @@ -166,6 +162,11 @@ func (s *session) Accept() error { // no op here } + // don't wait on multicast/broadcast + if s.mode == Multicast { + return nil + } + // wait for send response select { case err := <-s.errChan: @@ -185,7 +186,7 @@ func (s *session) Announce() error { // we don't need an error back msg.errChan = nil // announce to all - msg.broadcast = true + msg.mode = Broadcast // we don't need the link msg.link = "" @@ -222,7 +223,7 @@ func (s *session) Send(m *transport.Message) error { msg.data = data // if multicast don't set the link - if s.multicast { + if s.mode == Multicast { msg.link = "" } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 312a6681..a2671f62 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -8,6 +8,15 @@ import ( "github.com/micro/go-micro/transport" ) +const ( + // send over one link + Unicast Mode = iota + // send to all channel listeners + Multicast + // send to all links + Broadcast +) + var ( // DefaultDialTimeout is the dial timeout if none is specified DefaultDialTimeout = time.Second * 5 @@ -19,6 +28,9 @@ var ( ErrLinkNotFound = errors.New("link not found") ) +// Mode of the session +type Mode uint8 + // Tunnel creates a gre tunnel on top of the go-micro/transport. // It establishes multiple streams using the Micro-Tunnel-Channel header // and Micro-Tunnel-Session header. The tunnel id is a hash of @@ -36,7 +48,7 @@ type Tunnel interface { // Connect to a channel Dial(channel string, opts ...DialOption) (Session, error) // Accept connections on a channel - Listen(channel string) (Listener, error) + Listen(channel string, opts ...ListenOption) (Listener, error) // Name of the tunnel implementation String() string }