save changes
This commit is contained in:
parent
c445aed6b1
commit
9bd0fb9125
@ -234,7 +234,7 @@ func (n *network) resolveNodes() ([]string, error) {
|
|||||||
dns := &dns.Resolver{}
|
dns := &dns.Resolver{}
|
||||||
|
|
||||||
// append seed nodes if we have them
|
// append seed nodes if we have them
|
||||||
for _, node := range n.options.Peers {
|
for _, node := range n.options.Nodes {
|
||||||
// resolve anything that looks like a host name
|
// resolve anything that looks like a host name
|
||||||
records, err := dns.Resolve(node)
|
records, err := dns.Resolve(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -282,7 +282,8 @@ func (n *network) handleNetConn(s tunnel.Session, msg chan *message) {
|
|||||||
m := new(transport.Message)
|
m := new(transport.Message)
|
||||||
if err := s.Recv(m); err != nil {
|
if err := s.Recv(m); err != nil {
|
||||||
log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err)
|
log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err)
|
||||||
if err == io.EOF {
|
switch err {
|
||||||
|
case io.EOF, tunnel.ErrReadTimeout:
|
||||||
s.Close()
|
s.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -338,39 +339,10 @@ func (n *network) acceptNetConn(l tunnel.Listener, recv chan *message) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// updatePeerLinks updates link for a given peer
|
|
||||||
func (n *network) updatePeerLinks(peerAddr string, linkId string) error {
|
|
||||||
n.Lock()
|
|
||||||
defer n.Unlock()
|
|
||||||
log.Tracef("Network looking up link %s in the peer links", linkId)
|
|
||||||
// lookup the peer link
|
|
||||||
var peerLink tunnel.Link
|
|
||||||
for _, link := range n.tunnel.Links() {
|
|
||||||
if link.Id() == linkId {
|
|
||||||
peerLink = link
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if peerLink == nil {
|
|
||||||
return ErrPeerLinkNotFound
|
|
||||||
}
|
|
||||||
// if the peerLink is found in the returned links update peerLinks
|
|
||||||
log.Tracef("Network updating peer links for peer %s", peerAddr)
|
|
||||||
// add peerLink to the peerLinks map
|
|
||||||
if link, ok := n.peerLinks[peerAddr]; ok {
|
|
||||||
// if the existing has better Length then the new, replace it
|
|
||||||
if link.Length() < peerLink.Length() {
|
|
||||||
n.peerLinks[peerAddr] = peerLink
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
n.peerLinks[peerAddr] = peerLink
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// processNetChan processes messages received on NetworkChannel
|
// processNetChan processes messages received on NetworkChannel
|
||||||
func (n *network) processNetChan(listener tunnel.Listener) {
|
func (n *network) processNetChan(listener tunnel.Listener) {
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
// receive network message queue
|
// receive network message queue
|
||||||
recv := make(chan *message, 128)
|
recv := make(chan *message, 128)
|
||||||
|
|
||||||
@ -707,13 +679,51 @@ func (n *network) sendMsg(method, channel string, msg proto.Message) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updatePeerLinks updates link for a given peer
|
||||||
|
func (n *network) updatePeerLinks(peerAddr string, linkId string) error {
|
||||||
|
n.Lock()
|
||||||
|
defer n.Unlock()
|
||||||
|
|
||||||
|
log.Tracef("Network looking up link %s in the peer links", linkId)
|
||||||
|
|
||||||
|
// lookup the peer link
|
||||||
|
var peerLink tunnel.Link
|
||||||
|
|
||||||
|
for _, link := range n.tunnel.Links() {
|
||||||
|
if link.Id() == linkId {
|
||||||
|
peerLink = link
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerLink == nil {
|
||||||
|
return ErrPeerLinkNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the peerLink is found in the returned links update peerLinks
|
||||||
|
log.Tracef("Network updating peer links for peer %s", peerAddr)
|
||||||
|
|
||||||
|
// add peerLink to the peerLinks map
|
||||||
|
if link, ok := n.peerLinks[peerAddr]; ok {
|
||||||
|
// if the existing has better Length then the new, replace it
|
||||||
|
if link.Length() < peerLink.Length() {
|
||||||
|
n.peerLinks[peerAddr] = peerLink
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n.peerLinks[peerAddr] = peerLink
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// handleCtrlConn handles ControlChannel connections
|
// handleCtrlConn handles ControlChannel connections
|
||||||
func (n *network) handleCtrlConn(s tunnel.Session, msg chan *message) {
|
func (n *network) handleCtrlConn(s tunnel.Session, msg chan *message) {
|
||||||
for {
|
for {
|
||||||
m := new(transport.Message)
|
m := new(transport.Message)
|
||||||
if err := s.Recv(m); err != nil {
|
if err := s.Recv(m); err != nil {
|
||||||
log.Debugf("Network tunnel [%s] receive error: %v", ControlChannel, err)
|
log.Debugf("Network tunnel [%s] receive error: %v", ControlChannel, err)
|
||||||
if err == io.EOF {
|
switch err {
|
||||||
|
case io.EOF, tunnel.ErrReadTimeout:
|
||||||
s.Close()
|
s.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -843,6 +853,8 @@ func (n *network) getRouteMetric(router string, gateway string, link string) int
|
|||||||
|
|
||||||
// processCtrlChan processes messages received on ControlChannel
|
// processCtrlChan processes messages received on ControlChannel
|
||||||
func (n *network) processCtrlChan(listener tunnel.Listener) {
|
func (n *network) processCtrlChan(listener tunnel.Listener) {
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
// receive control message queue
|
// receive control message queue
|
||||||
recv := make(chan *message, 128)
|
recv := make(chan *message, 128)
|
||||||
|
|
||||||
@ -1151,12 +1163,7 @@ func (n *network) connect() {
|
|||||||
// Connect connects the network
|
// Connect connects the network
|
||||||
func (n *network) Connect() error {
|
func (n *network) Connect() error {
|
||||||
n.Lock()
|
n.Lock()
|
||||||
|
defer n.Unlock()
|
||||||
// connect network tunnel
|
|
||||||
if err := n.tunnel.Connect(); err != nil {
|
|
||||||
n.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// try to resolve network nodes
|
// try to resolve network nodes
|
||||||
nodes, err := n.resolveNodes()
|
nodes, err := n.resolveNodes()
|
||||||
@ -1169,10 +1176,14 @@ func (n *network) Connect() error {
|
|||||||
tunnel.Nodes(nodes...),
|
tunnel.Nodes(nodes...),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// connect network tunnel
|
||||||
|
if err := n.tunnel.Connect(); err != nil {
|
||||||
|
n.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// return if already connected
|
// return if already connected
|
||||||
if n.connected {
|
if n.connected {
|
||||||
// unlock first
|
|
||||||
n.Unlock()
|
|
||||||
// send the connect message
|
// send the connect message
|
||||||
n.sendConnect()
|
n.sendConnect()
|
||||||
return nil
|
return nil
|
||||||
@ -1187,32 +1198,36 @@ func (n *network) Connect() error {
|
|||||||
// dial into ControlChannel to send route adverts
|
// dial into ControlChannel to send route adverts
|
||||||
ctrlClient, err := n.tunnel.Dial(ControlChannel, tunnel.DialMode(tunnel.Multicast))
|
ctrlClient, err := n.tunnel.Dial(ControlChannel, tunnel.DialMode(tunnel.Multicast))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
n.tunClient[ControlChannel] = ctrlClient
|
n.tunClient[ControlChannel] = ctrlClient
|
||||||
|
|
||||||
// listen on ControlChannel
|
// listen on ControlChannel
|
||||||
ctrlListener, err := n.tunnel.Listen(ControlChannel, tunnel.ListenMode(tunnel.Multicast))
|
ctrlListener, err := n.tunnel.Listen(
|
||||||
|
ControlChannel,
|
||||||
|
tunnel.ListenMode(tunnel.Multicast),
|
||||||
|
tunnel.ListenTimeout(router.AdvertiseTableTick*2),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial into NetworkChannel to send network messages
|
// dial into NetworkChannel to send network messages
|
||||||
netClient, err := n.tunnel.Dial(NetworkChannel, tunnel.DialMode(tunnel.Multicast))
|
netClient, err := n.tunnel.Dial(NetworkChannel, tunnel.DialMode(tunnel.Multicast))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
n.tunClient[NetworkChannel] = netClient
|
n.tunClient[NetworkChannel] = netClient
|
||||||
|
|
||||||
// listen on NetworkChannel
|
// listen on NetworkChannel
|
||||||
netListener, err := n.tunnel.Listen(NetworkChannel, tunnel.ListenMode(tunnel.Multicast))
|
netListener, err := n.tunnel.Listen(
|
||||||
|
NetworkChannel,
|
||||||
|
tunnel.ListenMode(tunnel.Multicast),
|
||||||
|
tunnel.ListenTimeout(AnnounceTime*2),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1221,23 +1236,19 @@ func (n *network) Connect() error {
|
|||||||
|
|
||||||
// start the router
|
// start the router
|
||||||
if err := n.options.Router.Start(); err != nil {
|
if err := n.options.Router.Start(); err != nil {
|
||||||
n.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// start advertising routes
|
// start advertising routes
|
||||||
advertChan, err := n.options.Router.Advertise()
|
advertChan, err := n.options.Router.Advertise()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// start the server
|
// start the server
|
||||||
if err := n.server.Start(); err != nil {
|
if err := n.server.Start(); err != nil {
|
||||||
n.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
n.Unlock()
|
|
||||||
|
|
||||||
// send connect after there's a link established
|
// send connect after there's a link established
|
||||||
go n.connect()
|
go n.connect()
|
||||||
@ -1252,9 +1263,8 @@ func (n *network) Connect() error {
|
|||||||
// accept and process routes
|
// accept and process routes
|
||||||
go n.processCtrlChan(ctrlListener)
|
go n.processCtrlChan(ctrlListener)
|
||||||
|
|
||||||
n.Lock()
|
// we're now connected
|
||||||
n.connected = true
|
n.connected = true
|
||||||
n.Unlock()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1340,6 +1350,7 @@ func (n *network) Close() error {
|
|||||||
Address: n.node.address,
|
Address: n.node.address,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := n.sendMsg("close", NetworkChannel, msg); err != nil {
|
if err := n.sendMsg("close", NetworkChannel, msg); err != nil {
|
||||||
log.Debugf("Network failed to send close message: %s", err)
|
log.Debugf("Network failed to send close message: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -6,8 +6,6 @@ import (
|
|||||||
|
|
||||||
"github.com/micro/go-micro/client"
|
"github.com/micro/go-micro/client"
|
||||||
"github.com/micro/go-micro/server"
|
"github.com/micro/go-micro/server"
|
||||||
"github.com/micro/go-micro/transport"
|
|
||||||
"github.com/micro/go-micro/tunnel"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -22,8 +22,8 @@ type Options struct {
|
|||||||
Address string
|
Address string
|
||||||
// Advertise sets the address to advertise
|
// Advertise sets the address to advertise
|
||||||
Advertise string
|
Advertise string
|
||||||
// Peers is a list of peers to connect to
|
// Nodes is a list of nodes to connect to
|
||||||
Peers []string
|
Nodes []string
|
||||||
// Tunnel is network tunnel
|
// Tunnel is network tunnel
|
||||||
Tunnel tunnel.Tunnel
|
Tunnel tunnel.Tunnel
|
||||||
// Router is network router
|
// Router is network router
|
||||||
@ -62,10 +62,10 @@ func Advertise(a string) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peers is a list of peers to connect to
|
// Nodes is a list of nodes to connect to
|
||||||
func Peers(n ...string) Option {
|
func Nodes(n ...string) Option {
|
||||||
return func(o *Options) {
|
return func(o *Options) {
|
||||||
o.Peers = n
|
o.Nodes = n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +31,17 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
|
|||||||
r.Address = "1.0.0.1:53"
|
r.Address = "1.0.0.1:53"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:prealloc
|
||||||
|
var records []*resolver.Record
|
||||||
|
|
||||||
|
// parsed an actual ip
|
||||||
|
if v := net.ParseIP(host); v != nil {
|
||||||
|
records = append(records, &resolver.Record{
|
||||||
|
Address: net.JoinHostPort(host, port),
|
||||||
|
})
|
||||||
|
return records, nil
|
||||||
|
}
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||||
rec, err := dns.ExchangeContext(context.Background(), m, r.Address)
|
rec, err := dns.ExchangeContext(context.Background(), m, r.Address)
|
||||||
@ -38,9 +49,6 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:prealloc
|
|
||||||
var records []*resolver.Record
|
|
||||||
|
|
||||||
for _, answer := range rec.Answer {
|
for _, answer := range rec.Answer {
|
||||||
h := answer.Header()
|
h := answer.Header()
|
||||||
// check record type matches
|
// check record type matches
|
||||||
@ -59,5 +67,12 @@ func (r *Resolver) Resolve(name string) ([]*resolver.Record, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// no records returned so just best effort it
|
||||||
|
if len(records) == 0 {
|
||||||
|
records = append(records, &resolver.Record{
|
||||||
|
Address: net.JoinHostPort(host, port),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return records, nil
|
return records, nil
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get list of existing nodes
|
// get list of existing nodes
|
||||||
nodes := n.Network.Options().Peers
|
nodes := n.Network.Options().Nodes
|
||||||
|
|
||||||
// generate a node map
|
// generate a node map
|
||||||
nodeMap := make(map[string]bool)
|
nodeMap := make(map[string]bool)
|
||||||
@ -84,7 +84,7 @@ func (n *Network) Connect(ctx context.Context, req *pbNet.ConnectRequest, resp *
|
|||||||
|
|
||||||
// reinitialise the peers
|
// reinitialise the peers
|
||||||
n.Network.Init(
|
n.Network.Init(
|
||||||
network.Peers(nodes...),
|
network.Nodes(nodes...),
|
||||||
)
|
)
|
||||||
|
|
||||||
// call the connect method
|
// call the connect method
|
||||||
|
@ -197,19 +197,30 @@ func (t *tun) announce(channel, session string, link *link) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// monitor monitors outbound links and attempts to reconnect to the failed ones
|
// manage monitors outbound links and attempts to reconnect to the failed ones
|
||||||
func (t *tun) monitor() {
|
func (t *tun) manage() {
|
||||||
reconnect := time.NewTicker(ReconnectTime)
|
reconnect := time.NewTicker(ReconnectTime)
|
||||||
defer reconnect.Stop()
|
defer reconnect.Stop()
|
||||||
|
|
||||||
|
// do it immediately
|
||||||
|
t.manageLinks()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.closed:
|
case <-t.closed:
|
||||||
return
|
return
|
||||||
case <-reconnect.C:
|
case <-reconnect.C:
|
||||||
|
t.manageLinks()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// manageLinks is a function that can be called to immediately to link setup
|
||||||
|
func (t *tun) manageLinks() {
|
||||||
|
var delLinks []string
|
||||||
|
|
||||||
t.RLock()
|
t.RLock()
|
||||||
|
|
||||||
var delLinks []string
|
|
||||||
// check the link status and purge dead links
|
// check the link status and purge dead links
|
||||||
for node, link := range t.links {
|
for node, link := range t.links {
|
||||||
// check link status
|
// check link status
|
||||||
@ -241,27 +252,39 @@ func (t *tun) monitor() {
|
|||||||
|
|
||||||
// build list of unknown nodes to connect to
|
// build list of unknown nodes to connect to
|
||||||
t.RLock()
|
t.RLock()
|
||||||
|
|
||||||
for _, node := range t.options.Nodes {
|
for _, node := range t.options.Nodes {
|
||||||
if _, ok := t.links[node]; !ok {
|
if _, ok := t.links[node]; !ok {
|
||||||
connect = append(connect, node)
|
connect = append(connect, node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.RUnlock()
|
t.RUnlock()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for _, node := range connect {
|
for _, node := range connect {
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
// create new link
|
// create new link
|
||||||
link, err := t.setupLink(node)
|
link, err := t.setupLink(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Tunnel failed to setup node link to %s: %v", node, err)
|
log.Debugf("Tunnel failed to setup node link to %s: %v", node, err)
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the link
|
// save the link
|
||||||
t.Lock()
|
t.Lock()
|
||||||
t.links[node] = link
|
t.links[node] = link
|
||||||
t.Unlock()
|
t.Unlock()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
// wait for all threads to finish
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// process outgoing messages sent by all local sessions
|
// process outgoing messages sent by all local sessions
|
||||||
@ -757,11 +780,13 @@ func (t *tun) keepalive(link *link) {
|
|||||||
// It returns error if the link failed to be established
|
// It returns error if the link failed to be established
|
||||||
func (t *tun) setupLink(node string) (*link, error) {
|
func (t *tun) setupLink(node string) (*link, error) {
|
||||||
log.Debugf("Tunnel setting up link: %s", node)
|
log.Debugf("Tunnel setting up link: %s", node)
|
||||||
|
|
||||||
c, err := t.options.Transport.Dial(node)
|
c, err := t.options.Transport.Dial(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Tunnel failed to connect to %s: %v", node, err)
|
log.Debugf("Tunnel failed to connect to %s: %v", node, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Tunnel connected to %s", node)
|
log.Debugf("Tunnel connected to %s", node)
|
||||||
|
|
||||||
// create a new link
|
// create a new link
|
||||||
@ -795,30 +820,6 @@ func (t *tun) setupLink(node string) (*link, error) {
|
|||||||
return link, nil
|
return link, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) setupLinks() {
|
|
||||||
for _, node := range t.options.Nodes {
|
|
||||||
// skip zero length nodes
|
|
||||||
if len(node) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// link already exists
|
|
||||||
if _, ok := t.links[node]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect to node and return link
|
|
||||||
link, err := t.setupLink(node)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("Tunnel failed to establish node link to %s: %v", node, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// save the link
|
|
||||||
t.links[node] = link
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect the tunnel to all the nodes and listen for incoming tunnel connections
|
// connect the tunnel to all the nodes and listen for incoming tunnel connections
|
||||||
func (t *tun) connect() error {
|
func (t *tun) connect() error {
|
||||||
l, err := t.options.Transport.Listen(t.options.Address)
|
l, err := t.options.Transport.Listen(t.options.Address)
|
||||||
@ -869,11 +870,10 @@ func (t *tun) Connect() error {
|
|||||||
// already connected
|
// already connected
|
||||||
if t.connected {
|
if t.connected {
|
||||||
// setup links
|
// setup links
|
||||||
t.setupLinks()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// send the connect message
|
// connect the tunnel: start the listener
|
||||||
if err := t.connect(); err != nil {
|
if err := t.connect(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -883,16 +883,13 @@ func (t *tun) Connect() error {
|
|||||||
// create new close channel
|
// create new close channel
|
||||||
t.closed = make(chan bool)
|
t.closed = make(chan bool)
|
||||||
|
|
||||||
// setup links
|
// manage the links
|
||||||
t.setupLinks()
|
go t.manage()
|
||||||
|
|
||||||
// process outbound messages to be sent
|
// process outbound messages to be sent
|
||||||
// process sends to all links
|
// process sends to all links
|
||||||
go t.process()
|
go t.process()
|
||||||
|
|
||||||
// monitor links
|
|
||||||
go t.monitor()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1029,7 +1026,9 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
|||||||
// set the multicast option
|
// set the multicast option
|
||||||
c.mode = options.Mode
|
c.mode = options.Mode
|
||||||
// set the dial timeout
|
// set the dial timeout
|
||||||
c.timeout = options.Timeout
|
c.dialTimeout = options.Timeout
|
||||||
|
// set read timeout set to never
|
||||||
|
c.readTimeout = time.Duration(-1)
|
||||||
|
|
||||||
var links []*link
|
var links []*link
|
||||||
// did we measure the rtt
|
// did we measure the rtt
|
||||||
@ -1145,7 +1144,11 @@ func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) {
|
|||||||
func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
|
func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
|
||||||
log.Debugf("Tunnel listening on %s", channel)
|
log.Debugf("Tunnel listening on %s", channel)
|
||||||
|
|
||||||
var options ListenOptions
|
options := ListenOptions{
|
||||||
|
// Read timeout defaults to never
|
||||||
|
Timeout: time.Duration(-1),
|
||||||
|
}
|
||||||
|
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
o(&options)
|
o(&options)
|
||||||
}
|
}
|
||||||
@ -1167,6 +1170,8 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
|
|||||||
c.local = channel
|
c.local = channel
|
||||||
// set mode
|
// set mode
|
||||||
c.mode = options.Mode
|
c.mode = options.Mode
|
||||||
|
// set the timeout
|
||||||
|
c.readTimeout = options.Timeout
|
||||||
|
|
||||||
tl := &tunListener{
|
tl := &tunListener{
|
||||||
channel: channel,
|
channel: channel,
|
||||||
|
@ -65,18 +65,8 @@ func (t *tunListener) process() {
|
|||||||
return
|
return
|
||||||
// receive a new message
|
// receive a new message
|
||||||
case m := <-t.session.recv:
|
case m := <-t.session.recv:
|
||||||
var sessionId string
|
// session id
|
||||||
|
sessionId := m.session
|
||||||
// get the session id
|
|
||||||
switch m.mode {
|
|
||||||
case Multicast, Broadcast:
|
|
||||||
// use channel name if multicast/broadcast
|
|
||||||
sessionId = "multicast"
|
|
||||||
log.Tracef("Tunnel listener using session %s for real session %s", sessionId, m.session)
|
|
||||||
default:
|
|
||||||
// use session id if unicast
|
|
||||||
sessionId = m.session
|
|
||||||
}
|
|
||||||
|
|
||||||
// get a session
|
// get a session
|
||||||
sess, ok := conns[sessionId]
|
sess, ok := conns[sessionId]
|
||||||
@ -113,6 +103,8 @@ func (t *tunListener) process() {
|
|||||||
send: t.session.send,
|
send: t.session.send,
|
||||||
// error channel
|
// error channel
|
||||||
errChan: make(chan error, 1),
|
errChan: make(chan error, 1),
|
||||||
|
// set the read timeout
|
||||||
|
readTimeout: t.session.readTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the session
|
// save the session
|
||||||
@ -137,13 +129,11 @@ func (t *tunListener) process() {
|
|||||||
// no op
|
// no op
|
||||||
delete(conns, sessionId)
|
delete(conns, sessionId)
|
||||||
default:
|
default:
|
||||||
if sess.mode == Unicast {
|
|
||||||
// only close if unicast session
|
// only close if unicast session
|
||||||
// close and delete session
|
// close and delete session
|
||||||
close(sess.closed)
|
close(sess.closed)
|
||||||
delete(conns, sessionId)
|
delete(conns, sessionId)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// continue
|
// continue
|
||||||
continue
|
continue
|
||||||
|
@ -47,6 +47,8 @@ type ListenOption func(*ListenOptions)
|
|||||||
type ListenOptions struct {
|
type ListenOptions struct {
|
||||||
// specify mode of the session
|
// specify mode of the session
|
||||||
Mode Mode
|
Mode Mode
|
||||||
|
// The read timeout
|
||||||
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// The tunnel id
|
// The tunnel id
|
||||||
@ -84,16 +86,6 @@ func Transport(t transport.Transport) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultOptions returns router default options
|
|
||||||
func DefaultOptions() Options {
|
|
||||||
return Options{
|
|
||||||
Id: uuid.New().String(),
|
|
||||||
Address: DefaultAddress,
|
|
||||||
Token: DefaultToken,
|
|
||||||
Transport: quic.NewTransport(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen options
|
// Listen options
|
||||||
func ListenMode(m Mode) ListenOption {
|
func ListenMode(m Mode) ListenOption {
|
||||||
return func(o *ListenOptions) {
|
return func(o *ListenOptions) {
|
||||||
@ -101,6 +93,13 @@ func ListenMode(m Mode) ListenOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Timeout for reads and writes on the listener session
|
||||||
|
func ListenTimeout(t time.Duration) ListenOption {
|
||||||
|
return func(o *ListenOptions) {
|
||||||
|
o.Timeout = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Dial options
|
// Dial options
|
||||||
|
|
||||||
// Dial multicast sets the multicast option to send only to those mapped
|
// Dial multicast sets the multicast option to send only to those mapped
|
||||||
@ -124,3 +123,13 @@ func DialLink(id string) DialOption {
|
|||||||
o.Link = id
|
o.Link = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultOptions returns router default options
|
||||||
|
func DefaultOptions() Options {
|
||||||
|
return Options{
|
||||||
|
Id: uuid.New().String(),
|
||||||
|
Address: DefaultAddress,
|
||||||
|
Token: DefaultToken,
|
||||||
|
Transport: quic.NewTransport(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -39,8 +39,10 @@ type session struct {
|
|||||||
loopback bool
|
loopback bool
|
||||||
// mode of the connection
|
// mode of the connection
|
||||||
mode Mode
|
mode Mode
|
||||||
// the timeout
|
// the dial timeout
|
||||||
timeout time.Duration
|
dialTimeout time.Duration
|
||||||
|
// the read timeout
|
||||||
|
readTimeout time.Duration
|
||||||
// the link on which this message was received
|
// the link on which this message was received
|
||||||
link string
|
link string
|
||||||
// the error response
|
// the error response
|
||||||
@ -133,31 +135,43 @@ func (s *session) wait(msg *message) error {
|
|||||||
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
|
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
after := func(timeout time.Duration) time.Duration {
|
after := func(timeout time.Duration) <-chan time.Time {
|
||||||
|
if timeout < time.Duration(0) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the delta
|
||||||
d := time.Since(now)
|
d := time.Since(now)
|
||||||
|
|
||||||
// dial timeout minus time since
|
// dial timeout minus time since
|
||||||
wait := timeout - d
|
wait := timeout - d
|
||||||
|
|
||||||
if wait < time.Duration(0) {
|
if wait < time.Duration(0) {
|
||||||
return time.Duration(0)
|
wait = time.Duration(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
return wait
|
return time.After(wait)
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for the message type
|
// wait for the message type
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-s.recv:
|
case msg := <-s.recv:
|
||||||
|
// there may be no message type
|
||||||
|
if len(msgType) == 0 {
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ignore what we don't want
|
// ignore what we don't want
|
||||||
if msg.typ != msgType {
|
if msg.typ != msgType {
|
||||||
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
|
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// got the message
|
// got the message
|
||||||
return msg, nil
|
return msg, nil
|
||||||
case <-time.After(after(timeout)):
|
case <-after(timeout):
|
||||||
return nil, ErrDialTimeout
|
return nil, ErrReadTimeout
|
||||||
case <-s.closed:
|
case <-s.closed:
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
@ -193,7 +207,8 @@ func (s *session) Discover() error {
|
|||||||
after := func() time.Duration {
|
after := func() time.Duration {
|
||||||
d := time.Since(now)
|
d := time.Since(now)
|
||||||
// dial timeout minus time since
|
// dial timeout minus time since
|
||||||
wait := s.timeout - d
|
wait := s.dialTimeout - d
|
||||||
|
// make sure its always > 0
|
||||||
if wait < time.Duration(0) {
|
if wait < time.Duration(0) {
|
||||||
return time.Duration(0)
|
return time.Duration(0)
|
||||||
}
|
}
|
||||||
@ -248,7 +263,7 @@ func (s *session) Open() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// now wait for the accept message to be returned
|
// now wait for the accept message to be returned
|
||||||
msg, err := s.waitFor("accept", s.timeout)
|
msg, err := s.waitFor("accept", s.dialTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -341,11 +356,9 @@ func (s *session) Send(m *transport.Message) error {
|
|||||||
func (s *session) Recv(m *transport.Message) error {
|
func (s *session) Recv(m *transport.Message) error {
|
||||||
var msg *message
|
var msg *message
|
||||||
|
|
||||||
select {
|
msg, err := s.waitFor("", s.readTimeout)
|
||||||
case <-s.closed:
|
if err != nil {
|
||||||
return io.EOF
|
return err
|
||||||
// recv from backlog
|
|
||||||
case msg = <-s.recv:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check the error if one exists
|
// check the error if one exists
|
||||||
|
@ -26,6 +26,8 @@ var (
|
|||||||
ErrDiscoverChan = errors.New("failed to discover channel")
|
ErrDiscoverChan = errors.New("failed to discover channel")
|
||||||
// ErrLinkNotFound is returned when a link is specified at dial time and does not exist
|
// ErrLinkNotFound is returned when a link is specified at dial time and does not exist
|
||||||
ErrLinkNotFound = errors.New("link not found")
|
ErrLinkNotFound = errors.New("link not found")
|
||||||
|
// ErrReadTimeout is a timeout on session.Recv
|
||||||
|
ErrReadTimeout = errors.New("read timeout")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Mode of the session
|
// Mode of the session
|
||||||
|
Loading…
Reference in New Issue
Block a user