commit
a9be1288d2
@ -316,6 +316,22 @@ func (r *rpcClient) Options() Options {
|
|||||||
return r.opts
|
return r.opts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hasProxy checks if we have proxy set in the environment
|
||||||
|
func (r *rpcClient) hasProxy() bool {
|
||||||
|
// get proxy
|
||||||
|
if prx := os.Getenv("MICRO_PROXY"); len(prx) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// get proxy address
|
||||||
|
if prx := os.Getenv("MICRO_PROXY_ADDRESS"); len(prx) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// next returns an iterator for the next nodes to call
|
||||||
func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) {
|
func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) {
|
||||||
service := request.Service()
|
service := request.Service()
|
||||||
|
|
||||||
@ -431,10 +447,18 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ch := make(chan error, callOpts.Retries+1)
|
// get the retries
|
||||||
|
retries := callOpts.Retries
|
||||||
|
|
||||||
|
// disable retries when using a proxy
|
||||||
|
if r.hasProxy() {
|
||||||
|
retries = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan error, retries+1)
|
||||||
var gerr error
|
var gerr error
|
||||||
|
|
||||||
for i := 0; i <= callOpts.Retries; i++ {
|
for i := 0; i <= retries; i++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
ch <- call(i)
|
ch <- call(i)
|
||||||
}(i)
|
}(i)
|
||||||
@ -514,10 +538,18 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
ch := make(chan response, callOpts.Retries+1)
|
// get the retries
|
||||||
|
retries := callOpts.Retries
|
||||||
|
|
||||||
|
// disable retries when using a proxy
|
||||||
|
if r.hasProxy() {
|
||||||
|
retries = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan response, retries+1)
|
||||||
var grr error
|
var grr error
|
||||||
|
|
||||||
for i := 0; i <= callOpts.Retries; i++ {
|
for i := 0; i <= retries; i++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
s, err := call(i)
|
s, err := call(i)
|
||||||
ch <- response{s, err}
|
ch <- response{s, err}
|
||||||
|
@ -88,32 +88,24 @@ func (rwc *readWriteCloser) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getHeaders(m *codec.Message) {
|
func getHeaders(m *codec.Message) {
|
||||||
get := func(hdr string) string {
|
set := func(v, hdr string) string {
|
||||||
if hd := m.Header[hdr]; len(hd) > 0 {
|
if len(v) > 0 {
|
||||||
return hd
|
return v
|
||||||
}
|
}
|
||||||
// old
|
return m.Header[hdr]
|
||||||
return m.Header["X-"+hdr]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check error in header
|
// check error in header
|
||||||
if len(m.Error) == 0 {
|
m.Error = set(m.Error, "Micro-Error")
|
||||||
m.Error = get("Micro-Error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// check endpoint in header
|
// check endpoint in header
|
||||||
if len(m.Endpoint) == 0 {
|
m.Endpoint = set(m.Endpoint, "Micro-Endpoint")
|
||||||
m.Endpoint = get("Micro-Endpoint")
|
|
||||||
}
|
|
||||||
|
|
||||||
// check method in header
|
// check method in header
|
||||||
if len(m.Method) == 0 {
|
m.Method = set(m.Method, "Micro-Method")
|
||||||
m.Method = get("Micro-Method")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.Id) == 0 {
|
// set the request id
|
||||||
m.Id = get("Micro-Id")
|
m.Id = set(m.Id, "Micro-Id")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func setHeaders(m *codec.Message, stream string) {
|
func setHeaders(m *codec.Message, stream string) {
|
||||||
@ -122,7 +114,6 @@ func setHeaders(m *codec.Message, stream string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.Header[hdr] = v
|
m.Header[hdr] = v
|
||||||
m.Header["X-"+hdr] = v
|
|
||||||
}
|
}
|
||||||
|
|
||||||
set("Micro-Id", m.Id)
|
set("Micro-Id", m.Id)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -6,8 +6,6 @@ import (
|
|||||||
|
|
||||||
"github.com/micro/go-micro/client"
|
"github.com/micro/go-micro/client"
|
||||||
"github.com/micro/go-micro/server"
|
"github.com/micro/go-micro/server"
|
||||||
"github.com/micro/go-micro/transport"
|
|
||||||
"github.com/micro/go-micro/tunnel"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -56,14 +54,6 @@ type Network interface {
|
|||||||
Server() server.Server
|
Server() server.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
// message is network message
|
|
||||||
type message struct {
|
|
||||||
// msg is transport message
|
|
||||||
msg *transport.Message
|
|
||||||
// session is tunnel session
|
|
||||||
session tunnel.Session
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewNetwork returns a new network interface
|
// NewNetwork returns a new network interface
|
||||||
func NewNetwork(opts ...Option) Network {
|
func NewNetwork(opts ...Option) Network {
|
||||||
return newNetwork(opts...)
|
return newNetwork(opts...)
|
||||||
|
@ -216,7 +216,7 @@ func (n *node) DeletePeerNode(id string) error {
|
|||||||
|
|
||||||
// PruneStalePeerNodes prune the peers that have not been seen for longer than given time
|
// PruneStalePeerNodes prune the peers that have not been seen for longer than given time
|
||||||
// It returns a map of the the nodes that got pruned
|
// It returns a map of the the nodes that got pruned
|
||||||
func (n *node) PruneStalePeerNodes(pruneTime time.Duration) map[string]*node {
|
func (n *node) PruneStalePeers(pruneTime time.Duration) map[string]*node {
|
||||||
n.Lock()
|
n.Lock()
|
||||||
defer n.Unlock()
|
defer n.Unlock()
|
||||||
|
|
||||||
|
@ -225,7 +225,7 @@ func TestPruneStalePeerNodes(t *testing.T) {
|
|||||||
time.Sleep(pruneTime)
|
time.Sleep(pruneTime)
|
||||||
|
|
||||||
// should delete all nodes besides node
|
// should delete all nodes besides node
|
||||||
pruned := node.PruneStalePeerNodes(pruneTime)
|
pruned := node.PruneStalePeers(pruneTime)
|
||||||
|
|
||||||
if len(pruned) != len(nodes)-1 {
|
if len(pruned) != len(nodes)-1 {
|
||||||
t.Errorf("Expected pruned node count: %d, got: %d", len(nodes)-1, len(pruned))
|
t.Errorf("Expected pruned node count: %d, got: %d", len(nodes)-1, len(pruned))
|
||||||
|
@ -22,8 +22,8 @@ type Options struct {
|
|||||||
Address string
|
Address string
|
||||||
// Advertise sets the address to advertise
|
// Advertise sets the address to advertise
|
||||||
Advertise string
|
Advertise string
|
||||||
// Peers is a list of peers to connect to
|
// Nodes is a list of nodes to connect to
|
||||||
Peers []string
|
Nodes []string
|
||||||
// Tunnel is network tunnel
|
// Tunnel is network tunnel
|
||||||
Tunnel tunnel.Tunnel
|
Tunnel tunnel.Tunnel
|
||||||
// Router is network router
|
// Router is network router
|
||||||
@ -62,10 +62,10 @@ func Advertise(a string) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peers is a list of peers to connect to
|
// Nodes is a list of nodes to connect to
|
||||||
func Peers(n ...string) Option {
|
func Nodes(n ...string) Option {
|
||||||
return func(o *Options) {
|
return func(o *Options) {
|
||||||
o.Peers = n
|
o.Nodes = n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +31,17 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
|
|||||||
r.Address = "1.0.0.1:53"
|
r.Address = "1.0.0.1:53"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:prealloc
|
||||||
|
var records []*resolver.Record
|
||||||
|
|
||||||
|
// parsed an actual ip
|
||||||
|
if v := net.ParseIP(host); v != nil {
|
||||||
|
records = append(records, &resolver.Record{
|
||||||
|
Address: net.JoinHostPort(host, port),
|
||||||
|
})
|
||||||
|
return records, nil
|
||||||
|
}
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||||
rec, err := dns.ExchangeContext(context.Background(), m, r.Address)
|
rec, err := dns.ExchangeContext(context.Background(), m, r.Address)
|
||||||
@ -38,9 +49,6 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:prealloc
|
|
||||||
var records []*resolver.Record
|
|
||||||
|
|
||||||
for _, answer := range rec.Answer {
|
for _, answer := range rec.Answer {
|
||||||
h := answer.Header()
|
h := answer.Header()
|
||||||
// check record type matches
|
// check record type matches
|
||||||
@ -59,5 +67,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// no records returned so just best effort it
|
||||||
|
if len(records) == 0 {
|
||||||
|
records = append(records, &resolver.Record{
|
||||||
|
Address: net.JoinHostPort(host, port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return records, nil
|
return records, nil
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get list of existing nodes
|
// get list of existing nodes
|
||||||
nodes := n.Network.Options().Peers
|
nodes := n.Network.Options().Nodes
|
||||||
|
|
||||||
// generate a node map
|
// generate a node map
|
||||||
nodeMap := make(map[string]bool)
|
nodeMap := make(map[string]bool)
|
||||||
@ -84,7 +84,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp *
|
|||||||
|
|
||||||
// reinitialise the peers
|
// reinitialise the peers
|
||||||
n.Network.Init(
|
n.Network.Init(
|
||||||
network.Peers(nodes...),
|
network.Nodes(nodes...),
|
||||||
)
|
)
|
||||||
|
|
||||||
// call the connect method
|
// call the connect method
|
||||||
|
@ -83,7 +83,8 @@ func readLoop(r server.Request, s client.Stream) error {
|
|||||||
|
|
||||||
// toNodes returns a list of node addresses from given routes
|
// toNodes returns a list of node addresses from given routes
|
||||||
func toNodes(routes []router.Route) []string {
|
func toNodes(routes []router.Route) []string {
|
||||||
nodes := make([]string, len(routes))
|
nodes := make([]string, 0, len(routes))
|
||||||
|
|
||||||
for _, node := range routes {
|
for _, node := range routes {
|
||||||
address := node.Address
|
address := node.Address
|
||||||
if len(node.Gateway) > 0 {
|
if len(node.Gateway) > 0 {
|
||||||
@ -91,11 +92,13 @@ func toNodes(routes []router.Route) []string {
|
|||||||
}
|
}
|
||||||
nodes = append(nodes, address)
|
nodes = append(nodes, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSlice(r map[uint64]router.Route) []router.Route {
|
func toSlice(r map[uint64]router.Route) []router.Route {
|
||||||
routes := make([]router.Route, 0, len(r))
|
routes := make([]router.Route, 0, len(r))
|
||||||
|
|
||||||
for _, v := range r {
|
for _, v := range r {
|
||||||
routes = append(routes, v)
|
routes = append(routes, v)
|
||||||
}
|
}
|
||||||
@ -161,6 +164,8 @@ func (p *Proxy) filterRoutes(ctx context.Context, routes []router.Route) []route
|
|||||||
filteredRoutes = append(filteredRoutes, route)
|
filteredRoutes = append(filteredRoutes, route)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Tracef("Proxy filtered routes %+v\n", filteredRoutes)
|
||||||
|
|
||||||
return filteredRoutes
|
return filteredRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,13 +230,15 @@ func (p *Proxy) cacheRoutes(service string) ([]router.Route, error) {
|
|||||||
// refreshMetrics will refresh any metrics for our local cached routes.
|
// refreshMetrics will refresh any metrics for our local cached routes.
|
||||||
// we may not receive new watch events for these as they change.
|
// we may not receive new watch events for these as they change.
|
||||||
func (p *Proxy) refreshMetrics() {
|
func (p *Proxy) refreshMetrics() {
|
||||||
services := make([]string, 0, len(p.Routes))
|
|
||||||
|
|
||||||
// get a list of services to update
|
// get a list of services to update
|
||||||
p.RLock()
|
p.RLock()
|
||||||
|
|
||||||
|
services := make([]string, 0, len(p.Routes))
|
||||||
|
|
||||||
for service := range p.Routes {
|
for service := range p.Routes {
|
||||||
services = append(services, service)
|
services = append(services, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.RUnlock()
|
p.RUnlock()
|
||||||
|
|
||||||
// get and cache the routes for the service
|
// get and cache the routes for the service
|
||||||
@ -246,6 +253,8 @@ func (p *Proxy) manageRoutes(route router.Route, action string) error {
|
|||||||
p.Lock()
|
p.Lock()
|
||||||
defer p.Unlock()
|
defer p.Unlock()
|
||||||
|
|
||||||
|
log.Tracef("Proxy taking route action %v %+v\n", action, route)
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case "create", "update":
|
case "create", "update":
|
||||||
if _, ok := p.Routes[route.Service]; !ok {
|
if _, ok := p.Routes[route.Service]; !ok {
|
||||||
@ -253,7 +262,12 @@ func (p *Proxy) manageRoutes(route router.Route, action string) error {
|
|||||||
}
|
}
|
||||||
p.Routes[route.Service][route.Hash()] = route
|
p.Routes[route.Service][route.Hash()] = route
|
||||||
case "delete":
|
case "delete":
|
||||||
|
// delete that specific route
|
||||||
delete(p.Routes[route.Service], route.Hash())
|
delete(p.Routes[route.Service], route.Hash())
|
||||||
|
// clean up the cache entirely
|
||||||
|
if len(p.Routes[route.Service]) == 0 {
|
||||||
|
delete(p.Routes, route.Service)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown action: %s", action)
|
return fmt.Errorf("unknown action: %s", action)
|
||||||
}
|
}
|
||||||
@ -288,7 +302,7 @@ func (p *Proxy) ProcessMessage(ctx context.Context, msg server.Message) error {
|
|||||||
// TODO: check that we're not broadcast storming by sending to the same topic
|
// TODO: check that we're not broadcast storming by sending to the same topic
|
||||||
// that we're actually subscribed to
|
// that we're actually subscribed to
|
||||||
|
|
||||||
log.Tracef("Received message for %s", msg.Topic())
|
log.Tracef("Proxy received message for %s", msg.Topic())
|
||||||
|
|
||||||
var errors []string
|
var errors []string
|
||||||
|
|
||||||
@ -329,7 +343,7 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
|
|||||||
return errors.BadRequest("go.micro.proxy", "service name is blank")
|
return errors.BadRequest("go.micro.proxy", "service name is blank")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Received request for %s", service)
|
log.Tracef("Proxy received request for %s", service)
|
||||||
|
|
||||||
// are we network routing or local routing
|
// are we network routing or local routing
|
||||||
if len(p.Links) == 0 {
|
if len(p.Links) == 0 {
|
||||||
@ -363,15 +377,17 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
|
|||||||
}
|
}
|
||||||
|
|
||||||
//nolint:prealloc
|
//nolint:prealloc
|
||||||
var opts []client.CallOption
|
opts := []client.CallOption{
|
||||||
|
|
||||||
// set strategy to round robin
|
// set strategy to round robin
|
||||||
opts = append(opts, client.WithSelectOption(selector.WithStrategy(selector.RoundRobin)))
|
client.WithSelectOption(selector.WithStrategy(selector.RoundRobin)),
|
||||||
|
}
|
||||||
|
|
||||||
// if the address is already set just serve it
|
// if the address is already set just serve it
|
||||||
// TODO: figure it out if we should know to pick a link
|
// TODO: figure it out if we should know to pick a link
|
||||||
if len(addresses) > 0 {
|
if len(addresses) > 0 {
|
||||||
opts = append(opts, client.WithAddress(addresses...))
|
opts = append(opts,
|
||||||
|
client.WithAddress(addresses...),
|
||||||
|
)
|
||||||
|
|
||||||
// serve the normal way
|
// serve the normal way
|
||||||
return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...)
|
return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...)
|
||||||
@ -387,10 +403,16 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
|
|||||||
opts = append(opts, client.WithAddress(addresses...))
|
opts = append(opts, client.WithAddress(addresses...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Tracef("Proxy calling %+v\n", addresses)
|
||||||
// serve the normal way
|
// serve the normal way
|
||||||
return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...)
|
return p.serveRequest(ctx, p.Client, service, endpoint, req, rsp, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we're assuming we need routes to operate on
|
||||||
|
if len(routes) == 0 {
|
||||||
|
return errors.InternalServerError("go.micro.proxy", "route not found")
|
||||||
|
}
|
||||||
|
|
||||||
var gerr error
|
var gerr error
|
||||||
|
|
||||||
// we're routing globally with multiple links
|
// we're routing globally with multiple links
|
||||||
@ -404,11 +426,16 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Proxy using route %+v\n", route)
|
log.Tracef("Proxy using route %+v\n", route)
|
||||||
|
|
||||||
// set the address to call
|
// set the address to call
|
||||||
addresses := toNodes([]router.Route{route})
|
addresses := toNodes([]router.Route{route})
|
||||||
opts = append(opts, client.WithAddress(addresses...))
|
// set the address in the options
|
||||||
|
// disable retries since its one route processing
|
||||||
|
opts = append(opts,
|
||||||
|
client.WithAddress(addresses...),
|
||||||
|
client.WithRetries(0),
|
||||||
|
)
|
||||||
|
|
||||||
// do the request with the link
|
// do the request with the link
|
||||||
gerr = p.serveRequest(ctx, link, service, endpoint, req, rsp, opts...)
|
gerr = p.serveRequest(ctx, link, service, endpoint, req, rsp, opts...)
|
||||||
@ -558,7 +585,9 @@ func NewProxy(opts ...options.Option) proxy.Proxy {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
t := time.NewTicker(time.Minute)
|
// TODO: speed up refreshing of metrics
|
||||||
|
// without this ticking effort e.g stream
|
||||||
|
t := time.NewTicker(time.Second * 10)
|
||||||
defer t.Stop()
|
defer t.Stop()
|
||||||
|
|
||||||
// we must refresh route metrics since they do not trigger new events
|
// we must refresh route metrics since they do not trigger new events
|
||||||
|
@ -799,7 +799,8 @@ func (r *router) flushRouteEvents(evType EventType) ([]*Event, error) {
|
|||||||
|
|
||||||
// build a list of events to advertise
|
// build a list of events to advertise
|
||||||
events := make([]*Event, len(bestRoutes))
|
events := make([]*Event, len(bestRoutes))
|
||||||
i := 0
|
var i int
|
||||||
|
|
||||||
for _, route := range bestRoutes {
|
for _, route := range bestRoutes {
|
||||||
event := &Event{
|
event := &Event{
|
||||||
Type: evType,
|
Type: evType,
|
||||||
@ -823,9 +824,10 @@ func (r *router) Solicit() error {
|
|||||||
|
|
||||||
// advertise the routes
|
// advertise the routes
|
||||||
r.advertWg.Add(1)
|
r.advertWg.Add(1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer r.advertWg.Done()
|
r.publishAdvert(Solicitation, events)
|
||||||
r.publishAdvert(RouteUpdate, events)
|
r.advertWg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -111,6 +111,8 @@ const (
|
|||||||
Announce AdvertType = iota
|
Announce AdvertType = iota
|
||||||
// RouteUpdate advertises route updates
|
// RouteUpdate advertises route updates
|
||||||
RouteUpdate
|
RouteUpdate
|
||||||
|
// Solicitation indicates routes were solicited
|
||||||
|
Solicitation
|
||||||
)
|
)
|
||||||
|
|
||||||
// String returns human readable advertisement type
|
// String returns human readable advertisement type
|
||||||
@ -120,6 +122,8 @@ func (t AdvertType) String() string {
|
|||||||
return "announce"
|
return "announce"
|
||||||
case RouteUpdate:
|
case RouteUpdate:
|
||||||
return "update"
|
return "update"
|
||||||
|
case Solicitation:
|
||||||
|
return "solicitation"
|
||||||
default:
|
default:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
|
@ -86,24 +86,18 @@ func getHeader(hdr string, md map[string]string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getHeaders(m *codec.Message) {
|
func getHeaders(m *codec.Message) {
|
||||||
get := func(hdr, v string) string {
|
set := func(v, hdr string) string {
|
||||||
if len(v) > 0 {
|
if len(v) > 0 {
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
return m.Header[hdr]
|
||||||
if hd := m.Header[hdr]; len(hd) > 0 {
|
|
||||||
return hd
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// old
|
m.Id = set(m.Id, "Micro-Id")
|
||||||
return m.Header["X-"+hdr]
|
m.Error = set(m.Error, "Micro-Error")
|
||||||
}
|
m.Endpoint = set(m.Endpoint, "Micro-Endpoint")
|
||||||
|
m.Method = set(m.Method, "Micro-Method")
|
||||||
m.Id = get("Micro-Id", m.Id)
|
m.Target = set(m.Target, "Micro-Service")
|
||||||
m.Error = get("Micro-Error", m.Error)
|
|
||||||
m.Endpoint = get("Micro-Endpoint", m.Endpoint)
|
|
||||||
m.Method = get("Micro-Method", m.Method)
|
|
||||||
m.Target = get("Micro-Service", m.Target)
|
|
||||||
|
|
||||||
// TODO: remove this cruft
|
// TODO: remove this cruft
|
||||||
if len(m.Endpoint) == 0 {
|
if len(m.Endpoint) == 0 {
|
||||||
@ -321,7 +315,6 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error {
|
|||||||
|
|
||||||
// write an error if it failed
|
// write an error if it failed
|
||||||
m.Error = errors.Wrapf(err, "Unable to encode body").Error()
|
m.Error = errors.Wrapf(err, "Unable to encode body").Error()
|
||||||
m.Header["X-Micro-Error"] = m.Error
|
|
||||||
m.Header["Micro-Error"] = m.Error
|
m.Header["Micro-Error"] = m.Error
|
||||||
// no body to write
|
// no body to write
|
||||||
if err := c.codec.Write(m, nil); err != nil {
|
if err := c.codec.Write(m, nil); err != nil {
|
||||||
|
@ -549,6 +549,7 @@ func (s *rpcServer) Register() error {
|
|||||||
node.Metadata["protocol"] = "mucp"
|
node.Metadata["protocol"] = "mucp"
|
||||||
|
|
||||||
s.RLock()
|
s.RLock()
|
||||||
|
|
||||||
// Maps are ordered randomly, sort the keys for consistency
|
// Maps are ordered randomly, sort the keys for consistency
|
||||||
var handlerList []string
|
var handlerList []string
|
||||||
for n, e := range s.handlers {
|
for n, e := range s.handlers {
|
||||||
@ -557,6 +558,7 @@ func (s *rpcServer) Register() error {
|
|||||||
handlerList = append(handlerList, n)
|
handlerList = append(handlerList, n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(handlerList)
|
sort.Strings(handlerList)
|
||||||
|
|
||||||
var subscriberList []Subscriber
|
var subscriberList []Subscriber
|
||||||
@ -566,18 +568,20 @@ func (s *rpcServer) Register() error {
|
|||||||
subscriberList = append(subscriberList, e)
|
subscriberList = append(subscriberList, e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(subscriberList, func(i, j int) bool {
|
sort.Slice(subscriberList, func(i, j int) bool {
|
||||||
return subscriberList[i].Topic() > subscriberList[j].Topic()
|
return subscriberList[i].Topic() > subscriberList[j].Topic()
|
||||||
})
|
})
|
||||||
|
|
||||||
endpoints := make([]*registry.Endpoint, 0, len(handlerList)+len(subscriberList))
|
endpoints := make([]*registry.Endpoint, 0, len(handlerList)+len(subscriberList))
|
||||||
|
|
||||||
for _, n := range handlerList {
|
for _, n := range handlerList {
|
||||||
endpoints = append(endpoints, s.handlers[n].Endpoints()...)
|
endpoints = append(endpoints, s.handlers[n].Endpoints()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, e := range subscriberList {
|
for _, e := range subscriberList {
|
||||||
endpoints = append(endpoints, e.Endpoints()...)
|
endpoints = append(endpoints, e.Endpoints()...)
|
||||||
}
|
}
|
||||||
s.RUnlock()
|
|
||||||
|
|
||||||
service := ®istry.Service{
|
service := ®istry.Service{
|
||||||
Name: config.Name,
|
Name: config.Name,
|
||||||
@ -586,9 +590,10 @@ func (s *rpcServer) Register() error {
|
|||||||
Endpoints: endpoints,
|
Endpoints: endpoints,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Lock()
|
// get registered value
|
||||||
registered := s.registered
|
registered := s.registered
|
||||||
s.Unlock()
|
|
||||||
|
s.RUnlock()
|
||||||
|
|
||||||
if !registered {
|
if !registered {
|
||||||
log.Logf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id)
|
log.Logf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id)
|
||||||
@ -610,6 +615,8 @@ func (s *rpcServer) Register() error {
|
|||||||
defer s.Unlock()
|
defer s.Unlock()
|
||||||
|
|
||||||
s.registered = true
|
s.registered = true
|
||||||
|
// set what we're advertising
|
||||||
|
s.opts.Advertise = addr
|
||||||
|
|
||||||
// subscribe to the topic with own name
|
// subscribe to the topic with own name
|
||||||
sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent)
|
sub, err := s.opts.Broker.Subscribe(config.Name, s.HandleEvent)
|
||||||
|
@ -9,9 +9,9 @@ import (
|
|||||||
|
|
||||||
"github.com/micro/go-micro/client"
|
"github.com/micro/go-micro/client"
|
||||||
"github.com/micro/go-micro/config/cmd"
|
"github.com/micro/go-micro/config/cmd"
|
||||||
"github.com/micro/go-micro/debug/service/handler"
|
|
||||||
"github.com/micro/go-micro/debug/profile"
|
"github.com/micro/go-micro/debug/profile"
|
||||||
"github.com/micro/go-micro/debug/profile/pprof"
|
"github.com/micro/go-micro/debug/profile/pprof"
|
||||||
|
"github.com/micro/go-micro/debug/service/handler"
|
||||||
"github.com/micro/go-micro/plugin"
|
"github.com/micro/go-micro/plugin"
|
||||||
"github.com/micro/go-micro/server"
|
"github.com/micro/go-micro/server"
|
||||||
"github.com/micro/go-micro/util/log"
|
"github.com/micro/go-micro/util/log"
|
||||||
|
@ -127,7 +127,6 @@ func (t *tun) newSession(channel, sessionId string) (*session, bool) {
|
|||||||
closed: make(chan bool),
|
closed: make(chan bool),
|
||||||
recv: make(chan *message, 128),
|
recv: make(chan *message, 128),
|
||||||
send: t.send,
|
send: t.send,
|
||||||
wait: make(chan bool),
|
|
||||||
errChan: make(chan error, 1),
|
errChan: make(chan error, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,8 +197,8 @@ func (t *tun) announce(channel, session string, link *link) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// monitor monitors outbound links and attempts to reconnect to the failed ones
|
// manage monitors outbound links and attempts to reconnect to the failed ones
|
||||||
func (t *tun) monitor() {
|
func (t *tun) manage() {
|
||||||
reconnect := time.NewTicker(ReconnectTime)
|
reconnect := time.NewTicker(ReconnectTime)
|
||||||
defer reconnect.Stop()
|
defer reconnect.Stop()
|
||||||
|
|
||||||
@ -208,9 +207,48 @@ func (t *tun) monitor() {
|
|||||||
case <-t.closed:
|
case <-t.closed:
|
||||||
return
|
return
|
||||||
case <-reconnect.C:
|
case <-reconnect.C:
|
||||||
|
t.manageLinks()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// manageLink sends channel discover requests periodically and
|
||||||
|
// keepalive messages to link
|
||||||
|
func (t *tun) manageLink(link *link) {
|
||||||
|
keepalive := time.NewTicker(KeepAliveTime)
|
||||||
|
defer keepalive.Stop()
|
||||||
|
discover := time.NewTicker(DiscoverTime)
|
||||||
|
defer discover.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.closed:
|
||||||
|
return
|
||||||
|
case <-link.closed:
|
||||||
|
return
|
||||||
|
case <-discover.C:
|
||||||
|
// send a discovery message to all links
|
||||||
|
if err := t.sendMsg("discover", link); err != nil {
|
||||||
|
log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err)
|
||||||
|
}
|
||||||
|
case <-keepalive.C:
|
||||||
|
// send keepalive message
|
||||||
|
log.Debugf("Tunnel sending keepalive to link: %v", link.Remote())
|
||||||
|
if err := t.sendMsg("keepalive", link); err != nil {
|
||||||
|
log.Debugf("Tunnel error sending keepalive to link %v: %v", link.Remote(), err)
|
||||||
|
t.delLink(link.Remote())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// manageLinks is a function that can be called to immediately to link setup
|
||||||
|
func (t *tun) manageLinks() {
|
||||||
|
var delLinks []string
|
||||||
|
|
||||||
t.RLock()
|
t.RLock()
|
||||||
|
|
||||||
var delLinks []string
|
|
||||||
// check the link status and purge dead links
|
// check the link status and purge dead links
|
||||||
for node, link := range t.links {
|
for node, link := range t.links {
|
||||||
// check link status
|
// check link status
|
||||||
@ -242,27 +280,46 @@ func (t *tun) monitor() {
|
|||||||
|
|
||||||
// build list of unknown nodes to connect to
|
// build list of unknown nodes to connect to
|
||||||
t.RLock()
|
t.RLock()
|
||||||
|
|
||||||
for _, node := range t.options.Nodes {
|
for _, node := range t.options.Nodes {
|
||||||
if _, ok := t.links[node]; !ok {
|
if _, ok := t.links[node]; !ok {
|
||||||
connect = append(connect, node)
|
connect = append(connect, node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.RUnlock()
|
t.RUnlock()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for _, node := range connect {
|
for _, node := range connect {
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func(node string) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
// create new link
|
// create new link
|
||||||
link, err := t.setupLink(node)
|
link, err := t.setupLink(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Tunnel failed to setup node link to %s: %v", node, err)
|
log.Debugf("Tunnel failed to setup node link to %s: %v", node, err)
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the link
|
// save the link
|
||||||
t.Lock()
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
|
||||||
|
// just check nothing else was setup in the interim
|
||||||
|
if _, ok := t.links[node]; ok {
|
||||||
|
link.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// save the link
|
||||||
t.links[node] = link
|
t.links[node] = link
|
||||||
t.Unlock()
|
}(node)
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wait for all threads to finish
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// process outgoing messages sent by all local sessions
|
// process outgoing messages sent by all local sessions
|
||||||
@ -328,6 +385,7 @@ func (t *tun) process() {
|
|||||||
// and the message is being sent outbound via
|
// and the message is being sent outbound via
|
||||||
// a dialled connection don't use this link
|
// a dialled connection don't use this link
|
||||||
if loopback && msg.outbound {
|
if loopback && msg.outbound {
|
||||||
|
log.Tracef("Link for node %s is loopback", node)
|
||||||
err = errors.New("link is loopback")
|
err = errors.New("link is loopback")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -335,6 +393,7 @@ func (t *tun) process() {
|
|||||||
// if the message was being returned by the loopback listener
|
// if the message was being returned by the loopback listener
|
||||||
// send it back up the loopback link only
|
// send it back up the loopback link only
|
||||||
if msg.loopback && !loopback {
|
if msg.loopback && !loopback {
|
||||||
|
log.Tracef("Link for message %s is loopback", node)
|
||||||
err = errors.New("link is not loopback")
|
err = errors.New("link is not loopback")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -364,7 +423,7 @@ func (t *tun) process() {
|
|||||||
// send the message
|
// send the message
|
||||||
for _, link := range sendTo {
|
for _, link := range sendTo {
|
||||||
// send the message via the current link
|
// send the message via the current link
|
||||||
log.Tracef("Sending %+v to %s", newMsg.Header, link.Remote())
|
log.Tracef("Tunnel sending %+v to %s", newMsg.Header, link.Remote())
|
||||||
|
|
||||||
if errr := link.Send(newMsg); errr != nil {
|
if errr := link.Send(newMsg); errr != nil {
|
||||||
log.Debugf("Tunnel error sending %+v to %s: %v", newMsg.Header, link.Remote(), errr)
|
log.Debugf("Tunnel error sending %+v to %s: %v", newMsg.Header, link.Remote(), errr)
|
||||||
@ -470,6 +529,9 @@ func (t *tun) listen(link *link) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this state machine block handles the only message types
|
||||||
|
// that we know or care about; connect, close, open, accept,
|
||||||
|
// discover, announce, session, keepalive
|
||||||
switch mtype {
|
switch mtype {
|
||||||
case "connect":
|
case "connect":
|
||||||
log.Debugf("Tunnel link %s received connect message", link.Remote())
|
log.Debugf("Tunnel link %s received connect message", link.Remote())
|
||||||
@ -495,14 +557,14 @@ func (t *tun) listen(link *link) {
|
|||||||
t.links[link.Remote()] = link
|
t.links[link.Remote()] = link
|
||||||
t.Unlock()
|
t.Unlock()
|
||||||
|
|
||||||
// send back a discovery
|
// send back an announcement of our channels discovery
|
||||||
go t.announce("", "", link)
|
go t.announce("", "", link)
|
||||||
|
// ask for the things on the other wise
|
||||||
|
go t.sendMsg("discover", link)
|
||||||
// nothing more to do
|
// nothing more to do
|
||||||
continue
|
continue
|
||||||
case "close":
|
case "close":
|
||||||
// TODO: handle the close message
|
log.Debugf("Tunnel link %s received close message", link.Remote())
|
||||||
// maybe report io.EOF or kill the link
|
|
||||||
|
|
||||||
// if there is no channel then we close the link
|
// if there is no channel then we close the link
|
||||||
// as its a signal from the other side to close the connection
|
// as its a signal from the other side to close the connection
|
||||||
if len(channel) == 0 {
|
if len(channel) == 0 {
|
||||||
@ -521,6 +583,8 @@ func (t *tun) listen(link *link) {
|
|||||||
// try get the dialing socket
|
// try get the dialing socket
|
||||||
s, exists := t.getSession(channel, sessionId)
|
s, exists := t.getSession(channel, sessionId)
|
||||||
if exists && !loopback {
|
if exists && !loopback {
|
||||||
|
// only delete the session if its unicast
|
||||||
|
// otherwise ignore close on the multicast
|
||||||
if s.mode == Unicast {
|
if s.mode == Unicast {
|
||||||
// only delete this if its unicast
|
// only delete this if its unicast
|
||||||
// but not if its a loopback conn
|
// but not if its a loopback conn
|
||||||
@ -541,20 +605,24 @@ func (t *tun) listen(link *link) {
|
|||||||
// an accept returned by the listener
|
// an accept returned by the listener
|
||||||
case "accept":
|
case "accept":
|
||||||
s, exists := t.getSession(channel, sessionId)
|
s, exists := t.getSession(channel, sessionId)
|
||||||
// we don't need this
|
// just set accepted on anything not unicast
|
||||||
if exists && s.mode > Unicast {
|
if exists && s.mode > Unicast {
|
||||||
s.accepted = true
|
s.accepted = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// if its already accepted move on
|
||||||
if exists && s.accepted {
|
if exists && s.accepted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// otherwise we're going to process to accept
|
||||||
// a continued session
|
// a continued session
|
||||||
case "session":
|
case "session":
|
||||||
// process message
|
// process message
|
||||||
log.Tracef("Received %+v from %s", msg.Header, link.Remote())
|
log.Tracef("Tunnel received %+v from %s", msg.Header, link.Remote())
|
||||||
// an announcement of a channel listener
|
// an announcement of a channel listener
|
||||||
case "announce":
|
case "announce":
|
||||||
|
log.Tracef("Tunnel received %+v from %s", msg.Header, link.Remote())
|
||||||
|
|
||||||
// process the announcement
|
// process the announcement
|
||||||
channels := strings.Split(channel, ",")
|
channels := strings.Split(channel, ",")
|
||||||
|
|
||||||
@ -562,7 +630,10 @@ func (t *tun) listen(link *link) {
|
|||||||
link.setChannel(channels...)
|
link.setChannel(channels...)
|
||||||
|
|
||||||
// this was an announcement not intended for anything
|
// this was an announcement not intended for anything
|
||||||
if sessionId == "listener" || sessionId == "" {
|
// if the dialing side sent "discover" then a session
|
||||||
|
// id would be present. We skip in case of multicast.
|
||||||
|
switch sessionId {
|
||||||
|
case "listener", "multicast", "":
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -574,14 +645,19 @@ func (t *tun) listen(link *link) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// send the announce back to the caller
|
msg := &message{
|
||||||
s.recv <- &message{
|
|
||||||
typ: "announce",
|
typ: "announce",
|
||||||
tunnel: id,
|
tunnel: id,
|
||||||
channel: channel,
|
channel: channel,
|
||||||
session: sessionId,
|
session: sessionId,
|
||||||
link: link.id,
|
link: link.id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// send the announce back to the caller
|
||||||
|
select {
|
||||||
|
case <-s.closed:
|
||||||
|
case s.recv <- msg:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
case "discover":
|
case "discover":
|
||||||
@ -618,7 +694,7 @@ func (t *tun) listen(link *link) {
|
|||||||
s, exists = t.getSession(channel, "listener")
|
s, exists = t.getSession(channel, "listener")
|
||||||
// only return accept to the session
|
// only return accept to the session
|
||||||
case mtype == "accept":
|
case mtype == "accept":
|
||||||
log.Debugf("Received accept message for %s %s", channel, sessionId)
|
log.Debugf("Tunnel received accept message for channel: %s session: %s", channel, sessionId)
|
||||||
s, exists = t.getSession(channel, sessionId)
|
s, exists = t.getSession(channel, sessionId)
|
||||||
if exists && s.accepted {
|
if exists && s.accepted {
|
||||||
continue
|
continue
|
||||||
@ -638,7 +714,7 @@ func (t *tun) listen(link *link) {
|
|||||||
|
|
||||||
// bail if no session or listener has been found
|
// bail if no session or listener has been found
|
||||||
if !exists {
|
if !exists {
|
||||||
log.Debugf("Tunnel skipping no session %s %s exists", channel, sessionId)
|
log.Tracef("Tunnel skipping no channel: %s session: %s exists", channel, sessionId)
|
||||||
// drop it, we don't care about
|
// drop it, we don't care about
|
||||||
// messages we don't know about
|
// messages we don't know about
|
||||||
continue
|
continue
|
||||||
@ -651,22 +727,10 @@ func (t *tun) listen(link *link) {
|
|||||||
delete(t.sessions, channel)
|
delete(t.sessions, channel)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
// process
|
// otherwise process
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Tunnel using channel %s session %s", s.channel, s.session)
|
log.Tracef("Tunnel using channel: %s session: %s type: %s", s.channel, s.session, mtype)
|
||||||
|
|
||||||
// is the session new?
|
|
||||||
select {
|
|
||||||
// if its new the session is actually blocked waiting
|
|
||||||
// for a connection. so we check if its waiting.
|
|
||||||
case <-s.wait:
|
|
||||||
// if its waiting e.g its new then we close it
|
|
||||||
default:
|
|
||||||
// set remote address of the session
|
|
||||||
s.remote = msg.Header["Remote"]
|
|
||||||
close(s.wait)
|
|
||||||
}
|
|
||||||
|
|
||||||
// construct a new transport message
|
// construct a new transport message
|
||||||
tmsg := &transport.Message{
|
tmsg := &transport.Message{
|
||||||
@ -696,68 +760,26 @@ func (t *tun) listen(link *link) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// discover sends channel discover requests periodically
|
func (t *tun) sendMsg(method string, link *link) error {
|
||||||
func (t *tun) discover(link *link) {
|
return link.Send(&transport.Message{
|
||||||
tick := time.NewTicker(DiscoverTime)
|
|
||||||
defer tick.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-tick.C:
|
|
||||||
// send a discovery message to all links
|
|
||||||
if err := link.Send(&transport.Message{
|
|
||||||
Header: map[string]string{
|
Header: map[string]string{
|
||||||
"Micro-Tunnel": "discover",
|
"Micro-Tunnel": method,
|
||||||
"Micro-Tunnel-Id": t.id,
|
"Micro-Tunnel-Id": t.id,
|
||||||
},
|
},
|
||||||
}); err != nil {
|
})
|
||||||
log.Debugf("Tunnel failed to send discover to link %s: %v", link.Remote(), err)
|
|
||||||
}
|
|
||||||
case <-link.closed:
|
|
||||||
return
|
|
||||||
case <-t.closed:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// keepalive periodically sends keepalive messages to link
|
|
||||||
func (t *tun) keepalive(link *link) {
|
|
||||||
keepalive := time.NewTicker(KeepAliveTime)
|
|
||||||
defer keepalive.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-t.closed:
|
|
||||||
return
|
|
||||||
case <-link.closed:
|
|
||||||
return
|
|
||||||
case <-keepalive.C:
|
|
||||||
// send keepalive message
|
|
||||||
log.Debugf("Tunnel sending keepalive to link: %v", link.Remote())
|
|
||||||
if err := link.Send(&transport.Message{
|
|
||||||
Header: map[string]string{
|
|
||||||
"Micro-Tunnel": "keepalive",
|
|
||||||
"Micro-Tunnel-Id": t.id,
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err)
|
|
||||||
t.delLink(link.Remote())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupLink connects to node and returns link if successful
|
// setupLink connects to node and returns link if successful
|
||||||
// It returns error if the link failed to be established
|
// It returns error if the link failed to be established
|
||||||
func (t *tun) setupLink(node string) (*link, error) {
|
func (t *tun) setupLink(node string) (*link, error) {
|
||||||
log.Debugf("Tunnel setting up link: %s", node)
|
log.Debugf("Tunnel setting up link: %s", node)
|
||||||
|
|
||||||
c, err := t.options.Transport.Dial(node)
|
c, err := t.options.Transport.Dial(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Tunnel failed to connect to %s: %v", node, err)
|
log.Debugf("Tunnel failed to connect to %s: %v", node, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Tunnel connected to %s", node)
|
log.Debugf("Tunnel connected to %s", node)
|
||||||
|
|
||||||
// create a new link
|
// create a new link
|
||||||
@ -766,12 +788,8 @@ func (t *tun) setupLink(node string) (*link, error) {
|
|||||||
link.id = c.Remote()
|
link.id = c.Remote()
|
||||||
|
|
||||||
// send the first connect message
|
// send the first connect message
|
||||||
if err := link.Send(&transport.Message{
|
if err := t.sendMsg("connect", link); err != nil {
|
||||||
Header: map[string]string{
|
link.Close()
|
||||||
"Micro-Tunnel": "connect",
|
|
||||||
"Micro-Tunnel-Id": t.id,
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -782,37 +800,40 @@ func (t *tun) setupLink(node string) (*link, error) {
|
|||||||
// process incoming messages
|
// process incoming messages
|
||||||
go t.listen(link)
|
go t.listen(link)
|
||||||
|
|
||||||
// start keepalive monitor
|
// manage keepalives and discovery messages
|
||||||
go t.keepalive(link)
|
go t.manageLink(link)
|
||||||
|
|
||||||
// discover things on the remote side
|
|
||||||
go t.discover(link)
|
|
||||||
|
|
||||||
return link, nil
|
return link, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) setupLinks() {
|
func (t *tun) setupLinks() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for _, node := range t.options.Nodes {
|
for _, node := range t.options.Nodes {
|
||||||
// skip zero length nodes
|
wg.Add(1)
|
||||||
if len(node) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// link already exists
|
go func(node string) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// we're not trying to fix existing links
|
||||||
if _, ok := t.links[node]; ok {
|
if _, ok := t.links[node]; ok {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to node and return link
|
// create new link
|
||||||
link, err := t.setupLink(node)
|
link, err := t.setupLink(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Tunnel failed to establish node link to %s: %v", node, err)
|
log.Debugf("Tunnel failed to setup node link to %s: %v", node, err)
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the link
|
// save the link
|
||||||
t.links[node] = link
|
t.links[node] = link
|
||||||
|
}(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wait for all threads to finish
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect the tunnel to all the nodes and listen for incoming tunnel connections
|
// connect the tunnel to all the nodes and listen for incoming tunnel connections
|
||||||
@ -833,11 +854,8 @@ func (t *tun) connect() error {
|
|||||||
// create a new link
|
// create a new link
|
||||||
link := newLink(sock)
|
link := newLink(sock)
|
||||||
|
|
||||||
// start keepalive monitor
|
// manage the link
|
||||||
go t.keepalive(link)
|
go t.manageLink(link)
|
||||||
|
|
||||||
// discover things on the remote side
|
|
||||||
go t.discover(link)
|
|
||||||
|
|
||||||
// listen for inbound messages.
|
// listen for inbound messages.
|
||||||
// only save the link once connected.
|
// only save the link once connected.
|
||||||
@ -864,12 +882,13 @@ func (t *tun) Connect() error {
|
|||||||
|
|
||||||
// already connected
|
// already connected
|
||||||
if t.connected {
|
if t.connected {
|
||||||
// setup links
|
// do it immediately
|
||||||
t.setupLinks()
|
t.setupLinks()
|
||||||
|
// setup links
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// send the connect message
|
// connect the tunnel: start the listener
|
||||||
if err := t.connect(); err != nil {
|
if err := t.connect(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -879,15 +898,15 @@ func (t *tun) Connect() error {
|
|||||||
// create new close channel
|
// create new close channel
|
||||||
t.closed = make(chan bool)
|
t.closed = make(chan bool)
|
||||||
|
|
||||||
// setup links
|
|
||||||
t.setupLinks()
|
|
||||||
|
|
||||||
// process outbound messages to be sent
|
// process outbound messages to be sent
|
||||||
// process sends to all links
|
// process sends to all links
|
||||||
go t.process()
|
go t.process()
|
||||||
|
|
||||||
// monitor links
|
// call setup before managing them
|
||||||
go t.monitor()
|
t.setupLinks()
|
||||||
|
|
||||||
|
// manage the links
|
||||||
|
go t.manage()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1025,7 +1044,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
|||||||
// set the multicast option
|
// set the multicast option
|
||||||
c.mode = options.Mode
|
c.mode = options.Mode
|
||||||
// set the dial timeout
|
// set the dial timeout
|
||||||
c.timeout = options.Timeout
|
c.dialTimeout = options.Timeout
|
||||||
|
// set read timeout set to never
|
||||||
|
c.readTimeout = time.Duration(-1)
|
||||||
|
|
||||||
var links []*link
|
var links []*link
|
||||||
// did we measure the rtt
|
// did we measure the rtt
|
||||||
@ -1052,7 +1073,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
|||||||
|
|
||||||
t.RUnlock()
|
t.RUnlock()
|
||||||
|
|
||||||
// link not found
|
// link not found and one was specified so error out
|
||||||
if len(links) == 0 && len(options.Link) > 0 {
|
if len(links) == 0 && len(options.Link) > 0 {
|
||||||
// delete session and return error
|
// delete session and return error
|
||||||
t.delSession(c.channel, c.session)
|
t.delSession(c.channel, c.session)
|
||||||
@ -1061,15 +1082,14 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// discovered so set the link if not multicast
|
// discovered so set the link if not multicast
|
||||||
// TODO: pick the link efficiently based
|
|
||||||
// on link status and saturation.
|
|
||||||
if c.discovered && c.mode == Unicast {
|
if c.discovered && c.mode == Unicast {
|
||||||
// pickLink will pick the best link
|
// pickLink will pick the best link
|
||||||
link := t.pickLink(links)
|
link := t.pickLink(links)
|
||||||
|
// set the link
|
||||||
c.link = link.id
|
c.link = link.id
|
||||||
}
|
}
|
||||||
|
|
||||||
// shit fuck
|
// if its not already discovered we need to attempt to do so
|
||||||
if !c.discovered {
|
if !c.discovered {
|
||||||
// piggy back roundtrip
|
// piggy back roundtrip
|
||||||
nowRTT := time.Now()
|
nowRTT := time.Now()
|
||||||
@ -1098,7 +1118,15 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// a unicast session so we call "open" and wait for an "accept"
|
// return early if its not unicast
|
||||||
|
// we will not call "open" for multicast
|
||||||
|
if c.mode != Unicast {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: we go no further for multicast or broadcast.
|
||||||
|
// This is a unicast session so we call "open" and wait
|
||||||
|
// for an "accept"
|
||||||
|
|
||||||
// reset now in case we use it
|
// reset now in case we use it
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@ -1115,7 +1143,7 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
|||||||
d := time.Since(now)
|
d := time.Since(now)
|
||||||
|
|
||||||
// if we haven't measured the roundtrip do it now
|
// if we haven't measured the roundtrip do it now
|
||||||
if !measured && c.mode == Unicast {
|
if !measured {
|
||||||
// set the link time
|
// set the link time
|
||||||
t.RLock()
|
t.RLock()
|
||||||
link, ok := t.links[c.link]
|
link, ok := t.links[c.link]
|
||||||
@ -1134,7 +1162,11 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
|||||||
func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
|
func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
|
||||||
log.Debugf("Tunnel listening on %s", channel)
|
log.Debugf("Tunnel listening on %s", channel)
|
||||||
|
|
||||||
var options ListenOptions
|
options := ListenOptions{
|
||||||
|
// Read timeout defaults to never
|
||||||
|
Timeout: time.Duration(-1),
|
||||||
|
}
|
||||||
|
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
o(&options)
|
o(&options)
|
||||||
}
|
}
|
||||||
@ -1145,6 +1177,7 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
|
|||||||
return nil, errors.New("already listening on " + channel)
|
return nil, errors.New("already listening on " + channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete function removes the session when closed
|
||||||
delFunc := func() {
|
delFunc := func() {
|
||||||
t.delSession(channel, "listener")
|
t.delSession(channel, "listener")
|
||||||
}
|
}
|
||||||
@ -1155,6 +1188,8 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
|
|||||||
c.local = channel
|
c.local = channel
|
||||||
// set mode
|
// set mode
|
||||||
c.mode = options.Mode
|
c.mode = options.Mode
|
||||||
|
// set the timeout
|
||||||
|
c.readTimeout = options.Timeout
|
||||||
|
|
||||||
tl := &tunListener{
|
tl := &tunListener{
|
||||||
channel: channel,
|
channel: channel,
|
||||||
|
@ -2,6 +2,7 @@ package tunnel
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -14,7 +15,11 @@ import (
|
|||||||
type link struct {
|
type link struct {
|
||||||
transport.Socket
|
transport.Socket
|
||||||
|
|
||||||
|
// transport to use for connections
|
||||||
|
transport transport.Transport
|
||||||
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
// stops the link
|
// stops the link
|
||||||
closed chan bool
|
closed chan bool
|
||||||
// link state channel for testing link
|
// link state channel for testing link
|
||||||
@ -65,6 +70,8 @@ var (
|
|||||||
linkRequest = []byte{0, 0, 0, 0}
|
linkRequest = []byte{0, 0, 0, 0}
|
||||||
// the 4 byte 1 filled packet sent to determine link state
|
// the 4 byte 1 filled packet sent to determine link state
|
||||||
linkResponse = []byte{1, 1, 1, 1}
|
linkResponse = []byte{1, 1, 1, 1}
|
||||||
|
|
||||||
|
ErrLinkConnectTimeout = errors.New("link connect timeout")
|
||||||
)
|
)
|
||||||
|
|
||||||
func newLink(s transport.Socket) *link {
|
func newLink(s transport.Socket) *link {
|
||||||
@ -72,8 +79,8 @@ func newLink(s transport.Socket) *link {
|
|||||||
Socket: s,
|
Socket: s,
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
lastKeepAlive: time.Now(),
|
lastKeepAlive: time.Now(),
|
||||||
channels: make(map[string]time.Time),
|
|
||||||
closed: make(chan bool),
|
closed: make(chan bool),
|
||||||
|
channels: make(map[string]time.Time),
|
||||||
state: make(chan *packet, 64),
|
state: make(chan *packet, 64),
|
||||||
sendQueue: make(chan *packet, 128),
|
sendQueue: make(chan *packet, 128),
|
||||||
recvQueue: make(chan *packet, 128),
|
recvQueue: make(chan *packet, 128),
|
||||||
@ -87,6 +94,32 @@ func newLink(s transport.Socket) *link {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *link) connect(addr string) error {
|
||||||
|
c, err := l.transport.Dial(addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Lock()
|
||||||
|
l.Socket = c
|
||||||
|
l.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *link) accept(sock transport.Socket) error {
|
||||||
|
l.Lock()
|
||||||
|
l.Socket = sock
|
||||||
|
l.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *link) setLoopback(v bool) {
|
||||||
|
l.Lock()
|
||||||
|
l.loopback = v
|
||||||
|
l.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// setRate sets the bits per second rate as a float64
|
// setRate sets the bits per second rate as a float64
|
||||||
func (l *link) setRate(bits int64, delta time.Duration) {
|
func (l *link) setRate(bits int64, delta time.Duration) {
|
||||||
// rate of send in bits per nanosecond
|
// rate of send in bits per nanosecond
|
||||||
@ -167,6 +200,8 @@ func (l *link) process() {
|
|||||||
// process link state message
|
// process link state message
|
||||||
select {
|
select {
|
||||||
case l.state <- pk:
|
case l.state <- pk:
|
||||||
|
case <-l.closed:
|
||||||
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@ -188,7 +223,11 @@ func (l *link) process() {
|
|||||||
select {
|
select {
|
||||||
case pk := <-l.sendQueue:
|
case pk := <-l.sendQueue:
|
||||||
// send the message
|
// send the message
|
||||||
pk.status <- l.send(pk.message)
|
select {
|
||||||
|
case pk.status <- l.send(pk.message):
|
||||||
|
case <-l.closed:
|
||||||
|
return
|
||||||
|
}
|
||||||
case <-l.closed:
|
case <-l.closed:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -201,11 +240,15 @@ func (l *link) manage() {
|
|||||||
t := time.NewTicker(time.Minute)
|
t := time.NewTicker(time.Minute)
|
||||||
defer t.Stop()
|
defer t.Stop()
|
||||||
|
|
||||||
|
// get link id
|
||||||
|
linkId := l.Id()
|
||||||
|
|
||||||
// used to send link state packets
|
// used to send link state packets
|
||||||
send := func(b []byte) error {
|
send := func(b []byte) error {
|
||||||
return l.Send(&transport.Message{
|
return l.Send(&transport.Message{
|
||||||
Header: map[string]string{
|
Header: map[string]string{
|
||||||
"Micro-Method": "link",
|
"Micro-Method": "link",
|
||||||
|
"Micro-Link-Id": linkId,
|
||||||
}, Body: b,
|
}, Body: b,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -229,9 +272,7 @@ func (l *link) manage() {
|
|||||||
// check the type of message
|
// check the type of message
|
||||||
switch {
|
switch {
|
||||||
case bytes.Equal(p.message.Body, linkRequest):
|
case bytes.Equal(p.message.Body, linkRequest):
|
||||||
l.RLock()
|
log.Tracef("Link %s received link request", linkId)
|
||||||
log.Tracef("Link %s received link request %v", l.id, p.message.Body)
|
|
||||||
l.RUnlock()
|
|
||||||
|
|
||||||
// send response
|
// send response
|
||||||
if err := send(linkResponse); err != nil {
|
if err := send(linkResponse); err != nil {
|
||||||
@ -242,9 +283,7 @@ func (l *link) manage() {
|
|||||||
case bytes.Equal(p.message.Body, linkResponse):
|
case bytes.Equal(p.message.Body, linkResponse):
|
||||||
// set round trip time
|
// set round trip time
|
||||||
d := time.Since(now)
|
d := time.Since(now)
|
||||||
l.RLock()
|
log.Tracef("Link %s received link response in %v", linkId, d)
|
||||||
log.Tracef("Link %s received link response in %v", p.message.Body, d)
|
|
||||||
l.RUnlock()
|
|
||||||
// set the RTT
|
// set the RTT
|
||||||
l.setRTT(d)
|
l.setRTT(d)
|
||||||
}
|
}
|
||||||
@ -309,6 +348,12 @@ func (l *link) Rate() float64 {
|
|||||||
return l.rate
|
return l.rate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *link) Loopback() bool {
|
||||||
|
l.RLock()
|
||||||
|
defer l.RUnlock()
|
||||||
|
return l.loopback
|
||||||
|
}
|
||||||
|
|
||||||
// Length returns the roundtrip time as nanoseconds (lower is better).
|
// Length returns the roundtrip time as nanoseconds (lower is better).
|
||||||
// Returns 0 where no measurement has been taken.
|
// Returns 0 where no measurement has been taken.
|
||||||
func (l *link) Length() int64 {
|
func (l *link) Length() int64 {
|
||||||
@ -320,7 +365,6 @@ func (l *link) Length() int64 {
|
|||||||
func (l *link) Id() string {
|
func (l *link) Id() string {
|
||||||
l.RLock()
|
l.RLock()
|
||||||
defer l.RUnlock()
|
defer l.RUnlock()
|
||||||
|
|
||||||
return l.id
|
return l.id
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -350,13 +394,6 @@ func (l *link) Send(m *transport.Message) error {
|
|||||||
// get time now
|
// get time now
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// check if its closed first
|
|
||||||
select {
|
|
||||||
case <-l.closed:
|
|
||||||
return io.EOF
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// queue the message
|
// queue the message
|
||||||
select {
|
select {
|
||||||
case l.sendQueue <- p:
|
case l.sendQueue <- p:
|
||||||
|
@ -24,7 +24,7 @@ type tunListener struct {
|
|||||||
delFunc func()
|
delFunc func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// periodically announce self
|
// periodically announce self the channel being listened on
|
||||||
func (t *tunListener) announce() {
|
func (t *tunListener) announce() {
|
||||||
tick := time.NewTicker(time.Second * 30)
|
tick := time.NewTicker(time.Second * 30)
|
||||||
defer tick.Stop()
|
defer tick.Stop()
|
||||||
@ -48,9 +48,12 @@ func (t *tunListener) process() {
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// close the sessions
|
// close the sessions
|
||||||
for _, conn := range conns {
|
for id, conn := range conns {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
delete(conns, id)
|
||||||
}
|
}
|
||||||
|
// unassign
|
||||||
|
conns = nil
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -62,9 +65,24 @@ func (t *tunListener) process() {
|
|||||||
return
|
return
|
||||||
// receive a new message
|
// receive a new message
|
||||||
case m := <-t.session.recv:
|
case m := <-t.session.recv:
|
||||||
|
var sessionId string
|
||||||
|
var linkId string
|
||||||
|
|
||||||
|
switch m.mode {
|
||||||
|
case Multicast:
|
||||||
|
sessionId = "multicast"
|
||||||
|
linkId = "multicast"
|
||||||
|
case Broadcast:
|
||||||
|
sessionId = "broadcast"
|
||||||
|
linkId = "broadcast"
|
||||||
|
default:
|
||||||
|
sessionId = m.session
|
||||||
|
linkId = m.link
|
||||||
|
}
|
||||||
|
|
||||||
// get a session
|
// get a session
|
||||||
sess, ok := conns[m.session]
|
sess, ok := conns[sessionId]
|
||||||
log.Debugf("Tunnel listener received channel %s session %s exists: %t", m.channel, m.session, ok)
|
log.Tracef("Tunnel listener received channel %s session %s type %s exists: %t", m.channel, m.session, m.typ, ok)
|
||||||
if !ok {
|
if !ok {
|
||||||
// we only process open and session types
|
// we only process open and session types
|
||||||
switch m.typ {
|
switch m.typ {
|
||||||
@ -80,13 +98,13 @@ func (t *tunListener) process() {
|
|||||||
// the channel
|
// the channel
|
||||||
channel: m.channel,
|
channel: m.channel,
|
||||||
// the session id
|
// the session id
|
||||||
session: m.session,
|
session: sessionId,
|
||||||
// tunnel token
|
// tunnel token
|
||||||
token: t.token,
|
token: t.token,
|
||||||
// is loopback conn
|
// is loopback conn
|
||||||
loopback: m.loopback,
|
loopback: m.loopback,
|
||||||
// the link the message was received on
|
// the link the message was received on
|
||||||
link: m.link,
|
link: linkId,
|
||||||
// set the connection mode
|
// set the connection mode
|
||||||
mode: m.mode,
|
mode: m.mode,
|
||||||
// close chan
|
// close chan
|
||||||
@ -95,14 +113,14 @@ func (t *tunListener) process() {
|
|||||||
recv: make(chan *message, 128),
|
recv: make(chan *message, 128),
|
||||||
// use the internal send buffer
|
// use the internal send buffer
|
||||||
send: t.session.send,
|
send: t.session.send,
|
||||||
// wait
|
|
||||||
wait: make(chan bool),
|
|
||||||
// error channel
|
// error channel
|
||||||
errChan: make(chan error, 1),
|
errChan: make(chan error, 1),
|
||||||
|
// set the read timeout
|
||||||
|
readTimeout: t.session.readTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the session
|
// save the session
|
||||||
conns[m.session] = sess
|
conns[sessionId] = sess
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-t.closed:
|
case <-t.closed:
|
||||||
@ -114,17 +132,19 @@ func (t *tunListener) process() {
|
|||||||
|
|
||||||
// an existing session was found
|
// an existing session was found
|
||||||
|
|
||||||
// received a close message
|
|
||||||
switch m.typ {
|
switch m.typ {
|
||||||
case "close":
|
case "close":
|
||||||
|
// received a close message
|
||||||
select {
|
select {
|
||||||
|
// check if the session is closed
|
||||||
case <-sess.closed:
|
case <-sess.closed:
|
||||||
// no op
|
// no op
|
||||||
delete(conns, m.session)
|
delete(conns, sessionId)
|
||||||
default:
|
default:
|
||||||
|
// only close if unicast session
|
||||||
// close and delete session
|
// close and delete session
|
||||||
close(sess.closed)
|
close(sess.closed)
|
||||||
delete(conns, m.session)
|
delete(conns, sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// continue
|
// continue
|
||||||
@ -139,9 +159,9 @@ func (t *tunListener) process() {
|
|||||||
// send this to the accept chan
|
// send this to the accept chan
|
||||||
select {
|
select {
|
||||||
case <-sess.closed:
|
case <-sess.closed:
|
||||||
delete(conns, m.session)
|
delete(conns, sessionId)
|
||||||
case sess.recv <- m:
|
case sess.recv <- m:
|
||||||
log.Debugf("Tunnel listener sent to recv chan channel %s session %s", m.channel, m.session)
|
log.Tracef("Tunnel listener sent to recv chan channel %s session %s type %s", m.channel, sessionId, m.typ)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -180,6 +200,10 @@ func (t *tunListener) Accept() (Session, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
|
// return without accept
|
||||||
|
if c.mode != Unicast {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
// send back the accept
|
// send back the accept
|
||||||
if err := c.Accept(); err != nil {
|
if err := c.Accept(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -47,6 +47,8 @@ type ListenOption func(*ListenOptions)
|
|||||||
type ListenOptions struct {
|
type ListenOptions struct {
|
||||||
// specify mode of the session
|
// specify mode of the session
|
||||||
Mode Mode
|
Mode Mode
|
||||||
|
// The read timeout
|
||||||
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// The tunnel id
|
// The tunnel id
|
||||||
@ -84,16 +86,6 @@ func Transport(t transport.Transport) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultOptions returns router default options
|
|
||||||
func DefaultOptions() Options {
|
|
||||||
return Options{
|
|
||||||
Id: uuid.New().String(),
|
|
||||||
Address: DefaultAddress,
|
|
||||||
Token: DefaultToken,
|
|
||||||
Transport: quic.NewTransport(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen options
|
// Listen options
|
||||||
func ListenMode(m Mode) ListenOption {
|
func ListenMode(m Mode) ListenOption {
|
||||||
return func(o *ListenOptions) {
|
return func(o *ListenOptions) {
|
||||||
@ -101,6 +93,13 @@ func ListenMode(m Mode) ListenOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Timeout for reads and writes on the listener session
|
||||||
|
func ListenTimeout(t time.Duration) ListenOption {
|
||||||
|
return func(o *ListenOptions) {
|
||||||
|
o.Timeout = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Dial options
|
// Dial options
|
||||||
|
|
||||||
// Dial multicast sets the multicast option to send only to those mapped
|
// Dial multicast sets the multicast option to send only to those mapped
|
||||||
@ -124,3 +123,13 @@ func DialLink(id string) DialOption {
|
|||||||
o.Link = id
|
o.Link = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultOptions returns router default options
|
||||||
|
func DefaultOptions() Options {
|
||||||
|
return Options{
|
||||||
|
Id: uuid.New().String(),
|
||||||
|
Address: DefaultAddress,
|
||||||
|
Token: DefaultToken,
|
||||||
|
Transport: quic.NewTransport(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,7 +2,6 @@ package tunnel
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -30,8 +29,6 @@ type session struct {
|
|||||||
send chan *message
|
send chan *message
|
||||||
// recv chan
|
// recv chan
|
||||||
recv chan *message
|
recv chan *message
|
||||||
// wait until we have a connection
|
|
||||||
wait chan bool
|
|
||||||
// if the discovery worked
|
// if the discovery worked
|
||||||
discovered bool
|
discovered bool
|
||||||
// if the session was accepted
|
// if the session was accepted
|
||||||
@ -42,8 +39,10 @@ type session struct {
|
|||||||
loopback bool
|
loopback bool
|
||||||
// mode of the connection
|
// mode of the connection
|
||||||
mode Mode
|
mode Mode
|
||||||
// the timeout
|
// the dial timeout
|
||||||
timeout time.Duration
|
dialTimeout time.Duration
|
||||||
|
// the read timeout
|
||||||
|
readTimeout time.Duration
|
||||||
// the link on which this message was received
|
// the link on which this message was received
|
||||||
link string
|
link string
|
||||||
// the error response
|
// the error response
|
||||||
@ -109,65 +108,114 @@ func (s *session) newMessage(typ string) *message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *session) sendMsg(msg *message) error {
|
||||||
|
select {
|
||||||
|
case <-s.closed:
|
||||||
|
return io.EOF
|
||||||
|
case s.send <- msg:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) wait(msg *message) error {
|
||||||
|
// wait for an error response
|
||||||
|
select {
|
||||||
|
case err := <-msg.errChan:
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case <-s.closed:
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// waitFor waits for the message type required until the timeout specified
|
// waitFor waits for the message type required until the timeout specified
|
||||||
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
|
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
after := func(timeout time.Duration) time.Duration {
|
after := func(timeout time.Duration) <-chan time.Time {
|
||||||
|
if timeout < time.Duration(0) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the delta
|
||||||
d := time.Since(now)
|
d := time.Since(now)
|
||||||
|
|
||||||
// dial timeout minus time since
|
// dial timeout minus time since
|
||||||
wait := timeout - d
|
wait := timeout - d
|
||||||
|
|
||||||
if wait < time.Duration(0) {
|
if wait < time.Duration(0) {
|
||||||
return time.Duration(0)
|
wait = time.Duration(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
return wait
|
return time.After(wait)
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for the message type
|
// wait for the message type
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-s.recv:
|
case msg := <-s.recv:
|
||||||
|
// there may be no message type
|
||||||
|
if len(msgType) == 0 {
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ignore what we don't want
|
// ignore what we don't want
|
||||||
if msg.typ != msgType {
|
if msg.typ != msgType {
|
||||||
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
|
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// got the message
|
// got the message
|
||||||
return msg, nil
|
return msg, nil
|
||||||
case <-time.After(after(timeout)):
|
case <-after(timeout):
|
||||||
return nil, ErrDialTimeout
|
return nil, ErrReadTimeout
|
||||||
case <-s.closed:
|
case <-s.closed:
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Discover attempts to discover the link for a specific channel
|
// Discover attempts to discover the link for a specific channel.
|
||||||
|
// This is only used by the tunnel.Dial when first connecting.
|
||||||
func (s *session) Discover() error {
|
func (s *session) Discover() error {
|
||||||
// create a new discovery message for this channel
|
// create a new discovery message for this channel
|
||||||
msg := s.newMessage("discover")
|
msg := s.newMessage("discover")
|
||||||
|
// broadcast the message to all links
|
||||||
msg.mode = Broadcast
|
msg.mode = Broadcast
|
||||||
|
// its an outbound connection since we're dialling
|
||||||
msg.outbound = true
|
msg.outbound = true
|
||||||
|
// don't set the link since we don't know where it is
|
||||||
msg.link = ""
|
msg.link = ""
|
||||||
|
|
||||||
// send the discovery message
|
// if multicast then set that as session
|
||||||
s.send <- msg
|
if s.mode == Multicast {
|
||||||
|
msg.session = "multicast"
|
||||||
|
}
|
||||||
|
|
||||||
|
// send discover message
|
||||||
|
if err := s.sendMsg(msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// set time now
|
// set time now
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
|
// after strips down the dial timeout
|
||||||
after := func() time.Duration {
|
after := func() time.Duration {
|
||||||
d := time.Since(now)
|
d := time.Since(now)
|
||||||
// dial timeout minus time since
|
// dial timeout minus time since
|
||||||
wait := s.timeout - d
|
wait := s.dialTimeout - d
|
||||||
|
// make sure its always > 0
|
||||||
if wait < time.Duration(0) {
|
if wait < time.Duration(0) {
|
||||||
return time.Duration(0)
|
return time.Duration(0)
|
||||||
}
|
}
|
||||||
return wait
|
return wait
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// the discover message is sent out, now
|
||||||
// wait to hear back about the sent message
|
// wait to hear back about the sent message
|
||||||
select {
|
select {
|
||||||
case <-time.After(after()):
|
case <-time.After(after()):
|
||||||
@ -178,27 +226,16 @@ func (s *session) Discover() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
// bail early if its not unicast
|
||||||
|
// we don't need to wait for the announce
|
||||||
// set a new dialTimeout
|
|
||||||
dialTimeout := after()
|
|
||||||
|
|
||||||
// set a shorter delay for multicast
|
|
||||||
if s.mode != Unicast {
|
|
||||||
// shorten this
|
|
||||||
dialTimeout = time.Millisecond * 500
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for announce
|
|
||||||
_, err = s.waitFor("announce", dialTimeout)
|
|
||||||
|
|
||||||
// if its multicast just go ahead because this is best effort
|
|
||||||
if s.mode != Unicast {
|
if s.mode != Unicast {
|
||||||
s.discovered = true
|
s.discovered = true
|
||||||
s.accepted = true
|
s.accepted = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wait for announce
|
||||||
|
_, err := s.waitFor("announce", after())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -210,31 +247,23 @@ func (s *session) Discover() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Open will fire the open message for the session. This is called by the dialler.
|
// Open will fire the open message for the session. This is called by the dialler.
|
||||||
|
// This is to indicate that we want to create a new session.
|
||||||
func (s *session) Open() error {
|
func (s *session) Open() error {
|
||||||
// create a new message
|
// create a new message
|
||||||
msg := s.newMessage("open")
|
msg := s.newMessage("open")
|
||||||
|
|
||||||
// send open message
|
// send open message
|
||||||
s.send <- msg
|
if err := s.sendMsg(msg); err != nil {
|
||||||
|
|
||||||
// wait for an error response for send
|
|
||||||
select {
|
|
||||||
case err := <-msg.errChan:
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case <-s.closed:
|
|
||||||
return io.EOF
|
// wait for an error response for send
|
||||||
|
if err := s.wait(msg); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't wait on multicast/broadcast
|
// now wait for the accept message to be returned
|
||||||
if s.mode == Multicast {
|
msg, err := s.waitFor("accept", s.dialTimeout)
|
||||||
s.accepted = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// now wait for the accept
|
|
||||||
msg, err := s.waitFor("accept", s.timeout)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -252,32 +281,16 @@ func (s *session) Accept() error {
|
|||||||
msg := s.newMessage("accept")
|
msg := s.newMessage("accept")
|
||||||
|
|
||||||
// send the accept message
|
// send the accept message
|
||||||
select {
|
if err := s.sendMsg(msg); err != nil {
|
||||||
case <-s.closed:
|
return err
|
||||||
return io.EOF
|
|
||||||
case s.send <- msg:
|
|
||||||
// no op here
|
|
||||||
}
|
|
||||||
|
|
||||||
// don't wait on multicast/broadcast
|
|
||||||
if s.mode == Multicast {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for send response
|
// wait for send response
|
||||||
select {
|
return s.wait(msg)
|
||||||
case err := <-s.errChan:
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case <-s.closed:
|
|
||||||
return io.EOF
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Announce sends an announcement to notify that this session exists.
|
||||||
}
|
// This is primarily used by the listener.
|
||||||
|
|
||||||
// Announce sends an announcement to notify that this session exists. This is primarily used by the listener.
|
|
||||||
func (s *session) Announce() error {
|
func (s *session) Announce() error {
|
||||||
msg := s.newMessage("announce")
|
msg := s.newMessage("announce")
|
||||||
// we don't need an error back
|
// we don't need an error back
|
||||||
@ -287,23 +300,12 @@ func (s *session) Announce() error {
|
|||||||
// we don't need the link
|
// we don't need the link
|
||||||
msg.link = ""
|
msg.link = ""
|
||||||
|
|
||||||
select {
|
// send announce message
|
||||||
case s.send <- msg:
|
return s.sendMsg(msg)
|
||||||
return nil
|
|
||||||
case <-s.closed:
|
|
||||||
return io.EOF
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send is used to send a message
|
// Send is used to send a message
|
||||||
func (s *session) Send(m *transport.Message) error {
|
func (s *session) Send(m *transport.Message) error {
|
||||||
select {
|
|
||||||
case <-s.closed:
|
|
||||||
return io.EOF
|
|
||||||
default:
|
|
||||||
// no op
|
|
||||||
}
|
|
||||||
|
|
||||||
// encrypt the transport message payload
|
// encrypt the transport message payload
|
||||||
body, err := Encrypt(m.Body, s.token+s.channel+s.session)
|
body, err := Encrypt(m.Body, s.token+s.channel+s.session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -335,32 +337,28 @@ func (s *session) Send(m *transport.Message) error {
|
|||||||
msg.data = data
|
msg.data = data
|
||||||
|
|
||||||
// if multicast don't set the link
|
// if multicast don't set the link
|
||||||
if s.mode == Multicast {
|
if s.mode != Unicast {
|
||||||
msg.link = ""
|
msg.link = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Appending %+v to send backlog", msg)
|
log.Tracef("Appending %+v to send backlog", msg)
|
||||||
|
|
||||||
// send the actual message
|
// send the actual message
|
||||||
s.send <- msg
|
if err := s.sendMsg(msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// wait for an error response
|
// wait for an error response
|
||||||
select {
|
return s.wait(msg)
|
||||||
case err := <-msg.errChan:
|
|
||||||
return err
|
|
||||||
case <-s.closed:
|
|
||||||
return io.EOF
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recv is used to receive a message
|
// Recv is used to receive a message
|
||||||
func (s *session) Recv(m *transport.Message) error {
|
func (s *session) Recv(m *transport.Message) error {
|
||||||
var msg *message
|
var msg *message
|
||||||
|
|
||||||
select {
|
msg, err := s.waitFor("", s.readTimeout)
|
||||||
case <-s.closed:
|
if err != nil {
|
||||||
return errors.New("session is closed")
|
return err
|
||||||
// recv from backlog
|
|
||||||
case msg = <-s.recv:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check the error if one exists
|
// check the error if one exists
|
||||||
@ -371,10 +369,13 @@ func (s *session) Recv(m *transport.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//log.Tracef("Received %+v from recv backlog", msg)
|
//log.Tracef("Received %+v from recv backlog", msg)
|
||||||
log.Debugf("Received %+v from recv backlog", msg)
|
log.Tracef("Received %+v from recv backlog", msg)
|
||||||
|
|
||||||
// decrypt the received payload using the token
|
// decrypt the received payload using the token
|
||||||
body, err := Decrypt(msg.data.Body, s.token+s.channel+s.session)
|
// we have to used msg.session because multicast has a shared
|
||||||
|
// session id of "multicast" in this session struct on
|
||||||
|
// the listener side
|
||||||
|
body, err := Decrypt(msg.data.Body, s.token+s.channel+msg.session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to decrypt message body: %v", err)
|
log.Debugf("failed to decrypt message body: %v", err)
|
||||||
return err
|
return err
|
||||||
@ -390,7 +391,7 @@ func (s *session) Recv(m *transport.Message) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// encrypt the transport message payload
|
// encrypt the transport message payload
|
||||||
val, err := Decrypt([]byte(h), s.token+s.channel+s.session)
|
val, err := Decrypt([]byte(h), s.token+s.channel+msg.session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to decrypt message header %s: %v", k, err)
|
log.Debugf("failed to decrypt message header %s: %v", k, err)
|
||||||
return err
|
return err
|
||||||
@ -399,6 +400,12 @@ func (s *session) Recv(m *transport.Message) error {
|
|||||||
msg.data.Header[k] = string(val)
|
msg.data.Header[k] = string(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set the link
|
||||||
|
// TODO: decruft, this is only for multicast
|
||||||
|
// since the session is now a single session
|
||||||
|
// likely provide as part of message.Link()
|
||||||
|
msg.data.Header["Micro-Link"] = msg.link
|
||||||
|
|
||||||
// set message
|
// set message
|
||||||
*m = *msg.data
|
*m = *msg.data
|
||||||
// return nil
|
// return nil
|
||||||
@ -413,6 +420,11 @@ func (s *session) Close() error {
|
|||||||
default:
|
default:
|
||||||
close(s.closed)
|
close(s.closed)
|
||||||
|
|
||||||
|
// don't send close on multicast or broadcast
|
||||||
|
if s.mode != Unicast {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// append to backlog
|
// append to backlog
|
||||||
msg := s.newMessage("close")
|
msg := s.newMessage("close")
|
||||||
// no error response on close
|
// no error response on close
|
||||||
@ -421,7 +433,7 @@ func (s *session) Close() error {
|
|||||||
// send the close message
|
// send the close message
|
||||||
select {
|
select {
|
||||||
case s.send <- msg:
|
case s.send <- msg:
|
||||||
default:
|
case <-time.After(time.Millisecond * 10):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,6 +26,8 @@ var (
|
|||||||
ErrDiscoverChan = errors.New("failed to discover channel")
|
ErrDiscoverChan = errors.New("failed to discover channel")
|
||||||
// ErrLinkNotFound is returned when a link is specified at dial time and does not exist
|
// ErrLinkNotFound is returned when a link is specified at dial time and does not exist
|
||||||
ErrLinkNotFound = errors.New("link not found")
|
ErrLinkNotFound = errors.New("link not found")
|
||||||
|
// ErrReadTimeout is a timeout on session.Recv
|
||||||
|
ErrReadTimeout = errors.New("read timeout")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Mode of the session
|
// Mode of the session
|
||||||
@ -64,7 +66,9 @@ type Link interface {
|
|||||||
Length() int64
|
Length() int64
|
||||||
// Current transfer rate as bits per second (lower is better)
|
// Current transfer rate as bits per second (lower is better)
|
||||||
Rate() float64
|
Rate() float64
|
||||||
// State of the link e.g connected/closed
|
// Is this a loopback link
|
||||||
|
Loopback() bool
|
||||||
|
// State of the link: connected/closed/error
|
||||||
State() string
|
State() string
|
||||||
// honours transport socket
|
// honours transport socket
|
||||||
transport.Socket
|
transport.Socket
|
||||||
|
@ -1,55 +0,0 @@
|
|||||||
// +build !race
|
|
||||||
|
|
||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestReconnectTunnel(t *testing.T) {
|
|
||||||
// create a new tunnel client
|
|
||||||
tunA := NewTunnel(
|
|
||||||
Address("127.0.0.1:9096"),
|
|
||||||
Nodes("127.0.0.1:9097"),
|
|
||||||
)
|
|
||||||
|
|
||||||
// create a new tunnel server
|
|
||||||
tunB := NewTunnel(
|
|
||||||
Address("127.0.0.1:9097"),
|
|
||||||
)
|
|
||||||
|
|
||||||
// start tunnel
|
|
||||||
err := tunB.Connect()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tunB.Close()
|
|
||||||
|
|
||||||
// we manually override the tunnel.ReconnectTime value here
|
|
||||||
// this is so that we make the reconnects faster than the default 5s
|
|
||||||
ReconnectTime = 200 * time.Millisecond
|
|
||||||
|
|
||||||
// start tunnel
|
|
||||||
err = tunA.Connect()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tunA.Close()
|
|
||||||
|
|
||||||
wait := make(chan bool)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
// start tunnel listener
|
|
||||||
go testBrokenTunAccept(t, tunB, wait, &wg)
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
// start tunnel sender
|
|
||||||
go testBrokenTunSend(t, tunA, wait, &wg)
|
|
||||||
|
|
||||||
// wait until done
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
@ -8,6 +8,90 @@ import (
|
|||||||
"github.com/micro/go-micro/transport"
|
"github.com/micro/go-micro/transport"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// listen on some virtual address
|
||||||
|
tl, err := tun.Listen("test-tunnel")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiver ready; notify sender
|
||||||
|
wait <- true
|
||||||
|
|
||||||
|
// accept a connection
|
||||||
|
c, err := tl.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// accept the message and close the tunnel
|
||||||
|
// we do this to simulate loss of network connection
|
||||||
|
m := new(transport.Message)
|
||||||
|
if err := c.Recv(m); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// close all the links
|
||||||
|
for _, link := range tun.Links() {
|
||||||
|
link.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiver ready; notify sender
|
||||||
|
wait <- true
|
||||||
|
|
||||||
|
// accept the message
|
||||||
|
m = new(transport.Message)
|
||||||
|
if err := c.Recv(m); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// notify the sender we have received
|
||||||
|
wait <- true
|
||||||
|
}
|
||||||
|
|
||||||
|
func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup, reconnect time.Duration) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// wait for the listener to get ready
|
||||||
|
<-wait
|
||||||
|
|
||||||
|
// dial a new session
|
||||||
|
c, err := tun.Dial("test-tunnel")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
m := transport.Message{
|
||||||
|
Header: map[string]string{
|
||||||
|
"test": "send",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// send the message
|
||||||
|
if err := c.Send(&m); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for the listener to get ready
|
||||||
|
<-wait
|
||||||
|
|
||||||
|
// give it time to reconnect
|
||||||
|
time.Sleep(reconnect)
|
||||||
|
|
||||||
|
// send the message
|
||||||
|
if err := c.Send(&m); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for the listener to receive the message
|
||||||
|
// c.Send merely enqueues the message to the link send queue and returns
|
||||||
|
// in order to verify it was received we wait for the listener to tell us
|
||||||
|
<-wait
|
||||||
|
}
|
||||||
|
|
||||||
// testAccept will accept connections on the transport, create a new link and tunnel on top
|
// testAccept will accept connections on the transport, create a new link and tunnel on top
|
||||||
func testAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
|
func testAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@ -163,90 +247,6 @@ func TestLoopbackTunnel(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
// listen on some virtual address
|
|
||||||
tl, err := tun.Listen("test-tunnel")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// receiver ready; notify sender
|
|
||||||
wait <- true
|
|
||||||
|
|
||||||
// accept a connection
|
|
||||||
c, err := tl.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// accept the message and close the tunnel
|
|
||||||
// we do this to simulate loss of network connection
|
|
||||||
m := new(transport.Message)
|
|
||||||
if err := c.Recv(m); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// close all the links
|
|
||||||
for _, link := range tun.Links() {
|
|
||||||
link.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// receiver ready; notify sender
|
|
||||||
wait <- true
|
|
||||||
|
|
||||||
// accept the message
|
|
||||||
m = new(transport.Message)
|
|
||||||
if err := c.Recv(m); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// notify the sender we have received
|
|
||||||
wait <- true
|
|
||||||
}
|
|
||||||
|
|
||||||
func testBrokenTunSend(t *testing.T, tun Tunnel, wait chan bool, wg *sync.WaitGroup) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
// wait for the listener to get ready
|
|
||||||
<-wait
|
|
||||||
|
|
||||||
// dial a new session
|
|
||||||
c, err := tun.Dial("test-tunnel")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
|
|
||||||
m := transport.Message{
|
|
||||||
Header: map[string]string{
|
|
||||||
"test": "send",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the message
|
|
||||||
if err := c.Send(&m); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for the listener to get ready
|
|
||||||
<-wait
|
|
||||||
|
|
||||||
// give it time to reconnect
|
|
||||||
time.Sleep(5 * ReconnectTime)
|
|
||||||
|
|
||||||
// send the message
|
|
||||||
if err := c.Send(&m); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for the listener to receive the message
|
|
||||||
// c.Send merely enqueues the message to the link send queue and returns
|
|
||||||
// in order to verify it was received we wait for the listener to tell us
|
|
||||||
<-wait
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTunnelRTTRate(t *testing.T) {
|
func TestTunnelRTTRate(t *testing.T) {
|
||||||
// create a new tunnel client
|
// create a new tunnel client
|
||||||
tunA := NewTunnel(
|
tunA := NewTunnel(
|
||||||
@ -296,3 +296,49 @@ func TestTunnelRTTRate(t *testing.T) {
|
|||||||
t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate())
|
t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReconnectTunnel(t *testing.T) {
|
||||||
|
// we manually override the tunnel.ReconnectTime value here
|
||||||
|
// this is so that we make the reconnects faster than the default 5s
|
||||||
|
ReconnectTime = 200 * time.Millisecond
|
||||||
|
|
||||||
|
// create a new tunnel client
|
||||||
|
tunA := NewTunnel(
|
||||||
|
Address("127.0.0.1:9098"),
|
||||||
|
Nodes("127.0.0.1:9099"),
|
||||||
|
)
|
||||||
|
|
||||||
|
// create a new tunnel server
|
||||||
|
tunB := NewTunnel(
|
||||||
|
Address("127.0.0.1:9099"),
|
||||||
|
)
|
||||||
|
|
||||||
|
// start tunnel
|
||||||
|
err := tunB.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tunB.Close()
|
||||||
|
|
||||||
|
// start tunnel
|
||||||
|
err = tunA.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tunA.Close()
|
||||||
|
|
||||||
|
wait := make(chan bool)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
// start tunnel listener
|
||||||
|
go testBrokenTunAccept(t, tunB, wait, &wg)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
// start tunnel sender
|
||||||
|
go testBrokenTunSend(t, tunA, wait, &wg, ReconnectTime*5)
|
||||||
|
|
||||||
|
// wait until done
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user