diff --git a/client/rpc_client.go b/client/rpc_client.go index 75ce75ac..4da6df45 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -316,6 +316,22 @@ func (r *rpcClient) Options() Options { return r.opts } +// hasProxy checks if we have proxy set in the environment +func (r *rpcClient) hasProxy() bool { + // get proxy + if prx := os.Getenv("MICRO_PROXY"); len(prx) > 0 { + return true + } + + // get proxy address + if prx := os.Getenv("MICRO_PROXY_ADDRESS"); len(prx) > 0 { + return true + } + + return false +} + +// next returns an iterator for the next nodes to call func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) { service := request.Service() @@ -431,10 +447,18 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac return err } - ch := make(chan error, callOpts.Retries+1) + // get the retries + retries := callOpts.Retries + + // disable retries when using a proxy + if r.hasProxy() { + retries = 0 + } + + ch := make(chan error, retries+1) var gerr error - for i := 0; i <= callOpts.Retries; i++ { + for i := 0; i <= retries; i++ { go func(i int) { ch <- call(i) }(i) @@ -514,10 +538,18 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt err error } - ch := make(chan response, callOpts.Retries+1) + // get the retries + retries := callOpts.Retries + + // disable retries when using a proxy + if r.hasProxy() { + retries = 0 + } + + ch := make(chan response, retries+1) var grr error - for i := 0; i <= callOpts.Retries; i++ { + for i := 0; i <= retries; i++ { go func(i int) { s, err := call(i) ch <- response{s, err} diff --git a/client/rpc_codec.go b/client/rpc_codec.go index 60dc02b5..a71f6a11 100644 --- a/client/rpc_codec.go +++ b/client/rpc_codec.go @@ -88,32 +88,24 @@ func (rwc *readWriteCloser) Close() error { } func getHeaders(m *codec.Message) { - get := func(hdr string) string { - if hd := m.Header[hdr]; len(hd) > 0 { - return hd + set := func(v, hdr string) string { + if len(v) > 0 { + return v } - // old - return m.Header["X-"+hdr] + return m.Header[hdr] } // check error in header - if len(m.Error) == 0 { - m.Error = get("Micro-Error") - } + m.Error = set(m.Error, "Micro-Error") // check endpoint in header - if len(m.Endpoint) == 0 { - m.Endpoint = get("Micro-Endpoint") - } + m.Endpoint = set(m.Endpoint, "Micro-Endpoint") // check method in header - if len(m.Method) == 0 { - m.Method = get("Micro-Method") - } + m.Method = set(m.Method, "Micro-Method") - if len(m.Id) == 0 { - m.Id = get("Micro-Id") - } + // set the request id + m.Id = set(m.Id, "Micro-Id") } func setHeaders(m *codec.Message, stream string) { @@ -122,7 +114,6 @@ func setHeaders(m *codec.Message, stream string) { return } m.Header[hdr] = v - m.Header["X-"+hdr] = v } set("Micro-Id", m.Id) diff --git a/network/default.go b/network/default.go index ef859d51..9f29414d 100644 --- a/network/default.go +++ b/network/default.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "hash/fnv" + "io" "math" "sync" "time" @@ -70,6 +71,18 @@ type network struct { connected bool // closed closes the network closed chan bool + // whether we've discovered by the network + discovered chan bool + // solicted checks whether routes were solicited by one node + solicited chan string +} + +// message is network message +type message struct { + // msg is transport message + msg *transport.Message + // session is tunnel session + session tunnel.Session } // newNetwork returns a new network node @@ -145,14 +158,16 @@ func newNetwork(opts ...Option) Network { address: peerAddress, peers: make(map[string]*node), }, - options: options, - router: options.Router, - proxy: options.Proxy, - tunnel: options.Tunnel, - server: server, - client: client, - tunClient: make(map[string]transport.Client), - peerLinks: make(map[string]tunnel.Link), + options: options, + router: options.Router, + proxy: options.Proxy, + tunnel: options.Tunnel, + server: server, + client: client, + tunClient: make(map[string]transport.Client), + peerLinks: make(map[string]tunnel.Link), + discovered: make(chan bool, 1), + solicited: make(chan string, 1), } network.node.network = network @@ -187,10 +202,30 @@ func (n *network) Name() string { return n.options.Name } +func (n *network) initNodes(startup bool) { + nodes, err := n.resolveNodes() + if err != nil && !startup { + log.Debugf("Network failed to resolve nodes: %v", err) + return + } + + // initialize the tunnel + log.Tracef("Network initialising nodes %+v\n", nodes) + + n.tunnel.Init( + tunnel.Nodes(nodes...), + ) +} + // resolveNodes resolves network nodes to addresses func (n *network) resolveNodes() ([]string, error) { // resolve the network address to network nodes records, err := n.options.Resolver.Resolve(n.options.Name) + if err != nil { + log.Debugf("Network failed to resolve nodes: %v", err) + } + + // keep processing nodeMap := make(map[string]bool) @@ -219,7 +254,7 @@ func (n *network) resolveNodes() ([]string, error) { dns := &dns.Resolver{} // append seed nodes if we have them - for _, node := range n.options.Peers { + for _, node := range n.options.Nodes { // resolve anything that looks like a host name records, err := dns.Resolve(node) if err != nil { @@ -235,30 +270,7 @@ func (n *network) resolveNodes() ([]string, error) { } } - return nodes, err -} - -// resolve continuously resolves network nodes and initializes network tunnel with resolved addresses -func (n *network) resolve() { - resolve := time.NewTicker(ResolveTime) - defer resolve.Stop() - - for { - select { - case <-n.closed: - return - case <-resolve.C: - nodes, err := n.resolveNodes() - if err != nil { - log.Debugf("Network failed to resolve nodes: %v", err) - continue - } - // initialize the tunnel - n.tunnel.Init( - tunnel.Nodes(nodes...), - ) - } - } + return nodes, nil } // handleNetConn handles network announcement messages @@ -267,10 +279,20 @@ func (n *network) handleNetConn(s tunnel.Session, msg chan *message) { m := new(transport.Message) if err := s.Recv(m); err != nil { log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err) - if sessionErr := s.Close(); sessionErr != nil { - log.Debugf("Network tunnel [%s] closing connection error: %v", NetworkChannel, sessionErr) + switch err { + case io.EOF, tunnel.ErrReadTimeout: + s.Close() + return } - return + continue + } + + // check if peer is set + peer := m.Header["Micro-Peer"] + + // check who the message is intended for + if len(peer) > 0 && peer != n.options.Id { + continue } select { @@ -314,39 +336,10 @@ func (n *network) acceptNetConn(l tunnel.Listener, recv chan *message) { } } -// updatePeerLinks updates link for a given peer -func (n *network) updatePeerLinks(peerAddr string, linkId string) error { - n.Lock() - defer n.Unlock() - log.Tracef("Network looking up link %s in the peer links", linkId) - // lookup the peer link - var peerLink tunnel.Link - for _, link := range n.tunnel.Links() { - if link.Id() == linkId { - peerLink = link - break - } - } - if peerLink == nil { - return ErrPeerLinkNotFound - } - // if the peerLink is found in the returned links update peerLinks - log.Tracef("Network updating peer links for peer %s", peerAddr) - // add peerLink to the peerLinks map - if link, ok := n.peerLinks[peerAddr]; ok { - // if the existing has better Length then the new, replace it - if link.Length() < peerLink.Length() { - n.peerLinks[peerAddr] = peerLink - } - } else { - n.peerLinks[peerAddr] = peerLink - } - - return nil -} - // processNetChan processes messages received on NetworkChannel func (n *network) processNetChan(listener tunnel.Listener) { + defer listener.Close() + // receive network message queue recv := make(chan *message, 128) @@ -362,26 +355,32 @@ func (n *network) processNetChan(listener tunnel.Listener) { // mark the time the message has been received now := time.Now() pbNetConnect := &pbNet.Connect{} + if err := proto.Unmarshal(m.msg.Body, pbNetConnect); err != nil { log.Debugf("Network tunnel [%s] connect unmarshal error: %v", NetworkChannel, err) continue } + // don't process your own messages if pbNetConnect.Node.Id == n.options.Id { continue } + log.Debugf("Network received connect message from: %s", pbNetConnect.Node.Id) + peer := &node{ id: pbNetConnect.Node.Id, address: pbNetConnect.Node.Address, peers: make(map[string]*node), lastSeen: now, } + // update peer links - log.Tracef("Network updating peer link %s for peer: %s", m.session.Link(), pbNetConnect.Node.Address) - if err := n.updatePeerLinks(pbNetConnect.Node.Address, m.session.Link()); err != nil { + + if err := n.updatePeerLinks(pbNetConnect.Node.Address, m); err != nil { log.Debugf("Network failed updating peer links: %s", err) } + // add peer to the list of node peers if err := n.node.AddPeer(peer); err == ErrPeerExists { log.Debugf("Network peer exists, refreshing: %s", peer.id) @@ -389,50 +388,75 @@ func (n *network) processNetChan(listener tunnel.Listener) { if err := n.RefreshPeer(peer.id, now); err != nil { log.Debugf("Network failed refreshing peer %s: %v", peer.id, err) } - continue } + + // we send the peer message because someone has sent connect + // and wants to know what's on the network. The faster we + // respond the faster we start to converge + // get node peers down to MaxDepth encoded in protobuf msg := PeersToProto(n.node, MaxDepth) + node := pbNetConnect.Node.Id + // advertise yourself to the network - if err := n.sendMsg("peer", msg, NetworkChannel); err != nil { + if err := n.sendTo("peer", NetworkChannel, node, msg); err != nil { log.Debugf("Network failed to advertise peers: %v", err) } + // advertise all the routes when a new node has connected if err := n.router.Solicit(); err != nil { log.Debugf("Network failed to solicit routes: %s", err) } + + // specify that we're soliciting + select { + case n.solicited <- node: + default: + // don't block + } case "peer": // mark the time the message has been received now := time.Now() pbNetPeer := &pbNet.Peer{} + if err := proto.Unmarshal(m.msg.Body, pbNetPeer); err != nil { log.Debugf("Network tunnel [%s] peer unmarshal error: %v", NetworkChannel, err) continue } + // don't process your own messages if pbNetPeer.Node.Id == n.options.Id { continue } + log.Debugf("Network received peer message from: %s %s", pbNetPeer.Node.Id, pbNetPeer.Node.Address) + peer := &node{ id: pbNetPeer.Node.Id, address: pbNetPeer.Node.Address, peers: make(map[string]*node), lastSeen: now, } + // update peer links - log.Tracef("Network updating peer link %s for peer: %s", m.session.Link(), pbNetPeer.Node.Address) - if err := n.updatePeerLinks(pbNetPeer.Node.Address, m.session.Link()); err != nil { + + if err := n.updatePeerLinks(pbNetPeer.Node.Address, m); err != nil { log.Debugf("Network failed updating peer links: %s", err) } + if err := n.node.AddPeer(peer); err == nil { // send a solicit message when discovering new peer msg := &pbRtr.Solicit{ Id: n.options.Id, } - if err := n.sendMsg("solicit", msg, ControlChannel); err != nil { + + node := pbNetPeer.Node.Id + + // only solicit this peer + if err := n.sendTo("solicit", ControlChannel, node, msg); err != nil { log.Debugf("Network failed to send solicit message: %s", err) } + continue // we're expecting any error to be ErrPeerExists } else if err != ErrPeerExists { @@ -441,6 +465,7 @@ func (n *network) processNetChan(listener tunnel.Listener) { } log.Debugf("Network peer exists, refreshing: %s", pbNetPeer.Node.Id) + // update lastSeen time for the peer if err := n.RefreshPeer(pbNetPeer.Node.Id, now); err != nil { log.Debugf("Network failed refreshing peer %s: %v", pbNetPeer.Node.Id, err) @@ -452,28 +477,42 @@ func (n *network) processNetChan(listener tunnel.Listener) { if err := n.node.UpdatePeer(peer); err != nil { log.Debugf("Network failed to update peers: %v", err) } + + // tell the connect loop that we've been discovered + // so it stops sending connect messages out + select { + case n.discovered <- true: + default: + // don't block here + } case "close": pbNetClose := &pbNet.Close{} if err := proto.Unmarshal(m.msg.Body, pbNetClose); err != nil { log.Debugf("Network tunnel [%s] close unmarshal error: %v", NetworkChannel, err) continue } + // don't process your own messages if pbNetClose.Node.Id == n.options.Id { continue } + log.Debugf("Network received close message from: %s", pbNetClose.Node.Id) + peer := &node{ id: pbNetClose.Node.Id, address: pbNetClose.Node.Address, } + if err := n.DeletePeerNode(peer.id); err != nil { log.Debugf("Network failed to delete node %s routes: %v", peer.id, err) } + if err := n.prunePeerRoutes(peer); err != nil { log.Debugf("Network failed pruning peer %s routes: %v", peer.id, err) } - // deelete peer from the peerLinks + + // delete peer from the peerLinks n.Lock() delete(n.peerLinks, pbNetClose.Node.Address) n.Unlock() @@ -484,57 +523,6 @@ func (n *network) processNetChan(listener tunnel.Listener) { } } -// sendMsg sends a message to the tunnel channel -func (n *network) sendMsg(method string, msg proto.Message, channel string) error { - body, err := proto.Marshal(msg) - if err != nil { - return err - } - // create transport message and chuck it down the pipe - m := transport.Message{ - Header: map[string]string{ - "Micro-Method": method, - }, - Body: body, - } - - // check if the channel client is initialized - n.RLock() - client, ok := n.tunClient[channel] - if !ok || client == nil { - n.RUnlock() - return ErrClientNotFound - } - n.RUnlock() - - log.Debugf("Network sending %s message from: %s", method, n.options.Id) - if err := client.Send(&m); err != nil { - return err - } - - return nil -} - -// announce announces node peers to the network -func (n *network) announce(client transport.Client) { - announce := time.NewTicker(AnnounceTime) - defer announce.Stop() - - for { - select { - case <-n.closed: - return - case <-announce.C: - msg := PeersToProto(n.node, MaxDepth) - // advertise yourself to the network - if err := n.sendMsg("peer", msg, NetworkChannel); err != nil { - log.Debugf("Network failed to advertise peers: %v", err) - continue - } - } - } -} - // pruneRoutes prunes routes return by given query func (n *network) pruneRoutes(q ...router.QueryOption) error { routes, err := n.router.Table().Query(q...) @@ -572,60 +560,186 @@ func (n *network) prunePeerRoutes(peer *node) error { return nil } -// prune deltes node peers that have not been seen for longer than PruneTime seconds -// prune also removes all the routes either originated by or routable by the stale nodes -func (n *network) prune() { +// manage the process of announcing to peers and prune any peer nodes that have not been +// seen for a period of time. Also removes all the routes either originated by or routable +//by the stale nodes. it also resolves nodes periodically and adds them to the tunnel +func (n *network) manage() { + announce := time.NewTicker(AnnounceTime) + defer announce.Stop() prune := time.NewTicker(PruneTime) defer prune.Stop() + resolve := time.NewTicker(ResolveTime) + defer resolve.Stop() for { select { case <-n.closed: return + case <-announce.C: + msg := PeersToProto(n.node, MaxDepth) + // advertise yourself to the network + if err := n.sendMsg("peer", NetworkChannel, msg); err != nil { + log.Debugf("Network failed to advertise peers: %v", err) + } case <-prune.C: - pruned := n.PruneStalePeerNodes(PruneTime) + pruned := n.PruneStalePeers(PruneTime) + for id, peer := range pruned { log.Debugf("Network peer exceeded prune time: %s", id) + n.Lock() delete(n.peerLinks, peer.address) n.Unlock() + if err := n.prunePeerRoutes(peer); err != nil { log.Debugf("Network failed pruning peer %s routes: %v", id, err) } } + // get a list of all routes routes, err := n.options.Router.Table().List() if err != nil { log.Debugf("Network failed listing routes when pruning peers: %v", err) continue } + // collect all the router IDs in the routing table routers := make(map[string]bool) + for _, route := range routes { - if _, ok := routers[route.Router]; !ok { - routers[route.Router] = true - // if the router is NOT in our peer graph, delete all routes originated by it - if peerNode := n.node.GetPeerNode(route.Router); peerNode == nil { - if err := n.pruneRoutes(router.QueryRouter(route.Router)); err != nil { - log.Debugf("Network failed deleting routes by %s: %v", route.Router, err) - } - } + // check if its been processed + if _, ok := routers[route.Router]; ok { + continue + } + + // mark as processed + routers[route.Router] = true + + // if the router is NOT in our peer graph, delete all routes originated by it + if peer := n.node.GetPeerNode(route.Router); peer != nil { + continue + } + + if err := n.pruneRoutes(router.QueryRouter(route.Router)); err != nil { + log.Debugf("Network failed deleting routes by %s: %v", route.Router, err) } } + case <-resolve.C: + n.initNodes(false) } } } +// sendTo sends a message to a specific node as a one off. +// we need this because when links die, we have no discovery info, +// and sending to an existing multicast link doesn't immediately work +func (n *network) sendTo(method, channel, peer string, msg proto.Message) error { + body, err := proto.Marshal(msg) + if err != nil { + return err + } + c, err := n.tunnel.Dial(channel, tunnel.DialMode(tunnel.Multicast)) + if err != nil { + return err + } + defer c.Close() + + log.Debugf("Network sending %s message from: %s to %s", method, n.options.Id, peer) + + return c.Send(&transport.Message{ + Header: map[string]string{ + "Micro-Method": method, + "Micro-Peer": peer, + }, + Body: body, + }) +} + +// sendMsg sends a message to the tunnel channel +func (n *network) sendMsg(method, channel string, msg proto.Message) error { + body, err := proto.Marshal(msg) + if err != nil { + return err + } + + // check if the channel client is initialized + n.RLock() + client, ok := n.tunClient[channel] + if !ok || client == nil { + n.RUnlock() + return ErrClientNotFound + } + n.RUnlock() + + log.Debugf("Network sending %s message from: %s", method, n.options.Id) + + return client.Send(&transport.Message{ + Header: map[string]string{ + "Micro-Method": method, + }, + Body: body, + }) +} + +// updatePeerLinks updates link for a given peer +func (n *network) updatePeerLinks(peerAddr string, m *message) error { + n.Lock() + defer n.Unlock() + + linkId := m.msg.Header["Micro-Link"] + + log.Tracef("Network looking up link %s in the peer links", linkId) + + // lookup the peer link + var peerLink tunnel.Link + + for _, link := range n.tunnel.Links() { + if link.Id() == linkId { + peerLink = link + break + } + } + + if peerLink == nil { + return ErrPeerLinkNotFound + } + + // if the peerLink is found in the returned links update peerLinks + log.Tracef("Network updating peer links for peer %s", peerAddr) + + // add peerLink to the peerLinks map + if link, ok := n.peerLinks[peerAddr]; ok { + // if the existing has better Length then the new, replace it + if link.Length() < peerLink.Length() { + n.peerLinks[peerAddr] = peerLink + } + } else { + n.peerLinks[peerAddr] = peerLink + } + + return nil +} + // handleCtrlConn handles ControlChannel connections func (n *network) handleCtrlConn(s tunnel.Session, msg chan *message) { for { m := new(transport.Message) if err := s.Recv(m); err != nil { log.Debugf("Network tunnel [%s] receive error: %v", ControlChannel, err) - if sessionErr := s.Close(); sessionErr != nil { - log.Debugf("Network tunnel [%s] closing connection error: %v", ControlChannel, sessionErr) + switch err { + case io.EOF, tunnel.ErrReadTimeout: + s.Close() + return } - return + continue + } + + // check if peer is set + peer := m.Header["Micro-Peer"] + + // check who the message is intended for + if len(peer) > 0 && peer != n.options.Id { + continue } select { @@ -743,6 +857,8 @@ func (n *network) getRouteMetric(router string, gateway string, link string) int // processCtrlChan processes messages received on ControlChannel func (n *network) processCtrlChan(listener tunnel.Listener) { + defer listener.Close() + // receive control message queue recv := make(chan *message, 128) @@ -756,15 +872,19 @@ func (n *network) processCtrlChan(listener tunnel.Listener) { switch m.msg.Header["Micro-Method"] { case "advert": pbRtrAdvert := &pbRtr.Advert{} + if err := proto.Unmarshal(m.msg.Body, pbRtrAdvert); err != nil { log.Debugf("Network fail to unmarshal advert message: %v", err) continue } + // don't process your own messages if pbRtrAdvert.Id == n.options.Id { continue } + log.Debugf("Network received advert message from: %s", pbRtrAdvert.Id) + // loookup advertising node in our peer topology advertNode := n.node.GetPeerNode(pbRtrAdvert.Id) if advertNode == nil { @@ -774,6 +894,7 @@ func (n *network) processCtrlChan(listener tunnel.Listener) { } var events []*router.Event + for _, event := range pbRtrAdvert.Events { // we know the advertising node is not the origin of the route if pbRtrAdvert.Id != event.Route.Router { @@ -784,6 +905,7 @@ func (n *network) processCtrlChan(listener tunnel.Listener) { continue } } + route := router.Route{ Service: event.Route.Service, Address: event.Route.Address, @@ -793,6 +915,7 @@ func (n *network) processCtrlChan(listener tunnel.Listener) { Link: event.Route.Link, Metric: event.Route.Metric, } + // calculate route metric and add to the advertised metric // we need to make sure we do not overflow math.MaxInt64 metric := n.getRouteMetric(event.Route.Router, event.Route.Gateway, event.Route.Link) @@ -815,11 +938,13 @@ func (n *network) processCtrlChan(listener tunnel.Listener) { } events = append(events, e) } + // if no events are eligible for processing continue if len(events) == 0 { log.Tracef("Network no events to be processed by router: %s", n.options.Id) continue } + // create an advert and process it advert := &router.Advert{ Id: pbRtrAdvert.Id, @@ -839,16 +964,27 @@ func (n *network) processCtrlChan(listener tunnel.Listener) { log.Debugf("Network fail to unmarshal solicit message: %v", err) continue } + log.Debugf("Network received solicit message from: %s", pbRtrSolicit.Id) + // ignore solicitation when requested by you if pbRtrSolicit.Id == n.options.Id { continue } + log.Debugf("Network router flushing routes for: %s", pbRtrSolicit.Id) + // advertise all the routes when a new node has connected if err := n.router.Solicit(); err != nil { log.Debugf("Network failed to solicit routes: %s", err) } + + // specify that someone solicited the route + select { + case n.solicited <- pbRtrSolicit.Id: + default: + // don't block + } } case <-n.closed: return @@ -865,6 +1001,7 @@ func (n *network) advertise(advertChan <-chan *router.Advert) { case advert := <-advertChan: // create a proto advert var events []*pbRtr.Event + for _, event := range advert.Events { // the routes service address address := event.Route.Address @@ -898,16 +1035,33 @@ func (n *network) advertise(advertChan <-chan *router.Advert) { } events = append(events, e) } + msg := &pbRtr.Advert{ Id: advert.Id, Type: pbRtr.AdvertType(advert.Type), Timestamp: advert.Timestamp.UnixNano(), Events: events, } - if err := n.sendMsg("advert", msg, ControlChannel); err != nil { - log.Debugf("Network failed to advertise routes: %v", err) + + // send the advert to all on the control channel + // since its not a solicitation + if advert.Type != router.Solicitation { + if err := n.sendMsg("advert", ControlChannel, msg); err != nil { + log.Debugf("Network failed to advertise routes: %v", err) + } continue } + + // it's a solication, someone asked for it + // so we're going to pick off the node and send it + select { + case node := <-n.solicited: + // someone requested the route + n.sendTo("advert", ControlChannel, node, msg) + default: + // send to all since we can't get anything + n.sendMsg("advert", ControlChannel, msg) + } case <-n.closed: return } @@ -925,138 +1079,187 @@ func (n *network) sendConnect() { Address: n.node.address, }, } - if err := n.sendMsg("connect", msg, NetworkChannel); err != nil { + + if err := n.sendMsg("connect", NetworkChannel, msg); err != nil { log.Debugf("Network failed to send connect message: %s", err) } } +// connect will wait for a link to be established and send the connect +// message. We're trying to ensure convergence pretty quickly. So we want +// to hear back. In the case we become completely disconnected we'll +// connect again once a new link is established +func (n *network) connect() { + // discovered lets us know what we received a peer message back + var discovered bool + var attempts int + + // our advertise address + loopback := n.server.Options().Advertise + // actual address + address := n.tunnel.Address() + + for { + // connected is used to define if the link is connected + var connected bool + + // check the links state + for _, link := range n.tunnel.Links() { + // skip loopback + if link.Loopback() { + continue + } + + // if remote is ourselves + switch link.Remote() { + case loopback, address: + continue + } + + if link.State() == "connected" { + connected = true + break + } + } + + // if we're not connected wait + if !connected { + // reset discovered + discovered = false + // sleep for a second + time.Sleep(time.Second) + // now try again + continue + } + + // we're connected but are we discovered? + if !discovered { + // recreate the clients because all the tunnel links are gone + // so we haven't send discovery beneath + if err := n.createClients(); err != nil { + log.Debugf("Failed to recreate network/control clients: %v", err) + continue + } + + // send the connect message + n.sendConnect() + } + + // check if we've been discovered + select { + case <-n.discovered: + discovered = true + attempts = 0 + case <-n.closed: + return + case <-time.After(time.Second + backoff.Do(attempts)): + // we have to try again + attempts++ + + // reset attempts 5 == ~2mins + if attempts > 5 { + attempts = 0 + } + } + } +} + // Connect connects the network func (n *network) Connect() error { n.Lock() - - // try to resolve network nodes - nodes, err := n.resolveNodes() - if err != nil { - log.Debugf("Network failed to resolve nodes: %v", err) - } + defer n.Unlock() // connect network tunnel if err := n.tunnel.Connect(); err != nil { - n.Unlock() return err } - // initialize the tunnel to resolved nodes - n.tunnel.Init( - tunnel.Nodes(nodes...), - ) - // return if already connected if n.connected { - // unlock first - n.Unlock() + // initialise the nodes + n.initNodes(false) // send the connect message - n.sendConnect() + go n.sendConnect() return nil } + // initialise the nodes + n.initNodes(true) + // set our internal node address // if advertise address is not set if len(n.options.Advertise) == 0 { n.server.Init(server.Advertise(n.tunnel.Address())) } + // listen on NetworkChannel + netListener, err := n.tunnel.Listen( + NetworkChannel, + tunnel.ListenMode(tunnel.Multicast), + tunnel.ListenTimeout(AnnounceTime*2), + ) + if err != nil { + return err + } + + // listen on ControlChannel + ctrlListener, err := n.tunnel.Listen( + ControlChannel, + tunnel.ListenMode(tunnel.Multicast), + tunnel.ListenTimeout(router.AdvertiseTableTick*2), + ) + if err != nil { + return err + } + // dial into ControlChannel to send route adverts ctrlClient, err := n.tunnel.Dial(ControlChannel, tunnel.DialMode(tunnel.Multicast)) if err != nil { - n.Unlock() return err } n.tunClient[ControlChannel] = ctrlClient - // listen on ControlChannel - ctrlListener, err := n.tunnel.Listen(ControlChannel, tunnel.ListenMode(tunnel.Multicast)) - if err != nil { - n.Unlock() - return err - } - // dial into NetworkChannel to send network messages netClient, err := n.tunnel.Dial(NetworkChannel, tunnel.DialMode(tunnel.Multicast)) if err != nil { - n.Unlock() return err } n.tunClient[NetworkChannel] = netClient - // listen on NetworkChannel - netListener, err := n.tunnel.Listen(NetworkChannel, tunnel.ListenMode(tunnel.Multicast)) - if err != nil { - n.Unlock() - return err - } - // create closed channel n.closed = make(chan bool) // start the router if err := n.options.Router.Start(); err != nil { - n.Unlock() return err } // start advertising routes advertChan, err := n.options.Router.Advertise() if err != nil { - n.Unlock() return err } // start the server if err := n.server.Start(); err != nil { - n.Unlock() return err } - n.Unlock() - // send connect after there's a link established - go func() { - // wait for 30 ticks e.g 30 seconds - for i := 0; i < 30; i++ { - // get the current links - links := n.tunnel.Links() - - // if there are no links wait until we have one - if len(links) == 0 { - time.Sleep(time.Second) - continue - } - - // send the connect message - n.sendConnect() - // most importantly - break - } - }() - - // go resolving network nodes - go n.resolve() - // broadcast peers - go n.announce(netClient) - // prune stale nodes - go n.prune() - // listen to network messages - go n.processNetChan(netListener) // advertise service routes go n.advertise(advertChan) + // listen to network messages + go n.processNetChan(netListener) // accept and process routes go n.processCtrlChan(ctrlListener) + // manage connection once links are established + go n.connect() + // resolve nodes, broadcast announcements and prune stale nodes + go n.manage() - n.Lock() + // we're now connected n.connected = true - n.Unlock() return nil } @@ -1080,6 +1283,40 @@ func (n *network) close() error { return nil } +// createClients is used to create new clients in the event we lose all the tunnels +func (n *network) createClients() error { + // dial into ControlChannel to send route adverts + ctrlClient, err := n.tunnel.Dial(ControlChannel, tunnel.DialMode(tunnel.Multicast)) + if err != nil { + return err + } + + // dial into NetworkChannel to send network messages + netClient, err := n.tunnel.Dial(NetworkChannel, tunnel.DialMode(tunnel.Multicast)) + if err != nil { + return err + } + + n.Lock() + defer n.Unlock() + + // set the control client + c, ok := n.tunClient[ControlChannel] + if ok { + c.Close() + } + n.tunClient[ControlChannel] = ctrlClient + + // set the network client + c, ok = n.tunClient[NetworkChannel] + if ok { + c.Close() + } + n.tunClient[NetworkChannel] = netClient + + return nil +} + // Close closes network connection func (n *network) Close() error { n.Lock() @@ -1108,7 +1345,8 @@ func (n *network) Close() error { Address: n.node.address, }, } - if err := n.sendMsg("close", msg, NetworkChannel); err != nil { + + if err := n.sendMsg("close", NetworkChannel, msg); err != nil { log.Debugf("Network failed to send close message: %s", err) } } diff --git a/network/network.go b/network/network.go index 08a45665..e927241b 100644 --- a/network/network.go +++ b/network/network.go @@ -6,8 +6,6 @@ import ( "github.com/micro/go-micro/client" "github.com/micro/go-micro/server" - "github.com/micro/go-micro/transport" - "github.com/micro/go-micro/tunnel" ) var ( @@ -56,14 +54,6 @@ type Network interface { Server() server.Server } -// message is network message -type message struct { - // msg is transport message - msg *transport.Message - // session is tunnel session - session tunnel.Session -} - // NewNetwork returns a new network interface func NewNetwork(opts ...Option) Network { return newNetwork(opts...) diff --git a/network/node.go b/network/node.go index f0d3809c..f1a51b7d 100644 --- a/network/node.go +++ b/network/node.go @@ -216,7 +216,7 @@ func (n *node) DeletePeerNode(id string) error { // PruneStalePeerNodes prune the peers that have not been seen for longer than given time // It returns a map of the the nodes that got pruned -func (n *node) PruneStalePeerNodes(pruneTime time.Duration) map[string]*node { +func (n *node) PruneStalePeers(pruneTime time.Duration) map[string]*node { n.Lock() defer n.Unlock() diff --git a/network/node_test.go b/network/node_test.go index 80a41c70..71b21a19 100644 --- a/network/node_test.go +++ b/network/node_test.go @@ -225,7 +225,7 @@ func TestPruneStalePeerNodes(t *testing.T) { time.Sleep(pruneTime) // should delete all nodes besides node - pruned := node.PruneStalePeerNodes(pruneTime) + pruned := node.PruneStalePeers(pruneTime) if len(pruned) != len(nodes)-1 { t.Errorf("Expected pruned node count: %d, got: %d", len(nodes)-1, len(pruned)) diff --git a/network/options.go b/network/options.go index 0a3c5b2b..63ff6fe7 100644 --- a/network/options.go +++ b/network/options.go @@ -22,8 +22,8 @@ type Options struct { Address string // Advertise sets the address to advertise Advertise string - // Peers is a list of peers to connect to - Peers []string + // Nodes is a list of nodes to connect to + Nodes []string // Tunnel is network tunnel Tunnel tunnel.Tunnel // Router is network router @@ -62,10 +62,10 @@ func Advertise(a string) Option { } } -// Peers is a list of peers to connect to -func Peers(n ...string) Option { +// Nodes is a list of nodes to connect to +func Nodes(n ...string) Option { return func(o *Options) { - o.Peers = n + o.Nodes = n } } diff --git a/network/resolver/dns/dns.go b/network/resolver/dns/dns.go index b4029f06..a8d27b06 100644 --- a/network/resolver/dns/dns.go +++ b/network/resolver/dns/dns.go @@ -31,6 +31,17 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { r.Address = "1.0.0.1:53" } + //nolint:prealloc + var records []*resolver.Record + + // parsed an actual ip + if v := net.ParseIP(host); v != nil { + records = append(records, &resolver.Record{ + Address: net.JoinHostPort(host, port), + }) + return records, nil + } + m := new(dns.Msg) m.SetQuestion(dns.Fqdn(host), dns.TypeA) rec, err := dns.ExchangeContext(context.Background(), m, r.Address) @@ -38,9 +49,6 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { return nil, err } - //nolint:prealloc - var records []*resolver.Record - for _, answer := range rec.Answer { h := answer.Header() // check record type matches @@ -59,5 +67,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { }) } + // no records returned so just best effort it + if len(records) == 0 { + records = append(records, &resolver.Record{ + Address: net.JoinHostPort(host, port), + }) + } + return records, nil } diff --git a/network/service/handler/handler.go b/network/service/handler/handler.go index 5a9f9379..d8538a7c 100644 --- a/network/service/handler/handler.go +++ b/network/service/handler/handler.go @@ -55,7 +55,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp * } // get list of existing nodes - nodes := n.Network.Options().Peers + nodes := n.Network.Options().Nodes // generate a node map nodeMap := make(map[string]bool) @@ -84,7 +84,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp * // reinitialise the peers n.Network.Init( - network.Peers(nodes...), + network.Nodes(nodes...), ) // call the connect method diff --git a/proxy/mucp/mucp.go b/proxy/mucp/mucp.go index 0e04c7d6..d565d47a 100644 --- a/proxy/mucp/mucp.go +++ b/proxy/mucp/mucp.go @@ -83,7 +83,8 @@ func readLoop(r server.Request, s client.Stream) error { // toNodes returns a list of node addresses from given routes func toNodes(routes []router.Route) []string { - nodes := make([]string, len(routes)) + nodes := make([]string, 0, len(routes)) + for _, node := range routes { address := node.Address if len(node.Gateway) > 0 { @@ -91,11 +92,13 @@ func toNodes(routes []router.Route) []string { } nodes = append(nodes, address) } + return nodes } func toSlice(r map[uint64]router.Route) []router.Route { routes := make([]router.Route, 0, len(r)) + for _, v := range r { routes = append(routes, v) } @@ -161,6 +164,8 @@ func (p *Proxy) filterRoutes(ctx context.Context, routes []router.Route) []route filteredRoutes = append(filteredRoutes, route) } + log.Tracef("Proxy filtered routes %+v\n", filteredRoutes) + return filteredRoutes } @@ -225,13 +230,15 @@ func (p *Proxy) cacheRoutes(service string) ([]router.Route, error) { // refreshMetrics will refresh any metrics for our local cached routes. // we may not receive new watch events for these as they change. func (p *Proxy) refreshMetrics() { - services := make([]string, 0, len(p.Routes)) - // get a list of services to update p.RLock() + + services := make([]string, 0, len(p.Routes)) + for service := range p.Routes { services = append(services, service) } + p.RUnlock() // get and cache the routes for the service @@ -246,6 +253,8 @@ func (p *Proxy) manageRoutes(route router.Route, action string) error { p.Lock() defer p.Unlock() + log.Tracef("Proxy taking route action %v %+v\n", action, route) + switch action { case "create", "update": if _, ok := p.Routes[route.Service]; !ok { @@ -253,7 +262,12 @@ func (p *Proxy) manageRoutes(route router.Route, action string) error { } p.Routes[route.Service][route.Hash()] = route case "delete": + // delete that specific route delete(p.Routes[route.Service], route.Hash()) + // clean up the cache entirely + if len(p.Routes[route.Service]) == 0 { + delete(p.Routes, route.Service) + } default: return fmt.Errorf("unknown action: %s", action) } @@ -288,7 +302,7 @@ func (p *Proxy) ProcessMessage(ctx context.Context, msg server.Message) error { // TODO: check that we're not broadcast storming by sending to the same topic // that we're actually subscribed to - log.Tracef("Received message for %s", msg.Topic()) + log.Tracef("Proxy received message for %s", msg.Topic()) var errors []string @@ -329,7 +343,7 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server return errors.BadRequest("go.micro.proxy", "service name is blank") } - log.Tracef("Received request for %s", service) + log.Tracef("Proxy received request for %s", service) // are we network routing or local routing if len(p.Links) == 0 { @@ -363,15 +377,17 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server } //nolint:prealloc - var opts []client.CallOption - - // set strategy to round robin - opts = append(opts, client.WithSelectOption(selector.WithStrategy(selector.RoundRobin))) + opts := []client.CallOption{ + // set strategy to round robin + client.WithSelectOption(selector.WithStrategy(selector.RoundRobin)), + } // if the address is already set just serve it // TODO: figure it out if we should know to pick a link if len(addresses) > 0 { - opts = append(opts, client.WithAddress(addresses...)) + opts = append(opts, + client.WithAddress(addresses...), + ) // serve the normal way return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...) @@ -387,10 +403,16 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server opts = append(opts, client.WithAddress(addresses...)) } + log.Tracef("Proxy calling %+v\n", addresses) // serve the normal way return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...) } + // we're assuming we need routes to operate on + if len(routes) == 0 { + return errors.InternalServerError("go.micro.proxy", "route not found") + } + var gerr error // we're routing globally with multiple links @@ -404,11 +426,16 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server continue } - log.Debugf("Proxy using route %+v\n", route) + log.Tracef("Proxy using route %+v\n", route) // set the address to call addresses := toNodes([]router.Route{route}) - opts = append(opts, client.WithAddress(addresses...)) + // set the address in the options + // disable retries since its one route processing + opts = append(opts, + client.WithAddress(addresses...), + client.WithRetries(0), + ) // do the request with the link gerr = p.serveRequest(ctx, link, service, endpoint, req, rsp, opts...) @@ -558,7 +585,9 @@ func NewProxy(opts ...options.Option) proxy.Proxy { }() go func() { - t := time.NewTicker(time.Minute) + // TODO: speed up refreshing of metrics + // without this ticking effort e.g stream + t := time.NewTicker(time.Second * 10) defer t.Stop() // we must refresh route metrics since they do not trigger new events diff --git a/router/default.go b/router/default.go index b0739ebf..02ceac1b 100644 --- a/router/default.go +++ b/router/default.go @@ -799,7 +799,8 @@ func (r *router) flushRouteEvents(evType EventType) ([]*Event, error) { // build a list of events to advertise events := make([]*Event, len(bestRoutes)) - i := 0 + var i int + for _, route := range bestRoutes { event := &Event{ Type: evType, @@ -823,9 +824,10 @@ func (r *router) Solicit() error { // advertise the routes r.advertWg.Add(1) + go func() { - defer r.advertWg.Done() - r.publishAdvert(RouteUpdate, events) + r.publishAdvert(Solicitation, events) + r.advertWg.Done() }() return nil diff --git a/router/router.go b/router/router.go index e6018766..7758817c 100644 --- a/router/router.go +++ b/router/router.go @@ -111,6 +111,8 @@ const ( Announce AdvertType = iota // RouteUpdate advertises route updates RouteUpdate + // Solicitation indicates routes were solicited + Solicitation ) // String returns human readable advertisement type @@ -120,6 +122,8 @@ func (t AdvertType) String() string { return "announce" case RouteUpdate: return "update" + case Solicitation: + return "solicitation" default: return "unknown" } diff --git a/server/rpc_codec.go b/server/rpc_codec.go index 342110aa..53d552f4 100644 --- a/server/rpc_codec.go +++ b/server/rpc_codec.go @@ -86,24 +86,18 @@ func getHeader(hdr string, md map[string]string) string { } func getHeaders(m *codec.Message) { - get := func(hdr, v string) string { + set := func(v, hdr string) string { if len(v) > 0 { return v } - - if hd := m.Header[hdr]; len(hd) > 0 { - return hd - } - - // old - return m.Header["X-"+hdr] + return m.Header[hdr] } - m.Id = get("Micro-Id", m.Id) - m.Error = get("Micro-Error", m.Error) - m.Endpoint = get("Micro-Endpoint", m.Endpoint) - m.Method = get("Micro-Method", m.Method) - m.Target = get("Micro-Service", m.Target) + m.Id = set(m.Id, "Micro-Id") + m.Error = set(m.Error, "Micro-Error") + m.Endpoint = set(m.Endpoint, "Micro-Endpoint") + m.Method = set(m.Method, "Micro-Method") + m.Target = set(m.Target, "Micro-Service") // TODO: remove this cruft if len(m.Endpoint) == 0 { @@ -321,7 +315,6 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error { // write an error if it failed m.Error = errors.Wrapf(err, "Unable to encode body").Error() - m.Header["X-Micro-Error"] = m.Error m.Header["Micro-Error"] = m.Error // no body to write if err := c.codec.Write(m, nil); err != nil { diff --git a/server/rpc_server.go b/server/rpc_server.go index 191bf1e8..d23401c8 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -549,6 +549,7 @@ func (s *rpcServer) Register() error { node.Metadata["protocol"] = "mucp" s.RLock() + // Maps are ordered randomly, sort the keys for consistency var handlerList []string for n, e := range s.handlers { @@ -557,6 +558,7 @@ func (s *rpcServer) Register() error { handlerList = append(handlerList, n) } } + sort.Strings(handlerList) var subscriberList []Subscriber @@ -566,18 +568,20 @@ func (s *rpcServer) Register() error { subscriberList = append(subscriberList, e) } } + sort.Slice(subscriberList, func(i, j int) bool { return subscriberList[i].Topic() > subscriberList[j].Topic() }) endpoints := make([]*registry.Endpoint, 0, len(handlerList)+len(subscriberList)) + for _, n := range handlerList { endpoints = append(endpoints, s.handlers[n].Endpoints()...) } + for _, e := range subscriberList { endpoints = append(endpoints, e.Endpoints()...) } - s.RUnlock() service := ®istry.Service{ Name: config.Name, @@ -586,9 +590,10 @@ func (s *rpcServer) Register() error { Endpoints: endpoints, } - s.Lock() + // get registered value registered := s.registered - s.Unlock() + + s.RUnlock() if !registered { log.Logf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id) @@ -610,6 +615,8 @@ func (s *rpcServer) Register() error { defer s.Unlock() s.registered = true + // set what we're advertising + s.opts.Advertise = addr // subscribe to the topic with own name sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent) diff --git a/service.go b/service.go index 915e6512..efddd0a4 100644 --- a/service.go +++ b/service.go @@ -9,9 +9,9 @@ import ( "github.com/micro/go-micro/client" "github.com/micro/go-micro/config/cmd" - "github.com/micro/go-micro/debug/service/handler" "github.com/micro/go-micro/debug/profile" "github.com/micro/go-micro/debug/profile/pprof" + "github.com/micro/go-micro/debug/service/handler" "github.com/micro/go-micro/plugin" "github.com/micro/go-micro/server" "github.com/micro/go-micro/util/log" diff --git a/tunnel/default.go b/tunnel/default.go index 1cb0dbfc..a6cb08c9 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -127,7 +127,6 @@ func (t *tun) newSession(channel, sessionId string) (*session, bool) { closed: make(chan bool), recv: make(chan *message, 128), send: t.send, - wait: make(chan bool), errChan: make(chan error, 1), } @@ -198,8 +197,8 @@ func (t *tun) announce(channel, session string, link *link) { } } -// monitor monitors outbound links and attempts to reconnect to the failed ones -func (t *tun) monitor() { +// manage monitors outbound links and attempts to reconnect to the failed ones +func (t *tun) manage() { reconnect := time.NewTicker(ReconnectTime) defer reconnect.Stop() @@ -208,63 +207,121 @@ func (t *tun) monitor() { case <-t.closed: return case <-reconnect.C: - t.RLock() + t.manageLinks() + } + } +} - var delLinks []string - // check the link status and purge dead links - for node, link := range t.links { - // check link status - switch link.State() { - case "closed": - delLinks = append(delLinks, node) - case "error": - delLinks = append(delLinks, node) - } +// manageLink sends channel discover requests periodically and +// keepalive messages to link +func (t *tun) manageLink(link *link) { + keepalive := time.NewTicker(KeepAliveTime) + defer keepalive.Stop() + discover := time.NewTicker(DiscoverTime) + defer discover.Stop() + + for { + select { + case <-t.closed: + return + case <-link.closed: + return + case <-discover.C: + // send a discovery message to all links + if err := t.sendMsg("discover", link); err != nil { + log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err) } - - t.RUnlock() - - // delete the dead links - if len(delLinks) > 0 { - t.Lock() - for _, node := range delLinks { - log.Debugf("Tunnel deleting dead link for %s", node) - if link, ok := t.links[node]; ok { - link.Close() - delete(t.links, node) - } - } - t.Unlock() - } - - // check current link status - var connect []string - - // build list of unknown nodes to connect to - t.RLock() - for _, node := range t.options.Nodes { - if _, ok := t.links[node]; !ok { - connect = append(connect, node) - } - } - t.RUnlock() - - for _, node := range connect { - // create new link - link, err := t.setupLink(node) - if err != nil { - log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) - continue - } - // save the link - t.Lock() - t.links[node] = link - t.Unlock() + case <-keepalive.C: + // send keepalive message + log.Debugf("Tunnel sending keepalive to link: %v", link.Remote()) + if err := t.sendMsg("keepalive", link); err != nil { + log.Debugf("Tunnel error sending keepalive to link %v: %v", link.Remote(), err) + t.delLink(link.Remote()) + return } } } } +// manageLinks is a function that can be called to immediately to link setup +func (t *tun) manageLinks() { + var delLinks []string + + t.RLock() + + // check the link status and purge dead links + for node, link := range t.links { + // check link status + switch link.State() { + case "closed": + delLinks = append(delLinks, node) + case "error": + delLinks = append(delLinks, node) + } + } + + t.RUnlock() + + // delete the dead links + if len(delLinks) > 0 { + t.Lock() + for _, node := range delLinks { + log.Debugf("Tunnel deleting dead link for %s", node) + if link, ok := t.links[node]; ok { + link.Close() + delete(t.links, node) + } + } + t.Unlock() + } + + // check current link status + var connect []string + + // build list of unknown nodes to connect to + t.RLock() + + for _, node := range t.options.Nodes { + if _, ok := t.links[node]; !ok { + connect = append(connect, node) + } + } + + t.RUnlock() + + var wg sync.WaitGroup + + for _, node := range connect { + wg.Add(1) + + go func(node string) { + defer wg.Done() + + // create new link + link, err := t.setupLink(node) + if err != nil { + log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) + return + } + + // save the link + t.Lock() + defer t.Unlock() + + // just check nothing else was setup in the interim + if _, ok := t.links[node]; ok { + link.Close() + return + } + // save the link + t.links[node] = link + }(node) + } + + // wait for all threads to finish + wg.Wait() +} + // process outgoing messages sent by all local sessions func (t *tun) process() { // manage the send buffer @@ -328,6 +385,7 @@ func (t *tun) process() { // and the message is being sent outbound via // a dialled connection don't use this link if loopback && msg.outbound { + log.Tracef("Link for node %s is loopback", node) err = errors.New("link is loopback") continue } @@ -335,6 +393,7 @@ func (t *tun) process() { // if the message was being returned by the loopback listener // send it back up the loopback link only if msg.loopback && !loopback { + log.Tracef("Link for message %s is loopback", node) err = errors.New("link is not loopback") continue } @@ -364,7 +423,7 @@ func (t *tun) process() { // send the message for _, link := range sendTo { // send the message via the current link - log.Tracef("Sending %+v to %s", newMsg.Header, link.Remote()) + log.Tracef("Tunnel sending %+v to %s", newMsg.Header, link.Remote()) if errr := link.Send(newMsg); errr != nil { log.Debugf("Tunnel error sending %+v to %s: %v", newMsg.Header, link.Remote(), errr) @@ -470,6 +529,9 @@ func (t *tun) listen(link *link) { return } + // this state machine block handles the only message types + // that we know or care about; connect, close, open, accept, + // discover, announce, session, keepalive switch mtype { case "connect": log.Debugf("Tunnel link %s received connect message", link.Remote()) @@ -495,14 +557,14 @@ func (t *tun) listen(link *link) { t.links[link.Remote()] = link t.Unlock() - // send back a discovery + // send back an announcement of our channels discovery go t.announce("", "", link) + // ask for the things on the other wise + go t.sendMsg("discover", link) // nothing more to do continue case "close": - // TODO: handle the close message - // maybe report io.EOF or kill the link - + log.Debugf("Tunnel link %s received close message", link.Remote()) // if there is no channel then we close the link // as its a signal from the other side to close the connection if len(channel) == 0 { @@ -521,6 +583,8 @@ func (t *tun) listen(link *link) { // try get the dialing socket s, exists := t.getSession(channel, sessionId) if exists && !loopback { + // only delete the session if its unicast + // otherwise ignore close on the multicast if s.mode == Unicast { // only delete this if its unicast // but not if its a loopback conn @@ -541,20 +605,24 @@ func (t *tun) listen(link *link) { // an accept returned by the listener case "accept": s, exists := t.getSession(channel, sessionId) - // we don't need this + // just set accepted on anything not unicast if exists && s.mode > Unicast { s.accepted = true continue } + // if its already accepted move on if exists && s.accepted { continue } + // otherwise we're going to process to accept // a continued session case "session": // process message - log.Tracef("Received %+v from %s", msg.Header, link.Remote()) + log.Tracef("Tunnel received %+v from %s", msg.Header, link.Remote()) // an announcement of a channel listener case "announce": + log.Tracef("Tunnel received %+v from %s", msg.Header, link.Remote()) + // process the announcement channels := strings.Split(channel, ",") @@ -562,7 +630,10 @@ func (t *tun) listen(link *link) { link.setChannel(channels...) // this was an announcement not intended for anything - if sessionId == "listener" || sessionId == "" { + // if the dialing side sent "discover" then a session + // id would be present. We skip in case of multicast. + switch sessionId { + case "listener", "multicast", "": continue } @@ -574,14 +645,19 @@ func (t *tun) listen(link *link) { continue } - // send the announce back to the caller - s.recv <- &message{ + msg := &message{ typ: "announce", tunnel: id, channel: channel, session: sessionId, link: link.id, } + + // send the announce back to the caller + select { + case <-s.closed: + case s.recv <- msg: + } } continue case "discover": @@ -618,7 +694,7 @@ func (t *tun) listen(link *link) { s, exists = t.getSession(channel, "listener") // only return accept to the session case mtype == "accept": - log.Debugf("Received accept message for %s %s", channel, sessionId) + log.Debugf("Tunnel received accept message for channel: %s session: %s", channel, sessionId) s, exists = t.getSession(channel, sessionId) if exists && s.accepted { continue @@ -638,7 +714,7 @@ func (t *tun) listen(link *link) { // bail if no session or listener has been found if !exists { - log.Debugf("Tunnel skipping no session %s %s exists", channel, sessionId) + log.Tracef("Tunnel skipping no channel: %s session: %s exists", channel, sessionId) // drop it, we don't care about // messages we don't know about continue @@ -651,22 +727,10 @@ func (t *tun) listen(link *link) { delete(t.sessions, channel) continue default: - // process + // otherwise process } - log.Debugf("Tunnel using channel %s session %s", s.channel, s.session) - - // is the session new? - select { - // if its new the session is actually blocked waiting - // for a connection. so we check if its waiting. - case <-s.wait: - // if its waiting e.g its new then we close it - default: - // set remote address of the session - s.remote = msg.Header["Remote"] - close(s.wait) - } + log.Tracef("Tunnel using channel: %s session: %s type: %s", s.channel, s.session, mtype) // construct a new transport message tmsg := &transport.Message{ @@ -696,68 +760,26 @@ func (t *tun) listen(link *link) { } } -// discover sends channel discover requests periodically -func (t *tun) discover(link *link) { - tick := time.NewTicker(DiscoverTime) - defer tick.Stop() - - for { - select { - case <-tick.C: - // send a discovery message to all links - if err := link.Send(&transport.Message{ - Header: map[string]string{ - "Micro-Tunnel": "discover", - "Micro-Tunnel-Id": t.id, - }, - }); err != nil { - log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err) - } - case <-link.closed: - return - case <-t.closed: - return - } - } -} - -// keepalive periodically sends keepalive messages to link -func (t *tun) keepalive(link *link) { - keepalive := time.NewTicker(KeepAliveTime) - defer keepalive.Stop() - - for { - select { - case <-t.closed: - return - case <-link.closed: - return - case <-keepalive.C: - // send keepalive message - log.Debugf("Tunnel sending keepalive to link: %v", link.Remote()) - if err := link.Send(&transport.Message{ - Header: map[string]string{ - "Micro-Tunnel": "keepalive", - "Micro-Tunnel-Id": t.id, - }, - }); err != nil { - log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err) - t.delLink(link.Remote()) - return - } - } - } +func (t *tun) sendMsg(method string, link *link) error { + return link.Send(&transport.Message{ + Header: map[string]string{ + "Micro-Tunnel": method, + "Micro-Tunnel-Id": t.id, + }, + }) } // setupLink connects to node and returns link if successful // It returns error if the link failed to be established func (t *tun) setupLink(node string) (*link, error) { log.Debugf("Tunnel setting up link: %s", node) + c, err := t.options.Transport.Dial(node) if err != nil { log.Debugf("Tunnel failed to connect to %s: %v", node, err) return nil, err } + log.Debugf("Tunnel connected to %s", node) // create a new link @@ -766,12 +788,8 @@ func (t *tun) setupLink(node string) (*link, error) { link.id = c.Remote() // send the first connect message - if err := link.Send(&transport.Message{ - Header: map[string]string{ - "Micro-Tunnel": "connect", - "Micro-Tunnel-Id": t.id, - }, - }); err != nil { + if err := t.sendMsg("connect", link); err != nil { + link.Close() return nil, err } @@ -782,37 +800,40 @@ func (t *tun) setupLink(node string) (*link, error) { // process incoming messages go t.listen(link) - // start keepalive monitor - go t.keepalive(link) - - // discover things on the remote side - go t.discover(link) + // manage keepalives and discovery messages + go t.manageLink(link) return link, nil } func (t *tun) setupLinks() { + var wg sync.WaitGroup + for _, node := range t.options.Nodes { - // skip zero length nodes - if len(node) == 0 { - continue - } + wg.Add(1) - // link already exists - if _, ok := t.links[node]; ok { - continue - } + go func(node string) { + defer wg.Done() - // connect to node and return link - link, err := t.setupLink(node) - if err != nil { - log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) - continue - } + // we're not trying to fix existing links + if _, ok := t.links[node]; ok { + return + } - // save the link - t.links[node] = link + // create new link + link, err := t.setupLink(node) + if err != nil { + log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) + return + } + + // save the link + t.links[node] = link + }(node) } + + // wait for all threads to finish + wg.Wait() } // connect the tunnel to all the nodes and listen for incoming tunnel connections @@ -833,11 +854,8 @@ func (t *tun) connect() error { // create a new link link := newLink(sock) - // start keepalive monitor - go t.keepalive(link) - - // discover things on the remote side - go t.discover(link) + // manage the link + go t.manageLink(link) // listen for inbound messages. // only save the link once connected. @@ -864,12 +882,13 @@ func (t *tun) Connect() error { // already connected if t.connected { - // setup links + // do it immediately t.setupLinks() + // setup links return nil } - // send the connect message + // connect the tunnel: start the listener if err := t.connect(); err != nil { return err } @@ -879,15 +898,15 @@ func (t *tun) Connect() error { // create new close channel t.closed = make(chan bool) - // setup links - t.setupLinks() - // process outbound messages to be sent // process sends to all links go t.process() - // monitor links - go t.monitor() + // call setup before managing them + t.setupLinks() + + // manage the links + go t.manage() return nil } @@ -1025,7 +1044,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { // set the multicast option c.mode = options.Mode // set the dial timeout - c.timeout = options.Timeout + c.dialTimeout = options.Timeout + // set read timeout set to never + c.readTimeout = time.Duration(-1) var links []*link // did we measure the rtt @@ -1052,7 +1073,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { t.RUnlock() - // link not found + // link not found and one was specified so error out if len(links) == 0 && len(options.Link) > 0 { // delete session and return error t.delSession(c.channel, c.session) @@ -1061,15 +1082,14 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } // discovered so set the link if not multicast - // TODO: pick the link efficiently based - // on link status and saturation. if c.discovered && c.mode == Unicast { // pickLink will pick the best link link := t.pickLink(links) + // set the link c.link = link.id } - // shit fuck + // if its not already discovered we need to attempt to do so if !c.discovered { // piggy back roundtrip nowRTT := time.Now() @@ -1098,7 +1118,15 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { } } - // a unicast session so we call "open" and wait for an "accept" + // return early if its not unicast + // we will not call "open" for multicast + if c.mode != Unicast { + return c, nil + } + + // Note: we go no further for multicast or broadcast. + // This is a unicast session so we call "open" and wait + // for an "accept" // reset now in case we use it now := time.Now() @@ -1115,7 +1143,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { d := time.Since(now) // if we haven't measured the roundtrip do it now - if !measured && c.mode == Unicast { + if !measured { // set the link time t.RLock() link, ok := t.links[c.link] @@ -1134,7 +1162,11 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { log.Debugf("Tunnel listening on %s", channel) - var options ListenOptions + options := ListenOptions{ + // Read timeout defaults to never + Timeout: time.Duration(-1), + } + for _, o := range opts { o(&options) } @@ -1145,6 +1177,7 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { return nil, errors.New("already listening on " + channel) } + // delete function removes the session when closed delFunc := func() { t.delSession(channel, "listener") } @@ -1155,6 +1188,8 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { c.local = channel // set mode c.mode = options.Mode + // set the timeout + c.readTimeout = options.Timeout tl := &tunListener{ channel: channel, diff --git a/tunnel/link.go b/tunnel/link.go index 042f1cb9..a072a047 100644 --- a/tunnel/link.go +++ b/tunnel/link.go @@ -2,6 +2,7 @@ package tunnel import ( "bytes" + "errors" "io" "sync" "time" @@ -14,7 +15,11 @@ import ( type link struct { transport.Socket + // transport to use for connections + transport transport.Transport + sync.RWMutex + // stops the link closed chan bool // link state channel for testing link @@ -65,6 +70,8 @@ var ( linkRequest = []byte{0, 0, 0, 0} // the 4 byte 1 filled packet sent to determine link state linkResponse = []byte{1, 1, 1, 1} + + ErrLinkConnectTimeout = errors.New("link connect timeout") ) func newLink(s transport.Socket) *link { @@ -72,8 +79,8 @@ func newLink(s transport.Socket) *link { Socket: s, id: uuid.New().String(), lastKeepAlive: time.Now(), - channels: make(map[string]time.Time), closed: make(chan bool), + channels: make(map[string]time.Time), state: make(chan *packet, 64), sendQueue: make(chan *packet, 128), recvQueue: make(chan *packet, 128), @@ -87,6 +94,32 @@ func newLink(s transport.Socket) *link { return l } +func (l *link) connect(addr string) error { + c, err := l.transport.Dial(addr) + if err != nil { + return err + } + + l.Lock() + l.Socket = c + l.Unlock() + + return nil +} + +func (l *link) accept(sock transport.Socket) error { + l.Lock() + l.Socket = sock + l.Unlock() + return nil +} + +func (l *link) setLoopback(v bool) { + l.Lock() + l.loopback = v + l.Unlock() +} + // setRate sets the bits per second rate as a float64 func (l *link) setRate(bits int64, delta time.Duration) { // rate of send in bits per nanosecond @@ -167,6 +200,8 @@ func (l *link) process() { // process link state message select { case l.state <- pk: + case <-l.closed: + return default: } continue @@ -188,7 +223,11 @@ func (l *link) process() { select { case pk := <-l.sendQueue: // send the message - pk.status <- l.send(pk.message) + select { + case pk.status <- l.send(pk.message): + case <-l.closed: + return + } case <-l.closed: return } @@ -201,11 +240,15 @@ func (l *link) manage() { t := time.NewTicker(time.Minute) defer t.Stop() + // get link id + linkId := l.Id() + // used to send link state packets send := func(b []byte) error { return l.Send(&transport.Message{ Header: map[string]string{ - "Micro-Method": "link", + "Micro-Method": "link", + "Micro-Link-Id": linkId, }, Body: b, }) } @@ -229,9 +272,7 @@ func (l *link) manage() { // check the type of message switch { case bytes.Equal(p.message.Body, linkRequest): - l.RLock() - log.Tracef("Link %s received link request %v", l.id, p.message.Body) - l.RUnlock() + log.Tracef("Link %s received link request", linkId) // send response if err := send(linkResponse); err != nil { @@ -242,9 +283,7 @@ func (l *link) manage() { case bytes.Equal(p.message.Body, linkResponse): // set round trip time d := time.Since(now) - l.RLock() - log.Tracef("Link %s received link response in %v", p.message.Body, d) - l.RUnlock() + log.Tracef("Link %s received link response in %v", linkId, d) // set the RTT l.setRTT(d) } @@ -309,6 +348,12 @@ func (l *link) Rate() float64 { return l.rate } +func (l *link) Loopback() bool { + l.RLock() + defer l.RUnlock() + return l.loopback +} + // Length returns the roundtrip time as nanoseconds (lower is better). // Returns 0 where no measurement has been taken. func (l *link) Length() int64 { @@ -320,7 +365,6 @@ func (l *link) Length() int64 { func (l *link) Id() string { l.RLock() defer l.RUnlock() - return l.id } @@ -350,13 +394,6 @@ func (l *link) Send(m *transport.Message) error { // get time now now := time.Now() - // check if its closed first - select { - case <-l.closed: - return io.EOF - default: - } - // queue the message select { case l.sendQueue <- p: diff --git a/tunnel/listener.go b/tunnel/listener.go index 7ba6074e..6dbec5c8 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -24,7 +24,7 @@ type tunListener struct { delFunc func() } -// periodically announce self +// periodically announce self the channel being listened on func (t *tunListener) announce() { tick := time.NewTicker(time.Second * 30) defer tick.Stop() @@ -48,9 +48,12 @@ func (t *tunListener) process() { defer func() { // close the sessions - for _, conn := range conns { + for id, conn := range conns { conn.Close() + delete(conns, id) } + // unassign + conns = nil }() for { @@ -62,9 +65,24 @@ func (t *tunListener) process() { return // receive a new message case m := <-t.session.recv: + var sessionId string + var linkId string + + switch m.mode { + case Multicast: + sessionId = "multicast" + linkId = "multicast" + case Broadcast: + sessionId = "broadcast" + linkId = "broadcast" + default: + sessionId = m.session + linkId = m.link + } + // get a session - sess, ok := conns[m.session] - log.Debugf("Tunnel listener received channel %s session %s exists: %t", m.channel, m.session, ok) + sess, ok := conns[sessionId] + log.Tracef("Tunnel listener received channel %s session %s type %s exists: %t", m.channel, m.session, m.typ, ok) if !ok { // we only process open and session types switch m.typ { @@ -80,13 +98,13 @@ func (t *tunListener) process() { // the channel channel: m.channel, // the session id - session: m.session, + session: sessionId, // tunnel token token: t.token, // is loopback conn loopback: m.loopback, // the link the message was received on - link: m.link, + link: linkId, // set the connection mode mode: m.mode, // close chan @@ -95,14 +113,14 @@ func (t *tunListener) process() { recv: make(chan *message, 128), // use the internal send buffer send: t.session.send, - // wait - wait: make(chan bool), // error channel errChan: make(chan error, 1), + // set the read timeout + readTimeout: t.session.readTimeout, } // save the session - conns[m.session] = sess + conns[sessionId] = sess select { case <-t.closed: @@ -114,17 +132,19 @@ func (t *tunListener) process() { // an existing session was found - // received a close message switch m.typ { case "close": + // received a close message select { + // check if the session is closed case <-sess.closed: // no op - delete(conns, m.session) + delete(conns, sessionId) default: + // only close if unicast session // close and delete session close(sess.closed) - delete(conns, m.session) + delete(conns, sessionId) } // continue @@ -139,9 +159,9 @@ func (t *tunListener) process() { // send this to the accept chan select { case <-sess.closed: - delete(conns, m.session) + delete(conns, sessionId) case sess.recv <- m: - log.Debugf("Tunnel listener sent to recv chan channel %s session %s", m.channel, m.session) + log.Tracef("Tunnel listener sent to recv chan channel %s session %s type %s", m.channel, sessionId, m.typ) } } } @@ -180,6 +200,10 @@ func (t *tunListener) Accept() (Session, error) { if !ok { return nil, io.EOF } + // return without accept + if c.mode != Unicast { + return c, nil + } // send back the accept if err := c.Accept(); err != nil { return nil, err diff --git a/tunnel/options.go b/tunnel/options.go index 3a8f13db..145db19f 100644 --- a/tunnel/options.go +++ b/tunnel/options.go @@ -47,6 +47,8 @@ type ListenOption func(*ListenOptions) type ListenOptions struct { // specify mode of the session Mode Mode + // The read timeout + Timeout time.Duration } // The tunnel id @@ -84,16 +86,6 @@ func Transport(t transport.Transport) Option { } } -// DefaultOptions returns router default options -func DefaultOptions() Options { - return Options{ - Id: uuid.New().String(), - Address: DefaultAddress, - Token: DefaultToken, - Transport: quic.NewTransport(), - } -} - // Listen options func ListenMode(m Mode) ListenOption { return func(o *ListenOptions) { @@ -101,6 +93,13 @@ func ListenMode(m Mode) ListenOption { } } +// Timeout for reads and writes on the listener session +func ListenTimeout(t time.Duration) ListenOption { + return func(o *ListenOptions) { + o.Timeout = t + } +} + // Dial options // Dial multicast sets the multicast option to send only to those mapped @@ -124,3 +123,13 @@ func DialLink(id string) DialOption { o.Link = id } } + +// DefaultOptions returns router default options +func DefaultOptions() Options { + return Options{ + Id: uuid.New().String(), + Address: DefaultAddress, + Token: DefaultToken, + Transport: quic.NewTransport(), + } +} diff --git a/tunnel/session.go b/tunnel/session.go index b5b724b0..84153541 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -2,7 +2,6 @@ package tunnel import ( "encoding/hex" - "errors" "io" "time" @@ -30,8 +29,6 @@ type session struct { send chan *message // recv chan recv chan *message - // wait until we have a connection - wait chan bool // if the discovery worked discovered bool // if the session was accepted @@ -42,8 +39,10 @@ type session struct { loopback bool // mode of the connection mode Mode - // the timeout - timeout time.Duration + // the dial timeout + dialTimeout time.Duration + // the read timeout + readTimeout time.Duration // the link on which this message was received link string // the error response @@ -109,65 +108,114 @@ func (s *session) newMessage(typ string) *message { } } +func (s *session) sendMsg(msg *message) error { + select { + case <-s.closed: + return io.EOF + case s.send <- msg: + return nil + } +} + +func (s *session) wait(msg *message) error { + // wait for an error response + select { + case err := <-msg.errChan: + if err != nil { + return err + } + case <-s.closed: + return io.EOF + } + + return nil +} + // waitFor waits for the message type required until the timeout specified func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) { now := time.Now() - after := func(timeout time.Duration) time.Duration { + after := func(timeout time.Duration) <-chan time.Time { + if timeout < time.Duration(0) { + return nil + } + + // get the delta d := time.Since(now) + // dial timeout minus time since wait := timeout - d if wait < time.Duration(0) { - return time.Duration(0) + wait = time.Duration(0) } - return wait + return time.After(wait) } // wait for the message type for { select { case msg := <-s.recv: + // there may be no message type + if len(msgType) == 0 { + return msg, nil + } + // ignore what we don't want if msg.typ != msgType { log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) continue } + // got the message return msg, nil - case <-time.After(after(timeout)): - return nil, ErrDialTimeout + case <-after(timeout): + return nil, ErrReadTimeout case <-s.closed: return nil, io.EOF } } } -// Discover attempts to discover the link for a specific channel +// Discover attempts to discover the link for a specific channel. +// This is only used by the tunnel.Dial when first connecting. func (s *session) Discover() error { // create a new discovery message for this channel msg := s.newMessage("discover") + // broadcast the message to all links msg.mode = Broadcast + // its an outbound connection since we're dialling msg.outbound = true + // don't set the link since we don't know where it is msg.link = "" - // send the discovery message - s.send <- msg + // if multicast then set that as session + if s.mode == Multicast { + msg.session = "multicast" + } + + // send discover message + if err := s.sendMsg(msg); err != nil { + return err + } // set time now now := time.Now() + // after strips down the dial timeout after := func() time.Duration { d := time.Since(now) // dial timeout minus time since - wait := s.timeout - d + wait := s.dialTimeout - d + // make sure its always > 0 if wait < time.Duration(0) { return time.Duration(0) } return wait } + // the discover message is sent out, now // wait to hear back about the sent message select { case <-time.After(after()): @@ -178,27 +226,16 @@ func (s *session) Discover() error { } } - var err error - - // set a new dialTimeout - dialTimeout := after() - - // set a shorter delay for multicast - if s.mode != Unicast { - // shorten this - dialTimeout = time.Millisecond * 500 - } - - // wait for announce - _, err = s.waitFor("announce", dialTimeout) - - // if its multicast just go ahead because this is best effort + // bail early if its not unicast + // we don't need to wait for the announce if s.mode != Unicast { s.discovered = true s.accepted = true return nil } + // wait for announce + _, err := s.waitFor("announce", after()) if err != nil { return err } @@ -210,31 +247,23 @@ func (s *session) Discover() error { } // Open will fire the open message for the session. This is called by the dialler. +// This is to indicate that we want to create a new session. func (s *session) Open() error { // create a new message msg := s.newMessage("open") // send open message - s.send <- msg + if err := s.sendMsg(msg); err != nil { + return err + } // wait for an error response for send - select { - case err := <-msg.errChan: - if err != nil { - return err - } - case <-s.closed: - return io.EOF + if err := s.wait(msg); err != nil { + return err } - // don't wait on multicast/broadcast - if s.mode == Multicast { - s.accepted = true - return nil - } - - // now wait for the accept - msg, err := s.waitFor("accept", s.timeout) + // now wait for the accept message to be returned + msg, err := s.waitFor("accept", s.dialTimeout) if err != nil { return err } @@ -252,32 +281,16 @@ func (s *session) Accept() error { msg := s.newMessage("accept") // send the accept message - select { - case <-s.closed: - return io.EOF - case s.send <- msg: - // no op here - } - - // don't wait on multicast/broadcast - if s.mode == Multicast { - return nil + if err := s.sendMsg(msg); err != nil { + return err } // wait for send response - select { - case err := <-s.errChan: - if err != nil { - return err - } - case <-s.closed: - return io.EOF - } - - return nil + return s.wait(msg) } -// Announce sends an announcement to notify that this session exists. This is primarily used by the listener. +// Announce sends an announcement to notify that this session exists. +// This is primarily used by the listener. func (s *session) Announce() error { msg := s.newMessage("announce") // we don't need an error back @@ -287,23 +300,12 @@ func (s *session) Announce() error { // we don't need the link msg.link = "" - select { - case s.send <- msg: - return nil - case <-s.closed: - return io.EOF - } + // send announce message + return s.sendMsg(msg) } // Send is used to send a message func (s *session) Send(m *transport.Message) error { - select { - case <-s.closed: - return io.EOF - default: - // no op - } - // encrypt the transport message payload body, err := Encrypt(m.Body, s.token+s.channel+s.session) if err != nil { @@ -335,32 +337,28 @@ func (s *session) Send(m *transport.Message) error { msg.data = data // if multicast don't set the link - if s.mode == Multicast { + if s.mode != Unicast { msg.link = "" } log.Tracef("Appending %+v to send backlog", msg) + // send the actual message - s.send <- msg + if err := s.sendMsg(msg); err != nil { + return err + } // wait for an error response - select { - case err := <-msg.errChan: - return err - case <-s.closed: - return io.EOF - } + return s.wait(msg) } // Recv is used to receive a message func (s *session) Recv(m *transport.Message) error { var msg *message - select { - case <-s.closed: - return errors.New("session is closed") - // recv from backlog - case msg = <-s.recv: + msg, err := s.waitFor("", s.readTimeout) + if err != nil { + return err } // check the error if one exists @@ -371,10 +369,13 @@ func (s *session) Recv(m *transport.Message) error { } //log.Tracef("Received %+v from recv backlog", msg) - log.Debugf("Received %+v from recv backlog", msg) + log.Tracef("Received %+v from recv backlog", msg) // decrypt the received payload using the token - body, err := Decrypt(msg.data.Body, s.token+s.channel+s.session) + // we have to used msg.session because multicast has a shared + // session id of "multicast" in this session struct on + // the listener side + body, err := Decrypt(msg.data.Body, s.token+s.channel+msg.session) if err != nil { log.Debugf("failed to decrypt message body: %v", err) return err @@ -390,7 +391,7 @@ func (s *session) Recv(m *transport.Message) error { return err } // encrypt the transport message payload - val, err := Decrypt([]byte(h), s.token+s.channel+s.session) + val, err := Decrypt([]byte(h), s.token+s.channel+msg.session) if err != nil { log.Debugf("failed to decrypt message header %s: %v", k, err) return err @@ -399,6 +400,12 @@ func (s *session) Recv(m *transport.Message) error { msg.data.Header[k] = string(val) } + // set the link + // TODO: decruft, this is only for multicast + // since the session is now a single session + // likely provide as part of message.Link() + msg.data.Header["Micro-Link"] = msg.link + // set message *m = *msg.data // return nil @@ -413,6 +420,11 @@ func (s *session) Close() error { default: close(s.closed) + // don't send close on multicast or broadcast + if s.mode != Unicast { + return nil + } + // append to backlog msg := s.newMessage("close") // no error response on close @@ -421,7 +433,7 @@ func (s *session) Close() error { // send the close message select { case s.send <- msg: - default: + case <-time.After(time.Millisecond * 10): } } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 56928d8d..15ff6078 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -26,6 +26,8 @@ var ( ErrDiscoverChan = errors.New("failed to discover channel") // ErrLinkNotFound is returned when a link is specified at dial time and does not exist ErrLinkNotFound = errors.New("link not found") + // ErrReadTimeout is a timeout on session.Recv + ErrReadTimeout = errors.New("read timeout") ) // Mode of the session @@ -64,7 +66,9 @@ type Link interface { Length() int64 // Current transfer rate as bits per second (lower is better) Rate() float64 - // State of the link e.g connected/closed + // Is this a loopback link + Loopback() bool + // State of the link: connected/closed/error State() string // honours transport socket transport.Socket diff --git a/tunnel/tunnel_reconnect_test.go b/tunnel/tunnel_reconnect_test.go deleted file mode 100644 index 2c78b8b2..00000000 --- a/tunnel/tunnel_reconnect_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// +build !race - -package tunnel - -import ( - "sync" - "testing" - "time" -) - -func TestReconnectTunnel(t *testing.T) { - // create a new tunnel client - tunA := NewTunnel( - Address("127.0.0.1:9096"), - Nodes("127.0.0.1:9097"), - ) - - // create a new tunnel server - tunB := NewTunnel( - Address("127.0.0.1:9097"), - ) - - // start tunnel - err := tunB.Connect() - if err != nil { - t.Fatal(err) - } - defer tunB.Close() - - // we manually override the tunnel.ReconnectTime value here - // this is so that we make the reconnects faster than the default 5s - ReconnectTime = 200 * time.Millisecond - - // start tunnel - err = tunA.Connect() - if err != nil { - t.Fatal(err) - } - defer tunA.Close() - - wait := make(chan bool) - - var wg sync.WaitGroup - - wg.Add(1) - // start tunnel listener - go testBrokenTunAccept(t, tunB, wait, &wg) - - wg.Add(1) - // start tunnel sender - go testBrokenTunSend(t, tunA, wait, &wg) - - // wait until done - wg.Wait() -} diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index 8c3119da..9633e53d 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -8,6 +8,90 @@ import ( "github.com/micro/go-micro/transport" ) +func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { + defer wg.Done() + + // listen on some virtual address + tl, err := tun.Listen("test-tunnel") + if err != nil { + t.Fatal(err) + } + + // receiver ready; notify sender + wait <- true + + // accept a connection + c, err := tl.Accept() + if err != nil { + t.Fatal(err) + } + + // accept the message and close the tunnel + // we do this to simulate loss of network connection + m := new(transport.Message) + if err := c.Recv(m); err != nil { + t.Fatal(err) + } + + // close all the links + for _, link := range tun.Links() { + link.Close() + } + + // receiver ready; notify sender + wait <- true + + // accept the message + m = new(transport.Message) + if err := c.Recv(m); err != nil { + t.Fatal(err) + } + + // notify the sender we have received + wait <- true +} + +func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup, reconnect time.Duration) { + defer wg.Done() + + // wait for the listener to get ready + <-wait + + // dial a new session + c, err := tun.Dial("test-tunnel") + if err != nil { + t.Fatal(err) + } + defer c.Close() + + m := transport.Message{ + Header: map[string]string{ + "test": "send", + }, + } + + // send the message + if err := c.Send(&m); err != nil { + t.Fatal(err) + } + + // wait for the listener to get ready + <-wait + + // give it time to reconnect + time.Sleep(reconnect) + + // send the message + if err := c.Send(&m); err != nil { + t.Fatal(err) + } + + // wait for the listener to receive the message + // c.Send merely enqueues the message to the link send queue and returns + // in order to verify it was received we wait for the listener to tell us + <-wait +} + // testAccept will accept connections on the transport, create a new link and tunnel on top func testAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { defer wg.Done() @@ -163,90 +247,6 @@ func TestLoopbackTunnel(t *testing.T) { wg.Wait() } -func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { - defer wg.Done() - - // listen on some virtual address - tl, err := tun.Listen("test-tunnel") - if err != nil { - t.Fatal(err) - } - - // receiver ready; notify sender - wait <- true - - // accept a connection - c, err := tl.Accept() - if err != nil { - t.Fatal(err) - } - - // accept the message and close the tunnel - // we do this to simulate loss of network connection - m := new(transport.Message) - if err := c.Recv(m); err != nil { - t.Fatal(err) - } - - // close all the links - for _, link := range tun.Links() { - link.Close() - } - - // receiver ready; notify sender - wait <- true - - // accept the message - m = new(transport.Message) - if err := c.Recv(m); err != nil { - t.Fatal(err) - } - - // notify the sender we have received - wait <- true -} - -func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { - defer wg.Done() - - // wait for the listener to get ready - <-wait - - // dial a new session - c, err := tun.Dial("test-tunnel") - if err != nil { - t.Fatal(err) - } - defer c.Close() - - m := transport.Message{ - Header: map[string]string{ - "test": "send", - }, - } - - // send the message - if err := c.Send(&m); err != nil { - t.Fatal(err) - } - - // wait for the listener to get ready - <-wait - - // give it time to reconnect - time.Sleep(5 * ReconnectTime) - - // send the message - if err := c.Send(&m); err != nil { - t.Fatal(err) - } - - // wait for the listener to receive the message - // c.Send merely enqueues the message to the link send queue and returns - // in order to verify it was received we wait for the listener to tell us - <-wait -} - func TestTunnelRTTRate(t *testing.T) { // create a new tunnel client tunA := NewTunnel( @@ -296,3 +296,49 @@ func TestTunnelRTTRate(t *testing.T) { t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate()) } } + +func TestReconnectTunnel(t *testing.T) { + // we manually override the tunnel.ReconnectTime value here + // this is so that we make the reconnects faster than the default 5s + ReconnectTime = 200 * time.Millisecond + + // create a new tunnel client + tunA := NewTunnel( + Address("127.0.0.1:9098"), + Nodes("127.0.0.1:9099"), + ) + + // create a new tunnel server + tunB := NewTunnel( + Address("127.0.0.1:9099"), + ) + + // start tunnel + err := tunB.Connect() + if err != nil { + t.Fatal(err) + } + defer tunB.Close() + + // start tunnel + err = tunA.Connect() + if err != nil { + t.Fatal(err) + } + defer tunA.Close() + + wait := make(chan bool) + + var wg sync.WaitGroup + + wg.Add(1) + // start tunnel listener + go testBrokenTunAccept(t, tunB, wait, &wg) + + wg.Add(1) + // start tunnel sender + go testBrokenTunSend(t, tunA, wait, &wg, ReconnectTime*5) + + // wait until done + wg.Wait() +}