Merge pull request #1026 from micro/tun

Network & Tunnel refactor
This commit is contained in:
Asim Aslam 2019-12-08 16:01:46 +00:00 committed by GitHub
commit a9be1288d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1155 additions and 742 deletions

View File

@ -316,6 +316,22 @@ func (r *rpcClient) Options() Options {
return r.opts return r.opts
} }
// hasProxy checks if we have proxy set in the environment
func (r *rpcClient) hasProxy() bool {
// get proxy
if prx := os.Getenv("MICRO_PROXY"); len(prx) > 0 {
return true
}
// get proxy address
if prx := os.Getenv("MICRO_PROXY_ADDRESS"); len(prx) > 0 {
return true
}
return false
}
// next returns an iterator for the next nodes to call
func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) { func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) {
service := request.Service() service := request.Service()
@ -431,10 +447,18 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
return err return err
} }
ch := make(chan error, callOpts.Retries+1) // get the retries
retries := callOpts.Retries
// disable retries when using a proxy
if r.hasProxy() {
retries = 0
}
ch := make(chan error, retries+1)
var gerr error var gerr error
for i := 0; i <= callOpts.Retries; i++ { for i := 0; i <= retries; i++ {
go func(i int) { go func(i int) {
ch <- call(i) ch <- call(i)
}(i) }(i)
@ -514,10 +538,18 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt
err error err error
} }
ch := make(chan response, callOpts.Retries+1) // get the retries
retries := callOpts.Retries
// disable retries when using a proxy
if r.hasProxy() {
retries = 0
}
ch := make(chan response, retries+1)
var grr error var grr error
for i := 0; i <= callOpts.Retries; i++ { for i := 0; i <= retries; i++ {
go func(i int) { go func(i int) {
s, err := call(i) s, err := call(i)
ch <- response{s, err} ch <- response{s, err}

View File

@ -88,32 +88,24 @@ func (rwc *readWriteCloser) Close() error {
} }
func getHeaders(m *codec.Message) { func getHeaders(m *codec.Message) {
get := func(hdr string) string { set := func(v, hdr string) string {
if hd := m.Header[hdr]; len(hd) > 0 { if len(v) > 0 {
return hd return v
} }
// old return m.Header[hdr]
return m.Header["X-"+hdr]
} }
// check error in header // check error in header
if len(m.Error) == 0 { m.Error = set(m.Error, "Micro-Error")
m.Error = get("Micro-Error")
}
// check endpoint in header // check endpoint in header
if len(m.Endpoint) == 0 { m.Endpoint = set(m.Endpoint, "Micro-Endpoint")
m.Endpoint = get("Micro-Endpoint")
}
// check method in header // check method in header
if len(m.Method) == 0 { m.Method = set(m.Method, "Micro-Method")
m.Method = get("Micro-Method")
}
if len(m.Id) == 0 { // set the request id
m.Id = get("Micro-Id") m.Id = set(m.Id, "Micro-Id")
}
} }
func setHeaders(m *codec.Message, stream string) { func setHeaders(m *codec.Message, stream string) {
@ -122,7 +114,6 @@ func setHeaders(m *codec.Message, stream string) {
return return
} }
m.Header[hdr] = v m.Header[hdr] = v
m.Header["X-"+hdr] = v
} }
set("Micro-Id", m.Id) set("Micro-Id", m.Id)

File diff suppressed because it is too large Load Diff

View File

@ -6,8 +6,6 @@ import (
"github.com/micro/go-micro/client" "github.com/micro/go-micro/client"
"github.com/micro/go-micro/server" "github.com/micro/go-micro/server"
"github.com/micro/go-micro/transport"
"github.com/micro/go-micro/tunnel"
) )
var ( var (
@ -56,14 +54,6 @@ type Network interface {
Server() server.Server Server() server.Server
} }
// message is network message
type message struct {
// msg is transport message
msg *transport.Message
// session is tunnel session
session tunnel.Session
}
// NewNetwork returns a new network interface // NewNetwork returns a new network interface
func NewNetwork(opts ...Option) Network { func NewNetwork(opts ...Option) Network {
return newNetwork(opts...) return newNetwork(opts...)

View File

@ -216,7 +216,7 @@ func (n *node) DeletePeerNode(id string) error {
// PruneStalePeerNodes prune the peers that have not been seen for longer than given time // PruneStalePeerNodes prune the peers that have not been seen for longer than given time
// It returns a map of the the nodes that got pruned // It returns a map of the the nodes that got pruned
func (n *node) PruneStalePeerNodes(pruneTime time.Duration) map[string]*node { func (n *node) PruneStalePeers(pruneTime time.Duration) map[string]*node {
n.Lock() n.Lock()
defer n.Unlock() defer n.Unlock()

View File

@ -225,7 +225,7 @@ func TestPruneStalePeerNodes(t *testing.T) {
time.Sleep(pruneTime) time.Sleep(pruneTime)
// should delete all nodes besides node // should delete all nodes besides node
pruned := node.PruneStalePeerNodes(pruneTime) pruned := node.PruneStalePeers(pruneTime)
if len(pruned) != len(nodes)-1 { if len(pruned) != len(nodes)-1 {
t.Errorf("Expected pruned node count: %d, got: %d", len(nodes)-1, len(pruned)) t.Errorf("Expected pruned node count: %d, got: %d", len(nodes)-1, len(pruned))

View File

@ -22,8 +22,8 @@ type Options struct {
Address string Address string
// Advertise sets the address to advertise // Advertise sets the address to advertise
Advertise string Advertise string
// Peers is a list of peers to connect to // Nodes is a list of nodes to connect to
Peers []string Nodes []string
// Tunnel is network tunnel // Tunnel is network tunnel
Tunnel tunnel.Tunnel Tunnel tunnel.Tunnel
// Router is network router // Router is network router
@ -62,10 +62,10 @@ func Advertise(a string) Option {
} }
} }
// Peers is a list of peers to connect to // Nodes is a list of nodes to connect to
func Peers(n ...string) Option { func Nodes(n ...string) Option {
return func(o *Options) { return func(o *Options) {
o.Peers = n o.Nodes = n
} }
} }

View File

@ -31,6 +31,17 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
r.Address = "1.0.0.1:53" r.Address = "1.0.0.1:53"
} }
//nolint:prealloc
var records []*resolver.Record
// parsed an actual ip
if v := net.ParseIP(host); v != nil {
records = append(records, &resolver.Record{
Address: net.JoinHostPort(host, port),
})
return records, nil
}
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(host), dns.TypeA) m.SetQuestion(dns.Fqdn(host), dns.TypeA)
rec, err := dns.ExchangeContext(context.Background(), m, r.Address) rec, err := dns.ExchangeContext(context.Background(), m, r.Address)
@ -38,9 +49,6 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
return nil, err return nil, err
} }
//nolint:prealloc
var records []*resolver.Record
for _, answer := range rec.Answer { for _, answer := range rec.Answer {
h := answer.Header() h := answer.Header()
// check record type matches // check record type matches
@ -59,5 +67,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
}) })
} }
// no records returned so just best effort it
if len(records) == 0 {
records = append(records, &resolver.Record{
Address: net.JoinHostPort(host, port),
})
}
return records, nil return records, nil
} }

View File

@ -55,7 +55,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp *
} }
// get list of existing nodes // get list of existing nodes
nodes := n.Network.Options().Peers nodes := n.Network.Options().Nodes
// generate a node map // generate a node map
nodeMap := make(map[string]bool) nodeMap := make(map[string]bool)
@ -84,7 +84,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp *
// reinitialise the peers // reinitialise the peers
n.Network.Init( n.Network.Init(
network.Peers(nodes...), network.Nodes(nodes...),
) )
// call the connect method // call the connect method

View File

