diff --git a/network/default.go b/network/default.go index 26789436..83557592 100644 --- a/network/default.go +++ b/network/default.go @@ -234,7 +234,7 @@ func (n *network) resolveNodes() ([]string, error) { dns := &dns.Resolver{} // append seed nodes if we have them - for _, node := range n.options.Peers { + for _, node := range n.options.Nodes { // resolve anything that looks like a host name records, err := dns.Resolve(node) if err != nil { @@ -282,7 +282,8 @@ func (n *network) handleNetConn(s tunnel.Session, msg chan *message) { m := new(transport.Message) if err := s.Recv(m); err != nil { log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err) - if err == io.EOF { + switch err { + case io.EOF, tunnel.ErrReadTimeout: s.Close() return } @@ -338,39 +339,10 @@ func (n *network) acceptNetConn(l tunnel.Listener, recv chan *message) { } } -// updatePeerLinks updates link for a given peer -func (n *network) updatePeerLinks(peerAddr string, linkId string) error { - n.Lock() - defer n.Unlock() - log.Tracef("Network looking up link %s in the peer links", linkId) - // lookup the peer link - var peerLink tunnel.Link - for _, link := range n.tunnel.Links() { - if link.Id() == linkId { - peerLink = link - break - } - } - if peerLink == nil { - return ErrPeerLinkNotFound - } - // if the peerLink is found in the returned links update peerLinks - log.Tracef("Network updating peer links for peer %s", peerAddr) - // add peerLink to the peerLinks map - if link, ok := n.peerLinks[peerAddr]; ok { - // if the existing has better Length then the new, replace it - if link.Length() < peerLink.Length() { - n.peerLinks[peerAddr] = peerLink - } - } else { - n.peerLinks[peerAddr] = peerLink - } - - return nil -} - // processNetChan processes messages received on NetworkChannel func (n *network) processNetChan(listener tunnel.Listener) { + defer listener.Close() + // receive network message queue recv := make(chan *message, 128) @@ -707,13 +679,51 @@ func (n *network) sendMsg(method, channel string, msg proto.Message) error { }) } +// updatePeerLinks updates link for a given peer +func (n *network) updatePeerLinks(peerAddr string, linkId string) error { + n.Lock() + defer n.Unlock() + + log.Tracef("Network looking up link %s in the peer links", linkId) + + // lookup the peer link + var peerLink tunnel.Link + + for _, link := range n.tunnel.Links() { + if link.Id() == linkId { + peerLink = link + break + } + } + + if peerLink == nil { + return ErrPeerLinkNotFound + } + + // if the peerLink is found in the returned links update peerLinks + log.Tracef("Network updating peer links for peer %s", peerAddr) + + // add peerLink to the peerLinks map + if link, ok := n.peerLinks[peerAddr]; ok { + // if the existing has better Length then the new, replace it + if link.Length() < peerLink.Length() { + n.peerLinks[peerAddr] = peerLink + } + } else { + n.peerLinks[peerAddr] = peerLink + } + + return nil +} + // handleCtrlConn handles ControlChannel connections func (n *network) handleCtrlConn(s tunnel.Session, msg chan *message) { for { m := new(transport.Message) if err := s.Recv(m); err != nil { log.Debugf("Network tunnel [%s] receive error: %v", ControlChannel, err) - if err == io.EOF { + switch err { + case io.EOF, tunnel.ErrReadTimeout: s.Close() return } @@ -843,6 +853,8 @@ func (n *network) getRouteMetric(router string, gateway string, link string) int // processCtrlChan processes messages received on ControlChannel func (n *network) processCtrlChan(listener tunnel.Listener) { + defer listener.Close() + // receive control message queue recv := make(chan *message, 128) @@ -1151,12 +1163,7 @@ func (n *network) connect() { // Connect connects the network func (n *network) Connect() error { n.Lock() - - // connect network tunnel - if err := n.tunnel.Connect(); err != nil { - n.Unlock() - return err - } + defer n.Unlock() // try to resolve network nodes nodes, err := n.resolveNodes() @@ -1169,10 +1176,14 @@ func (n *network) Connect() error { tunnel.Nodes(nodes...), ) + // connect network tunnel + if err := n.tunnel.Connect(); err != nil { + n.Unlock() + return err + } + // return if already connected if n.connected { - // unlock first - n.Unlock() // send the connect message n.sendConnect() return nil @@ -1187,32 +1198,36 @@ func (n *network) Connect() error { // dial into ControlChannel to send route adverts ctrlClient, err := n.tunnel.Dial(ControlChannel, tunnel.DialMode(tunnel.Multicast)) if err != nil { - n.Unlock() return err } n.tunClient[ControlChannel] = ctrlClient // listen on ControlChannel - ctrlListener, err := n.tunnel.Listen(ControlChannel, tunnel.ListenMode(tunnel.Multicast)) + ctrlListener, err := n.tunnel.Listen( + ControlChannel, + tunnel.ListenMode(tunnel.Multicast), + tunnel.ListenTimeout(router.AdvertiseTableTick*2), + ) if err != nil { - n.Unlock() return err } // dial into NetworkChannel to send network messages netClient, err := n.tunnel.Dial(NetworkChannel, tunnel.DialMode(tunnel.Multicast)) if err != nil { - n.Unlock() return err } n.tunClient[NetworkChannel] = netClient // listen on NetworkChannel - netListener, err := n.tunnel.Listen(NetworkChannel, tunnel.ListenMode(tunnel.Multicast)) + netListener, err := n.tunnel.Listen( + NetworkChannel, + tunnel.ListenMode(tunnel.Multicast), + tunnel.ListenTimeout(AnnounceTime*2), + ) if err != nil { - n.Unlock() return err } @@ -1221,23 +1236,19 @@ func (n *network) Connect() error { // start the router if err := n.options.Router.Start(); err != nil { - n.Unlock() return err } // start advertising routes advertChan, err := n.options.Router.Advertise() if err != nil { - n.Unlock() return err } // start the server if err := n.server.Start(); err != nil { - n.Unlock() return err } - n.Unlock() // send connect after there's a link established go n.connect() @@ -1252,9 +1263,8 @@ func (n *network) Connect() error { // accept and process routes go n.processCtrlChan(ctrlListener) - n.Lock() + // we're now connected n.connected = true - n.Unlock() return nil } @@ -1340,6 +1350,7 @@ func (n *network) Close() error { Address: n.node.address, }, } + if err := n.sendMsg("close", NetworkChannel, msg); err != nil { log.Debugf("Network failed to send close message: %s", err) } diff --git a/network/network.go b/network/network.go index 8f71709c..e927241b 100644 --- a/network/network.go +++ b/network/network.go @@ -6,8 +6,6 @@ import ( "github.com/micro/go-micro/client" "github.com/micro/go-micro/server" - "github.com/micro/go-micro/transport" - "github.com/micro/go-micro/tunnel" ) var ( diff --git a/network/options.go b/network/options.go index 0a3c5b2b..63ff6fe7 100644 --- a/network/options.go +++ b/network/options.go @@ -22,8 +22,8 @@ type Options struct { Address string // Advertise sets the address to advertise Advertise string - // Peers is a list of peers to connect to - Peers []string + // Nodes is a list of nodes to connect to + Nodes []string // Tunnel is network tunnel Tunnel tunnel.Tunnel // Router is network router @@ -62,10 +62,10 @@ func Advertise(a string) Option { } } -// Peers is a list of peers to connect to -func Peers(n ...string) Option { +// Nodes is a list of nodes to connect to +func Nodes(n ...string) Option { return func(o *Options) { - o.Peers = n + o.Nodes = n } } diff --git a/network/resolver/dns/dns.go b/network/resolver/dns/dns.go index b4029f06..a8d27b06 100644 --- a/network/resolver/dns/dns.go +++ b/network/resolver/dns/dns.go @@ -31,6 +31,17 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { r.Address = "1.0.0.1:53" } + //nolint:prealloc + var records []*resolver.Record + + // parsed an actual ip + if v := net.ParseIP(host); v != nil { + records = append(records, &resolver.Record{ + Address: net.JoinHostPort(host, port), + }) + return records, nil + } + m := new(dns.Msg) m.SetQuestion(dns.Fqdn(host), dns.TypeA) rec, err := dns.ExchangeContext(context.Background(), m, r.Address) @@ -38,9 +49,6 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { return nil, err } - //nolint:prealloc - var records []*resolver.Record - for _, answer := range rec.Answer { h := answer.Header() // check record type matches @@ -59,5 +67,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { }) } + // no records returned so just best effort it + if len(records) == 0 { + records = append(records, &resolver.Record{ + Address: net.JoinHostPort(host, port), + }) + } + return records, nil } diff --git a/network/service/handler/handler.go b/network/service/handler/handler.go index 5a9f9379..d8538a7c 100644 --- a/network/service/handler/handler.go +++ b/network/service/handler/handler.go @@ -55,7 +55,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp * } // get list of existing nodes - nodes := n.Network.Options().Peers + nodes := n.Network.Options().Nodes // generate a node map nodeMap := make(map[string]bool) @@ -84,7 +84,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp * // reinitialise the peers n.Network.Init( - network.Peers(nodes...), + network.Nodes(nodes...), ) // call the connect method diff --git a/tunnel/default.go b/tunnel/default.go index 5943757f..15909654 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -197,73 +197,96 @@ func (t *tun) announce(channel, session string, link *link) { } } -// monitor monitors outbound links and attempts to reconnect to the failed ones -func (t *tun) monitor() { +// manage monitors outbound links and attempts to reconnect to the failed ones +func (t *tun) manage() { reconnect := time.NewTicker(ReconnectTime) defer reconnect.Stop() + // do it immediately + t.manageLinks() + for { select { case <-t.closed: return case <-reconnect.C: - t.RLock() - - var delLinks []string - // check the link status and purge dead links - for node, link := range t.links { - // check link status - switch link.State() { - case "closed": - delLinks = append(delLinks, node) - case "error": - delLinks = append(delLinks, node) - } - } - - t.RUnlock() - - // delete the dead links - if len(delLinks) > 0 { - t.Lock() - for _, node := range delLinks { - log.Debugf("Tunnel deleting dead link for %s", node) - if link, ok := t.links[node]; ok { - link.Close() - delete(t.links, node) - } - } - t.Unlock() - } - - // check current link status - var connect []string - - // build list of unknown nodes to connect to - t.RLock() - for _, node := range t.options.Nodes { - if _, ok := t.links[node]; !ok { - connect = append(connect, node) - } - } - t.RUnlock() - - for _, node := range connect { - // create new link - link, err := t.setupLink(node) - if err != nil { - log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) - continue - } - // save the link - t.Lock() - t.links[node] = link - t.Unlock() - } + t.manageLinks() } } } +// manageLinks is a function that can be called to immediately to link setup +func (t *tun) manageLinks() { + var delLinks []string + + t.RLock() + + // check the link status and purge dead links + for node, link := range t.links { + // check link status + switch link.State() { + case "closed": + delLinks = append(delLinks, node) + case "error": + delLinks = append(delLinks, node) + } + } + + t.RUnlock() + + // delete the dead links + if len(delLinks) > 0 { + t.Lock() + for _, node := range delLinks { + log.Debugf("Tunnel deleting dead link for %s", node) + if link, ok := t.links[node]; ok { + link.Close() + delete(t.links, node) + } + } + t.Unlock() + } + + // check current link status + var connect []string + + // build list of unknown nodes to connect to + t.RLock() + + for _, node := range t.options.Nodes { + if _, ok := t.links[node]; !ok { + connect = append(connect, node) + } + } + + t.RUnlock() + + var wg sync.WaitGroup + + for _, node := range connect { + wg.Add(1) + + go func() { + defer wg.Done() + + // create new link + link, err := t.setupLink(node) + if err != nil { + log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) + return + } + + // save the link + t.Lock() + t.links[node] = link + t.Unlock() + }() + } + + // wait for all threads to finish + wg.Wait() +} + // process outgoing messages sent by all local sessions func (t *tun) process() { // manage the send buffer @@ -757,11 +780,13 @@ func (t *tun) keepalive(link *link) { // It returns error if the link failed to be established func (t *tun) setupLink(node string) (*link, error) { log.Debugf("Tunnel setting up link: %s", node) + c, err := t.options.Transport.Dial(node) if err != nil { log.Debugf("Tunnel failed to connect to %s: %v", node, err) return nil, err } + log.Debugf("Tunnel connected to %s", node) // create a new link @@ -795,30 +820,6 @@ func (t *tun) setupLink(node string) (*link, error) { return link, nil } -func (t *tun) setupLinks() { - for _, node := range t.options.Nodes { - // skip zero length nodes - if len(node) == 0 { - continue - } - - // link already exists - if _, ok := t.links[node]; ok { - continue - } - - // connect to node and return link - link, err := t.setupLink(node) - if err != nil { - log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) - continue - } - - // save the link - t.links[node] = link - } -} - // connect the tunnel to all the nodes and listen for incoming tunnel connections func (t *tun) connect() error { l, err := t.options.Transport.Listen(t.options.Address) @@ -869,11 +870,10 @@ func (t *tun) Connect() error { // already connected if t.connected { // setup links - t.setupLinks() return nil } - // send the connect message + // connect the tunnel: start the listener if err := t.connect(); err != nil { return err } @@ -883,16 +883,13 @@ func (t *tun) Connect() error { // create new close channel t.closed = make(chan bool) - // setup links - t.setupLinks() + // manage the links + go t.manage() // process outbound messages to be sent // process sends to all links go t.process() - // monitor links - go t.monitor() - return nil } @@ -1029,7 +1026,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // set the multicast option c.mode = options.Mode // set the dial timeout - c.timeout = options.Timeout + c.dialTimeout = options.Timeout + // set read timeout set to never + c.readTimeout = time.Duration(-1) var links []*link // did we measure the rtt @@ -1145,7 +1144,11 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { log.Debugf("Tunnel listening on %s", channel) - var options ListenOptions + options := ListenOptions{ + // Read timeout defaults to never + Timeout: time.Duration(-1), + } + for _, o := range opts { o(&options) } @@ -1167,6 +1170,8 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { c.local = channel // set mode c.mode = options.Mode + // set the timeout + c.readTimeout = options.Timeout tl := &tunListener{ channel: channel, diff --git a/tunnel/listener.go b/tunnel/listener.go index 288b7dac..2417edd6 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -65,18 +65,8 @@ func (t *tunListener) process() { return // receive a new message case m := <-t.session.recv: - var sessionId string - - // get the session id - switch m.mode { - case Multicast, Broadcast: - // use channel name if multicast/broadcast - sessionId = "multicast" - log.Tracef("Tunnel listener using session %s for real session %s", sessionId, m.session) - default: - // use session id if unicast - sessionId = m.session - } + // session id + sessionId := m.session // get a session sess, ok := conns[sessionId] @@ -113,6 +103,8 @@ func (t *tunListener) process() { send: t.session.send, // error channel errChan: make(chan error, 1), + // set the read timeout + readTimeout: t.session.readTimeout, } // save the session @@ -137,12 +129,10 @@ func (t *tunListener) process() { // no op delete(conns, sessionId) default: - if sess.mode == Unicast { - // only close if unicast session - // close and delete session - close(sess.closed) - delete(conns, sessionId) - } + // only close if unicast session + // close and delete session + close(sess.closed) + delete(conns, sessionId) } // continue diff --git a/tunnel/options.go b/tunnel/options.go index 3a8f13db..145db19f 100644 --- a/tunnel/options.go +++ b/tunnel/options.go @@ -47,6 +47,8 @@ type ListenOption func(*ListenOptions) type ListenOptions struct { // specify mode of the session Mode Mode + // The read timeout + Timeout time.Duration } // The tunnel id @@ -84,16 +86,6 @@ func Transport(t transport.Transport) Option { } } -// DefaultOptions returns router default options -func DefaultOptions() Options { - return Options{ - Id: uuid.New().String(), - Address: DefaultAddress, - Token: DefaultToken, - Transport: quic.NewTransport(), - } -} - // Listen options func ListenMode(m Mode) ListenOption { return func(o *ListenOptions) { @@ -101,6 +93,13 @@ func ListenMode(m Mode) ListenOption { } } +// Timeout for reads and writes on the listener session +func ListenTimeout(t time.Duration) ListenOption { + return func(o *ListenOptions) { + o.Timeout = t + } +} + // Dial options // Dial multicast sets the multicast option to send only to those mapped @@ -124,3 +123,13 @@ func DialLink(id string) DialOption { o.Link = id } } + +// DefaultOptions returns router default options +func DefaultOptions() Options { + return Options{ + Id: uuid.New().String(), + Address: DefaultAddress, + Token: DefaultToken, + Transport: quic.NewTransport(), + } +} diff --git a/tunnel/session.go b/tunnel/session.go index e545d316..ebb179df 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -39,8 +39,10 @@ type session struct { loopback bool // mode of the connection mode Mode - // the timeout - timeout time.Duration + // the dial timeout + dialTimeout time.Duration + // the read timeout + readTimeout time.Duration // the link on which this message was received link string // the error response @@ -133,31 +135,43 @@ func (s *session) wait(msg *message) error { func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) { now := time.Now() - after := func(timeout time.Duration) time.Duration { + after := func(timeout time.Duration) <-chan time.Time { + if timeout < time.Duration(0) { + return nil + } + + // get the delta d := time.Since(now) + // dial timeout minus time since wait := timeout - d if wait < time.Duration(0) { - return time.Duration(0) + wait = time.Duration(0) } - return wait + return time.After(wait) } // wait for the message type for { select { case msg := <-s.recv: + // there may be no message type + if len(msgType) == 0 { + return msg, nil + } + // ignore what we don't want if msg.typ != msgType { log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) continue } + // got the message return msg, nil - case <-time.After(after(timeout)): - return nil, ErrDialTimeout + case <-after(timeout): + return nil, ErrReadTimeout case <-s.closed: return nil, io.EOF } @@ -193,7 +207,8 @@ func (s *session) Discover() error { after := func() time.Duration { d := time.Since(now) // dial timeout minus time since - wait := s.timeout - d + wait := s.dialTimeout - d + // make sure its always > 0 if wait < time.Duration(0) { return time.Duration(0) } @@ -248,7 +263,7 @@ func (s *session) Open() error { } // now wait for the accept message to be returned - msg, err := s.waitFor("accept", s.timeout) + msg, err := s.waitFor("accept", s.dialTimeout) if err != nil { return err } @@ -341,11 +356,9 @@ func (s *session) Send(m *transport.Message) error { func (s *session) Recv(m *transport.Message) error { var msg *message - select { - case <-s.closed: - return io.EOF - // recv from backlog - case msg = <-s.recv: + msg, err := s.waitFor("", s.readTimeout) + if err != nil { + return err } // check the error if one exists diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index abd61cb4..15ff6078 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -26,6 +26,8 @@ var ( ErrDiscoverChan = errors.New("failed to discover channel") // ErrLinkNotFound is returned when a link is specified at dial time and does not exist ErrLinkNotFound = errors.New("link not found") + // ErrReadTimeout is a timeout on session.Recv + ErrReadTimeout = errors.New("read timeout") ) // Mode of the session