save working solution

This commit is contained in:
Asim Aslam 2019-12-06 00:18:40 +00:00
parent 219efd27e9
commit 1d8c66780e
4 changed files with 71 additions and 33 deletions

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"io"
"math" "math"
"sync" "sync"
"time" "time"
@ -70,6 +71,9 @@ type network struct {
connected bool connected bool
// closed closes the network // closed closes the network
closed chan bool closed chan bool
// whether we've announced the first connect successfully
// and received back some sort of peer message
announced chan bool
} }
// newNetwork returns a new network node // newNetwork returns a new network node
@ -267,11 +271,12 @@ func (n *network) handleNetConn(s tunnel.Session, msg chan *message) {
m := new(transport.Message) m := new(transport.Message)
if err := s.Recv(m); err != nil { if err := s.Recv(m); err != nil {
log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err) log.Debugf("Network tunnel [%s] receive error: %v", NetworkChannel, err)
if sessionErr := s.Close(); sessionErr != nil { if err == io.EOF {
log.Debugf("Network tunnel [%s] closing connection error: %v", NetworkChannel, sessionErr) s.Close()
}
return return
} }
continue
}
select { select {
case msg <- &message{ case msg <- &message{
@ -452,6 +457,14 @@ func (n *network) processNetChan(listener tunnel.Listener) {
if err := n.node.UpdatePeer(peer); err != nil { if err := n.node.UpdatePeer(peer); err != nil {
log.Debugf("Network failed to update peers: %v", err) log.Debugf("Network failed to update peers: %v", err)
} }
// check if we announced and were discovered
select {
case <-n.announced:
// we've sent the connect and received this response
default:
close(n.announced)
}
case "close": case "close":
pbNetClose := &pbNet.Close{} pbNetClose := &pbNet.Close{}
if err := proto.Unmarshal(m.msg.Body, pbNetClose); err != nil { if err := proto.Unmarshal(m.msg.Body, pbNetClose); err != nil {
@ -622,11 +635,12 @@ func (n *network) handleCtrlConn(s tunnel.Session, msg chan *message) {
m := new(transport.Message) m := new(transport.Message)
if err := s.Recv(m); err != nil { if err := s.Recv(m); err != nil {
log.Debugf("Network tunnel [%s] receive error: %v", ControlChannel, err) log.Debugf("Network tunnel [%s] receive error: %v", ControlChannel, err)
if sessionErr := s.Close(); sessionErr != nil { if err == io.EOF {
log.Debugf("Network tunnel [%s] closing connection error: %v", ControlChannel, sessionErr) s.Close()
}
return return
} }
continue
}
select { select {
case msg <- &message{ case msg <- &message{
@ -930,6 +944,40 @@ func (n *network) sendConnect() {
} }
} }
// connect will wait for a link to be established
func (n *network) connect() {
// wait for connected state
var connected bool
for {
// check the links
for _, link := range n.tunnel.Links() {
if link.State() == "connected" {
connected = true
break
}
}
// if we're not conencted wait
if !connected {
time.Sleep(time.Second)
continue
}
// send the connect message
n.sendConnect()
// check the announce channel
select {
case <-n.announced:
return
default:
time.Sleep(time.Second)
// we have to go again
}
}
}
// Connect connects the network // Connect connects the network
func (n *network) Connect() error { func (n *network) Connect() error {
n.Lock() n.Lock()
@ -1000,6 +1048,8 @@ func (n *network) Connect() error {
// create closed channel // create closed channel
n.closed = make(chan bool) n.closed = make(chan bool)
// create new announced channel
n.announced = make(chan bool)
// start the router // start the router
if err := n.options.Router.Start(); err != nil { if err := n.options.Router.Start(); err != nil {
@ -1022,25 +1072,7 @@ func (n *network) Connect() error {
n.Unlock() n.Unlock()
// send connect after there's a link established // send connect after there's a link established
go func() { go n.connect()
// wait for 30 ticks e.g 30 seconds
for i := 0; i < 30; i++ {
// get the current links
links := n.tunnel.Links()
// if there are no links wait until we have one
if len(links) == 0 {
time.Sleep(time.Second)
continue
}
// send the connect message
n.sendConnect()
// most importantly
break
}
}()
// go resolving network nodes // go resolving network nodes
go n.resolve() go n.resolve()
// broadcast peers // broadcast peers

View File

@ -629,7 +629,7 @@ func (t *tun) listen(link *link) {
s, exists = t.getSession(channel, "listener") s, exists = t.getSession(channel, "listener")
// only return accept to the session // only return accept to the session
case mtype == "accept": case mtype == "accept":
log.Debugf("Received accept message for %s %s", channel, sessionId) log.Debugf("Received accept message for channel: %s session: %s", channel, sessionId)
s, exists = t.getSession(channel, sessionId) s, exists = t.getSession(channel, sessionId)
if exists && s.accepted { if exists && s.accepted {
continue continue
@ -649,7 +649,7 @@ func (t *tun) listen(link *link) {
// bail if no session or listener has been found // bail if no session or listener has been found
if !exists { if !exists {
log.Debugf("Tunnel skipping no session %s %s exists", channel, sessionId) log.Debugf("Tunnel skipping no channel: %s session: %s exists", channel, sessionId)
// drop it, we don't care about // drop it, we don't care about
// messages we don't know about // messages we don't know about
continue continue
@ -665,7 +665,7 @@ func (t *tun) listen(link *link) {
// otherwise process // otherwise process
} }
log.Debugf("Tunnel using channel %s session %s type %s", s.channel, s.session, mtype) log.Debugf("Tunnel using channel: %s session: %s type: %s", s.channel, s.session, mtype)
// construct a new transport message // construct a new transport message
tmsg := &transport.Message{ tmsg := &transport.Message{

View File

@ -71,7 +71,7 @@ func (t *tunListener) process() {
switch m.mode { switch m.mode {
case Multicast, Broadcast: case Multicast, Broadcast:
// use channel name if multicast/broadcast // use channel name if multicast/broadcast
sessionId = m.channel sessionId = "multicast"
log.Tracef("Tunnel listener using session %s for real session %s", sessionId, m.session) log.Tracef("Tunnel listener using session %s for real session %s", sessionId, m.session)
default: default:
// use session id if unicast // use session id if unicast
@ -198,6 +198,10 @@ func (t *tunListener) Accept() (Session, error) {
if !ok { if !ok {
return nil, io.EOF return nil, io.EOF
} }
// return without accept
if c.mode != Unicast {
return c, nil
}
// send back the accept // send back the accept
if err := c.Accept(); err != nil { if err := c.Accept(); err != nil {
return nil, err return nil, err

View File

@ -2,7 +2,6 @@ package tunnel
import ( import (
"encoding/hex" "encoding/hex"
"errors"
"io" "io"
"time" "time"
@ -344,7 +343,7 @@ func (s *session) Recv(m *transport.Message) error {
select { select {
case <-s.closed: case <-s.closed:
return errors.New("session is closed") return io.EOF
// recv from backlog // recv from backlog
case msg = <-s.recv: case msg = <-s.recv:
} }
@ -360,7 +359,10 @@ func (s *session) Recv(m *transport.Message) error {
log.Debugf("Received %+v from recv backlog", msg) log.Debugf("Received %+v from recv backlog", msg)
// decrypt the received payload using the token // decrypt the received payload using the token
body, err := Decrypt(msg.data.Body, s.token+s.channel+s.session) // we have to used msg.session because multicast has a shared
// session id of "multicast" in this session struct on
// the listener side
body, err := Decrypt(msg.data.Body, s.token+s.channel+msg.session)
if err != nil { if err != nil {
log.Debugf("failed to decrypt message body: %v", err) log.Debugf("failed to decrypt message body: %v", err)
return err return err
@ -376,7 +378,7 @@ func (s *session) Recv(m *transport.Message) error {
return err return err
} }
// encrypt the transport message payload // encrypt the transport message payload
val, err := Decrypt([]byte(h), s.token+s.channel+s.session) val, err := Decrypt([]byte(h), s.token+s.channel+msg.session)
if err != nil { if err != nil {
log.Debugf("failed to decrypt message header %s: %v", k, err) log.Debugf("failed to decrypt message header %s: %v", k, err)
return err return err