| @@ -316,6 +316,22 @@ func (r *rpcClient) Options() Options { | |||||||
| 	return r.opts | 	return r.opts | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // hasProxy checks if we have proxy set in the environment | ||||||
|  | func (r *rpcClient) hasProxy() bool { | ||||||
|  | 	// get proxy | ||||||
|  | 	if prx := os.Getenv("MICRO_PROXY"); len(prx) > 0 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// get proxy address | ||||||
|  | 	if prx := os.Getenv("MICRO_PROXY_ADDRESS"); len(prx) > 0 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // next returns an iterator for the next nodes to call | ||||||
| func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) { | func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) { | ||||||
| 	service := request.Service() | 	service := request.Service() | ||||||
|  |  | ||||||
| @@ -431,10 +447,18 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ch := make(chan error, callOpts.Retries+1) | 	// get the retries | ||||||
|  | 	retries := callOpts.Retries | ||||||
|  |  | ||||||
|  | 	// disable retries when using a proxy | ||||||
|  | 	if r.hasProxy() { | ||||||
|  | 		retries = 0 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ch := make(chan error, retries+1) | ||||||
| 	var gerr error | 	var gerr error | ||||||
|  |  | ||||||
| 	for i := 0; i <= callOpts.Retries; i++ { | 	for i := 0; i <= retries; i++ { | ||||||
| 		go func(i int) { | 		go func(i int) { | ||||||
| 			ch <- call(i) | 			ch <- call(i) | ||||||
| 		}(i) | 		}(i) | ||||||
| @@ -514,10 +538,18 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt | |||||||
| 		err    error | 		err    error | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ch := make(chan response, callOpts.Retries+1) | 	// get the retries | ||||||
|  | 	retries := callOpts.Retries | ||||||
|  |  | ||||||
|  | 	// disable retries when using a proxy | ||||||
|  | 	if r.hasProxy() { | ||||||
|  | 		retries = 0 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ch := make(chan response, retries+1) | ||||||
| 	var grr error | 	var grr error | ||||||
|  |  | ||||||
| 	for i := 0; i <= callOpts.Retries; i++ { | 	for i := 0; i <= retries; i++ { | ||||||
| 		go func(i int) { | 		go func(i int) { | ||||||
| 			s, err := call(i) | 			s, err := call(i) | ||||||
| 			ch <- response{s, err} | 			ch <- response{s, err} | ||||||
|   | |||||||
| @@ -88,32 +88,24 @@ func (rwc *readWriteCloser) Close() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func getHeaders(m *codec.Message) { | func getHeaders(m *codec.Message) { | ||||||
| 	get := func(hdr string) string { | 	set := func(v, hdr string) string { | ||||||
| 		if hd := m.Header[hdr]; len(hd) > 0 { | 		if len(v) > 0 { | ||||||
| 			return hd | 			return v | ||||||
| 		} | 		} | ||||||
| 		// old | 		return m.Header[hdr] | ||||||
| 		return m.Header["X-"+hdr] |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// check error in header | 	// check error in header | ||||||
| 	if len(m.Error) == 0 { | 	m.Error = set(m.Error, "Micro-Error") | ||||||
| 		m.Error = get("Micro-Error") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// check endpoint in header | 	// check endpoint in header | ||||||
| 	if len(m.Endpoint) == 0 { | 	m.Endpoint = set(m.Endpoint, "Micro-Endpoint") | ||||||
| 		m.Endpoint = get("Micro-Endpoint") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// check method in header | 	// check method in header | ||||||
| 	if len(m.Method) == 0 { | 	m.Method = set(m.Method, "Micro-Method") | ||||||
| 		m.Method = get("Micro-Method") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(m.Id) == 0 { | 	// set the request id | ||||||
| 		m.Id = get("Micro-Id") | 	m.Id = set(m.Id, "Micro-Id") | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func setHeaders(m *codec.Message, stream string) { | func setHeaders(m *codec.Message, stream string) { | ||||||
| @@ -122,7 +114,6 @@ func setHeaders(m *codec.Message, stream string) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		m.Header[hdr] = v | 		m.Header[hdr] = v | ||||||
| 		m.Header["X-"+hdr] = v |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	set("Micro-Id", m.Id) | 	set("Micro-Id", m.Id) | ||||||
|   | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -6,8 +6,6 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/client" | 	"github.com/micro/go-micro/client" | ||||||
| 	"github.com/micro/go-micro/server" | 	"github.com/micro/go-micro/server" | ||||||
| 	"github.com/micro/go-micro/transport" |  | ||||||
| 	"github.com/micro/go-micro/tunnel" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| @@ -56,14 +54,6 @@ type Network interface { | |||||||
| 	Server() server.Server | 	Server() server.Server | ||||||
| } | } | ||||||
|  |  | ||||||
| // message is network message |  | ||||||
| type message struct { |  | ||||||
| 	// msg is transport message |  | ||||||
| 	msg *transport.Message |  | ||||||
| 	// session is tunnel session |  | ||||||
| 	session tunnel.Session |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewNetwork returns a new network interface | // NewNetwork returns a new network interface | ||||||
| func NewNetwork(opts ...Option) Network { | func NewNetwork(opts ...Option) Network { | ||||||
| 	return newNetwork(opts...) | 	return newNetwork(opts...) | ||||||
|   | |||||||
| @@ -216,7 +216,7 @@ func (n *node) DeletePeerNode(id string) error { | |||||||
|  |  | ||||||
| // PruneStalePeerNodes prune the peers that have not been seen for longer than given time | // PruneStalePeerNodes prune the peers that have not been seen for longer than given time | ||||||
| // It returns a map of the the nodes that got pruned | // It returns a map of the the nodes that got pruned | ||||||
| func (n *node) PruneStalePeerNodes(pruneTime time.Duration) map[string]*node { | func (n *node) PruneStalePeers(pruneTime time.Duration) map[string]*node { | ||||||
| 	n.Lock() | 	n.Lock() | ||||||
| 	defer n.Unlock() | 	defer n.Unlock() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -225,7 +225,7 @@ func TestPruneStalePeerNodes(t *testing.T) { | |||||||
| 	time.Sleep(pruneTime) | 	time.Sleep(pruneTime) | ||||||
|  |  | ||||||
| 	// should delete all nodes besides node | 	// should delete all nodes besides node | ||||||
| 	pruned := node.PruneStalePeerNodes(pruneTime) | 	pruned := node.PruneStalePeers(pruneTime) | ||||||
|  |  | ||||||
| 	if len(pruned) != len(nodes)-1 { | 	if len(pruned) != len(nodes)-1 { | ||||||
| 		t.Errorf("Expected pruned node count: %d, got: %d", len(nodes)-1, len(pruned)) | 		t.Errorf("Expected pruned node count: %d, got: %d", len(nodes)-1, len(pruned)) | ||||||
|   | |||||||
| @@ -22,8 +22,8 @@ type Options struct { | |||||||
| 	Address string | 	Address string | ||||||
| 	// Advertise sets the address to advertise | 	// Advertise sets the address to advertise | ||||||
| 	Advertise string | 	Advertise string | ||||||
| 	// Peers is a list of peers to connect to | 	// Nodes is a list of nodes to connect to | ||||||
| 	Peers []string | 	Nodes []string | ||||||
| 	// Tunnel is network tunnel | 	// Tunnel is network tunnel | ||||||
| 	Tunnel tunnel.Tunnel | 	Tunnel tunnel.Tunnel | ||||||
| 	// Router is network router | 	// Router is network router | ||||||
| @@ -62,10 +62,10 @@ func Advertise(a string) Option { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Peers is a list of peers to connect to | // Nodes is a list of nodes to connect to | ||||||
| func Peers(n ...string) Option { | func Nodes(n ...string) Option { | ||||||
| 	return func(o *Options) { | 	return func(o *Options) { | ||||||
| 		o.Peers = n | 		o.Nodes = n | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -31,6 +31,17 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { | |||||||
| 		r.Address = "1.0.0.1:53" | 		r.Address = "1.0.0.1:53" | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	//nolint:prealloc | ||||||
|  | 	var records []*resolver.Record | ||||||
|  |  | ||||||
|  | 	// parsed an actual ip | ||||||
|  | 	if v := net.ParseIP(host); v != nil { | ||||||
|  | 		records = append(records, &resolver.Record{ | ||||||
|  | 			Address: net.JoinHostPort(host, port), | ||||||
|  | 		}) | ||||||
|  | 		return records, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	m := new(dns.Msg) | 	m := new(dns.Msg) | ||||||
| 	m.SetQuestion(dns.Fqdn(host), dns.TypeA) | 	m.SetQuestion(dns.Fqdn(host), dns.TypeA) | ||||||
| 	rec, err := dns.ExchangeContext(context.Background(), m, r.Address) | 	rec, err := dns.ExchangeContext(context.Background(), m, r.Address) | ||||||
| @@ -38,9 +49,6 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	//nolint:prealloc |  | ||||||
| 	var records []*resolver.Record |  | ||||||
|  |  | ||||||
| 	for _, answer := range rec.Answer { | 	for _, answer := range rec.Answer { | ||||||
| 		h := answer.Header() | 		h := answer.Header() | ||||||
| 		// check record type matches | 		// check record type matches | ||||||
| @@ -59,5 +67,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// no records returned so just best effort it | ||||||
|  | 	if len(records) == 0 { | ||||||
|  | 		records = append(records, &resolver.Record{ | ||||||
|  | 			Address: net.JoinHostPort(host, port), | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return records, nil | 	return records, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -55,7 +55,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp * | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// get list of existing nodes | 	// get list of existing nodes | ||||||
| 	nodes := n.Network.Options().Peers | 	nodes := n.Network.Options().Nodes | ||||||
|  |  | ||||||
| 	// generate a node map | 	// generate a node map | ||||||
| 	nodeMap := make(map[string]bool) | 	nodeMap := make(map[string]bool) | ||||||
| @@ -84,7 +84,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp * | |||||||
|  |  | ||||||
| 	// reinitialise the peers | 	// reinitialise the peers | ||||||
| 	n.Network.Init( | 	n.Network.Init( | ||||||
| 		network.Peers(nodes...), | 		network.Nodes(nodes...), | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	// call the connect method | 	// call the connect method | ||||||
|   | |||||||
| @@ -83,7 +83,8 @@ func readLoop(r server.Request, s client.Stream) error { | |||||||
|  |  | ||||||
| // toNodes returns a list of node addresses from given routes | // toNodes returns a list of node addresses from given routes | ||||||
| func toNodes(routes []router.Route) []string { | func toNodes(routes []router.Route) []string { | ||||||
| 	nodes := make([]string, len(routes)) | 	nodes := make([]string, 0, len(routes)) | ||||||
|  |  | ||||||
| 	for _, node := range routes { | 	for _, node := range routes { | ||||||
| 		address := node.Address | 		address := node.Address | ||||||
| 		if len(node.Gateway) > 0 { | 		if len(node.Gateway) > 0 { | ||||||
| @@ -91,11 +92,13 @@ func toNodes(routes []router.Route) []string { | |||||||
| 		} | 		} | ||||||
| 		nodes = append(nodes, address) | 		nodes = append(nodes, address) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nodes | 	return nodes | ||||||
| } | } | ||||||
|  |  | ||||||
| func toSlice(r map[uint64]router.Route) []router.Route { | func toSlice(r map[uint64]router.Route) []router.Route { | ||||||
| 	routes := make([]router.Route, 0, len(r)) | 	routes := make([]router.Route, 0, len(r)) | ||||||
|  |  | ||||||
| 	for _, v := range r { | 	for _, v := range r { | ||||||
| 		routes = append(routes, v) | 		routes = append(routes, v) | ||||||
| 	} | 	} | ||||||
| @@ -161,6 +164,8 @@ func (p *Proxy) filterRoutes(ctx context.Context, routes []router.Route) []route | |||||||
| 		filteredRoutes = append(filteredRoutes, route) | 		filteredRoutes = append(filteredRoutes, route) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	log.Tracef("Proxy filtered routes %+v\n", filteredRoutes) | ||||||
|  |  | ||||||
| 	return filteredRoutes | 	return filteredRoutes | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -225,13 +230,15 @@ func (p *Proxy) cacheRoutes(service string) ([]router.Route, error) { | |||||||
| // refreshMetrics will refresh any metrics for our local cached routes. | // refreshMetrics will refresh any metrics for our local cached routes. | ||||||
| // we may not receive new watch events for these as they change. | // we may not receive new watch events for these as they change. | ||||||
| func (p *Proxy) refreshMetrics() { | func (p *Proxy) refreshMetrics() { | ||||||
| 	services := make([]string, 0, len(p.Routes)) |  | ||||||
|  |  | ||||||
| 	// get a list of services to update | 	// get a list of services to update | ||||||
| 	p.RLock() | 	p.RLock() | ||||||
|  |  | ||||||
|  | 	services := make([]string, 0, len(p.Routes)) | ||||||
|  |  | ||||||
| 	for service := range p.Routes { | 	for service := range p.Routes { | ||||||
| 		services = append(services, service) | 		services = append(services, service) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	p.RUnlock() | 	p.RUnlock() | ||||||
|  |  | ||||||
| 	// get and cache the routes for the service | 	// get and cache the routes for the service | ||||||
| @@ -246,6 +253,8 @@ func (p *Proxy) manageRoutes(route router.Route, action string) error { | |||||||
| 	p.Lock() | 	p.Lock() | ||||||
| 	defer p.Unlock() | 	defer p.Unlock() | ||||||
|  |  | ||||||
|  | 	log.Tracef("Proxy taking route action %v %+v\n", action, route) | ||||||
|  |  | ||||||
| 	switch action { | 	switch action { | ||||||
| 	case "create", "update": | 	case "create", "update": | ||||||
| 		if _, ok := p.Routes[route.Service]; !ok { | 		if _, ok := p.Routes[route.Service]; !ok { | ||||||
| @@ -253,7 +262,12 @@ func (p *Proxy) manageRoutes(route router.Route, action string) error { | |||||||
| 		} | 		} | ||||||
| 		p.Routes[route.Service][route.Hash()] = route | 		p.Routes[route.Service][route.Hash()] = route | ||||||
| 	case "delete": | 	case "delete": | ||||||
|  | 		// delete that specific route | ||||||
| 		delete(p.Routes[route.Service], route.Hash()) | 		delete(p.Routes[route.Service], route.Hash()) | ||||||
|  | 		// clean up the cache entirely | ||||||
|  | 		if len(p.Routes[route.Service]) == 0 { | ||||||
|  | 			delete(p.Routes, route.Service) | ||||||
|  | 		} | ||||||
| 	default: | 	default: | ||||||
| 		return fmt.Errorf("unknown action: %s", action) | 		return fmt.Errorf("unknown action: %s", action) | ||||||
| 	} | 	} | ||||||
| @@ -288,7 +302,7 @@ func (p *Proxy) ProcessMessage(ctx context.Context, msg server.Message) error { | |||||||
| 	// TODO: check that we're not broadcast storming by sending to the same topic | 	// TODO: check that we're not broadcast storming by sending to the same topic | ||||||
| 	// that we're actually subscribed to | 	// that we're actually subscribed to | ||||||
|  |  | ||||||
| 	log.Tracef("Received message for %s", msg.Topic()) | 	log.Tracef("Proxy received message for %s", msg.Topic()) | ||||||
|  |  | ||||||
| 	var errors []string | 	var errors []string | ||||||
|  |  | ||||||
| @@ -329,7 +343,7 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server | |||||||
| 		return errors.BadRequest("go.micro.proxy", "service name is blank") | 		return errors.BadRequest("go.micro.proxy", "service name is blank") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Tracef("Received request for %s", service) | 	log.Tracef("Proxy received request for %s", service) | ||||||
|  |  | ||||||
| 	// are we network routing or local routing | 	// are we network routing or local routing | ||||||
| 	if len(p.Links) == 0 { | 	if len(p.Links) == 0 { | ||||||
| @@ -363,15 +377,17 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	//nolint:prealloc | 	//nolint:prealloc | ||||||
| 	var opts []client.CallOption | 	opts := []client.CallOption{ | ||||||
|  |  | ||||||
| 		// set strategy to round robin | 		// set strategy to round robin | ||||||
| 	opts = append(opts, client.WithSelectOption(selector.WithStrategy(selector.RoundRobin))) | 		client.WithSelectOption(selector.WithStrategy(selector.RoundRobin)), | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// if the address is already set just serve it | 	// if the address is already set just serve it | ||||||
| 	// TODO: figure it out if we should know to pick a link | 	// TODO: figure it out if we should know to pick a link | ||||||
| 	if len(addresses) > 0 { | 	if len(addresses) > 0 { | ||||||
| 		opts = append(opts, client.WithAddress(addresses...)) | 		opts = append(opts, | ||||||
|  | 			client.WithAddress(addresses...), | ||||||
|  | 		) | ||||||
|  |  | ||||||
| 		// serve the normal way | 		// serve the normal way | ||||||
| 		return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...) | 		return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...) | ||||||
| @@ -387,10 +403,16 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server | |||||||
| 			opts = append(opts, client.WithAddress(addresses...)) | 			opts = append(opts, client.WithAddress(addresses...)) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		log.Tracef("Proxy calling %+v\n", addresses) | ||||||
| 		// serve the normal way | 		// serve the normal way | ||||||
| 		return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...) | 		return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// we're assuming we need routes to operate on | ||||||
|  | 	if len(routes) == 0 { | ||||||
|  | 		return errors.InternalServerError("go.micro.proxy", "route not found") | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	var gerr error | 	var gerr error | ||||||
|  |  | ||||||
| 	// we're routing globally with multiple links | 	// we're routing globally with multiple links | ||||||
| @@ -404,11 +426,16 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server | |||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debugf("Proxy using route %+v\n", route) | 		log.Tracef("Proxy using route %+v\n", route) | ||||||
|  |  | ||||||
| 		// set the address to call | 		// set the address to call | ||||||
| 		addresses := toNodes([]router.Route{route}) | 		addresses := toNodes([]router.Route{route}) | ||||||
| 		opts = append(opts, client.WithAddress(addresses...)) | 		// set the address in the options | ||||||
|  | 		// disable retries since its one route processing | ||||||
|  | 		opts = append(opts, | ||||||
|  | 			client.WithAddress(addresses...), | ||||||
|  | 			client.WithRetries(0), | ||||||
|  | 		) | ||||||
|  |  | ||||||
| 		// do the request with the link | 		// do the request with the link | ||||||
| 		gerr = p.serveRequest(ctx, link, service, endpoint, req, rsp, opts...) | 		gerr = p.serveRequest(ctx, link, service, endpoint, req, rsp, opts...) | ||||||
| @@ -558,7 +585,9 @@ func NewProxy(opts ...options.Option) proxy.Proxy { | |||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| 		t := time.NewTicker(time.Minute) | 		// TODO: speed up refreshing of metrics | ||||||
|  | 		// without this ticking effort e.g stream | ||||||
|  | 		t := time.NewTicker(time.Second * 10) | ||||||
| 		defer t.Stop() | 		defer t.Stop() | ||||||
|  |  | ||||||
| 		// we must refresh route metrics since they do not trigger new events | 		// we must refresh route metrics since they do not trigger new events | ||||||
|   | |||||||
| @@ -799,7 +799,8 @@ func (r *router) flushRouteEvents(evType EventType) ([]*Event, error) { | |||||||
|  |  | ||||||
| 	// build a list of events to advertise | 	// build a list of events to advertise | ||||||
| 	events := make([]*Event, len(bestRoutes)) | 	events := make([]*Event, len(bestRoutes)) | ||||||
| 	i := 0 | 	var i int | ||||||
|  |  | ||||||
| 	for _, route := range bestRoutes { | 	for _, route := range bestRoutes { | ||||||
| 		event := &Event{ | 		event := &Event{ | ||||||
| 			Type:      evType, | 			Type:      evType, | ||||||
| @@ -823,9 +824,10 @@ func (r *router) Solicit() error { | |||||||
|  |  | ||||||
| 	// advertise the routes | 	// advertise the routes | ||||||
| 	r.advertWg.Add(1) | 	r.advertWg.Add(1) | ||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| 		defer r.advertWg.Done() | 		r.publishAdvert(Solicitation, events) | ||||||
| 		r.publishAdvert(RouteUpdate, events) | 		r.advertWg.Done() | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
|   | |||||||
| @@ -111,6 +111,8 @@ const ( | |||||||
| 	Announce AdvertType = iota | 	Announce AdvertType = iota | ||||||
| 	// RouteUpdate advertises route updates | 	// RouteUpdate advertises route updates | ||||||
| 	RouteUpdate | 	RouteUpdate | ||||||
|  | 	// Solicitation indicates routes were solicited | ||||||
|  | 	Solicitation | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // String returns human readable advertisement type | // String returns human readable advertisement type | ||||||
| @@ -120,6 +122,8 @@ func (t AdvertType) String() string { | |||||||
| 		return "announce" | 		return "announce" | ||||||
| 	case RouteUpdate: | 	case RouteUpdate: | ||||||
| 		return "update" | 		return "update" | ||||||
|  | 	case Solicitation: | ||||||
|  | 		return "solicitation" | ||||||
| 	default: | 	default: | ||||||
| 		return "unknown" | 		return "unknown" | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -86,24 +86,18 @@ func getHeader(hdr string, md map[string]string) string { | |||||||
| } | } | ||||||
|  |  | ||||||
| func getHeaders(m *codec.Message) { | func getHeaders(m *codec.Message) { | ||||||
| 	get := func(hdr, v string) string { | 	set := func(v, hdr string) string { | ||||||
| 		if len(v) > 0 { | 		if len(v) > 0 { | ||||||
| 			return v | 			return v | ||||||
| 		} | 		} | ||||||
|  | 		return m.Header[hdr] | ||||||
| 		if hd := m.Header[hdr]; len(hd) > 0 { |  | ||||||
| 			return hd |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 		// old | 	m.Id = set(m.Id, "Micro-Id") | ||||||
| 		return m.Header["X-"+hdr] | 	m.Error = set(m.Error, "Micro-Error") | ||||||
| 	} | 	m.Endpoint = set(m.Endpoint, "Micro-Endpoint") | ||||||
|  | 	m.Method = set(m.Method, "Micro-Method") | ||||||
| 	m.Id = get("Micro-Id", m.Id) | 	m.Target = set(m.Target, "Micro-Service") | ||||||
| 	m.Error = get("Micro-Error", m.Error) |  | ||||||
| 	m.Endpoint = get("Micro-Endpoint", m.Endpoint) |  | ||||||
| 	m.Method = get("Micro-Method", m.Method) |  | ||||||
| 	m.Target = get("Micro-Service", m.Target) |  | ||||||
|  |  | ||||||
| 	// TODO: remove this cruft | 	// TODO: remove this cruft | ||||||
| 	if len(m.Endpoint) == 0 { | 	if len(m.Endpoint) == 0 { | ||||||
| @@ -321,7 +315,6 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error { | |||||||
|  |  | ||||||
| 		// write an error if it failed | 		// write an error if it failed | ||||||
| 		m.Error = errors.Wrapf(err, "Unable to encode body").Error() | 		m.Error = errors.Wrapf(err, "Unable to encode body").Error() | ||||||
| 		m.Header["X-Micro-Error"] = m.Error |  | ||||||
| 		m.Header["Micro-Error"] = m.Error | 		m.Header["Micro-Error"] = m.Error | ||||||
| 		// no body to write | 		// no body to write | ||||||
| 		if err := c.codec.Write(m, nil); err != nil { | 		if err := c.codec.Write(m, nil); err != nil { | ||||||
|   | |||||||
| @@ -549,6 +549,7 @@ func (s *rpcServer) Register() error { | |||||||
| 	node.Metadata["protocol"] = "mucp" | 	node.Metadata["protocol"] = "mucp" | ||||||
|  |  | ||||||
| 	s.RLock() | 	s.RLock() | ||||||
|  |  | ||||||
| 	// Maps are ordered randomly, sort the keys for consistency | 	// Maps are ordered randomly, sort the keys for consistency | ||||||
| 	var handlerList []string | 	var handlerList []string | ||||||
| 	for n, e := range s.handlers { | 	for n, e := range s.handlers { | ||||||
| @@ -557,6 +558,7 @@ func (s *rpcServer) Register() error { | |||||||
| 			handlerList = append(handlerList, n) | 			handlerList = append(handlerList, n) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	sort.Strings(handlerList) | 	sort.Strings(handlerList) | ||||||
|  |  | ||||||
| 	var subscriberList []Subscriber | 	var subscriberList []Subscriber | ||||||
| @@ -566,18 +568,20 @@ func (s *rpcServer) Register() error { | |||||||
| 			subscriberList = append(subscriberList, e) | 			subscriberList = append(subscriberList, e) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	sort.Slice(subscriberList, func(i, j int) bool { | 	sort.Slice(subscriberList, func(i, j int) bool { | ||||||
| 		return subscriberList[i].Topic() > subscriberList[j].Topic() | 		return subscriberList[i].Topic() > subscriberList[j].Topic() | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	endpoints := make([]*registry.Endpoint, 0, len(handlerList)+len(subscriberList)) | 	endpoints := make([]*registry.Endpoint, 0, len(handlerList)+len(subscriberList)) | ||||||
|  |  | ||||||
| 	for _, n := range handlerList { | 	for _, n := range handlerList { | ||||||
| 		endpoints = append(endpoints, s.handlers[n].Endpoints()...) | 		endpoints = append(endpoints, s.handlers[n].Endpoints()...) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, e := range subscriberList { | 	for _, e := range subscriberList { | ||||||
| 		endpoints = append(endpoints, e.Endpoints()...) | 		endpoints = append(endpoints, e.Endpoints()...) | ||||||
| 	} | 	} | ||||||
| 	s.RUnlock() |  | ||||||
|  |  | ||||||
| 	service := ®istry.Service{ | 	service := ®istry.Service{ | ||||||
| 		Name:      config.Name, | 		Name:      config.Name, | ||||||
| @@ -586,9 +590,10 @@ func (s *rpcServer) Register() error { | |||||||
| 		Endpoints: endpoints, | 		Endpoints: endpoints, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.Lock() | 	// get registered value | ||||||
| 	registered := s.registered | 	registered := s.registered | ||||||
| 	s.Unlock() |  | ||||||
|  | 	s.RUnlock() | ||||||
|  |  | ||||||
| 	if !registered { | 	if !registered { | ||||||
| 		log.Logf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id) | 		log.Logf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id) | ||||||
| @@ -610,6 +615,8 @@ func (s *rpcServer) Register() error { | |||||||
| 	defer s.Unlock() | 	defer s.Unlock() | ||||||
|  |  | ||||||
| 	s.registered = true | 	s.registered = true | ||||||
|  | 	// set what we're advertising | ||||||
|  | 	s.opts.Advertise = addr | ||||||
|  |  | ||||||
| 	// subscribe to the topic with own name | 	// subscribe to the topic with own name | ||||||
| 	sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent) | 	sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent) | ||||||
|   | |||||||
| @@ -9,9 +9,9 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/client" | 	"github.com/micro/go-micro/client" | ||||||
| 	"github.com/micro/go-micro/config/cmd" | 	"github.com/micro/go-micro/config/cmd" | ||||||
| 	"github.com/micro/go-micro/debug/service/handler" |  | ||||||
| 	"github.com/micro/go-micro/debug/profile" | 	"github.com/micro/go-micro/debug/profile" | ||||||
| 	"github.com/micro/go-micro/debug/profile/pprof" | 	"github.com/micro/go-micro/debug/profile/pprof" | ||||||
|  | 	"github.com/micro/go-micro/debug/service/handler" | ||||||
| 	"github.com/micro/go-micro/plugin" | 	"github.com/micro/go-micro/plugin" | ||||||
| 	"github.com/micro/go-micro/server" | 	"github.com/micro/go-micro/server" | ||||||
| 	"github.com/micro/go-micro/util/log" | 	"github.com/micro/go-micro/util/log" | ||||||
|   | |||||||
| @@ -127,7 +127,6 @@ func (t *tun) newSession(channel, sessionId string) (*session, bool) { | |||||||
| 		closed:  make(chan bool), | 		closed:  make(chan bool), | ||||||
| 		recv:    make(chan *message, 128), | 		recv:    make(chan *message, 128), | ||||||
| 		send:    t.send, | 		send:    t.send, | ||||||
| 		wait:    make(chan bool), |  | ||||||
| 		errChan: make(chan error, 1), | 		errChan: make(chan error, 1), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -198,8 +197,8 @@ func (t *tun) announce(channel, session string, link *link) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // monitor monitors outbound links and attempts to reconnect to the failed ones | // manage monitors outbound links and attempts to reconnect to the failed ones | ||||||
| func (t *tun) monitor() { | func (t *tun) manage() { | ||||||
| 	reconnect := time.NewTicker(ReconnectTime) | 	reconnect := time.NewTicker(ReconnectTime) | ||||||
| 	defer reconnect.Stop() | 	defer reconnect.Stop() | ||||||
|  |  | ||||||
| @@ -208,9 +207,48 @@ func (t *tun) monitor() { | |||||||
| 		case <-t.closed: | 		case <-t.closed: | ||||||
| 			return | 			return | ||||||
| 		case <-reconnect.C: | 		case <-reconnect.C: | ||||||
|  | 			t.manageLinks() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // manageLink sends channel discover requests periodically and | ||||||
|  | // keepalive messages to link | ||||||
|  | func (t *tun) manageLink(link *link) { | ||||||
|  | 	keepalive := time.NewTicker(KeepAliveTime) | ||||||
|  | 	defer keepalive.Stop() | ||||||
|  | 	discover := time.NewTicker(DiscoverTime) | ||||||
|  | 	defer discover.Stop() | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case <-t.closed: | ||||||
|  | 			return | ||||||
|  | 		case <-link.closed: | ||||||
|  | 			return | ||||||
|  | 		case <-discover.C: | ||||||
|  | 			// send a discovery message to all links | ||||||
|  | 			if err := t.sendMsg("discover", link); err != nil { | ||||||
|  | 				log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err) | ||||||
|  | 			} | ||||||
|  | 		case <-keepalive.C: | ||||||
|  | 			// send keepalive message | ||||||
|  | 			log.Debugf("Tunnel sending keepalive to link: %v", link.Remote()) | ||||||
|  | 			if err := t.sendMsg("keepalive", link); err != nil { | ||||||
|  | 				log.Debugf("Tunnel error sending keepalive to link %v: %v", link.Remote(), err) | ||||||
|  | 				t.delLink(link.Remote()) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // manageLinks is a function that can be called to immediately to link setup | ||||||
|  | func (t *tun) manageLinks() { | ||||||
|  | 	var delLinks []string | ||||||
|  |  | ||||||
| 	t.RLock() | 	t.RLock() | ||||||
|  |  | ||||||
| 			var delLinks []string |  | ||||||
| 	// check the link status and purge dead links | 	// check the link status and purge dead links | ||||||
| 	for node, link := range t.links { | 	for node, link := range t.links { | ||||||
| 		// check link status | 		// check link status | ||||||
| @@ -242,27 +280,46 @@ func (t *tun) monitor() { | |||||||
|  |  | ||||||
| 	// build list of unknown nodes to connect to | 	// build list of unknown nodes to connect to | ||||||
| 	t.RLock() | 	t.RLock() | ||||||
|  |  | ||||||
| 	for _, node := range t.options.Nodes { | 	for _, node := range t.options.Nodes { | ||||||
| 		if _, ok := t.links[node]; !ok { | 		if _, ok := t.links[node]; !ok { | ||||||
| 			connect = append(connect, node) | 			connect = append(connect, node) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	t.RUnlock() | 	t.RUnlock() | ||||||
|  |  | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  |  | ||||||
| 	for _, node := range connect { | 	for _, node := range connect { | ||||||
|  | 		wg.Add(1) | ||||||
|  |  | ||||||
|  | 		go func(node string) { | ||||||
|  | 			defer wg.Done() | ||||||
|  |  | ||||||
| 			// create new link | 			// create new link | ||||||
| 			link, err := t.setupLink(node) | 			link, err := t.setupLink(node) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) | 				log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) | ||||||
| 					continue | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// save the link | 			// save the link | ||||||
| 			t.Lock() | 			t.Lock() | ||||||
|  | 			defer t.Unlock() | ||||||
|  |  | ||||||
|  | 			// just check nothing else was setup in the interim | ||||||
|  | 			if _, ok := t.links[node]; ok { | ||||||
|  | 				link.Close() | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			// save the link | ||||||
| 			t.links[node] = link | 			t.links[node] = link | ||||||
| 				t.Unlock() | 		}(node) | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// wait for all threads to finish | ||||||
|  | 	wg.Wait() | ||||||
| } | } | ||||||
|  |  | ||||||
| // process outgoing messages sent by all local sessions | // process outgoing messages sent by all local sessions | ||||||
| @@ -328,6 +385,7 @@ func (t *tun) process() { | |||||||
| 				// and the message is being sent outbound via | 				// and the message is being sent outbound via | ||||||
| 				// a dialled connection don't use this link | 				// a dialled connection don't use this link | ||||||
| 				if loopback && msg.outbound { | 				if loopback && msg.outbound { | ||||||
|  | 					log.Tracef("Link for node %s is loopback", node) | ||||||
| 					err = errors.New("link is loopback") | 					err = errors.New("link is loopback") | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
| @@ -335,6 +393,7 @@ func (t *tun) process() { | |||||||
| 				// if the message was being returned by the loopback listener | 				// if the message was being returned by the loopback listener | ||||||
| 				// send it back up the loopback link only | 				// send it back up the loopback link only | ||||||
| 				if msg.loopback && !loopback { | 				if msg.loopback && !loopback { | ||||||
|  | 					log.Tracef("Link for message %s is loopback", node) | ||||||
| 					err = errors.New("link is not loopback") | 					err = errors.New("link is not loopback") | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
| @@ -364,7 +423,7 @@ func (t *tun) process() { | |||||||
| 			// send the message | 			// send the message | ||||||
| 			for _, link := range sendTo { | 			for _, link := range sendTo { | ||||||
| 				// send the message via the current link | 				// send the message via the current link | ||||||
| 				log.Tracef("Sending %+v to %s", newMsg.Header, link.Remote()) | 				log.Tracef("Tunnel sending %+v to %s", newMsg.Header, link.Remote()) | ||||||
|  |  | ||||||
| 				if errr := link.Send(newMsg); errr != nil { | 				if errr := link.Send(newMsg); errr != nil { | ||||||
| 					log.Debugf("Tunnel error sending %+v to %s: %v", newMsg.Header, link.Remote(), errr) | 					log.Debugf("Tunnel error sending %+v to %s: %v", newMsg.Header, link.Remote(), errr) | ||||||
| @@ -470,6 +529,9 @@ func (t *tun) listen(link *link) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// this state machine block handles the only message types | ||||||
|  | 		// that we know or care about; connect, close, open, accept, | ||||||
|  | 		// discover, announce, session, keepalive | ||||||
| 		switch mtype { | 		switch mtype { | ||||||
| 		case "connect": | 		case "connect": | ||||||
| 			log.Debugf("Tunnel link %s received connect message", link.Remote()) | 			log.Debugf("Tunnel link %s received connect message", link.Remote()) | ||||||
| @@ -495,14 +557,14 @@ func (t *tun) listen(link *link) { | |||||||
| 			t.links[link.Remote()] = link | 			t.links[link.Remote()] = link | ||||||
| 			t.Unlock() | 			t.Unlock() | ||||||
|  |  | ||||||
| 			// send back a discovery | 			// send back an announcement of our channels discovery | ||||||
| 			go t.announce("", "", link) | 			go t.announce("", "", link) | ||||||
|  | 			// ask for the things on the other wise | ||||||
|  | 			go t.sendMsg("discover", link) | ||||||
| 			// nothing more to do | 			// nothing more to do | ||||||
| 			continue | 			continue | ||||||
| 		case "close": | 		case "close": | ||||||
| 			// TODO: handle the close message | 			log.Debugf("Tunnel link %s received close message", link.Remote()) | ||||||
| 			// maybe report io.EOF or kill the link |  | ||||||
|  |  | ||||||
| 			// if there is no channel then we close the link | 			// if there is no channel then we close the link | ||||||
| 			// as its a signal from the other side to close the connection | 			// as its a signal from the other side to close the connection | ||||||
| 			if len(channel) == 0 { | 			if len(channel) == 0 { | ||||||
| @@ -521,6 +583,8 @@ func (t *tun) listen(link *link) { | |||||||
| 			// try get the dialing socket | 			// try get the dialing socket | ||||||
| 			s, exists := t.getSession(channel, sessionId) | 			s, exists := t.getSession(channel, sessionId) | ||||||
| 			if exists && !loopback { | 			if exists && !loopback { | ||||||
|  | 				// only delete the session if its unicast | ||||||
|  | 				// otherwise ignore close on the multicast | ||||||
| 				if s.mode == Unicast { | 				if s.mode == Unicast { | ||||||
| 					// only delete this if its unicast | 					// only delete this if its unicast | ||||||
| 					// but not if its a loopback conn | 					// but not if its a loopback conn | ||||||
| @@ -541,20 +605,24 @@ func (t *tun) listen(link *link) { | |||||||
| 		// an accept returned by the listener | 		// an accept returned by the listener | ||||||
| 		case "accept": | 		case "accept": | ||||||
| 			s, exists := t.getSession(channel, sessionId) | 			s, exists := t.getSession(channel, sessionId) | ||||||
| 			// we don't need this | 			// just set accepted on anything not unicast | ||||||
| 			if exists && s.mode > Unicast { | 			if exists && s.mode > Unicast { | ||||||
| 				s.accepted = true | 				s.accepted = true | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  | 			// if its already accepted move on | ||||||
| 			if exists && s.accepted { | 			if exists && s.accepted { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  | 			// otherwise we're going to process to accept | ||||||
| 		// a continued session | 		// a continued session | ||||||
| 		case "session": | 		case "session": | ||||||
| 			// process message | 			// process message | ||||||
| 			log.Tracef("Received %+v from %s", msg.Header, link.Remote()) | 			log.Tracef("Tunnel received %+v from %s", msg.Header, link.Remote()) | ||||||
| 		// an announcement of a channel listener | 		// an announcement of a channel listener | ||||||
| 		case "announce": | 		case "announce": | ||||||
|  | 			log.Tracef("Tunnel received %+v from %s", msg.Header, link.Remote()) | ||||||
|  |  | ||||||
| 			// process the announcement | 			// process the announcement | ||||||
| 			channels := strings.Split(channel, ",") | 			channels := strings.Split(channel, ",") | ||||||
|  |  | ||||||
| @@ -562,7 +630,10 @@ func (t *tun) listen(link *link) { | |||||||
| 			link.setChannel(channels...) | 			link.setChannel(channels...) | ||||||
|  |  | ||||||
| 			// this was an announcement not intended for anything | 			// this was an announcement not intended for anything | ||||||
| 			if sessionId == "listener" || sessionId == "" { | 			// if the dialing side sent "discover" then a session | ||||||
|  | 			// id would be present. We skip in case of multicast. | ||||||
|  | 			switch sessionId { | ||||||
|  | 			case "listener", "multicast", "": | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| @@ -574,14 +645,19 @@ func (t *tun) listen(link *link) { | |||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// send the announce back to the caller | 				msg := &message{ | ||||||
| 				s.recv <- &message{ |  | ||||||
| 					typ:     "announce", | 					typ:     "announce", | ||||||
| 					tunnel:  id, | 					tunnel:  id, | ||||||
| 					channel: channel, | 					channel: channel, | ||||||
| 					session: sessionId, | 					session: sessionId, | ||||||
| 					link:    link.id, | 					link:    link.id, | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
|  | 				// send the announce back to the caller | ||||||
|  | 				select { | ||||||
|  | 				case <-s.closed: | ||||||
|  | 				case s.recv <- msg: | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			continue | 			continue | ||||||
| 		case "discover": | 		case "discover": | ||||||
| @@ -618,7 +694,7 @@ func (t *tun) listen(link *link) { | |||||||
| 			s, exists = t.getSession(channel, "listener") | 			s, exists = t.getSession(channel, "listener") | ||||||
| 		// only return accept to the session | 		// only return accept to the session | ||||||
| 		case mtype == "accept": | 		case mtype == "accept": | ||||||
| 			log.Debugf("Received accept message for %s %s", channel, sessionId) | 			log.Debugf("Tunnel received accept message for channel: %s session: %s", channel, sessionId) | ||||||
| 			s, exists = t.getSession(channel, sessionId) | 			s, exists = t.getSession(channel, sessionId) | ||||||
| 			if exists && s.accepted { | 			if exists && s.accepted { | ||||||
| 				continue | 				continue | ||||||
| @@ -638,7 +714,7 @@ func (t *tun) listen(link *link) { | |||||||
|  |  | ||||||
| 		// bail if no session or listener has been found | 		// bail if no session or listener has been found | ||||||
| 		if !exists { | 		if !exists { | ||||||
| 			log.Debugf("Tunnel skipping no session %s %s exists", channel, sessionId) | 			log.Tracef("Tunnel skipping no channel: %s session: %s exists", channel, sessionId) | ||||||
| 			// drop it, we don't care about | 			// drop it, we don't care about | ||||||
| 			// messages we don't know about | 			// messages we don't know about | ||||||
| 			continue | 			continue | ||||||
| @@ -651,22 +727,10 @@ func (t *tun) listen(link *link) { | |||||||
| 			delete(t.sessions, channel) | 			delete(t.sessions, channel) | ||||||
| 			continue | 			continue | ||||||
| 		default: | 		default: | ||||||
| 			// process | 			// otherwise process | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debugf("Tunnel using channel %s session %s", s.channel, s.session) | 		log.Tracef("Tunnel using channel: %s session: %s type: %s", s.channel, s.session, mtype) | ||||||
|  |  | ||||||
| 		// is the session new? |  | ||||||
| 		select { |  | ||||||
| 		// if its new the session is actually blocked waiting |  | ||||||
| 		// for a connection. so we check if its waiting. |  | ||||||
| 		case <-s.wait: |  | ||||||
| 		// if its waiting e.g its new then we close it |  | ||||||
| 		default: |  | ||||||
| 			// set remote address of the session |  | ||||||
| 			s.remote = msg.Header["Remote"] |  | ||||||
| 			close(s.wait) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// construct a new transport message | 		// construct a new transport message | ||||||
| 		tmsg := &transport.Message{ | 		tmsg := &transport.Message{ | ||||||
| @@ -696,68 +760,26 @@ func (t *tun) listen(link *link) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // discover sends channel discover requests periodically | func (t *tun) sendMsg(method string, link *link) error { | ||||||
| func (t *tun) discover(link *link) { | 	return link.Send(&transport.Message{ | ||||||
| 	tick := time.NewTicker(DiscoverTime) |  | ||||||
| 	defer tick.Stop() |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case <-tick.C: |  | ||||||
| 			// send a discovery message to all links |  | ||||||
| 			if err := link.Send(&transport.Message{ |  | ||||||
| 		Header: map[string]string{ | 		Header: map[string]string{ | ||||||
| 					"Micro-Tunnel":    "discover", | 			"Micro-Tunnel":    method, | ||||||
| 			"Micro-Tunnel-Id": t.id, | 			"Micro-Tunnel-Id": t.id, | ||||||
| 		}, | 		}, | ||||||
| 			}); err != nil { | 	}) | ||||||
| 				log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err) |  | ||||||
| 			} |  | ||||||
| 		case <-link.closed: |  | ||||||
| 			return |  | ||||||
| 		case <-t.closed: |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // keepalive periodically sends keepalive messages to link |  | ||||||
| func (t *tun) keepalive(link *link) { |  | ||||||
| 	keepalive := time.NewTicker(KeepAliveTime) |  | ||||||
| 	defer keepalive.Stop() |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case <-t.closed: |  | ||||||
| 			return |  | ||||||
| 		case <-link.closed: |  | ||||||
| 			return |  | ||||||
| 		case <-keepalive.C: |  | ||||||
| 			// send keepalive message |  | ||||||
| 			log.Debugf("Tunnel sending keepalive to link: %v", link.Remote()) |  | ||||||
| 			if err := link.Send(&transport.Message{ |  | ||||||
| 				Header: map[string]string{ |  | ||||||
| 					"Micro-Tunnel":    "keepalive", |  | ||||||
| 					"Micro-Tunnel-Id": t.id, |  | ||||||
| 				}, |  | ||||||
| 			}); err != nil { |  | ||||||
| 				log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err) |  | ||||||
| 				t.delLink(link.Remote()) |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // setupLink connects to node and returns link if successful | // setupLink connects to node and returns link if successful | ||||||
| // It returns error if the link failed to be established | // It returns error if the link failed to be established | ||||||
| func (t *tun) setupLink(node string) (*link, error) { | func (t *tun) setupLink(node string) (*link, error) { | ||||||
| 	log.Debugf("Tunnel setting up link: %s", node) | 	log.Debugf("Tunnel setting up link: %s", node) | ||||||
|  |  | ||||||
| 	c, err := t.options.Transport.Dial(node) | 	c, err := t.options.Transport.Dial(node) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Debugf("Tunnel failed to connect to %s: %v", node, err) | 		log.Debugf("Tunnel failed to connect to %s: %v", node, err) | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debugf("Tunnel connected to %s", node) | 	log.Debugf("Tunnel connected to %s", node) | ||||||
|  |  | ||||||
| 	// create a new link | 	// create a new link | ||||||
| @@ -766,12 +788,8 @@ func (t *tun) setupLink(node string) (*link, error) { | |||||||
| 	link.id = c.Remote() | 	link.id = c.Remote() | ||||||
|  |  | ||||||
| 	// send the first connect message | 	// send the first connect message | ||||||
| 	if err := link.Send(&transport.Message{ | 	if err := t.sendMsg("connect", link); err != nil { | ||||||
| 		Header: map[string]string{ | 		link.Close() | ||||||
| 			"Micro-Tunnel":    "connect", |  | ||||||
| 			"Micro-Tunnel-Id": t.id, |  | ||||||
| 		}, |  | ||||||
| 	}); err != nil { |  | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -782,37 +800,40 @@ func (t *tun) setupLink(node string) (*link, error) { | |||||||
| 	// process incoming messages | 	// process incoming messages | ||||||
| 	go t.listen(link) | 	go t.listen(link) | ||||||
|  |  | ||||||
| 	// start keepalive monitor | 	// manage keepalives and discovery messages | ||||||
| 	go t.keepalive(link) | 	go t.manageLink(link) | ||||||
|  |  | ||||||
| 	// discover things on the remote side |  | ||||||
| 	go t.discover(link) |  | ||||||
|  |  | ||||||
| 	return link, nil | 	return link, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *tun) setupLinks() { | func (t *tun) setupLinks() { | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  |  | ||||||
| 	for _, node := range t.options.Nodes { | 	for _, node := range t.options.Nodes { | ||||||
| 		// skip zero length nodes | 		wg.Add(1) | ||||||
| 		if len(node) == 0 { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// link already exists | 		go func(node string) { | ||||||
|  | 			defer wg.Done() | ||||||
|  |  | ||||||
|  | 			// we're not trying to fix existing links | ||||||
| 			if _, ok := t.links[node]; ok { | 			if _, ok := t.links[node]; ok { | ||||||
| 			continue | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 		// connect to node and return link | 			// create new link | ||||||
| 			link, err := t.setupLink(node) | 			link, err := t.setupLink(node) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 			log.Debugf("Tunnel failed to establish node link to %s: %v", node, err) | 				log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) | ||||||
| 			continue | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// save the link | 			// save the link | ||||||
| 			t.links[node] = link | 			t.links[node] = link | ||||||
|  | 		}(node) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// wait for all threads to finish | ||||||
|  | 	wg.Wait() | ||||||
| } | } | ||||||
|  |  | ||||||
| // connect the tunnel to all the nodes and listen for incoming tunnel connections | // connect the tunnel to all the nodes and listen for incoming tunnel connections | ||||||
| @@ -833,11 +854,8 @@ func (t *tun) connect() error { | |||||||
| 			// create a new link | 			// create a new link | ||||||
| 			link := newLink(sock) | 			link := newLink(sock) | ||||||
|  |  | ||||||
| 			// start keepalive monitor | 			// manage the link | ||||||
| 			go t.keepalive(link) | 			go t.manageLink(link) | ||||||
|  |  | ||||||
| 			// discover things on the remote side |  | ||||||
| 			go t.discover(link) |  | ||||||
|  |  | ||||||
| 			// listen for inbound messages. | 			// listen for inbound messages. | ||||||
| 			// only save the link once connected. | 			// only save the link once connected. | ||||||
| @@ -864,12 +882,13 @@ func (t *tun) Connect() error { | |||||||
|  |  | ||||||
| 	// already connected | 	// already connected | ||||||
| 	if t.connected { | 	if t.connected { | ||||||
| 		// setup links | 		// do it immediately | ||||||
| 		t.setupLinks() | 		t.setupLinks() | ||||||
|  | 		// setup links | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// send the connect message | 	// connect the tunnel: start the listener | ||||||
| 	if err := t.connect(); err != nil { | 	if err := t.connect(); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -879,15 +898,15 @@ func (t *tun) Connect() error { | |||||||
| 	// create new close channel | 	// create new close channel | ||||||
| 	t.closed = make(chan bool) | 	t.closed = make(chan bool) | ||||||
|  |  | ||||||
| 	// setup links |  | ||||||
| 	t.setupLinks() |  | ||||||
|  |  | ||||||
| 	// process outbound messages to be sent | 	// process outbound messages to be sent | ||||||
| 	// process sends to all links | 	// process sends to all links | ||||||
| 	go t.process() | 	go t.process() | ||||||
|  |  | ||||||
| 	// monitor links | 	// call setup before managing them | ||||||
| 	go t.monitor() | 	t.setupLinks() | ||||||
|  |  | ||||||
|  | 	// manage the links | ||||||
|  | 	go t.manage() | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -1025,7 +1044,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
| 	// set the multicast option | 	// set the multicast option | ||||||
| 	c.mode = options.Mode | 	c.mode = options.Mode | ||||||
| 	// set the dial timeout | 	// set the dial timeout | ||||||
| 	c.timeout = options.Timeout | 	c.dialTimeout = options.Timeout | ||||||
|  | 	// set read timeout set to never | ||||||
|  | 	c.readTimeout = time.Duration(-1) | ||||||
|  |  | ||||||
| 	var links []*link | 	var links []*link | ||||||
| 	// did we measure the rtt | 	// did we measure the rtt | ||||||
| @@ -1052,7 +1073,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
|  |  | ||||||
| 	t.RUnlock() | 	t.RUnlock() | ||||||
|  |  | ||||||
| 	// link not found | 	// link not found and one was specified so error out | ||||||
| 	if len(links) == 0 && len(options.Link) > 0 { | 	if len(links) == 0 && len(options.Link) > 0 { | ||||||
| 		// delete session and return error | 		// delete session and return error | ||||||
| 		t.delSession(c.channel, c.session) | 		t.delSession(c.channel, c.session) | ||||||
| @@ -1061,15 +1082,14 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// discovered so set the link if not multicast | 	// discovered so set the link if not multicast | ||||||
| 	// TODO: pick the link efficiently based |  | ||||||
| 	// on link status and saturation. |  | ||||||
| 	if c.discovered && c.mode == Unicast { | 	if c.discovered && c.mode == Unicast { | ||||||
| 		// pickLink will pick the best link | 		// pickLink will pick the best link | ||||||
| 		link := t.pickLink(links) | 		link := t.pickLink(links) | ||||||
|  | 		// set the link | ||||||
| 		c.link = link.id | 		c.link = link.id | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// shit fuck | 	// if its not already discovered we need to attempt to do so | ||||||
| 	if !c.discovered { | 	if !c.discovered { | ||||||
| 		// piggy back roundtrip | 		// piggy back roundtrip | ||||||
| 		nowRTT := time.Now() | 		nowRTT := time.Now() | ||||||
| @@ -1098,7 +1118,15 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// a unicast session so we call "open" and wait for an "accept" | 	// return early if its not unicast | ||||||
|  | 	// we will not call "open" for multicast | ||||||
|  | 	if c.mode != Unicast { | ||||||
|  | 		return c, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Note: we go no further for multicast or broadcast. | ||||||
|  | 	// This is a unicast session so we call "open" and wait | ||||||
|  | 	// for an "accept" | ||||||
|  |  | ||||||
| 	// reset now in case we use it | 	// reset now in case we use it | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| @@ -1115,7 +1143,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
| 	d := time.Since(now) | 	d := time.Since(now) | ||||||
|  |  | ||||||
| 	// if we haven't measured the roundtrip do it now | 	// if we haven't measured the roundtrip do it now | ||||||
| 	if !measured && c.mode == Unicast { | 	if !measured { | ||||||
| 		// set the link time | 		// set the link time | ||||||
| 		t.RLock() | 		t.RLock() | ||||||
| 		link, ok := t.links[c.link] | 		link, ok := t.links[c.link] | ||||||
| @@ -1134,7 +1162,11 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | |||||||
| func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { | func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { | ||||||
| 	log.Debugf("Tunnel listening on %s", channel) | 	log.Debugf("Tunnel listening on %s", channel) | ||||||
|  |  | ||||||
| 	var options ListenOptions | 	options := ListenOptions{ | ||||||
|  | 		// Read timeout defaults to never | ||||||
|  | 		Timeout: time.Duration(-1), | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	for _, o := range opts { | 	for _, o := range opts { | ||||||
| 		o(&options) | 		o(&options) | ||||||
| 	} | 	} | ||||||
| @@ -1145,6 +1177,7 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { | |||||||
| 		return nil, errors.New("already listening on " + channel) | 		return nil, errors.New("already listening on " + channel) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// delete function removes the session when closed | ||||||
| 	delFunc := func() { | 	delFunc := func() { | ||||||
| 		t.delSession(channel, "listener") | 		t.delSession(channel, "listener") | ||||||
| 	} | 	} | ||||||
| @@ -1155,6 +1188,8 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) { | |||||||
| 	c.local = channel | 	c.local = channel | ||||||
| 	// set mode | 	// set mode | ||||||
| 	c.mode = options.Mode | 	c.mode = options.Mode | ||||||
|  | 	// set the timeout | ||||||
|  | 	c.readTimeout = options.Timeout | ||||||
|  |  | ||||||
| 	tl := &tunListener{ | 	tl := &tunListener{ | ||||||
| 		channel: channel, | 		channel: channel, | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package tunnel | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"errors" | ||||||
| 	"io" | 	"io" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -14,7 +15,11 @@ import ( | |||||||
| type link struct { | type link struct { | ||||||
| 	transport.Socket | 	transport.Socket | ||||||
|  |  | ||||||
|  | 	// transport to use for connections | ||||||
|  | 	transport transport.Transport | ||||||
|  |  | ||||||
| 	sync.RWMutex | 	sync.RWMutex | ||||||
|  |  | ||||||
| 	// stops the link | 	// stops the link | ||||||
| 	closed chan bool | 	closed chan bool | ||||||
| 	// link state channel for testing link | 	// link state channel for testing link | ||||||
| @@ -65,6 +70,8 @@ var ( | |||||||
| 	linkRequest = []byte{0, 0, 0, 0} | 	linkRequest = []byte{0, 0, 0, 0} | ||||||
| 	// the 4 byte 1 filled packet sent to determine link state | 	// the 4 byte 1 filled packet sent to determine link state | ||||||
| 	linkResponse = []byte{1, 1, 1, 1} | 	linkResponse = []byte{1, 1, 1, 1} | ||||||
|  |  | ||||||
|  | 	ErrLinkConnectTimeout = errors.New("link connect timeout") | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func newLink(s transport.Socket) *link { | func newLink(s transport.Socket) *link { | ||||||
| @@ -72,8 +79,8 @@ func newLink(s transport.Socket) *link { | |||||||
| 		Socket:        s, | 		Socket:        s, | ||||||
| 		id:            uuid.New().String(), | 		id:            uuid.New().String(), | ||||||
| 		lastKeepAlive: time.Now(), | 		lastKeepAlive: time.Now(), | ||||||
| 		channels:      make(map[string]time.Time), |  | ||||||
| 		closed:        make(chan bool), | 		closed:        make(chan bool), | ||||||
|  | 		channels:      make(map[string]time.Time), | ||||||
| 		state:         make(chan *packet, 64), | 		state:         make(chan *packet, 64), | ||||||
| 		sendQueue:     make(chan *packet, 128), | 		sendQueue:     make(chan *packet, 128), | ||||||
| 		recvQueue:     make(chan *packet, 128), | 		recvQueue:     make(chan *packet, 128), | ||||||
| @@ -87,6 +94,32 @@ func newLink(s transport.Socket) *link { | |||||||
| 	return l | 	return l | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (l *link) connect(addr string) error { | ||||||
|  | 	c, err := l.transport.Dial(addr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	l.Lock() | ||||||
|  | 	l.Socket = c | ||||||
|  | 	l.Unlock() | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l *link) accept(sock transport.Socket) error { | ||||||
|  | 	l.Lock() | ||||||
|  | 	l.Socket = sock | ||||||
|  | 	l.Unlock() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l *link) setLoopback(v bool) { | ||||||
|  | 	l.Lock() | ||||||
|  | 	l.loopback = v | ||||||
|  | 	l.Unlock() | ||||||
|  | } | ||||||
|  |  | ||||||
| // setRate sets the bits per second rate as a float64 | // setRate sets the bits per second rate as a float64 | ||||||
| func (l *link) setRate(bits int64, delta time.Duration) { | func (l *link) setRate(bits int64, delta time.Duration) { | ||||||
| 	// rate of send in bits per nanosecond | 	// rate of send in bits per nanosecond | ||||||
| @@ -167,6 +200,8 @@ func (l *link) process() { | |||||||
| 				// process link state message | 				// process link state message | ||||||
| 				select { | 				select { | ||||||
| 				case l.state <- pk: | 				case l.state <- pk: | ||||||
|  | 				case <-l.closed: | ||||||
|  | 					return | ||||||
| 				default: | 				default: | ||||||
| 				} | 				} | ||||||
| 				continue | 				continue | ||||||
| @@ -188,7 +223,11 @@ func (l *link) process() { | |||||||
| 		select { | 		select { | ||||||
| 		case pk := <-l.sendQueue: | 		case pk := <-l.sendQueue: | ||||||
| 			// send the message | 			// send the message | ||||||
| 			pk.status <- l.send(pk.message) | 			select { | ||||||
|  | 			case pk.status <- l.send(pk.message): | ||||||
|  | 			case <-l.closed: | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
| 		case <-l.closed: | 		case <-l.closed: | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @@ -201,11 +240,15 @@ func (l *link) manage() { | |||||||
| 	t := time.NewTicker(time.Minute) | 	t := time.NewTicker(time.Minute) | ||||||
| 	defer t.Stop() | 	defer t.Stop() | ||||||
|  |  | ||||||
|  | 	// get link id | ||||||
|  | 	linkId := l.Id() | ||||||
|  |  | ||||||
| 	// used to send link state packets | 	// used to send link state packets | ||||||
| 	send := func(b []byte) error { | 	send := func(b []byte) error { | ||||||
| 		return l.Send(&transport.Message{ | 		return l.Send(&transport.Message{ | ||||||
| 			Header: map[string]string{ | 			Header: map[string]string{ | ||||||
| 				"Micro-Method":  "link", | 				"Micro-Method":  "link", | ||||||
|  | 				"Micro-Link-Id": linkId, | ||||||
| 			}, Body: b, | 			}, Body: b, | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| @@ -229,9 +272,7 @@ func (l *link) manage() { | |||||||
| 			// check the type of message | 			// check the type of message | ||||||
| 			switch { | 			switch { | ||||||
| 			case bytes.Equal(p.message.Body, linkRequest): | 			case bytes.Equal(p.message.Body, linkRequest): | ||||||
| 				l.RLock() | 				log.Tracef("Link %s received link request", linkId) | ||||||
| 				log.Tracef("Link %s received link request %v", l.id, p.message.Body) |  | ||||||
| 				l.RUnlock() |  | ||||||
|  |  | ||||||
| 				// send response | 				// send response | ||||||
| 				if err := send(linkResponse); err != nil { | 				if err := send(linkResponse); err != nil { | ||||||
| @@ -242,9 +283,7 @@ func (l *link) manage() { | |||||||
| 			case bytes.Equal(p.message.Body, linkResponse): | 			case bytes.Equal(p.message.Body, linkResponse): | ||||||
| 				// set round trip time | 				// set round trip time | ||||||
| 				d := time.Since(now) | 				d := time.Since(now) | ||||||
| 				l.RLock() | 				log.Tracef("Link %s received link response in %v", linkId, d) | ||||||
| 				log.Tracef("Link %s received link response in %v", p.message.Body, d) |  | ||||||
| 				l.RUnlock() |  | ||||||
| 				// set the RTT | 				// set the RTT | ||||||
| 				l.setRTT(d) | 				l.setRTT(d) | ||||||
| 			} | 			} | ||||||
| @@ -309,6 +348,12 @@ func (l *link) Rate() float64 { | |||||||
| 	return l.rate | 	return l.rate | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (l *link) Loopback() bool { | ||||||
|  | 	l.RLock() | ||||||
|  | 	defer l.RUnlock() | ||||||
|  | 	return l.loopback | ||||||
|  | } | ||||||
|  |  | ||||||
| // Length returns the roundtrip time as nanoseconds (lower is better). | // Length returns the roundtrip time as nanoseconds (lower is better). | ||||||
| // Returns 0 where no measurement has been taken. | // Returns 0 where no measurement has been taken. | ||||||
| func (l *link) Length() int64 { | func (l *link) Length() int64 { | ||||||
| @@ -320,7 +365,6 @@ func (l *link) Length() int64 { | |||||||
| func (l *link) Id() string { | func (l *link) Id() string { | ||||||
| 	l.RLock() | 	l.RLock() | ||||||
| 	defer l.RUnlock() | 	defer l.RUnlock() | ||||||
|  |  | ||||||
| 	return l.id | 	return l.id | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -350,13 +394,6 @@ func (l *link) Send(m *transport.Message) error { | |||||||
| 	// get time now | 	// get time now | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
|  |  | ||||||
| 	// check if its closed first |  | ||||||
| 	select { |  | ||||||
| 	case <-l.closed: |  | ||||||
| 		return io.EOF |  | ||||||
| 	default: |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// queue the message | 	// queue the message | ||||||
| 	select { | 	select { | ||||||
| 	case l.sendQueue <- p: | 	case l.sendQueue <- p: | ||||||
|   | |||||||
| @@ -24,7 +24,7 @@ type tunListener struct { | |||||||
| 	delFunc func() | 	delFunc func() | ||||||
| } | } | ||||||
|  |  | ||||||
| // periodically announce self | // periodically announce self the channel being listened on | ||||||
| func (t *tunListener) announce() { | func (t *tunListener) announce() { | ||||||
| 	tick := time.NewTicker(time.Second * 30) | 	tick := time.NewTicker(time.Second * 30) | ||||||
| 	defer tick.Stop() | 	defer tick.Stop() | ||||||
| @@ -48,9 +48,12 @@ func (t *tunListener) process() { | |||||||
|  |  | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		// close the sessions | 		// close the sessions | ||||||
| 		for _, conn := range conns { | 		for id, conn := range conns { | ||||||
| 			conn.Close() | 			conn.Close() | ||||||
|  | 			delete(conns, id) | ||||||
| 		} | 		} | ||||||
|  | 		// unassign | ||||||
|  | 		conns = nil | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| @@ -62,9 +65,24 @@ func (t *tunListener) process() { | |||||||
| 			return | 			return | ||||||
| 		// receive a new message | 		// receive a new message | ||||||
| 		case m := <-t.session.recv: | 		case m := <-t.session.recv: | ||||||
|  | 			var sessionId string | ||||||
|  | 			var linkId string | ||||||
|  |  | ||||||
|  | 			switch m.mode { | ||||||
|  | 			case Multicast: | ||||||
|  | 				sessionId = "multicast" | ||||||
|  | 				linkId = "multicast" | ||||||
|  | 			case Broadcast: | ||||||
|  | 				sessionId = "broadcast" | ||||||
|  | 				linkId = "broadcast" | ||||||
|  | 			default: | ||||||
|  | 				sessionId = m.session | ||||||
|  | 				linkId = m.link | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			// get a session | 			// get a session | ||||||
| 			sess, ok := conns[m.session] | 			sess, ok := conns[sessionId] | ||||||
| 			log.Debugf("Tunnel listener received channel %s session %s exists: %t", m.channel, m.session, ok) | 			log.Tracef("Tunnel listener received channel %s session %s type %s exists: %t", m.channel, m.session, m.typ, ok) | ||||||
| 			if !ok { | 			if !ok { | ||||||
| 				// we only process open and session types | 				// we only process open and session types | ||||||
| 				switch m.typ { | 				switch m.typ { | ||||||
| @@ -80,13 +98,13 @@ func (t *tunListener) process() { | |||||||
| 					// the channel | 					// the channel | ||||||
| 					channel: m.channel, | 					channel: m.channel, | ||||||
| 					// the session id | 					// the session id | ||||||
| 					session: m.session, | 					session: sessionId, | ||||||
| 					// tunnel token | 					// tunnel token | ||||||
| 					token: t.token, | 					token: t.token, | ||||||
| 					// is loopback conn | 					// is loopback conn | ||||||
| 					loopback: m.loopback, | 					loopback: m.loopback, | ||||||
| 					// the link the message was received on | 					// the link the message was received on | ||||||
| 					link: m.link, | 					link: linkId, | ||||||
| 					// set the connection mode | 					// set the connection mode | ||||||
| 					mode: m.mode, | 					mode: m.mode, | ||||||
| 					// close chan | 					// close chan | ||||||
| @@ -95,14 +113,14 @@ func (t *tunListener) process() { | |||||||
| 					recv: make(chan *message, 128), | 					recv: make(chan *message, 128), | ||||||
| 					// use the internal send buffer | 					// use the internal send buffer | ||||||
| 					send: t.session.send, | 					send: t.session.send, | ||||||
| 					// wait |  | ||||||
| 					wait: make(chan bool), |  | ||||||
| 					// error channel | 					// error channel | ||||||
| 					errChan: make(chan error, 1), | 					errChan: make(chan error, 1), | ||||||
|  | 					// set the read timeout | ||||||
|  | 					readTimeout: t.session.readTimeout, | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// save the session | 				// save the session | ||||||
| 				conns[m.session] = sess | 				conns[sessionId] = sess | ||||||
|  |  | ||||||
| 				select { | 				select { | ||||||
| 				case <-t.closed: | 				case <-t.closed: | ||||||
| @@ -114,17 +132,19 @@ func (t *tunListener) process() { | |||||||
|  |  | ||||||
| 			// an existing session was found | 			// an existing session was found | ||||||
|  |  | ||||||
| 			// received a close message |  | ||||||
| 			switch m.typ { | 			switch m.typ { | ||||||
| 			case "close": | 			case "close": | ||||||
|  | 				// received a close message | ||||||
| 				select { | 				select { | ||||||
|  | 				// check if the session is closed | ||||||
| 				case <-sess.closed: | 				case <-sess.closed: | ||||||
| 					// no op | 					// no op | ||||||
| 					delete(conns, m.session) | 					delete(conns, sessionId) | ||||||
| 				default: | 				default: | ||||||
|  | 					// only close if unicast session | ||||||
| 					// close and delete session | 					// close and delete session | ||||||
| 					close(sess.closed) | 					close(sess.closed) | ||||||
| 					delete(conns, m.session) | 					delete(conns, sessionId) | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// continue | 				// continue | ||||||
| @@ -139,9 +159,9 @@ func (t *tunListener) process() { | |||||||
| 			// send this to the accept chan | 			// send this to the accept chan | ||||||
| 			select { | 			select { | ||||||
| 			case <-sess.closed: | 			case <-sess.closed: | ||||||
| 				delete(conns, m.session) | 				delete(conns, sessionId) | ||||||
| 			case sess.recv <- m: | 			case sess.recv <- m: | ||||||
| 				log.Debugf("Tunnel listener sent to recv chan channel %s session %s", m.channel, m.session) | 				log.Tracef("Tunnel listener sent to recv chan channel %s session %s type %s", m.channel, sessionId, m.typ) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -180,6 +200,10 @@ func (t *tunListener) Accept() (Session, error) { | |||||||
| 		if !ok { | 		if !ok { | ||||||
| 			return nil, io.EOF | 			return nil, io.EOF | ||||||
| 		} | 		} | ||||||
|  | 		// return without accept | ||||||
|  | 		if c.mode != Unicast { | ||||||
|  | 			return c, nil | ||||||
|  | 		} | ||||||
| 		// send back the accept | 		// send back the accept | ||||||
| 		if err := c.Accept(); err != nil { | 		if err := c.Accept(); err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
|   | |||||||
| @@ -47,6 +47,8 @@ type ListenOption func(*ListenOptions) | |||||||
| type ListenOptions struct { | type ListenOptions struct { | ||||||
| 	// specify mode of the session | 	// specify mode of the session | ||||||
| 	Mode Mode | 	Mode Mode | ||||||
|  | 	// The read timeout | ||||||
|  | 	Timeout time.Duration | ||||||
| } | } | ||||||
|  |  | ||||||
| // The tunnel id | // The tunnel id | ||||||
| @@ -84,16 +86,6 @@ func Transport(t transport.Transport) Option { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // DefaultOptions returns router default options |  | ||||||
| func DefaultOptions() Options { |  | ||||||
| 	return Options{ |  | ||||||
| 		Id:        uuid.New().String(), |  | ||||||
| 		Address:   DefaultAddress, |  | ||||||
| 		Token:     DefaultToken, |  | ||||||
| 		Transport: quic.NewTransport(), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Listen options | // Listen options | ||||||
| func ListenMode(m Mode) ListenOption { | func ListenMode(m Mode) ListenOption { | ||||||
| 	return func(o *ListenOptions) { | 	return func(o *ListenOptions) { | ||||||
| @@ -101,6 +93,13 @@ func ListenMode(m Mode) ListenOption { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Timeout for reads and writes on the listener session | ||||||
|  | func ListenTimeout(t time.Duration) ListenOption { | ||||||
|  | 	return func(o *ListenOptions) { | ||||||
|  | 		o.Timeout = t | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // Dial options | // Dial options | ||||||
|  |  | ||||||
| // Dial multicast sets the multicast option to send only to those mapped | // Dial multicast sets the multicast option to send only to those mapped | ||||||
| @@ -124,3 +123,13 @@ func DialLink(id string) DialOption { | |||||||
| 		o.Link = id | 		o.Link = id | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // DefaultOptions returns router default options | ||||||
|  | func DefaultOptions() Options { | ||||||
|  | 	return Options{ | ||||||
|  | 		Id:        uuid.New().String(), | ||||||
|  | 		Address:   DefaultAddress, | ||||||
|  | 		Token:     DefaultToken, | ||||||
|  | 		Transport: quic.NewTransport(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,7 +2,6 @@ package tunnel | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"encoding/hex" | 	"encoding/hex" | ||||||
| 	"errors" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -30,8 +29,6 @@ type session struct { | |||||||
| 	send chan *message | 	send chan *message | ||||||
| 	// recv chan | 	// recv chan | ||||||
| 	recv chan *message | 	recv chan *message | ||||||
| 	// wait until we have a connection |  | ||||||
| 	wait chan bool |  | ||||||
| 	// if the discovery worked | 	// if the discovery worked | ||||||
| 	discovered bool | 	discovered bool | ||||||
| 	// if the session was accepted | 	// if the session was accepted | ||||||
| @@ -42,8 +39,10 @@ type session struct { | |||||||
| 	loopback bool | 	loopback bool | ||||||
| 	// mode of the connection | 	// mode of the connection | ||||||
| 	mode Mode | 	mode Mode | ||||||
| 	// the timeout | 	// the dial timeout | ||||||
| 	timeout time.Duration | 	dialTimeout time.Duration | ||||||
|  | 	// the read timeout | ||||||
|  | 	readTimeout time.Duration | ||||||
| 	// the link on which this message was received | 	// the link on which this message was received | ||||||
| 	link string | 	link string | ||||||
| 	// the error response | 	// the error response | ||||||
| @@ -109,65 +108,114 @@ func (s *session) newMessage(typ string) *message { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *session) sendMsg(msg *message) error { | ||||||
|  | 	select { | ||||||
|  | 	case <-s.closed: | ||||||
|  | 		return io.EOF | ||||||
|  | 	case s.send <- msg: | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *session) wait(msg *message) error { | ||||||
|  | 	// wait for an error response | ||||||
|  | 	select { | ||||||
|  | 	case err := <-msg.errChan: | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	case <-s.closed: | ||||||
|  | 		return io.EOF | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| // waitFor waits for the message type required until the timeout specified | // waitFor waits for the message type required until the timeout specified | ||||||
| func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) { | func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) { | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
|  |  | ||||||
| 	after := func(timeout time.Duration) time.Duration { | 	after := func(timeout time.Duration) <-chan time.Time { | ||||||
|  | 		if timeout < time.Duration(0) { | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// get the delta | ||||||
| 		d := time.Since(now) | 		d := time.Since(now) | ||||||
|  |  | ||||||
| 		// dial timeout minus time since | 		// dial timeout minus time since | ||||||
| 		wait := timeout - d | 		wait := timeout - d | ||||||
|  |  | ||||||
| 		if wait < time.Duration(0) { | 		if wait < time.Duration(0) { | ||||||
| 			return time.Duration(0) | 			wait = time.Duration(0) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		return wait | 		return time.After(wait) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// wait for the message type | 	// wait for the message type | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		case msg := <-s.recv: | 		case msg := <-s.recv: | ||||||
|  | 			// there may be no message type | ||||||
|  | 			if len(msgType) == 0 { | ||||||
|  | 				return msg, nil | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			// ignore what we don't want | 			// ignore what we don't want | ||||||
| 			if msg.typ != msgType { | 			if msg.typ != msgType { | ||||||
| 				log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) | 				log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// got the message | 			// got the message | ||||||
| 			return msg, nil | 			return msg, nil | ||||||
| 		case <-time.After(after(timeout)): | 		case <-after(timeout): | ||||||
| 			return nil, ErrDialTimeout | 			return nil, ErrReadTimeout | ||||||
| 		case <-s.closed: | 		case <-s.closed: | ||||||
| 			return nil, io.EOF | 			return nil, io.EOF | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Discover attempts to discover the link for a specific channel | // Discover attempts to discover the link for a specific channel. | ||||||
|  | // This is only used by the tunnel.Dial when first connecting. | ||||||
| func (s *session) Discover() error { | func (s *session) Discover() error { | ||||||
| 	// create a new discovery message for this channel | 	// create a new discovery message for this channel | ||||||
| 	msg := s.newMessage("discover") | 	msg := s.newMessage("discover") | ||||||
|  | 	// broadcast the message to all links | ||||||
| 	msg.mode = Broadcast | 	msg.mode = Broadcast | ||||||
|  | 	// its an outbound connection since we're dialling | ||||||
| 	msg.outbound = true | 	msg.outbound = true | ||||||
|  | 	// don't set the link since we don't know where it is | ||||||
| 	msg.link = "" | 	msg.link = "" | ||||||
|  |  | ||||||
| 	// send the discovery message | 	// if multicast then set that as session | ||||||
| 	s.send <- msg | 	if s.mode == Multicast { | ||||||
|  | 		msg.session = "multicast" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// send discover message | ||||||
|  | 	if err := s.sendMsg(msg); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// set time now | 	// set time now | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
|  |  | ||||||
|  | 	// after strips down the dial timeout | ||||||
| 	after := func() time.Duration { | 	after := func() time.Duration { | ||||||
| 		d := time.Since(now) | 		d := time.Since(now) | ||||||
| 		// dial timeout minus time since | 		// dial timeout minus time since | ||||||
| 		wait := s.timeout - d | 		wait := s.dialTimeout - d | ||||||
|  | 		// make sure its always > 0 | ||||||
| 		if wait < time.Duration(0) { | 		if wait < time.Duration(0) { | ||||||
| 			return time.Duration(0) | 			return time.Duration(0) | ||||||
| 		} | 		} | ||||||
| 		return wait | 		return wait | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// the discover message is sent out, now | ||||||
| 	// wait to hear back about the sent message | 	// wait to hear back about the sent message | ||||||
| 	select { | 	select { | ||||||
| 	case <-time.After(after()): | 	case <-time.After(after()): | ||||||
| @@ -178,27 +226,16 @@ func (s *session) Discover() error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var err error | 	// bail early if its not unicast | ||||||
|  | 	// we don't need to wait for the announce | ||||||
| 	// set a new dialTimeout |  | ||||||
| 	dialTimeout := after() |  | ||||||
|  |  | ||||||
| 	// set a shorter delay for multicast |  | ||||||
| 	if s.mode != Unicast { |  | ||||||
| 		// shorten this |  | ||||||
| 		dialTimeout = time.Millisecond * 500 |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// wait for announce |  | ||||||
| 	_, err = s.waitFor("announce", dialTimeout) |  | ||||||
|  |  | ||||||
| 	// if its multicast just go ahead because this is best effort |  | ||||||
| 	if s.mode != Unicast { | 	if s.mode != Unicast { | ||||||
| 		s.discovered = true | 		s.discovered = true | ||||||
| 		s.accepted = true | 		s.accepted = true | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// wait for announce | ||||||
|  | 	_, err := s.waitFor("announce", after()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -210,31 +247,23 @@ func (s *session) Discover() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Open will fire the open message for the session. This is called by the dialler. | // Open will fire the open message for the session. This is called by the dialler. | ||||||
|  | // This is to indicate that we want to create a new session. | ||||||
| func (s *session) Open() error { | func (s *session) Open() error { | ||||||
| 	// create a new message | 	// create a new message | ||||||
| 	msg := s.newMessage("open") | 	msg := s.newMessage("open") | ||||||
|  |  | ||||||
| 	// send open message | 	// send open message | ||||||
| 	s.send <- msg | 	if err := s.sendMsg(msg); err != nil { | ||||||
|  |  | ||||||
| 	// wait for an error response for send |  | ||||||
| 	select { |  | ||||||
| 	case err := <-msg.errChan: |  | ||||||
| 		if err != nil { |  | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	case <-s.closed: |  | ||||||
| 		return io.EOF | 	// wait for an error response for send | ||||||
|  | 	if err := s.wait(msg); err != nil { | ||||||
|  | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// don't wait on multicast/broadcast | 	// now wait for the accept message to be returned | ||||||
| 	if s.mode == Multicast { | 	msg, err := s.waitFor("accept", s.dialTimeout) | ||||||
| 		s.accepted = true |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// now wait for the accept |  | ||||||
| 	msg, err := s.waitFor("accept", s.timeout) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -252,32 +281,16 @@ func (s *session) Accept() error { | |||||||
| 	msg := s.newMessage("accept") | 	msg := s.newMessage("accept") | ||||||
|  |  | ||||||
| 	// send the accept message | 	// send the accept message | ||||||
| 	select { | 	if err := s.sendMsg(msg); err != nil { | ||||||
| 	case <-s.closed: | 		return err | ||||||
| 		return io.EOF |  | ||||||
| 	case s.send <- msg: |  | ||||||
| 		// no op here |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// don't wait on multicast/broadcast |  | ||||||
| 	if s.mode == Multicast { |  | ||||||
| 		return nil |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// wait for send response | 	// wait for send response | ||||||
| 	select { | 	return s.wait(msg) | ||||||
| 	case err := <-s.errChan: |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	case <-s.closed: |  | ||||||
| 		return io.EOF |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Announce sends an announcement to notify that this session exists. This is primarily used by the listener. | // Announce sends an announcement to notify that this session exists. | ||||||
|  | // This is primarily used by the listener. | ||||||
| func (s *session) Announce() error { | func (s *session) Announce() error { | ||||||
| 	msg := s.newMessage("announce") | 	msg := s.newMessage("announce") | ||||||
| 	// we don't need an error back | 	// we don't need an error back | ||||||
| @@ -287,23 +300,12 @@ func (s *session) Announce() error { | |||||||
| 	// we don't need the link | 	// we don't need the link | ||||||
| 	msg.link = "" | 	msg.link = "" | ||||||
|  |  | ||||||
| 	select { | 	// send announce message | ||||||
| 	case s.send <- msg: | 	return s.sendMsg(msg) | ||||||
| 		return nil |  | ||||||
| 	case <-s.closed: |  | ||||||
| 		return io.EOF |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Send is used to send a message | // Send is used to send a message | ||||||
| func (s *session) Send(m *transport.Message) error { | func (s *session) Send(m *transport.Message) error { | ||||||
| 	select { |  | ||||||
| 	case <-s.closed: |  | ||||||
| 		return io.EOF |  | ||||||
| 	default: |  | ||||||
| 		// no op |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// encrypt the transport message payload | 	// encrypt the transport message payload | ||||||
| 	body, err := Encrypt(m.Body, s.token+s.channel+s.session) | 	body, err := Encrypt(m.Body, s.token+s.channel+s.session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -335,32 +337,28 @@ func (s *session) Send(m *transport.Message) error { | |||||||
| 	msg.data = data | 	msg.data = data | ||||||
|  |  | ||||||
| 	// if multicast don't set the link | 	// if multicast don't set the link | ||||||
| 	if s.mode == Multicast { | 	if s.mode != Unicast { | ||||||
| 		msg.link = "" | 		msg.link = "" | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Tracef("Appending %+v to send backlog", msg) | 	log.Tracef("Appending %+v to send backlog", msg) | ||||||
|  |  | ||||||
| 	// send the actual message | 	// send the actual message | ||||||
| 	s.send <- msg | 	if err := s.sendMsg(msg); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// wait for an error response | 	// wait for an error response | ||||||
| 	select { | 	return s.wait(msg) | ||||||
| 	case err := <-msg.errChan: |  | ||||||
| 		return err |  | ||||||
| 	case <-s.closed: |  | ||||||
| 		return io.EOF |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Recv is used to receive a message | // Recv is used to receive a message | ||||||
| func (s *session) Recv(m *transport.Message) error { | func (s *session) Recv(m *transport.Message) error { | ||||||
| 	var msg *message | 	var msg *message | ||||||
|  |  | ||||||
| 	select { | 	msg, err := s.waitFor("", s.readTimeout) | ||||||
| 	case <-s.closed: | 	if err != nil { | ||||||
| 		return errors.New("session is closed") | 		return err | ||||||
| 	// recv from backlog |  | ||||||
| 	case msg = <-s.recv: |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// check the error if one exists | 	// check the error if one exists | ||||||
| @@ -371,10 +369,13 @@ func (s *session) Recv(m *transport.Message) error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	//log.Tracef("Received %+v from recv backlog", msg) | 	//log.Tracef("Received %+v from recv backlog", msg) | ||||||
| 	log.Debugf("Received %+v from recv backlog", msg) | 	log.Tracef("Received %+v from recv backlog", msg) | ||||||
|  |  | ||||||
| 	// decrypt the received payload using the token | 	// decrypt the received payload using the token | ||||||
| 	body, err := Decrypt(msg.data.Body, s.token+s.channel+s.session) | 	// we have to used msg.session because multicast has a shared | ||||||
|  | 	// session id of "multicast" in this session struct on | ||||||
|  | 	// the listener side | ||||||
|  | 	body, err := Decrypt(msg.data.Body, s.token+s.channel+msg.session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Debugf("failed to decrypt message body: %v", err) | 		log.Debugf("failed to decrypt message body: %v", err) | ||||||
| 		return err | 		return err | ||||||
| @@ -390,7 +391,7 @@ func (s *session) Recv(m *transport.Message) error { | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		// encrypt the transport message payload | 		// encrypt the transport message payload | ||||||
| 		val, err := Decrypt([]byte(h), s.token+s.channel+s.session) | 		val, err := Decrypt([]byte(h), s.token+s.channel+msg.session) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Debugf("failed to decrypt message header %s: %v", k, err) | 			log.Debugf("failed to decrypt message header %s: %v", k, err) | ||||||
| 			return err | 			return err | ||||||
| @@ -399,6 +400,12 @@ func (s *session) Recv(m *transport.Message) error { | |||||||
| 		msg.data.Header[k] = string(val) | 		msg.data.Header[k] = string(val) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// set the link | ||||||
|  | 	// TODO: decruft, this is only for multicast | ||||||
|  | 	// since the session is now a single session | ||||||
|  | 	// likely provide as part of message.Link() | ||||||
|  | 	msg.data.Header["Micro-Link"] = msg.link | ||||||
|  |  | ||||||
| 	// set message | 	// set message | ||||||
| 	*m = *msg.data | 	*m = *msg.data | ||||||
| 	// return nil | 	// return nil | ||||||
| @@ -413,6 +420,11 @@ func (s *session) Close() error { | |||||||
| 	default: | 	default: | ||||||
| 		close(s.closed) | 		close(s.closed) | ||||||
|  |  | ||||||
|  | 		// don't send close on multicast or broadcast | ||||||
|  | 		if s.mode != Unicast { | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// append to backlog | 		// append to backlog | ||||||
| 		msg := s.newMessage("close") | 		msg := s.newMessage("close") | ||||||
| 		// no error response on close | 		// no error response on close | ||||||
| @@ -421,7 +433,7 @@ func (s *session) Close() error { | |||||||
| 		// send the close message | 		// send the close message | ||||||
| 		select { | 		select { | ||||||
| 		case s.send <- msg: | 		case s.send <- msg: | ||||||
| 		default: | 		case <-time.After(time.Millisecond * 10): | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -26,6 +26,8 @@ var ( | |||||||
| 	ErrDiscoverChan = errors.New("failed to discover channel") | 	ErrDiscoverChan = errors.New("failed to discover channel") | ||||||
| 	// ErrLinkNotFound is returned when a link is specified at dial time and does not exist | 	// ErrLinkNotFound is returned when a link is specified at dial time and does not exist | ||||||
| 	ErrLinkNotFound = errors.New("link not found") | 	ErrLinkNotFound = errors.New("link not found") | ||||||
|  | 	// ErrReadTimeout is a timeout on session.Recv | ||||||
|  | 	ErrReadTimeout = errors.New("read timeout") | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Mode of the session | // Mode of the session | ||||||
| @@ -64,7 +66,9 @@ type Link interface { | |||||||
| 	Length() int64 | 	Length() int64 | ||||||
| 	// Current transfer rate as bits per second (lower is better) | 	// Current transfer rate as bits per second (lower is better) | ||||||
| 	Rate() float64 | 	Rate() float64 | ||||||
| 	// State of the link e.g connected/closed | 	// Is this a loopback link | ||||||
|  | 	Loopback() bool | ||||||
|  | 	// State of the link: connected/closed/error | ||||||
| 	State() string | 	State() string | ||||||
| 	// honours transport socket | 	// honours transport socket | ||||||
| 	transport.Socket | 	transport.Socket | ||||||
|   | |||||||
| @@ -1,55 +0,0 @@ | |||||||
| // +build !race |  | ||||||
|  |  | ||||||
| package tunnel |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"sync" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestReconnectTunnel(t *testing.T) { |  | ||||||
| 	// create a new tunnel client |  | ||||||
| 	tunA := NewTunnel( |  | ||||||
| 		Address("127.0.0.1:9096"), |  | ||||||
| 		Nodes("127.0.0.1:9097"), |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	// create a new tunnel server |  | ||||||
| 	tunB := NewTunnel( |  | ||||||
| 		Address("127.0.0.1:9097"), |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	// start tunnel |  | ||||||
| 	err := tunB.Connect() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer tunB.Close() |  | ||||||
|  |  | ||||||
| 	// we manually override the tunnel.ReconnectTime value here |  | ||||||
| 	// this is so that we make the reconnects faster than the default 5s |  | ||||||
| 	ReconnectTime = 200 * time.Millisecond |  | ||||||
|  |  | ||||||
| 	// start tunnel |  | ||||||
| 	err = tunA.Connect() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer tunA.Close() |  | ||||||
|  |  | ||||||
| 	wait := make(chan bool) |  | ||||||
|  |  | ||||||
| 	var wg sync.WaitGroup |  | ||||||
|  |  | ||||||
| 	wg.Add(1) |  | ||||||
| 	// start tunnel listener |  | ||||||
| 	go testBrokenTunAccept(t, tunB, wait, &wg) |  | ||||||
|  |  | ||||||
| 	wg.Add(1) |  | ||||||
| 	// start tunnel sender |  | ||||||
| 	go testBrokenTunSend(t, tunA, wait, &wg) |  | ||||||
|  |  | ||||||
| 	// wait until done |  | ||||||
| 	wg.Wait() |  | ||||||
| } |  | ||||||
| @@ -8,6 +8,90 @@ import ( | |||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { | ||||||
|  | 	defer wg.Done() | ||||||
|  |  | ||||||
|  | 	// listen on some virtual address | ||||||
|  | 	tl, err := tun.Listen("test-tunnel") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// receiver ready; notify sender | ||||||
|  | 	wait <- true | ||||||
|  |  | ||||||
|  | 	// accept a connection | ||||||
|  | 	c, err := tl.Accept() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// accept the message and close the tunnel | ||||||
|  | 	// we do this to simulate loss of network connection | ||||||
|  | 	m := new(transport.Message) | ||||||
|  | 	if err := c.Recv(m); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// close all the links | ||||||
|  | 	for _, link := range tun.Links() { | ||||||
|  | 		link.Close() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// receiver ready; notify sender | ||||||
|  | 	wait <- true | ||||||
|  |  | ||||||
|  | 	// accept the message | ||||||
|  | 	m = new(transport.Message) | ||||||
|  | 	if err := c.Recv(m); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// notify the sender we have received | ||||||
|  | 	wait <- true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup, reconnect time.Duration) { | ||||||
|  | 	defer wg.Done() | ||||||
|  |  | ||||||
|  | 	// wait for the listener to get ready | ||||||
|  | 	<-wait | ||||||
|  |  | ||||||
|  | 	// dial a new session | ||||||
|  | 	c, err := tun.Dial("test-tunnel") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer c.Close() | ||||||
|  |  | ||||||
|  | 	m := transport.Message{ | ||||||
|  | 		Header: map[string]string{ | ||||||
|  | 			"test": "send", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// send the message | ||||||
|  | 	if err := c.Send(&m); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// wait for the listener to get ready | ||||||
|  | 	<-wait | ||||||
|  |  | ||||||
|  | 	// give it time to reconnect | ||||||
|  | 	time.Sleep(reconnect) | ||||||
|  |  | ||||||
|  | 	// send the message | ||||||
|  | 	if err := c.Send(&m); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// wait for the listener to receive the message | ||||||
|  | 	// c.Send merely enqueues the message to the link send queue and returns | ||||||
|  | 	// in order to verify it was received we wait for the listener to tell us | ||||||
|  | 	<-wait | ||||||
|  | } | ||||||
|  |  | ||||||
| // testAccept will accept connections on the transport, create a new link and tunnel on top | // testAccept will accept connections on the transport, create a new link and tunnel on top | ||||||
| func testAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { | func testAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { | ||||||
| 	defer wg.Done() | 	defer wg.Done() | ||||||
| @@ -163,90 +247,6 @@ func TestLoopbackTunnel(t *testing.T) { | |||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
| } | } | ||||||
|  |  | ||||||
| func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { |  | ||||||
| 	defer wg.Done() |  | ||||||
|  |  | ||||||
| 	// listen on some virtual address |  | ||||||
| 	tl, err := tun.Listen("test-tunnel") |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// receiver ready; notify sender |  | ||||||
| 	wait <- true |  | ||||||
|  |  | ||||||
| 	// accept a connection |  | ||||||
| 	c, err := tl.Accept() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// accept the message and close the tunnel |  | ||||||
| 	// we do this to simulate loss of network connection |  | ||||||
| 	m := new(transport.Message) |  | ||||||
| 	if err := c.Recv(m); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// close all the links |  | ||||||
| 	for _, link := range tun.Links() { |  | ||||||
| 		link.Close() |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// receiver ready; notify sender |  | ||||||
| 	wait <- true |  | ||||||
|  |  | ||||||
| 	// accept the message |  | ||||||
| 	m = new(transport.Message) |  | ||||||
| 	if err := c.Recv(m); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// notify the sender we have received |  | ||||||
| 	wait <- true |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) { |  | ||||||
| 	defer wg.Done() |  | ||||||
|  |  | ||||||
| 	// wait for the listener to get ready |  | ||||||
| 	<-wait |  | ||||||
|  |  | ||||||
| 	// dial a new session |  | ||||||
| 	c, err := tun.Dial("test-tunnel") |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer c.Close() |  | ||||||
|  |  | ||||||
| 	m := transport.Message{ |  | ||||||
| 		Header: map[string]string{ |  | ||||||
| 			"test": "send", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// send the message |  | ||||||
| 	if err := c.Send(&m); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// wait for the listener to get ready |  | ||||||
| 	<-wait |  | ||||||
|  |  | ||||||
| 	// give it time to reconnect |  | ||||||
| 	time.Sleep(5 * ReconnectTime) |  | ||||||
|  |  | ||||||
| 	// send the message |  | ||||||
| 	if err := c.Send(&m); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// wait for the listener to receive the message |  | ||||||
| 	// c.Send merely enqueues the message to the link send queue and returns |  | ||||||
| 	// in order to verify it was received we wait for the listener to tell us |  | ||||||
| 	<-wait |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestTunnelRTTRate(t *testing.T) { | func TestTunnelRTTRate(t *testing.T) { | ||||||
| 	// create a new tunnel client | 	// create a new tunnel client | ||||||
| 	tunA := NewTunnel( | 	tunA := NewTunnel( | ||||||
| @@ -296,3 +296,49 @@ func TestTunnelRTTRate(t *testing.T) { | |||||||
| 		t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate()) | 		t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestReconnectTunnel(t *testing.T) { | ||||||
|  | 	// we manually override the tunnel.ReconnectTime value here | ||||||
|  | 	// this is so that we make the reconnects faster than the default 5s | ||||||
|  | 	ReconnectTime = 200 * time.Millisecond | ||||||
|  |  | ||||||
|  | 	// create a new tunnel client | ||||||
|  | 	tunA := NewTunnel( | ||||||
|  | 		Address("127.0.0.1:9098"), | ||||||
|  | 		Nodes("127.0.0.1:9099"), | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	// create a new tunnel server | ||||||
|  | 	tunB := NewTunnel( | ||||||
|  | 		Address("127.0.0.1:9099"), | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	// start tunnel | ||||||
|  | 	err := tunB.Connect() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer tunB.Close() | ||||||
|  |  | ||||||
|  | 	// start tunnel | ||||||
|  | 	err = tunA.Connect() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer tunA.Close() | ||||||
|  |  | ||||||
|  | 	wait := make(chan bool) | ||||||
|  |  | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  |  | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	// start tunnel listener | ||||||
|  | 	go testBrokenTunAccept(t, tunB, wait, &wg) | ||||||
|  |  | ||||||
|  | 	wg.Add(1) | ||||||
|  | 	// start tunnel sender | ||||||
|  | 	go testBrokenTunSend(t, tunA, wait, &wg, ReconnectTime*5) | ||||||
|  |  | ||||||
|  | 	// wait until done | ||||||
|  | 	wg.Wait() | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user