@ -83,7 +83,8 @@ func readLoop(r server.Request, s client.Stream) error {
// toNodes returns a list of node addresses from given routes // toNodes returns a list of node addresses from given routes
func toNodes(routes []router.Route) []string { func toNodes(routes []router.Route) []string {
nodes := make([]string, len(routes)) nodes := make([]string, 0, len(routes))
for _, node := range routes { for _, node := range routes {
address := node.Address address := node.Address
if len(node.Gateway) > 0 { if len(node.Gateway) > 0 {
@ -91,11 +92,13 @@ func toNodes(routes []router.Route) []string {
} }
nodes = append(nodes, address) nodes = append(nodes, address)
} }
return nodes return nodes
} }
func toSlice(r map[uint64]router.Route) []router.Route { func toSlice(r map[uint64]router.Route) []router.Route {
routes := make([]router.Route, 0, len(r)) routes := make([]router.Route, 0, len(r))
for _, v := range r { for _, v := range r {
routes = append(routes, v) routes = append(routes, v)
} }
@ -161,6 +164,8 @@ func (p *Proxy) filterRoutes(ctx context.Context, routes []router.Route) []route
filteredRoutes = append(filteredRoutes, route) filteredRoutes = append(filteredRoutes, route)
} }
log.Tracef("Proxy filtered routes %+v\n", filteredRoutes)
return filteredRoutes return filteredRoutes
} }
@ -225,13 +230,15 @@ func (p *Proxy) cacheRoutes(service string) ([]router.Route, error) {
// refreshMetrics will refresh any metrics for our local cached routes. // refreshMetrics will refresh any metrics for our local cached routes.
// we may not receive new watch events for these as they change. // we may not receive new watch events for these as they change.
func (p *Proxy) refreshMetrics() { func (p *Proxy) refreshMetrics() {
services := make([]string, 0, len(p.Routes))
// get a list of services to update // get a list of services to update
p.RLock() p.RLock()
services := make([]string, 0, len(p.Routes))
for service := range p.Routes { for service := range p.Routes {
services = append(services, service) services = append(services, service)
} }
p.RUnlock() p.RUnlock()
// get and cache the routes for the service // get and cache the routes for the service
@ -246,6 +253,8 @@ func (p *Proxy) manageRoutes(route router.Route, action string) error {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
log.Tracef("Proxy taking route action %v %+v\n", action, route)
switch action { switch action {
case "create", "update": case "create", "update":
if _, ok := p.Routes[route.Service]; !ok { if _, ok := p.Routes[route.Service]; !ok {
@ -253,7 +262,12 @@ func (p *Proxy) manageRoutes(route router.Route, action string) error {
} }
p.Routes[route.Service][route.Hash()] = route p.Routes[route.Service][route.Hash()] = route
case "delete": case "delete":
// delete that specific route
delete(p.Routes[route.Service], route.Hash()) delete(p.Routes[route.Service], route.Hash())
// clean up the cache entirely
if len(p.Routes[route.Service]) == 0 {
delete(p.Routes, route.Service)
}
default: default:
return fmt.Errorf("unknown action: %s", action) return fmt.Errorf("unknown action: %s", action)
} }
@ -288,7 +302,7 @@ func (p *Proxy) ProcessMessage(ctx context.Context, msg server.Message) error {
// TODO: check that we're not broadcast storming by sending to the same topic // TODO: check that we're not broadcast storming by sending to the same topic
// that we're actually subscribed to // that we're actually subscribed to
log.Tracef("Received message for %s", msg.Topic()) log.Tracef("Proxy received message for %s", msg.Topic())
var errors []string var errors []string
@ -329,7 +343,7 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
return errors.BadRequest("go.micro.proxy", "service name is blank") return errors.BadRequest("go.micro.proxy", "service name is blank")
} }
log.Tracef("Received request for %s", service) log.Tracef("Proxy received request for %s", service)
// are we network routing or local routing // are we network routing or local routing
if len(p.Links) == 0 { if len(p.Links) == 0 {
@ -363,15 +377,17 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
} }
//nolint:prealloc //nolint:prealloc
var opts []client.CallOption opts := []client.CallOption{
// set strategy to round robin // set strategy to round robin
opts = append(opts, client.WithSelectOption(selector.WithStrategy(selector.RoundRobin))) client.WithSelectOption(selector.WithStrategy(selector.RoundRobin)),
}
// if the address is already set just serve it // if the address is already set just serve it
// TODO: figure it out if we should know to pick a link // TODO: figure it out if we should know to pick a link
if len(addresses) > 0 { if len(addresses) > 0 {
opts = append(opts, client.WithAddress(addresses...)) opts = append(opts,
client.WithAddress(addresses...),
)
// serve the normal way // serve the normal way
return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...) return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...)
@ -387,10 +403,16 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
opts = append(opts, client.WithAddress(addresses...)) opts = append(opts, client.WithAddress(addresses...))
} }
log.Tracef("Proxy calling %+v\n", addresses)
// serve the normal way // serve the normal way
return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...) return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...)
} }
// we're assuming we need routes to operate on
if len(routes) == 0 {
return errors.InternalServerError("go.micro.proxy", "route not found")
}
var gerr error var gerr error
// we're routing globally with multiple links // we're routing globally with multiple links
@ -404,11 +426,16 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
continue continue
} }
log.Debugf("Proxy using route %+v\n", route) log.Tracef("Proxy using route %+v\n", route)
// set the address to call // set the address to call
addresses := toNodes([]router.Route{route}) addresses := toNodes([]router.Route{route})
opts = append(opts, client.WithAddress(addresses...)) // set the address in the options
// disable retries since its one route processing
opts = append(opts,
client.WithAddress(addresses...),
client.WithRetries(0),
)
// do the request with the link // do the request with the link
gerr = p.serveRequest(ctx, link, service, endpoint, req, rsp, opts...) gerr = p.serveRequest(ctx, link, service, endpoint, req, rsp, opts...)
@ -558,7 +585,9 @@ func NewProxy(opts ...options.Option) proxy.Proxy {
}() }()
go func() { go func() {
t := time.NewTicker(time.Minute) // TODO: speed up refreshing of metrics
// without this ticking effort e.g stream
t := time.NewTicker(time.Second * 10)
defer t.Stop() defer t.Stop()
// we must refresh route metrics since they do not trigger new events // we must refresh route metrics since they do not trigger new events

View File

@ -799,7 +799,8 @@ func (r *router) flushRouteEvents(evType EventType) ([]*Event, error) {
// build a list of events to advertise // build a list of events to advertise
events := make([]*Event, len(bestRoutes)) events := make([]*Event, len(bestRoutes))
i := 0 var i int
for _, route := range bestRoutes { for _, route := range bestRoutes {
event := &Event{ event := &Event{
Type: evType, Type: evType,
@ -823,9 +824,10 @@ func (r *router) Solicit() error {
// advertise the routes // advertise the routes
r.advertWg.Add(1) r.advertWg.Add(1)
go func() { go func() {
defer r.advertWg.Done() r.publishAdvert(Solicitation, events)
r.publishAdvert(RouteUpdate, events) r.advertWg.Done()
}() }()
return nil return nil

View File

@ -111,6 +111,8 @@ const (
Announce AdvertType = iota Announce AdvertType = iota
// RouteUpdate advertises route updates // RouteUpdate advertises route updates
RouteUpdate RouteUpdate
// Solicitation indicates routes were solicited
Solicitation
) )
// String returns human readable advertisement type // String returns human readable advertisement type
@ -120,6 +122,8 @@ func (t AdvertType) String() string {
return "announce" return "announce"
case RouteUpdate: case RouteUpdate:
return "update" return "update"
case Solicitation:
return "solicitation"
default: default:
return "unknown" return "unknown"
} }

View File

@ -86,24 +86,18 @@ func getHeader(hdr string, md map[string]string) string {
} }
func getHeaders(m *codec.Message) { func getHeaders(m *codec.Message) {
get := func(hdr, v string) string { set := func(v, hdr string) string {
if len(v) > 0 { if len(v) > 0 {
return v return v
} }
return m.Header[hdr]
if hd := m.Header[hdr]; len(hd) > 0 {
return hd
} }
// old m.Id = set(m.Id, "Micro-Id")
return m.Header["X-"+hdr] m.Error = set(m.Error, "Micro-Error")
} m.Endpoint = set(m.Endpoint, "Micro-Endpoint")
m.Method = set(m.Method, "Micro-Method")
m.Id = get("Micro-Id", m.Id) m.Target = set(m.Target, "Micro-Service")
m.Error = get("Micro-Error", m.Error)
m.Endpoint = get("Micro-Endpoint", m.Endpoint)
m.Method = get("Micro-Method", m.Method)
m.Target = get("Micro-Service", m.Target)
// TODO: remove this cruft // TODO: remove this cruft
if len(m.Endpoint) == 0 { if len(m.Endpoint) == 0 {
@ -321,7 +315,6 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error {
// write an error if it failed // write an error if it failed
m.Error = errors.Wrapf(err, "Unable to encode body").Error() m.Error = errors.Wrapf(err, "Unable to encode body").Error()
m.Header["X-Micro-Error"] = m.Error
m.Header["Micro-Error"] = m.Error m.Header["Micro-Error"] = m.Error
// no body to write // no body to write
if err := c.codec.Write(m, nil); err != nil { if err := c.codec.Write(m, nil); err != nil {

View File

@ -549,6 +549,7 @@ func (s *rpcServer) Register() error {
node.Metadata["protocol"] = "mucp" node.Metadata["protocol"] = "mucp"
s.RLock() s.RLock()
// Maps are ordered randomly, sort the keys for consistency // Maps are ordered randomly, sort the keys for consistency
var handlerList []string var handlerList []string
for n, e := range s.handlers { for n, e := range s.handlers {
@ -557,6 +558,7 @@ func (s *rpcServer) Register() error {
handlerList = append(handlerList, n) handlerList = append(handlerList, n)
} }
} }
sort.Strings(handlerList) sort.Strings(handlerList)
var subscriberList []Subscriber var subscriberList []Subscriber
@ -566,18 +568,20 @@ func (s *rpcServer) Register() error {
subscriberList = append(subscriberList, e) subscriberList = append(subscriberList, e)
} }
} }
sort.Slice(subscriberList, func(i, j int) bool { sort.Slice(subscriberList, func(i, j int) bool {
return subscriberList[i].Topic() > subscriberList[j].Topic() return subscriberList[i].Topic() > subscriberList[j].Topic()
}) })
endpoints := make([]*registry.Endpoint, 0, len(handlerList)+len(subscriberList)) endpoints := make([]*registry.Endpoint, 0, len(handlerList)+len(subscriberList))
for _, n := range handlerList { for _, n := range handlerList {
endpoints = append(endpoints, s.handlers[n].Endpoints()...) endpoints = append(endpoints, s.handlers[n].Endpoints()...)
} }
for _, e := range subscriberList { for _, e := range subscriberList {
endpoints = append(endpoints, e.Endpoints()...) endpoints = append(endpoints, e.Endpoints()...)
} }
s.RUnlock()
service := &registry.Service{ service := &registry.Service{
Name: config.Name, Name: config.Name,
@ -586,9 +590,10 @@ func (s *rpcServer) Register() error {
Endpoints: endpoints, Endpoints: endpoints,
} }
s.Lock() // get registered value
registered := s.registered registered := s.registered
s.Unlock()
s.RUnlock()
if !registered { if !registered {
log.Logf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id) log.Logf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id)
@ -610,6 +615,8 @@ func (s *rpcServer) Register() error {
defer s.Unlock() defer s.Unlock()
s.registered = true s.registered = true
// set what we're advertising
s.opts.Advertise = addr
// subscribe to the topic with own name // subscribe to the topic with own name
sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent) sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent)

View File

@ -9,9 +9,9 @@ import (
"github.com/micro/go-micro/client" "github.com/micro/go-micro/client"
"github.com/micro/go-micro/config/cmd" "github.com/micro/go-micro/config/cmd"
"github.com/micro/go-micro/debug/service/handler"
"github.com/micro/go-micro/debug/profile" "github.com/micro/go-micro/debug/profile"
"github.com/micro/go-micro/debug/profile/pprof" "github.com/micro/go-micro/debug/profile/pprof"
"github.com/micro/go-micro/debug/service/handler"
"github.com/micro/go-micro/plugin" "github.com/micro/go-micro/plugin"
"github.com/micro/go-micro/server" "github.com/micro/go-micro/server"
"github.com/micro/go-micro/util/log" "github.com/micro/go-micro/util/log"

View File

@ -127,7 +127,6 @@ func (t *tun) newSession(channel, sessionId string) (*session, bool) {
closed: make(chan bool), closed: make(chan bool),
recv: make(chan *message, 128), recv: make(chan *message, 128),
send: t.send, send: t.send,
wait: make(chan bool),
errChan: make(chan error, 1), errChan: make(chan error, 1),
} }
@ -198,8 +197,8 @@ func (t *tun) announce(channel, session string, link *link) {
} }
} }
// monitor monitors outbound links and attempts to reconnect to the failed ones // manage monitors outbound links and attempts to reconnect to the failed ones
func (t *tun) monitor() { func (t *tun) manage() {
reconnect := time.NewTicker(ReconnectTime) reconnect := time.NewTicker(ReconnectTime)
defer reconnect.Stop() defer reconnect.Stop()
@ -208,9 +207,48 @@ func (t *tun) monitor() {
case <-t.closed: case <-t.closed:
return return
case <-reconnect.C: case <-reconnect.C:
t.manageLinks()
}
}
}
// manageLink sends channel discover requests periodically and
// keepalive messages to link
func (t *tun) manageLink(link *link) {
keepalive := time.NewTicker(KeepAliveTime)
defer keepalive.Stop()
discover := time.NewTicker(DiscoverTime)
defer discover.Stop()
for {
select {
case <-t.closed:
return
case <-link.closed:
return
case <-discover.C:
// send a discovery message to all links
if err := t.sendMsg("discover", link); err != nil {
log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err)
}
case <-keepalive.C:
// send keepalive message
log.Debugf("Tunnel sending keepalive to link: %v", link.Remote())
if err := t.sendMsg("keepalive", link); err != nil {
log.Debugf("Tunnel error sending keepalive to link %v: %v", link.Remote(), err)
t.delLink(link.Remote())
return
}
}
}
}
// manageLinks is a function that can be called to immediately to link setup
func (t *tun) manageLinks() {
var delLinks []string
t.RLock() t.RLock()
var delLinks []string
// check the link status and purge dead links // check the link status and purge dead links
for node, link := range t.links { for node, link := range t.links {
// check link status // check link status
@ -242,27 +280,46 @@ func (t *tun) monitor() {
// build list of unknown nodes to connect to // build list of unknown nodes to connect to
t.RLock() t.RLock()
for _, node := range t.options.Nodes { for _, node := range t.options.Nodes {
if _, ok := t.links[node]; !ok { if _, ok := t.links[node]; !ok {
connect = append(connect, node) connect = append(connect, node)
} }
} }
t.RUnlock() t.RUnlock()
var wg sync.WaitGroup
for _, node := range connect { for _, node := range connect {
wg.Add(1)
go func(node string) {
defer wg.Done()
// create new link // create new link
link, err := t.setupLink(node) link, err := t.setupLink(node)
if err != nil { if err != nil {
log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) log.Debugf("Tunnel failed to setup node link to %s: %v", node, err)
continue return
} }
// save the link // save the link
t.Lock() t.Lock()
defer t.Unlock()
// just check nothing else was setup in the interim
if _, ok := t.links[node]; ok {
link.Close()
return
}
// save the link
t.links[node] = link t.links[node] = link
t.Unlock() }(node)
}
}
} }
// wait for all threads to finish
wg.Wait()
} }
// process outgoing messages sent by all local sessions // process outgoing messages sent by all local sessions
@ -328,6 +385,7 @@ func (t *tun) process() {
// and the message is being sent outbound via // and the message is being sent outbound via
// a dialled connection don't use this link // a dialled connection don't use this link
if loopback && msg.outbound { if loopback && msg.outbound {
log.Tracef("Link for node %s is loopback", node)
err = errors.New("link is loopback") err = errors.New("link is loopback")
continue continue
} }
@ -335,6 +393,7 @@ func (t *tun) process() {
// if the message was being returned by the loopback listener // if the message was being returned by the loopback listener
// send it back up the loopback link only // send it back up the loopback link only
if msg.loopback && !loopback { if msg.loopback && !loopback {
log.Tracef("Link for message %s is loopback", node)
err = errors.New("link is not loopback") err = errors.New("link is not loopback")
continue continue
} }
@ -364,7 +423,7 @@ func (t *tun) process() {
// send the message // send the message
for _, link := range sendTo { for _, link := range sendTo {
// send the message via the current link // send the message via the current link
log.Tracef("Sending %+v to %s", newMsg.Header, link.Remote()) log.Tracef("Tunnel sending %+v to %s", newMsg.Header, link.Remote())
if errr := link.Send(newMsg); errr != nil { if errr := link.Send(newMsg); errr != nil {
log.Debugf("Tunnel error sending %+v to %s: %v", newMsg.Header, link.Remote(), errr) log.Debugf("Tunnel error sending %+v to %s: %v", newMsg.Header, link.Remote(), errr)
@ -470,6 +529,9 @@ func (t *tun) listen(link *link) {
return return
} }
// this state machine block handles the only message types
// that we know or care about; connect, close, open, accept,
// discover, announce, session, keepalive
switch mtype { switch mtype {
case "connect": case "connect":
log.Debugf("Tunnel link %s received connect message", link.Remote()) log.Debugf("Tunnel link %s received connect message", link.Remote())
@ -495,14 +557,14 @@ func (t *tun) listen(link *link) {
t.links[link.Remote()] = link t.links[link.Remote()] = link
t.Unlock() t.Unlock()
// send back a discovery // send back an announcement of our channels discovery
go t.announce("", "", link) go t.announce("", "", link)
// ask for the things on the other wise
go t.sendMsg("discover", link)
// nothing more to do // nothing more to do
continue continue
case "close": case "close":
// TODO: handle the close message log.Debugf("Tunnel link %s received close message", link.Remote())
// maybe report io.EOF or kill the link
// if there is no channel then we close the link // if there is no channel then we close the link
// as its a signal from the other side to close the connection // as its a signal from the other side to close the connection
if len(channel) == 0 { if len(channel) == 0 {
@ -521,6 +583,8 @@ func (t *tun) listen(link *link) {
// try get the dialing socket // try get the dialing socket
s, exists := t.getSession(channel, sessionId) s, exists := t.getSession(channel, sessionId)
if exists && !loopback { if exists && !loopback {
// only delete the session if its unicast
// otherwise ignore close on the multicast
if s.mode == Unicast { if s.mode == Unicast {
// only delete this if its unicast // only delete this if its unicast
// but not if its a loopback conn // but not if its a loopback conn
@ -541,20 +605,24 @@ func (t *tun) listen(link *link) {
// an accept returned by the listener // an accept returned by the listener
case "accept": case "accept":
s, exists := t.getSession(channel, sessionId) s, exists := t.getSession(channel, sessionId)
// we don't need this // just set accepted on anything not unicast
if exists && s.mode > Unicast { if exists && s.mode > Unicast {
s.accepted = true s.accepted = true
continue continue
} }
// if its already accepted move on
if exists && s.accepted { if exists && s.accepted {
continue continue
} }
// otherwise we're going to process to accept
// a continued session // a continued session
case "session": case "session":
// process message // process message
log.Tracef("Received %+v from %s", msg.Header, link.Remote()) log.Tracef("Tunnel received %+v from %s", msg.Header, link.Remote())
// an announcement of a channel listener // an announcement of a channel listener
case "announce": case "announce":
log.Tracef("Tunnel received %+v from %s", msg.Header, link.Remote())
// process the announcement // process the announcement
channels := strings.Split(channel, ",") channels := strings.Split(channel, ",")
@ -562,7 +630,10 @@ func (t *tun) listen(link *link) {
link.setChannel(channels...) link.setChannel(channels...)
// this was an announcement not intended for anything // this was an announcement not intended for anything
if sessionId == "listener" || sessionId == "" { // if the dialing side sent "discover" then a session
// id would be present. We skip in case of multicast.
switch sessionId {
case "listener", "multicast", "":
continue continue
} }
@ -574,14 +645,19 @@ func (t *tun) listen(link *link) {
continue continue
} }
// send the announce back to the caller msg := &message{
s.recv <- &message{
typ: "announce", typ: "announce",
tunnel: id, tunnel: id,
channel: channel, channel: channel,
session: sessionId, session: sessionId,
link: link.id, link: link.id,
} }
// send the announce back to the caller
select {
case <-s.closed:
case s.recv <- msg:
}
} }
continue continue
case "discover": case "discover":
@ -618,7 +694,7 @@ func (t *tun) listen(link *link) {
s, exists = t.getSession(channel, "listener") s, exists = t.getSession(channel, "listener")
// only return accept to the session // only return accept to the session
case mtype == "accept": case mtype == "accept":
log.Debugf("Received accept message for %s %s", channel, sessionId) log.Debugf("Tunnel received accept message for channel: %s session: %s", channel, sessionId)
s, exists = t.getSession(channel, sessionId) s, exists = t.getSession(channel, sessionId)
if exists && s.accepted { if exists && s.accepted {
continue continue
@ -638,7 +714,7 @@ func (t *tun) listen(link *link) {
// bail if no session or listener has been found // bail if no session or listener has been found
if !exists { if !exists {
log.Debugf("Tunnel skipping no session %s %s exists", channel, sessionId) log.Tracef("Tunnel skipping no channel: %s session: %s exists", channel, sessionId)
// drop it, we don't care about // drop it, we don't care about
// messages we don't know about // messages we don't know about
continue continue
@ -651,22 +727,10 @@ func (t *tun) listen(link *link) {
delete(t.sessions, channel) delete(t.sessions, channel)
continue continue
default: default:
// process // otherwise process
} }
log.Debugf("Tunnel using channel %s session %s", s.channel, s.session) log.Tracef("Tunnel using channel: %s session: %s type: %s", s.channel, s.session, mtype)
// is the session new?
select {
// if its new the session is actually blocked waiting
// for a connection. so we check if its waiting.
case <-s.wait:
// if its waiting e.g its new then we close it
default:
// set remote address of the session
s.remote = msg.Header["Remote"]
close(s.wait)
}
// construct a new transport message // construct a new transport message
tmsg := &transport.Message{ tmsg := &transport.Message{
@ -696,68 +760,26 @@ func (t *tun) listen(link *link) {
} }
} }
// discover sends channel discover requests periodically func (t *tun) sendMsg(method string, link *link) error {
func (t *tun) discover(link *link) { return link.Send(&transport.Message{
tick := time.NewTicker(DiscoverTime)
defer tick.Stop()
for {
select {
case <-tick.C:
// send a discovery message to all links
if err := link.Send(&transport.Message{
Header: map[string]string{ Header: map[string]string{
"Micro-Tunnel": "discover", "Micro-Tunnel": method,
"Micro-Tunnel-Id": t.id, "Micro-Tunnel-Id": t.id,
}, },
}); err != nil { })
log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err)
}
case <-link.closed:
return
case <-t.closed:
return
}
}
}
// keepalive periodically sends keepalive messages to link
func (t *tun) keepalive(link *link) {
keepalive := time.NewTicker(KeepAliveTime)
defer keepalive.Stop()
for {
select {
case <-t.closed:
return
case <-link.closed:
return
case <-keepalive.C:
// send keepalive message
log.Debugf("Tunnel sending keepalive to link: %v", link.Remote())
if err := link.Send(&transport.Message{
Header: map[string]string{
"Micro-Tunnel": "keepalive",
"Micro-Tunnel-Id": t.id,
},
}); err != nil {
log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err)
t.delLink(link.Remote())
return
}
}
}
} }
// setupLink connects to node and returns link if successful // setupLink connects to node and returns link if successful
// It returns error if the link failed to be established // It returns error if the link failed to be established
func (t *tun) setupLink(node string) (*link, error) { func (t *tun) setupLink(node string) (*link, error) {
log.Debugf("Tunnel setting up link: %s", node) log.Debugf("Tunnel setting up link: %s", node)
c, err := t.options.Transport.Dial(node) c, err := t.options.Transport.Dial(node)
if err != nil { if err != nil {
log.Debugf("Tunnel failed to connect to %s: %v", node, err) log.Debugf("Tunnel failed to connect to %s: %v", node, err)
return nil, err return nil, err
} }
log.Debugf("Tunnel connected to %s", node) log.Debugf("Tunnel connected to %s", node)
// create a new link // create a new link
@ -766,12 +788,8 @@ func (t *tun) setupLink(node string) (*link, error) {
link.id = c.Remote() link.id = c.Remote()
// send the first connect message // send the first connect message
if err := link.Send(&transport.Message{ if err := t.sendMsg("connect", link); err != nil {
Header: map[string]string{ link.Close()
"Micro-Tunnel": "connect",
"Micro-Tunnel-Id": t.id,
},
}); err != nil {
return nil, err return nil, err
} }
@ -782,37 +800,40 @@ func (t *tun) setupLink(node string) (*link, error) {
// process incoming messages // process incoming messages
go t.listen(link) go t.listen(link)
// start keepalive monitor // manage keepalives and discovery messages
go t.keepalive(link) go t.manageLink(link)
// discover things on the remote side
go t.discover(link)
return link, nil return link, nil
} }
func (t *tun) setupLinks() { func (t *tun) setupLinks() {
var wg sync.WaitGroup
for _, node := range t.options.Nodes { for _, node := range t.options.Nodes {
// skip zero length nodes wg.Add(1)
if len(node) == 0 {
continue
}
// link already exists go func(node string) {
defer wg.Done()
// we're not trying to fix existing links
if _, ok := t.links[node]; ok { if _, ok := t.links[node]; ok {
continue return
} }
// connect to node and return link // create new link
link, err := t.setupLink(node) link, err := t.setupLink(node)
if err != nil { if err != nil {
log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) log.Debugf("Tunnel failed to setup node link to %s: %v", node, err)
continue return
} }
// save the link // save the link
t.links[node] = link t.links[node] = link
}(node)
} }
// wait for all threads to finish
wg.Wait()
} }
// connect the tunnel to all the nodes and listen for incoming tunnel connections // connect the tunnel to all the nodes and listen for incoming tunnel connections
@ -833,11 +854,8 @@ func (t *tun) connect() error {
// create a new link // create a new link
link := newLink(sock) link := newLink(sock)
// start keepalive monitor // manage the link
go t.keepalive(link) go t.manageLink(link)
// discover things on the remote side
go t.discover(link)
// listen for inbound messages. // listen for inbound messages.
// only save the link once connected. // only save the link once connected.
@ -864,12 +882,13 @@ func (t *tun) Connect() error {
// already connected // already connected
if t.connected { if t.connected {
// setup links // do it immediately
t.setupLinks() t.setupLinks()
// setup links
return nil return nil
} }
// send the connect message // connect the tunnel: start the listener
if err := t.connect(); err != nil { if err := t.connect(); err != nil {
return err return err
} }
@ -879,15 +898,15 @@ func (t *tun) Connect() error {
// create new close channel // create new close channel
t.closed = make(chan bool) t.closed = make(chan bool)
// setup links
t.setupLinks()
// process outbound messages to be sent // process outbound messages to be sent
// process sends to all links // process sends to all links
go t.process() go t.process()
// monitor links // call setup before managing them
go t.monitor() t.setupLinks()
// manage the links
go t.manage()
return nil return nil
} }
@ -1025,7 +1044,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
// set the multicast option // set the multicast option
c.mode = options.Mode c.mode = options.Mode
// set the dial timeout // set the dial timeout
c.timeout = options.Timeout c.dialTimeout = options.Timeout
// set read timeout set to never
c.readTimeout = time.Duration(-1)
var links []*link var links []*link
// did we measure the rtt // did we measure the rtt
@ -1052,7 +1073,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
t.RUnlock() t.RUnlock()
// link not found // link not found and one was specified so error out
if len(links) == 0 && len(options.Link) > 0 { if len(links) == 0 && len(options.Link) > 0 {
// delete session and return error // delete session and return error
t.delSession(c.channel, c.session) t.delSession(c.channel, c.session)
@ -1061,15 +1082,14 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
} }
// discovered so set the link if not multicast // discovered so set the link if not multicast
// TODO: pick the link efficiently based
// on link status and saturation.
if c.discovered && c.mode == Unicast { if c.discovered && c.mode == Unicast {
// pickLink will pick the best link // pickLink will pick the best link
link := t.pickLink(links) link := t.pickLink(links)
// set the link
c.link = link.id c.link = link.id
} }
// shit fuck // if its not already discovered we need to attempt to do so
if !c.discovered { if !c.discovered {
// piggy back roundtrip // piggy back roundtrip
nowRTT := time.Now() nowRTT := time.Now()
@ -1098,7 +1118,15 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
} }
} }
// a unicast session so we call "open" and wait for an "accept" // return early if its not unicast
// we will not call "open" for multicast
if c.mode != Unicast {
return c, nil
}
// Note: we go no further for multicast or broadcast.
// This is a unicast session so we call "open" and wait
// for an "accept"
// reset now in case we use it // reset now in case we use it
now := time.Now() now := time.Now()
@ -1115,7 +1143,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
d := time.Since(now) d := time.Since(now)
// if we haven't measured the roundtrip do it now // if we haven't measured the roundtrip do it now
if !measured && c.mode == Unicast { if !measured {
// set the link time // set the link time
t.RLock() t.RLock()
link, ok := t.links[c.link] link, ok := t.links[c.link]
@ -1134,7 +1162,11 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
log.Debugf("Tunnel listening on %s", channel) log.Debugf("Tunnel listening on %s", channel)
var options ListenOptions options := ListenOptions{
// Read timeout defaults to never
Timeout: time.Duration(-1),
}
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
@ -1145,6 +1177,7 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
return nil, errors.New("already listening on " + channel) return nil, errors.New("already listening on " + channel)
} }
// delete function removes the session when closed
delFunc := func() { delFunc := func() {
t.delSession(channel, "listener") t.delSession(channel, "listener")
} }
@ -1155,6 +1188,8 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
c.local = channel c.local = channel
// set mode // set mode
c.mode = options.Mode c.mode = options.Mode
// set the timeout
c.readTimeout = options.Timeout
tl := &tunListener{ tl := &tunListener{
channel: channel, channel: channel,

View File

@ -2,6 +2,7 @@ package tunnel
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"sync" "sync"
"time" "time"
@ -14,7 +15,11 @@ import (
type link struct { type link struct {
transport.Socket transport.Socket
// transport to use for connections
transport transport.Transport
sync.RWMutex sync.RWMutex
// stops the link // stops the link
closed chan bool closed chan bool
// link state channel for testing link // link state channel for testing link
@ -65,6 +70,8 @@ var (
linkRequest = []byte{0, 0, 0, 0} linkRequest = []byte{0, 0, 0, 0}
// the 4 byte 1 filled packet sent to determine link state // the 4 byte 1 filled packet sent to determine link state
linkResponse = []byte{1, 1, 1, 1} linkResponse = []byte{1, 1, 1, 1}
ErrLinkConnectTimeout = errors.New("link connect timeout")
) )
func newLink(s transport.Socket) *link { func newLink(s transport.Socket) *link {
@ -72,8 +79,8 @@ func newLink(s transport.Socket) *link {
Socket: s, Socket: s,
id: uuid.New().String(), id: uuid.New().String(),
lastKeepAlive: time.Now(), lastKeepAlive: time.Now(),
channels: make(map[string]time.Time),
closed: make(chan bool), closed: make(chan bool),
channels: make(map[string]time.Time),
state: make(chan *packet, 64), state: make(chan *packet, 64),
sendQueue: make(chan *packet, 128), sendQueue: make(chan *packet, 128),
recvQueue: make(chan *packet, 128), recvQueue: make(chan *packet, 128),
@ -87,6 +94,32 @@ func newLink(s transport.Socket) *link {
return l return l
} }
func (l *link) connect(addr string) error {
c, err := l.transport.Dial(addr)
if err != nil {
return err
}
l.Lock()
l.Socket = c
l.Unlock()
return nil
}
func (l *link) accept(sock transport.Socket) error {
l.Lock()
l.Socket = sock
l.Unlock()
return nil
}
func (l *link) setLoopback(v bool) {
l.Lock()
l.loopback = v
l.Unlock()
}
// setRate sets the bits per second rate as a float64 // setRate sets the bits per second rate as a float64
func (l *link) setRate(bits int64, delta time.Duration) { func (l *link) setRate(bits int64, delta time.Duration) {
// rate of send in bits per nanosecond // rate of send in bits per nanosecond
@ -167,6 +200,8 @@ func (l *link) process() {
// process link state message // process link state message
select { select {
case l.state <- pk: case l.state <- pk:
case <-l.closed:
return
default: default:
} }
continue continue
@ -188,7 +223,11 @@ func (l *link) process() {
select { select {
case pk := <-l.sendQueue: case pk := <-l.sendQueue:
// send the message // send the message
pk.status <- l.send(pk.message) select {
case pk.status <- l.send(pk.message):
case <-l.closed:
return
}
case <-l.closed: case <-l.closed:
return return
} }
@ -201,11 +240,15 @@ func (l *link) manage() {
t := time.NewTicker(time.Minute) t := time.NewTicker(time.Minute)
defer t.Stop() defer t.Stop()
// get link id
linkId := l.Id()
// used to send link state packets // used to send link state packets
send := func(b []byte) error { send := func(b []byte) error {
return l.Send(&transport.Message{ return l.Send(&transport.Message{
Header: map[string]string{ Header: map[string]string{
"Micro-Method": "link", "Micro-Method": "link",
"Micro-Link-Id": linkId,
}, Body: b, }, Body: b,
}) })
} }
@ -229,9 +272,7 @@ func (l *link) manage() {
// check the type of message // check the type of message
switch { switch {
case bytes.Equal(p.message.Body, linkRequest): case bytes.Equal(p.message.Body, linkRequest):
l.RLock() log.Tracef("Link %s received link request", linkId)
log.Tracef("Link %s received link request %v", l.id, p.message.Body)
l.RUnlock()
// send response // send response
if err := send(linkResponse); err != nil { if err := send(linkResponse); err != nil {
@ -242,9 +283,7 @@ func (l *link) manage() {
case bytes.Equal(p.message.Body, linkResponse): case bytes.Equal(p.message.Body, linkResponse):
// set round trip time // set round trip time
d := time.Since(now) d := time.Since(now)
l.RLock() log.Tracef("Link %s received link response in %v", linkId, d)
log.Tracef("Link %s received link response in %v", p.message.Body, d)
l.RUnlock()
// set the RTT // set the RTT
l.setRTT(d) l.setRTT(d)
} }
@ -309,6 +348,12 @@ func (l *link) Rate() float64 {
return l.rate return l.rate
} }
func (l *link) Loopback() bool {
l.RLock()
defer l.RUnlock()
return l.loopback
}
// Length returns the roundtrip time as nanoseconds (lower is better). // Length returns the roundtrip time as nanoseconds (lower is better).
// Returns 0 where no measurement has been taken. // Returns 0 where no measurement has been taken.
func (l *link) Length() int64 { func (l *link) Length() int64 {
@ -320,7 +365,6 @@ func (l *link) Length() int64 {
func (l *link) Id() string { func (l *link) Id() string {
l.RLock() l.RLock()
defer l.RUnlock() defer l.RUnlock()
return l.id return l.id
} }
@ -350,13 +394,6 @@ func (l *link) Send(m *transport.Message) error {
// get time now // get time now
now := time.Now() now := time.Now()
// check if its closed first
select {
case <-l.closed:
return io.EOF
default:
}
// queue the message // queue the message
select { select {
case l.sendQueue <- p: case l.sendQueue <- p:

View File

@ -24,7 +24,7 @@ type tunListener struct {
delFunc func() delFunc func()
} }
// periodically announce self // periodically announce self the channel being listened on
func (t *tunListener) announce() { func (t *tunListener) announce() {
tick := time.NewTicker(time.Second * 30) tick := time.NewTicker(time.Second * 30)
defer tick.Stop() defer tick.Stop()
@ -48,9 +48,12 @@ func (t *tunListener) process() {
defer func() { defer func() {
// close the sessions // close the sessions
for _, conn := range conns { for id, conn := range conns {
conn.Close() conn.Close()
delete(conns, id)
} }
// unassign
conns = nil
}() }()
for { for {
@ -62,9 +65,24 @@ func (t *tunListener) process() {
return return
// receive a new message // receive a new message
case m := <-t.session.recv: case m := <-t.session.recv:
var sessionId string
var linkId string
switch m.mode {
case Multicast:
sessionId = "multicast"
linkId = "multicast"
case Broadcast:
sessionId = "broadcast"
linkId = "broadcast"
default:
sessionId = m.session
linkId = m.link
}
// get a session // get a session
sess, ok := conns[m.session] sess, ok := conns[sessionId]
log.Debugf("Tunnel listener received channel %s session %s exists: %t", m.channel, m.session, ok) log.Tracef("Tunnel listener received channel %s session %s type %s exists: %t", m.channel, m.session, m.typ, ok)
if !ok { if !ok {
// we only process open and session types // we only process open and session types
switch m.typ { switch m.typ {
@ -80,13 +98,13 @@ func (t *tunListener) process() {
// the channel // the channel
channel: m.channel, channel: m.channel,
// the session id // the session id
session: m.session, session: sessionId,
// tunnel token // tunnel token
token: t.token, token: t.token,
// is loopback conn // is loopback conn
loopback: m.loopback, loopback: m.loopback,
// the link the message was received on // the link the message was received on
link: m.link, link: linkId,
// set the connection mode // set the connection mode
mode: m.mode, mode: m.mode,
// close chan // close chan
@ -95,14 +113,14 @@ func (t *tunListener) process() {
recv: make(chan *message, 128), recv: make(chan *message, 128),
// use the internal send buffer // use the internal send buffer
send: t.session.send, send: t.session.send,
// wait
wait: make(chan bool),
// error channel // error channel
errChan: make(chan error, 1), errChan: make(chan error, 1),
// set the read timeout
readTimeout: t.session.readTimeout,
} }
// save the session // save the session
conns[m.session] = sess conns[sessionId] = sess
select { select {
case <-t.closed: case <-t.closed:
@ -114,17 +132,19 @@ func (t *tunListener) process() {
// an existing session was found // an existing session was found
// received a close message
switch m.typ { switch m.typ {
case "close": case "close":
// received a close message
select { select {
// check if the session is closed
case <-sess.closed: case <-sess.closed:
// no op // no op
delete(conns, m.session) delete(conns, sessionId)
default: default:
// only close if unicast session
// close and delete session // close and delete session
close(sess.closed) close(sess.closed)
delete(conns, m.session) delete(conns, sessionId)
} }
// continue // continue
@ -139,9 +159,9 @@ func (t *tunListener) process() {
// send this to the accept chan // send this to the accept chan
select { select {
case <-sess.closed: case <-sess.closed:
delete(conns, m.session) delete(conns, sessionId)
case sess.recv <- m: case sess.recv <- m:
log.Debugf("Tunnel listener sent to recv chan channel %s session %s", m.channel, m.session) log.Tracef("Tunnel listener sent to recv chan channel %s session %s type %s", m.channel, sessionId, m.typ)
} }
} }
} }
@ -180,6 +200,10 @@ func (t *tunListener) Accept() (Session, error) {
if !ok { if !ok {
return nil, io.EOF return nil, io.EOF
} }
// return without accept
if c.mode != Unicast {
return c, nil
}
// send back the accept // send back the accept
if err := c.Accept(); err != nil { if err := c.Accept(); err != nil {
return nil, err return nil, err

View File

@ -47,6 +47,8 @@ type ListenOption func(*ListenOptions)
type ListenOptions struct { type ListenOptions struct {
// specify mode of the session // specify mode of the session
Mode Mode Mode Mode
// The read timeout
Timeout time.Duration
} }
// The tunnel id // The tunnel id
@ -84,16 +86,6 @@ func Transport(t transport.Transport) Option {
} }
} }
// DefaultOptions returns router default options
func DefaultOptions() Options {
return Options{
Id: uuid.New().String(),
Address: DefaultAddress,
Token: DefaultToken,
Transport: quic.NewTransport(),
}
}
// Listen options // Listen options
func ListenMode(m Mode) ListenOption { func ListenMode(m Mode) ListenOption {
return func(o *ListenOptions) { return func(o *ListenOptions) {
@ -101,6 +93,13 @@ func ListenMode(m Mode) ListenOption {
} }
} }
// Timeout for reads and writes on the listener session
func ListenTimeout(t time.Duration) ListenOption {
return func(o *ListenOptions) {
o.Timeout = t
}
}
// Dial options // Dial options
// Dial multicast sets the multicast option to send only to those mapped // Dial multicast sets the multicast option to send only to those mapped
@ -124,3 +123,13 @@ func DialLink(id string) DialOption {
o.Link = id o.Link = id
} }
} }
// DefaultOptions returns router default options
func DefaultOptions() Options {
return Options{
Id: uuid.New().String(),
Address: DefaultAddress,
Token: DefaultToken,
Transport: quic.NewTransport(),
}
}

View File

@ -2,7 +2,6 @@ package tunnel
import ( import (
"encoding/hex" "encoding/hex"
"errors"
"io" "io"
"time" "time"
@ -30,8 +29,6 @@ type session struct {
send chan *message send chan *message
// recv chan // recv chan
recv chan *message recv chan *message
// wait until we have a connection
wait chan bool
// if the discovery worked // if the discovery worked
discovered bool discovered bool
// if the session was accepted // if the session was accepted
@ -42,8 +39,10 @@ type session struct {
loopback bool loopback bool
// mode of the connection // mode of the connection
mode Mode mode Mode
// the timeout // the dial timeout
timeout time.Duration dialTimeout time.Duration
// the read timeout
readTimeout time.Duration
// the link on which this message was received // the link on which this message was received
link string link string
// the error response // the error response
@ -109,65 +108,114 @@ func (s *session) newMessage(typ string) *message {
} }
} }
func (s *session) sendMsg(msg *message) error {
select {
case <-s.closed:
return io.EOF
case s.send <- msg:
return nil
}
}
func (s *session) wait(msg *message) error {
// wait for an error response
select {
case err := <-msg.errChan:
if err != nil {
return err
}
case <-s.closed:
return io.EOF
}
return nil
}
// waitFor waits for the message type required until the timeout specified // waitFor waits for the message type required until the timeout specified
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) { func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
now := time.Now() now := time.Now()
after := func(timeout time.Duration) time.Duration { after := func(timeout time.Duration) <-chan time.Time {
if timeout < time.Duration(0) {
return nil
}
// get the delta
d := time.Since(now) d := time.Since(now)
// dial timeout minus time since // dial timeout minus time since
wait := timeout - d wait := timeout - d
if wait < time.Duration(0) { if wait < time.Duration(0) {
return time.Duration(0) wait = time.Duration(0)
} }
return wait return time.After(wait)
} }
// wait for the message type // wait for the message type
for { for {
select { select {
case msg := <-s.recv: case msg := <-s.recv:
// there may be no message type
if len(msgType) == 0 {
return msg, nil
}
// ignore what we don't want // ignore what we don't want
if msg.typ != msgType { if msg.typ != msgType {
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
continue continue
} }
// got the message // got the message
return msg, nil return msg, nil
case <-time.After(after(timeout)): case <-after(timeout):
return nil, ErrDialTimeout return nil, ErrReadTimeout
case <-s.closed: case <-s.closed:
return nil, io.EOF return nil, io.EOF
} }
} }
} }
// Discover attempts to discover the link for a specific channel // Discover attempts to discover the link for a specific channel.
// This is only used by the tunnel.Dial when first connecting.
func (s *session) Discover() error { func (s *session) Discover() error {
// create a new discovery message for this channel // create a new discovery message for this channel
msg := s.newMessage("discover") msg := s.newMessage("discover")
// broadcast the message to all links
msg.mode = Broadcast msg.mode = Broadcast
// its an outbound connection since we're dialling
msg.outbound = true msg.outbound = true
// don't set the link since we don't know where it is
msg.link = "" msg.link = ""
// send the discovery message // if multicast then set that as session
s.send <- msg if s.mode == Multicast {
msg.session = "multicast"
}
// send discover message
if err := s.sendMsg(msg); err != nil {
return err
}
// set time now // set time now
now := time.Now() now := time.Now()
// after strips down the dial timeout
after := func() time.Duration { after := func() time.Duration {
d := time.Since(now) d := time.Since(now)
// dial timeout minus time since // dial timeout minus time since
wait := s.timeout - d wait := s.dialTimeout - d
// make sure its always > 0
if wait < time.Duration(0) { if wait < time.Duration(0) {
return time.Duration(0) return time.Duration(0)
} }
return wait return wait
} }
// the discover message is sent out, now
// wait to hear back about the sent message // wait to hear back about the sent message
select { select {
case <-time.After(after()): case <-time.After(after()):
@ -178,27 +226,16 @@ func (s *session) Discover() error {
} }
} }
var err error // bail early if its not unicast
// we don't need to wait for the announce
// set a new dialTimeout
dialTimeout := after()
// set a shorter delay for multicast
if s.mode != Unicast {
// shorten this
dialTimeout = time.Millisecond * 500
}
// wait for announce
_, err = s.waitFor("announce", dialTimeout)
// if its multicast just go ahead because this is best effort
if s.mode != Unicast { if s.mode != Unicast {
s.discovered = true s.discovered = true
s.accepted = true s.accepted = true
return nil return nil
} }
// wait for announce
_, err := s.waitFor("announce", after())
if err != nil { if err != nil {
return err return err
} }
@ -210,31 +247,23 @@ func (s *session) Discover() error {
} }
// Open will fire the open message for the session. This is called by the dialler. // Open will fire the open message for the session. This is called by the dialler.
// This is to indicate that we want to create a new session.
func (s *session) Open() error { func (s *session) Open() error {
// create a new message // create a new message
msg := s.newMessage("open") msg := s.newMessage("open")
// send open message // send open message
s.send <- msg if err := s.sendMsg(msg); err != nil {
// wait for an error response for send
select {
case err := <-msg.errChan:
if err != nil {
return err return err
} }
case <-s.closed:
return io.EOF // wait for an error response for send
if err := s.wait(msg); err != nil {
return err
} }
// don't wait on multicast/broadcast // now wait for the accept message to be returned
if s.mode == Multicast { msg, err := s.waitFor("accept", s.dialTimeout)
s.accepted = true
return nil
}
// now wait for the accept
msg, err := s.waitFor("accept", s.timeout)
if err != nil { if err != nil {
return err return err
} }
@ -252,32 +281,16 @@ func (s *session) Accept() error {
msg := s.newMessage("accept") msg := s.newMessage("accept")
// send the accept message // send the accept message
select { if err := s.sendMsg(msg); err != nil {
case <-s.closed: return err
return io.EOF
case s.send <- msg:
// no op here
}
// don't wait on multicast/broadcast
if s.mode == Multicast {
return nil
} }
// wait for send response // wait for send response
select { return s.wait(msg)
case err := <-s.errChan:
if err != nil {
return err
}
case <-s.closed:
return io.EOF
} }
return nil // Announce sends an announcement to notify that this session exists.
} // This is primarily used by the listener.
// Announce sends an announcement to notify that this session exists. This is primarily used by the listener.
func (s *session) Announce() error { func (s *session) Announce() error {
msg := s.newMessage("announce") msg := s.newMessage("announce")
// we don't need an error back // we don't need an error back
@ -287,23 +300,12 @@ func (s *session) Announce() error {
// we don't need the link // we don't need the link
msg.link = "" msg.link = ""
select { // send announce message
case s.send <- msg: return s.sendMsg(msg)
return nil
case <-s.closed:
return io.EOF
}
} }
// Send is used to send a message // Send is used to send a message
func (s *session) Send(m *transport.Message) error { func (s *session) Send(m *transport.Message) error {
select {
case <-s.closed:
return io.EOF
default:
// no op
}
// encrypt the transport message payload // encrypt the transport message payload
body, err := Encrypt(m.Body, s.token+s.channel+s.session) body, err := Encrypt(m.Body, s.token+s.channel+s.session)
if err != nil { if err != nil {
@ -335,32 +337,28 @@ func (s *session) Send(m *transport.Message) error {
msg.data = data msg.data = data
// if multicast don't set the link // if multicast don't set the link
if s.mode == Multicast { if s.mode != Unicast {
msg.link = "" msg.link = ""
} }
log.Tracef("Appending %+v to send backlog", msg) log.Tracef("Appending %+v to send backlog", msg)
// send the actual message // send the actual message
s.send <- msg if err := s.sendMsg(msg); err != nil {
return err
}
// wait for an error response // wait for an error response
select { return s.wait(msg)
case err := <-msg.errChan:
return err
case <-s.closed:
return io.EOF
}
} }
// Recv is used to receive a message // Recv is used to receive a message
func (s *session) Recv(m *transport.Message) error { func (s *session) Recv(m *transport.Message) error {
var msg *message var msg *message
select { msg, err := s.waitFor("", s.readTimeout)
case <-s.closed: if err != nil {
return errors.New("session is closed") return err
// recv from backlog
case msg = <-s.recv:
} }
// check the error if one exists // check the error if one exists
@ -371,10 +369,13 @@ func (s *session) Recv(m *transport.Message) error {
} }
//log.Tracef("Received %+v from recv backlog", msg) //log.Tracef("Received %+v from recv backlog", msg)
log.Debugf("Received %+v from recv backlog", msg) log.Tracef("Received %+v from recv backlog", msg)
// decrypt the received payload using the token // decrypt the received payload using the token
body, err := Decrypt(msg.data.Body, s.token+s.channel+s.session) // we have to used msg.session because multicast has a shared
// session id of "multicast" in this session struct on
// the listener side
body, err := Decrypt(msg.data.Body, s.token+s.channel+msg.session)
if err != nil { if err != nil {
log.Debugf("failed to decrypt message body: %v", err) log.Debugf("failed to decrypt message body: %v", err)
return err return err
@ -390,7 +391,7 @@ func (s *session) Recv(m *transport.Message) error {
return err return err
} }
// encrypt the transport message payload // encrypt the transport message payload
val, err := Decrypt([]byte(h), s.token+s.channel+s.session) val, err := Decrypt([]byte(h), s.token+s.channel+msg.session)
if err != nil { if err != nil {
log.Debugf("failed to decrypt message header %s: %v", k, err) log.Debugf("failed to decrypt message header %s: %v", k, err)
return err return err
@ -399,6 +400,12 @@ func (s *session) Recv(m *transport.Message) error {
msg.data.Header[k] = string(val) msg.data.Header[k] = string(val)
} }
// set the link
// TODO: decruft, this is only for multicast
// since the session is now a single session
// likely provide as part of message.Link()
msg.data.Header["Micro-Link"] = msg.link
// set message // set message
*m = *msg.data *m = *msg.data
// return nil // return nil
@ -413,6 +420,11 @@ func (s *session) Close() error {
default: default:
close(s.closed) close(s.closed)
// don't send close on multicast or broadcast
if s.mode != Unicast {
return nil
}
// append to backlog // append to backlog
msg := s.newMessage("close") msg := s.newMessage("close")
// no error response on close // no error response on close
@ -421,7 +433,7 @@ func (s *session) Close() error {
// send the close message // send the close message
select { select {
case s.send <- msg: case s.send <- msg:
default: case <-time.After(time.Millisecond * 10):
} }
} }

View File

@ -26,6 +26,8 @@ var (
ErrDiscoverChan = errors.New("failed to discover channel") ErrDiscoverChan = errors.New("failed to discover channel")
// ErrLinkNotFound is returned when a link is specified at dial time and does not exist // ErrLinkNotFound is returned when a link is specified at dial time and does not exist
ErrLinkNotFound = errors.New("link not found") ErrLinkNotFound = errors.New("link not found")
// ErrReadTimeout is a timeout on session.Recv
ErrReadTimeout = errors.New("read timeout")
) )
// Mode of the session // Mode of the session
@ -64,7 +66,9 @@ type Link interface {
Length() int64 Length() int64
// Current transfer rate as bits per second (lower is better) // Current transfer rate as bits per second (lower is better)
Rate() float64 Rate() float64
// State of the link e.g connected/closed // Is this a loopback link
Loopback() bool
// State of the link: connected/closed/error
State() string State() string
// honours transport socket // honours transport socket
transport.Socket transport.Socket

View File

@ -1,55 +0,0 @@
// +build !race
package tunnel
import (
"sync"
"testing"
"time"
)
func TestReconnectTunnel(t *testing.T) {
// create a new tunnel client
tunA := NewTunnel(
Address("127.0.0.1:9096"),
Nodes("127.0.0.1:9097"),
)
// create a new tunnel server
tunB := NewTunnel(
Address("127.0.0.1:9097"),
)
// start tunnel
err := tunB.Connect()
if err != nil {
t.Fatal(err)
}
defer tunB.Close()
// we manually override the tunnel.ReconnectTime value here
// this is so that we make the reconnects faster than the default 5s
ReconnectTime = 200 * time.Millisecond
// start tunnel
err = tunA.Connect()
if err != nil {
t.Fatal(err)
}
defer tunA.Close()
wait := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)
// start tunnel listener
go testBrokenTunAccept(t, tunB, wait, &wg)
wg.Add(1)
// start tunnel sender
go testBrokenTunSend(t, tunA, wait, &wg)
// wait until done
wg.Wait()
}

View File

@ -8,6 +8,90 @@ import (
"github.com/micro/go-micro/transport" "github.com/micro/go-micro/transport"
) )
func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
defer wg.Done()
// listen on some virtual address
tl, err := tun.Listen("test-tunnel")
if err != nil {
t.Fatal(err)
}
// receiver ready; notify sender
wait <- true
// accept a connection
c, err := tl.Accept()
if err != nil {
t.Fatal(err)
}
// accept the message and close the tunnel
// we do this to simulate loss of network connection
m := new(transport.Message)
if err := c.Recv(m); err != nil {
t.Fatal(err)
}
// close all the links
for _, link := range tun.Links() {
link.Close()
}
// receiver ready; notify sender
wait <- true
// accept the message
m = new(transport.Message)
if err := c.Recv(m); err != nil {
t.Fatal(err)
}
// notify the sender we have received
wait <- true
}
func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup, reconnect time.Duration) {
defer wg.Done()
// wait for the listener to get ready
<-wait
// dial a new session
c, err := tun.Dial("test-tunnel")
if err != nil {
t.Fatal(err)
}
defer c.Close()
m := transport.Message{
Header: map[string]string{
"test": "send",
},
}
// send the message
if err := c.Send(&m); err != nil {
t.Fatal(err)
}
// wait for the listener to get ready
<-wait
// give it time to reconnect
time.Sleep(reconnect)
// send the message
if err := c.Send(&m); err != nil {
t.Fatal(err)
}
// wait for the listener to receive the message
// c.Send merely enqueues the message to the link send queue and returns
// in order to verify it was received we wait for the listener to tell us
<-wait
}
// testAccept will accept connections on the transport, create a new link and tunnel on top // testAccept will accept connections on the transport, create a new link and tunnel on top
func testAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { func testAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
@ -163,90 +247,6 @@ func TestLoopbackTunnel(t *testing.T) {
wg.Wait() wg.Wait()
} }
func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
defer wg.Done()
// listen on some virtual address
tl, err := tun.Listen("test-tunnel")
if err != nil {
t.Fatal(err)
}
// receiver ready; notify sender
wait <- true
// accept a connection
c, err := tl.Accept()
if err != nil {
t.Fatal(err)
}
// accept the message and close the tunnel
// we do this to simulate loss of network connection
m := new(transport.Message)
if err := c.Recv(m); err != nil {
t.Fatal(err)
}
// close all the links
for _, link := range tun.Links() {
link.Close()
}
// receiver ready; notify sender
wait <- true
// accept the message
m = new(transport.Message)
if err := c.Recv(m); err != nil {
t.Fatal(err)
}
// notify the sender we have received
wait <- true
}
func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
defer wg.Done()
// wait for the listener to get ready
<-wait
// dial a new session
c, err := tun.Dial("test-tunnel")
if err != nil {
t.Fatal(err)
}
defer c.Close()
m := transport.Message{
Header: map[string]string{
"test": "send",
},
}
// send the message
if err := c.Send(&m); err != nil {
t.Fatal(err)
}
// wait for the listener to get ready
<-wait
// give it time to reconnect
time.Sleep(5 * ReconnectTime)
// send the message
if err := c.Send(&m); err != nil {
t.Fatal(err)
}
// wait for the listener to receive the message
// c.Send merely enqueues the message to the link send queue and returns
// in order to verify it was received we wait for the listener to tell us
<-wait
}
func TestTunnelRTTRate(t *testing.T) { func TestTunnelRTTRate(t *testing.T) {
// create a new tunnel client // create a new tunnel client
tunA := NewTunnel( tunA := NewTunnel(
@ -296,3 +296,49 @@ func TestTunnelRTTRate(t *testing.T) {
t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate()) t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate())
} }
} }
func TestReconnectTunnel(t *testing.T) {
// we manually override the tunnel.ReconnectTime value here
// this is so that we make the reconnects faster than the default 5s
ReconnectTime = 200 * time.Millisecond
// create a new tunnel client
tunA := NewTunnel(
Address("127.0.0.1:9098"),
Nodes("127.0.0.1:9099"),
)
// create a new tunnel server
tunB := NewTunnel(
Address("127.0.0.1:9099"),
)
// start tunnel
err := tunB.Connect()
if err != nil {
t.Fatal(err)
}
defer tunB.Close()
// start tunnel
err = tunA.Connect()
if err != nil {
t.Fatal(err)
}
defer tunA.Close()
wait := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)
// start tunnel listener
go testBrokenTunAccept(t, tunB, wait, &wg)
wg.Add(1)
// start tunnel sender
go testBrokenTunSend(t, tunA, wait, &wg, ReconnectTime*5)
// wait until done
wg.Wait()
}