Merge pull request #729 from micro/tunnel
Tunnel session management and unicast/multicast
This commit is contained in:
		| @@ -718,7 +718,7 @@ func (n *network) Connect() error { | |||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	// dial into ControlChannel to send route adverts | 	// dial into ControlChannel to send route adverts | ||||||
| 	ctrlClient, err := n.Tunnel.Dial(ControlChannel) | 	ctrlClient, err := n.Tunnel.Dial(ControlChannel, tunnel.DialMulticast()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -732,7 +732,7 @@ func (n *network) Connect() error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// dial into NetworkChannel to send network messages | 	// dial into NetworkChannel to send network messages | ||||||
| 	netClient, err := n.Tunnel.Dial(NetworkChannel) | 	netClient, err := n.Tunnel.Dial(NetworkChannel, tunnel.DialMulticast()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -58,7 +58,7 @@ func (t *tunBroker) Disconnect() error { | |||||||
| func (t *tunBroker) Publish(topic string, m *broker.Message, opts ...broker.PublishOption) error { | func (t *tunBroker) Publish(topic string, m *broker.Message, opts ...broker.PublishOption) error { | ||||||
| 	// TODO: this is probably inefficient, we might want to just maintain an open connection | 	// TODO: this is probably inefficient, we might want to just maintain an open connection | ||||||
| 	// it may be easier to add broadcast to the tunnel | 	// it may be easier to add broadcast to the tunnel | ||||||
| 	c, err := t.tunnel.Dial(topic) | 	c, err := t.tunnel.Dial(topic, tunnel.DialMulticast()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -87,11 +87,17 @@ func (t *tun) getSession(channel, session string) (*session, bool) { | |||||||
| 	return s, ok | 	return s, ok | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (t *tun) delSession(channel, session string) { | ||||||
|  | 	t.Lock() | ||||||
|  | 	delete(t.sessions, channel+session) | ||||||
|  | 	t.Unlock() | ||||||
|  | } | ||||||
|  |  | ||||||
| // newSession creates a new session and saves it | // newSession creates a new session and saves it | ||||||
| func (t *tun) newSession(channel, sessionId string) (*session, bool) { | func (t *tun) newSession(channel, sessionId string) (*session, bool) { | ||||||
| 	// new session | 	// new session | ||||||
| 	s := &session{ | 	s := &session{ | ||||||
| 		id:      t.id, | 		tunnel:  t.id, | ||||||
| 		channel: channel, | 		channel: channel, | ||||||
| 		session: sessionId, | 		session: sessionId, | ||||||
| 		closed:  make(chan bool), | 		closed:  make(chan bool), | ||||||
| @@ -150,7 +156,9 @@ func (t *tun) monitor() { | |||||||
| 					log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) | 					log.Debugf("Tunnel failed to setup node link to %s: %v", node, err) | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  | 				// set the link id to the node | ||||||
|  | 				// TODO: hash it | ||||||
|  | 				link.id = node | ||||||
| 				// save the link | 				// save the link | ||||||
| 				t.Lock() | 				t.Lock() | ||||||
| 				t.links[node] = link | 				t.links[node] = link | ||||||
| @@ -169,18 +177,21 @@ func (t *tun) process() { | |||||||
| 		case msg := <-t.send: | 		case msg := <-t.send: | ||||||
| 			newMsg := &transport.Message{ | 			newMsg := &transport.Message{ | ||||||
| 				Header: make(map[string]string), | 				Header: make(map[string]string), | ||||||
| 				Body:   msg.data.Body, |  | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			for k, v := range msg.data.Header { | 			// set the data | ||||||
| 				newMsg.Header[k] = v | 			if msg.data != nil { | ||||||
|  | 				for k, v := range msg.data.Header { | ||||||
|  | 					newMsg.Header[k] = v | ||||||
|  | 				} | ||||||
|  | 				newMsg.Body = msg.data.Body | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// set message head | 			// set message head | ||||||
| 			newMsg.Header["Micro-Tunnel"] = msg.typ | 			newMsg.Header["Micro-Tunnel"] = msg.typ | ||||||
|  |  | ||||||
| 			// set the tunnel id on the outgoing message | 			// set the tunnel id on the outgoing message | ||||||
| 			newMsg.Header["Micro-Tunnel-Id"] = msg.id | 			newMsg.Header["Micro-Tunnel-Id"] = msg.tunnel | ||||||
|  |  | ||||||
| 			// set the tunnel channel on the outgoing message | 			// set the tunnel channel on the outgoing message | ||||||
| 			newMsg.Header["Micro-Tunnel-Channel"] = msg.channel | 			newMsg.Header["Micro-Tunnel-Channel"] = msg.channel | ||||||
| @@ -195,7 +206,7 @@ func (t *tun) process() { | |||||||
| 			t.Lock() | 			t.Lock() | ||||||
|  |  | ||||||
| 			if len(t.links) == 0 { | 			if len(t.links) == 0 { | ||||||
| 				log.Debugf("No links to send to") | 				log.Debugf("No links to send message type: %s channel: %s", msg.typ, msg.channel) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			var sent bool | 			var sent bool | ||||||
| @@ -232,25 +243,55 @@ func (t *tun) process() { | |||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
|  | 				// check the multicast mappings | ||||||
|  | 				if msg.multicast { | ||||||
|  | 					link.RLock() | ||||||
|  | 					_, ok := link.channels[msg.channel] | ||||||
|  | 					link.RUnlock() | ||||||
|  | 					// channel mapping not found in link | ||||||
|  | 					if !ok { | ||||||
|  | 						continue | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  |  | ||||||
| 				// send the message via the current link | 				// send the message via the current link | ||||||
| 				log.Debugf("Sending %+v to %s", newMsg, node) | 				log.Debugf("Sending %+v to %s", newMsg, node) | ||||||
|  |  | ||||||
| 				if errr := link.Send(newMsg); errr != nil { | 				if errr := link.Send(newMsg); errr != nil { | ||||||
| 					log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, errr) | 					log.Debugf("Tunnel error sending %+v to %s: %v", newMsg, node, errr) | ||||||
| 					err = errors.New(errr.Error()) | 					err = errors.New(errr.Error()) | ||||||
|  | 					// kill the link | ||||||
|  | 					link.Close() | ||||||
|  | 					// delete the link | ||||||
| 					delete(t.links, node) | 					delete(t.links, node) | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// is sent | 				// is sent | ||||||
| 				sent = true | 				sent = true | ||||||
|  |  | ||||||
|  | 				// keep sending broadcast messages | ||||||
|  | 				if msg.broadcast || msg.multicast { | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// break on unicast | ||||||
|  | 				break | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			t.Unlock() | 			t.Unlock() | ||||||
|  |  | ||||||
|  | 			// set the error if not sent | ||||||
| 			var gerr error | 			var gerr error | ||||||
| 			if !sent { | 			if !sent { | ||||||
| 				gerr = err | 				gerr = err | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			// skip if its not been set | ||||||
|  | 			if msg.errChan == nil { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			// return error non blocking | 			// return error non blocking | ||||||
| 			select { | 			select { | ||||||
| 			case msg.errChan <- gerr: | 			case msg.errChan <- gerr: | ||||||
| @@ -262,14 +303,25 @@ func (t *tun) process() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (t *tun) delLink(id string) { | ||||||
|  | 	t.Lock() | ||||||
|  | 	defer t.Unlock() | ||||||
|  | 	// get the link | ||||||
|  | 	link, ok := t.links[id] | ||||||
|  | 	if !ok { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	// close and delete | ||||||
|  | 	link.Close() | ||||||
|  | 	delete(t.links, id) | ||||||
|  | } | ||||||
|  |  | ||||||
| // process incoming messages | // process incoming messages | ||||||
| func (t *tun) listen(link *link) { | func (t *tun) listen(link *link) { | ||||||
| 	// remove the link on exit | 	// remove the link on exit | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		log.Debugf("Tunnel deleting connection from %s", link.Remote()) | 		log.Debugf("Tunnel deleting connection from %s", link.Remote()) | ||||||
| 		t.Lock() | 		t.delLink(link.Remote()) | ||||||
| 		delete(t.links, link.Remote()) |  | ||||||
| 		t.Unlock() |  | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	// let us know if its a loopback | 	// let us know if its a loopback | ||||||
| @@ -292,18 +344,34 @@ func (t *tun) listen(link *link) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		switch msg.Header["Micro-Tunnel"] { | 		// message type | ||||||
|  | 		mtype := msg.Header["Micro-Tunnel"] | ||||||
|  | 		// the tunnel id | ||||||
|  | 		id := msg.Header["Micro-Tunnel-Id"] | ||||||
|  | 		// the tunnel channel | ||||||
|  | 		channel := msg.Header["Micro-Tunnel-Channel"] | ||||||
|  | 		// the session id | ||||||
|  | 		sessionId := msg.Header["Micro-Tunnel-Session"] | ||||||
|  |  | ||||||
|  | 		// if its not connected throw away the link | ||||||
|  | 		// the first message we process needs to be connect | ||||||
|  | 		if !link.connected && mtype != "connect" { | ||||||
|  | 			log.Debugf("Tunnel link %s not connected", link.id) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		switch mtype { | ||||||
| 		case "connect": | 		case "connect": | ||||||
| 			log.Debugf("Tunnel link %s received connect message", link.Remote()) | 			log.Debugf("Tunnel link %s received connect message", link.Remote()) | ||||||
|  |  | ||||||
| 			id := msg.Header["Micro-Tunnel-Id"] |  | ||||||
|  |  | ||||||
| 			// are we connecting to ourselves? | 			// are we connecting to ourselves? | ||||||
| 			if id == t.id { | 			if id == t.id { | ||||||
| 				link.loopback = true | 				link.loopback = true | ||||||
| 				loopback = true | 				loopback = true | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			// set to remote node | ||||||
|  | 			link.id = id | ||||||
| 			// set as connected | 			// set as connected | ||||||
| 			link.connected = true | 			link.connected = true | ||||||
|  |  | ||||||
| @@ -315,10 +383,31 @@ func (t *tun) listen(link *link) { | |||||||
| 			// nothing more to do | 			// nothing more to do | ||||||
| 			continue | 			continue | ||||||
| 		case "close": | 		case "close": | ||||||
| 			log.Debugf("Tunnel link %s closing connection", link.Remote()) |  | ||||||
| 			// TODO: handle the close message | 			// TODO: handle the close message | ||||||
| 			// maybe report io.EOF or kill the link | 			// maybe report io.EOF or kill the link | ||||||
| 			return |  | ||||||
|  | 			// close the link entirely | ||||||
|  | 			if len(channel) == 0 { | ||||||
|  | 				log.Debugf("Tunnel link %s received close message", link.Remote()) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// the entire listener was closed so remove it from the mapping | ||||||
|  | 			if sessionId == "listener" { | ||||||
|  | 				link.Lock() | ||||||
|  | 				delete(link.channels, channel) | ||||||
|  | 				link.Unlock() | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// try get the dialing socket | ||||||
|  | 			s, exists := t.getSession(channel, sessionId) | ||||||
|  | 			if exists { | ||||||
|  | 				// close and continue | ||||||
|  | 				s.Close() | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			// otherwise its a session mapping of sorts | ||||||
| 		case "keepalive": | 		case "keepalive": | ||||||
| 			log.Debugf("Tunnel link %s received keepalive", link.Remote()) | 			log.Debugf("Tunnel link %s received keepalive", link.Remote()) | ||||||
| 			t.Lock() | 			t.Lock() | ||||||
| @@ -326,27 +415,64 @@ func (t *tun) listen(link *link) { | |||||||
| 			link.lastKeepAlive = time.Now() | 			link.lastKeepAlive = time.Now() | ||||||
| 			t.Unlock() | 			t.Unlock() | ||||||
| 			continue | 			continue | ||||||
| 		case "message": | 		// a new connection dialled outbound | ||||||
|  | 		case "open": | ||||||
|  | 			// we just let it pass through to be processed | ||||||
|  | 		// an accept returned by the listener | ||||||
|  | 		case "accept": | ||||||
|  |  | ||||||
|  | 		// a continued session | ||||||
|  | 		case "session": | ||||||
| 			// process message | 			// process message | ||||||
| 			log.Debugf("Received %+v from %s", msg, link.Remote()) | 			log.Debugf("Received %+v from %s", msg, link.Remote()) | ||||||
|  | 		// an announcement of a channel listener | ||||||
|  | 		case "announce": | ||||||
|  | 			// update mapping in the link | ||||||
|  | 			link.Lock() | ||||||
|  | 			link.channels[channel] = time.Now() | ||||||
|  | 			link.Unlock() | ||||||
|  |  | ||||||
|  | 			// get the session that asked for the discovery | ||||||
|  | 			s, exists := t.getSession(channel, sessionId) | ||||||
|  | 			if exists { | ||||||
|  | 				// don't bother it's already discovered | ||||||
|  | 				if s.discovered { | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// send the announce back to the caller | ||||||
|  | 				s.recv <- &message{ | ||||||
|  | 					typ:     "announce", | ||||||
|  | 					tunnel:  id, | ||||||
|  | 					channel: channel, | ||||||
|  | 					session: sessionId, | ||||||
|  | 					link:    link.id, | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		case "discover": | ||||||
|  | 			// looking for existing mapping | ||||||
|  | 			_, exists := t.getSession(channel, "listener") | ||||||
|  | 			if exists { | ||||||
|  | 				log.Debugf("Tunnel sending announce for discovery of channel %s", channel) | ||||||
|  | 				// send back the announcement | ||||||
|  | 				link.Send(&transport.Message{ | ||||||
|  | 					Header: map[string]string{ | ||||||
|  | 						"Micro-Tunnel":         "announce", | ||||||
|  | 						"Micro-Tunnel-Id":      t.id, | ||||||
|  | 						"Micro-Tunnel-Channel": channel, | ||||||
|  | 						"Micro-Tunnel-Session": sessionId, | ||||||
|  | 						"Micro-Tunnel-Link":    link.id, | ||||||
|  | 						"Micro-Tunnel-Token":   t.token, | ||||||
|  | 					}, | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
| 		default: | 		default: | ||||||
| 			// blackhole it | 			// blackhole it | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// if its not connected throw away the link |  | ||||||
| 		if !link.connected { |  | ||||||
| 			log.Debugf("Tunnel link %s not connected", link.id) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// the tunnel id |  | ||||||
| 		id := msg.Header["Micro-Tunnel-Id"] |  | ||||||
| 		// the tunnel channel |  | ||||||
| 		channel := msg.Header["Micro-Tunnel-Channel"] |  | ||||||
| 		// the session id |  | ||||||
| 		sessionId := msg.Header["Micro-Tunnel-Session"] |  | ||||||
|  |  | ||||||
| 		// strip tunnel message header | 		// strip tunnel message header | ||||||
| 		for k, _ := range msg.Header { | 		for k, _ := range msg.Header { | ||||||
| 			if strings.HasPrefix(k, "Micro-Tunnel") { | 			if strings.HasPrefix(k, "Micro-Tunnel") { | ||||||
| @@ -368,8 +494,15 @@ func (t *tun) listen(link *link) { | |||||||
| 		// If its a loopback connection then we've enabled link direction | 		// If its a loopback connection then we've enabled link direction | ||||||
| 		// listening side is used for listening, the dialling side for dialling | 		// listening side is used for listening, the dialling side for dialling | ||||||
| 		switch { | 		switch { | ||||||
| 		case loopback: | 		case loopback, mtype == "open": | ||||||
| 			s, exists = t.getSession(channel, "listener") | 			s, exists = t.getSession(channel, "listener") | ||||||
|  | 		// only return accept to the session | ||||||
|  | 		case mtype == "accept": | ||||||
|  | 			log.Debugf("Received accept message for %s %s", channel, sessionId) | ||||||
|  | 			s, exists = t.getSession(channel, sessionId) | ||||||
|  | 			if exists && s.accepted { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
| 		default: | 		default: | ||||||
| 			// get the session based on the tunnel id and session | 			// get the session based on the tunnel id and session | ||||||
| 			// this could be something we dialed in which case | 			// this could be something we dialed in which case | ||||||
| @@ -383,7 +516,7 @@ func (t *tun) listen(link *link) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// bail if no session has been found | 		// bail if no session or listener has been found | ||||||
| 		if !exists { | 		if !exists { | ||||||
| 			log.Debugf("Tunnel skipping no session exists") | 			log.Debugf("Tunnel skipping no session exists") | ||||||
| 			// drop it, we don't care about | 			// drop it, we don't care about | ||||||
| @@ -391,8 +524,6 @@ func (t *tun) listen(link *link) { | |||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debugf("Tunnel using session %s %s", s.channel, s.session) |  | ||||||
|  |  | ||||||
| 		// is the session closed? | 		// is the session closed? | ||||||
| 		select { | 		select { | ||||||
| 		case <-s.closed: | 		case <-s.closed: | ||||||
| @@ -403,6 +534,8 @@ func (t *tun) listen(link *link) { | |||||||
| 			// process | 			// process | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		log.Debugf("Tunnel using channel %s session %s", s.channel, s.session) | ||||||
|  |  | ||||||
| 		// is the session new? | 		// is the session new? | ||||||
| 		select { | 		select { | ||||||
| 		// if its new the session is actually blocked waiting | 		// if its new the session is actually blocked waiting | ||||||
| @@ -423,7 +556,8 @@ func (t *tun) listen(link *link) { | |||||||
|  |  | ||||||
| 		// construct the internal message | 		// construct the internal message | ||||||
| 		imsg := &message{ | 		imsg := &message{ | ||||||
| 			id:       id, | 			tunnel:   id, | ||||||
|  | 			typ:      mtype, | ||||||
| 			channel:  channel, | 			channel:  channel, | ||||||
| 			session:  sessionId, | 			session:  sessionId, | ||||||
| 			data:     tmsg, | 			data:     tmsg, | ||||||
| @@ -461,9 +595,7 @@ func (t *tun) keepalive(link *link) { | |||||||
| 				}, | 				}, | ||||||
| 			}); err != nil { | 			}); err != nil { | ||||||
| 				log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err) | 				log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err) | ||||||
| 				t.Lock() | 				t.delLink(link.Remote()) | ||||||
| 				delete(t.links, link.Remote()) |  | ||||||
| 				t.Unlock() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @@ -481,6 +613,7 @@ func (t *tun) setupLink(node string) (*link, error) { | |||||||
| 	} | 	} | ||||||
| 	log.Debugf("Tunnel connected to %s", node) | 	log.Debugf("Tunnel connected to %s", node) | ||||||
|  |  | ||||||
|  | 	// send the first connect message | ||||||
| 	if err := c.Send(&transport.Message{ | 	if err := c.Send(&transport.Message{ | ||||||
| 		Header: map[string]string{ | 		Header: map[string]string{ | ||||||
| 			"Micro-Tunnel":       "connect", | 			"Micro-Tunnel":       "connect", | ||||||
| @@ -493,9 +626,11 @@ func (t *tun) setupLink(node string) (*link, error) { | |||||||
|  |  | ||||||
| 	// create a new link | 	// create a new link | ||||||
| 	link := newLink(c) | 	link := newLink(c) | ||||||
| 	link.connected = true | 	// set link id to remote side | ||||||
|  | 	link.id = c.Remote() | ||||||
| 	// we made the outbound connection | 	// we made the outbound connection | ||||||
| 	// and sent the connect message | 	// and sent the connect message | ||||||
|  | 	link.connected = true | ||||||
|  |  | ||||||
| 	// process incoming messages | 	// process incoming messages | ||||||
| 	go t.listen(link) | 	go t.listen(link) | ||||||
| @@ -553,7 +688,7 @@ func (t *tun) connect() error { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// save the link | 		// save the link | ||||||
| 		t.links[node] = link | 		t.links[link.Remote()] = link | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// process outbound messages to be sent | 	// process outbound messages to be sent | ||||||
| @@ -627,6 +762,8 @@ func (t *tun) Close() error { | |||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug("Tunnel closing") | ||||||
|  |  | ||||||
| 	select { | 	select { | ||||||
| 	case <-t.closed: | 	case <-t.closed: | ||||||
| 		return nil | 		return nil | ||||||
| @@ -650,7 +787,7 @@ func (t *tun) Close() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Dial an address | // Dial an address | ||||||
| func (t *tun) Dial(channel string) (Session, error) { | func (t *tun) Dial(channel string, opts ...DialOption) (Session, error) { | ||||||
| 	log.Debugf("Tunnel dialing %s", channel) | 	log.Debugf("Tunnel dialing %s", channel) | ||||||
| 	c, ok := t.newSession(channel, t.newSessionId()) | 	c, ok := t.newSession(channel, t.newSessionId()) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| @@ -663,18 +800,93 @@ func (t *tun) Dial(channel string) (Session, error) { | |||||||
| 	// outbound session | 	// outbound session | ||||||
| 	c.outbound = true | 	c.outbound = true | ||||||
|  |  | ||||||
|  | 	// get opts | ||||||
|  | 	options := DialOptions{ | ||||||
|  | 		Timeout: DefaultDialTimeout, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// set the multicast option | ||||||
|  | 	c.multicast = options.Multicast | ||||||
|  | 	// set the dial timeout | ||||||
|  | 	c.timeout = options.Timeout | ||||||
|  |  | ||||||
|  | 	// don't bother with the song and dance below | ||||||
|  | 	// we're just going to assume things come online | ||||||
|  | 	// as and when. | ||||||
|  | 	if c.multicast { | ||||||
|  | 		return c, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// non multicast so we need to find the link | ||||||
|  | 	t.RLock() | ||||||
|  | 	for _, link := range t.links { | ||||||
|  | 		link.RLock() | ||||||
|  | 		_, ok := link.channels[channel] | ||||||
|  | 		link.RUnlock() | ||||||
|  |  | ||||||
|  | 		// we have at least one channel mapping | ||||||
|  | 		if ok { | ||||||
|  | 			c.discovered = true | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	t.RUnlock() | ||||||
|  |  | ||||||
|  | 	// shit fuck | ||||||
|  | 	if !c.discovered { | ||||||
|  | 		msg := c.newMessage("discover") | ||||||
|  | 		msg.broadcast = true | ||||||
|  | 		msg.outbound = true | ||||||
|  | 		msg.link = "" | ||||||
|  |  | ||||||
|  | 		// send the discovery message | ||||||
|  | 		t.send <- msg | ||||||
|  |  | ||||||
|  | 		select { | ||||||
|  | 		case err := <-c.errChan: | ||||||
|  | 			if err != nil { | ||||||
|  | 				return nil, err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// wait for announce | ||||||
|  | 		select { | ||||||
|  | 		case msg := <-c.recv: | ||||||
|  | 			if msg.typ != "announce" { | ||||||
|  | 				return nil, errors.New("failed to discover channel") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// try to open the session | ||||||
|  | 	err := c.Open() | ||||||
|  | 	if err != nil { | ||||||
|  | 		// delete the session | ||||||
|  | 		t.delSession(c.channel, c.session) | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return c, nil | 	return c, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Accept a connection on the address | // Accept a connection on the address | ||||||
| func (t *tun) Listen(channel string) (Listener, error) { | func (t *tun) Listen(channel string) (Listener, error) { | ||||||
| 	log.Debugf("Tunnel listening on %s", channel) | 	log.Debugf("Tunnel listening on %s", channel) | ||||||
|  |  | ||||||
| 	// create a new session by hashing the address | 	// create a new session by hashing the address | ||||||
| 	c, ok := t.newSession(channel, "listener") | 	c, ok := t.newSession(channel, "listener") | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, errors.New("already listening on " + channel) | 		return nil, errors.New("already listening on " + channel) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	delFunc := func() { | ||||||
|  | 		t.delSession(channel, "listener") | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// set remote. it will be replaced by the first message received | 	// set remote. it will be replaced by the first message received | ||||||
| 	c.remote = "remote" | 	c.remote = "remote" | ||||||
| 	// set local | 	// set local | ||||||
| @@ -690,6 +902,8 @@ func (t *tun) Listen(channel string) (Listener, error) { | |||||||
| 		tunClosed: t.closed, | 		tunClosed: t.closed, | ||||||
| 		// the listener session | 		// the listener session | ||||||
| 		session: c, | 		session: c, | ||||||
|  | 		// delete session | ||||||
|  | 		delFunc: delFunc, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// this kicks off the internal message processor | 	// this kicks off the internal message processor | ||||||
| @@ -698,10 +912,26 @@ func (t *tun) Listen(channel string) (Listener, error) { | |||||||
| 	// to the existign sessions | 	// to the existign sessions | ||||||
| 	go tl.process() | 	go tl.process() | ||||||
|  |  | ||||||
|  | 	// announces the listener channel to others | ||||||
|  | 	go tl.announce() | ||||||
|  |  | ||||||
| 	// return the listener | 	// return the listener | ||||||
| 	return tl, nil | 	return tl, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (t *tun) Links() []Link { | ||||||
|  | 	t.RLock() | ||||||
|  | 	defer t.RUnlock() | ||||||
|  |  | ||||||
|  | 	var links []Link | ||||||
|  |  | ||||||
|  | 	for _, link := range t.links { | ||||||
|  | 		links = append(links, link) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return links | ||||||
|  | } | ||||||
|  |  | ||||||
| func (t *tun) String() string { | func (t *tun) String() string { | ||||||
| 	return "mucp" | 	return "mucp" | ||||||
| } | } | ||||||
|   | |||||||
| @@ -9,9 +9,10 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| type link struct { | type link struct { | ||||||
|  | 	transport.Socket | ||||||
|  |  | ||||||
| 	sync.RWMutex | 	sync.RWMutex | ||||||
|  |  | ||||||
| 	transport.Socket |  | ||||||
| 	// unique id of this link e.g uuid | 	// unique id of this link e.g uuid | ||||||
| 	// which we define for ourselves | 	// which we define for ourselves | ||||||
| 	id string | 	id string | ||||||
| @@ -27,11 +28,77 @@ type link struct { | |||||||
| 	// the last time we received a keepalive | 	// the last time we received a keepalive | ||||||
| 	// on this link from the remote side | 	// on this link from the remote side | ||||||
| 	lastKeepAlive time.Time | 	lastKeepAlive time.Time | ||||||
|  | 	// channels keeps a mapping of channels and last seen | ||||||
|  | 	channels map[string]time.Time | ||||||
|  | 	// stop the link | ||||||
|  | 	closed chan bool | ||||||
| } | } | ||||||
|  |  | ||||||
| func newLink(s transport.Socket) *link { | func newLink(s transport.Socket) *link { | ||||||
| 	return &link{ | 	l := &link{ | ||||||
| 		Socket: s, | 		Socket:   s, | ||||||
| 		id:     uuid.New().String(), | 		id:       uuid.New().String(), | ||||||
|  | 		channels: make(map[string]time.Time), | ||||||
|  | 		closed:   make(chan bool), | ||||||
|  | 	} | ||||||
|  | 	go l.run() | ||||||
|  | 	return l | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l *link) run() { | ||||||
|  | 	t := time.NewTicker(time.Minute) | ||||||
|  | 	defer t.Stop() | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case <-l.closed: | ||||||
|  | 			return | ||||||
|  | 		case <-t.C: | ||||||
|  | 			// drop any channel mappings older than 2 minutes | ||||||
|  | 			var kill []string | ||||||
|  | 			killTime := time.Minute * 2 | ||||||
|  |  | ||||||
|  | 			l.RLock() | ||||||
|  | 			for ch, t := range l.channels { | ||||||
|  | 				if d := time.Since(t); d > killTime { | ||||||
|  | 					kill = append(kill, ch) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			l.RUnlock() | ||||||
|  |  | ||||||
|  | 			// if nothing to kill don't bother with a wasted lock | ||||||
|  | 			if len(kill) == 0 { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// kill the channels! | ||||||
|  | 			l.Lock() | ||||||
|  | 			for _, ch := range kill { | ||||||
|  | 				delete(l.channels, ch) | ||||||
|  | 			} | ||||||
|  | 			l.Unlock() | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (l *link) Id() string { | ||||||
|  | 	l.RLock() | ||||||
|  | 	defer l.RUnlock() | ||||||
|  |  | ||||||
|  | 	return l.id | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l *link) Close() error { | ||||||
|  | 	l.Lock() | ||||||
|  | 	defer l.Unlock() | ||||||
|  |  | ||||||
|  | 	select { | ||||||
|  | 	case <-l.closed: | ||||||
|  | 		return nil | ||||||
|  | 	default: | ||||||
|  | 		close(l.closed) | ||||||
|  | 		return l.Socket.Close() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package tunnel | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"io" | 	"io" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/util/log" | 	"github.com/micro/go-micro/util/log" | ||||||
| ) | ) | ||||||
| @@ -17,26 +18,62 @@ type tunListener struct { | |||||||
| 	tunClosed chan bool | 	tunClosed chan bool | ||||||
| 	// the listener session | 	// the listener session | ||||||
| 	session *session | 	session *session | ||||||
|  | 	// del func to kill listener | ||||||
|  | 	delFunc func() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // periodically announce self | ||||||
|  | func (t *tunListener) announce() { | ||||||
|  | 	tick := time.NewTicker(time.Minute) | ||||||
|  | 	defer tick.Stop() | ||||||
|  |  | ||||||
|  | 	// first announcement | ||||||
|  | 	t.session.Announce() | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case <-tick.C: | ||||||
|  | 			t.session.Announce() | ||||||
|  | 		case <-t.closed: | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *tunListener) process() { | func (t *tunListener) process() { | ||||||
| 	// our connection map for session | 	// our connection map for session | ||||||
| 	conns := make(map[string]*session) | 	conns := make(map[string]*session) | ||||||
|  |  | ||||||
|  | 	defer func() { | ||||||
|  | 		// close the sessions | ||||||
|  | 		for _, conn := range conns { | ||||||
|  | 			conn.Close() | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		case <-t.closed: | 		case <-t.closed: | ||||||
| 			return | 			return | ||||||
|  | 		case <-t.tunClosed: | ||||||
|  | 			t.Close() | ||||||
|  | 			return | ||||||
| 		// receive a new message | 		// receive a new message | ||||||
| 		case m := <-t.session.recv: | 		case m := <-t.session.recv: | ||||||
| 			// get a session | 			// get a session | ||||||
| 			sess, ok := conns[m.session] | 			sess, ok := conns[m.session] | ||||||
| 			log.Debugf("Tunnel listener received id %s session %s exists: %t", m.id, m.session, ok) | 			log.Debugf("Tunnel listener received channel %s session %s exists: %t", m.channel, m.session, ok) | ||||||
| 			if !ok { | 			if !ok { | ||||||
|  | 				switch m.typ { | ||||||
|  | 				case "open", "session": | ||||||
|  | 				default: | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  |  | ||||||
| 				// create a new session session | 				// create a new session session | ||||||
| 				sess = &session{ | 				sess = &session{ | ||||||
| 					// the id of the remote side | 					// the id of the remote side | ||||||
| 					id: m.id, | 					tunnel: m.tunnel, | ||||||
| 					// the channel | 					// the channel | ||||||
| 					channel: m.channel, | 					channel: m.channel, | ||||||
| 					// the session id | 					// the session id | ||||||
| @@ -45,6 +82,8 @@ func (t *tunListener) process() { | |||||||
| 					loopback: m.loopback, | 					loopback: m.loopback, | ||||||
| 					// the link the message was received on | 					// the link the message was received on | ||||||
| 					link: m.link, | 					link: m.link, | ||||||
|  | 					// set multicast | ||||||
|  | 					multicast: m.multicast, | ||||||
| 					// close chan | 					// close chan | ||||||
| 					closed: make(chan bool), | 					closed: make(chan bool), | ||||||
| 					// recv called by the acceptor | 					// recv called by the acceptor | ||||||
| @@ -60,20 +99,44 @@ func (t *tunListener) process() { | |||||||
| 				// save the session | 				// save the session | ||||||
| 				conns[m.session] = sess | 				conns[m.session] = sess | ||||||
|  |  | ||||||
| 				// send to accept chan |  | ||||||
| 				select { | 				select { | ||||||
| 				case <-t.closed: | 				case <-t.closed: | ||||||
| 					return | 					return | ||||||
|  | 				// send to accept chan | ||||||
| 				case t.accept <- sess: | 				case t.accept <- sess: | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			// an existing session was found | ||||||
|  |  | ||||||
|  | 			// received a close message | ||||||
|  | 			switch m.typ { | ||||||
|  | 			case "close": | ||||||
|  | 				select { | ||||||
|  | 				case <-sess.closed: | ||||||
|  | 					// no op | ||||||
|  | 					delete(conns, m.session) | ||||||
|  | 				default: | ||||||
|  | 					// close and delete session | ||||||
|  | 					close(sess.closed) | ||||||
|  | 					delete(conns, m.session) | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// continue | ||||||
|  | 				continue | ||||||
|  | 			case "session": | ||||||
|  | 				// operate on this | ||||||
|  | 			default: | ||||||
|  | 				// non operational type | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			// send this to the accept chan | 			// send this to the accept chan | ||||||
| 			select { | 			select { | ||||||
| 			case <-sess.closed: | 			case <-sess.closed: | ||||||
| 				delete(conns, m.session) | 				delete(conns, m.session) | ||||||
| 			case sess.recv <- m: | 			case sess.recv <- m: | ||||||
| 				log.Debugf("Tunnel listener sent to recv chan id %s session %s", m.id, m.session) | 				log.Debugf("Tunnel listener sent to recv chan channel %s session %s", m.channel, m.session) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -89,6 +152,9 @@ func (t *tunListener) Close() error { | |||||||
| 	case <-t.closed: | 	case <-t.closed: | ||||||
| 		return nil | 		return nil | ||||||
| 	default: | 	default: | ||||||
|  | 		// close and delete | ||||||
|  | 		t.delFunc() | ||||||
|  | 		t.session.Close() | ||||||
| 		close(t.closed) | 		close(t.closed) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| @@ -102,13 +168,17 @@ func (t *tunListener) Accept() (Session, error) { | |||||||
| 		return nil, io.EOF | 		return nil, io.EOF | ||||||
| 	case <-t.tunClosed: | 	case <-t.tunClosed: | ||||||
| 		// close the listener when the tunnel closes | 		// close the listener when the tunnel closes | ||||||
| 		t.Close() |  | ||||||
| 		return nil, io.EOF | 		return nil, io.EOF | ||||||
| 	// wait for a new connection | 	// wait for a new connection | ||||||
| 	case c, ok := <-t.accept: | 	case c, ok := <-t.accept: | ||||||
|  | 		// check if the accept chan is closed | ||||||
| 		if !ok { | 		if !ok { | ||||||
| 			return nil, io.EOF | 			return nil, io.EOF | ||||||
| 		} | 		} | ||||||
|  | 		// send back the accept | ||||||
|  | 		if err := c.Accept(); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
| 		return c, nil | 		return c, nil | ||||||
| 	} | 	} | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
|   | |||||||
| @@ -1,6 +1,8 @@ | |||||||
| package tunnel | package tunnel | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| 	"github.com/micro/go-micro/transport/quic" | 	"github.com/micro/go-micro/transport/quic" | ||||||
| @@ -29,6 +31,15 @@ type Options struct { | |||||||
| 	Transport transport.Transport | 	Transport transport.Transport | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type DialOption func(*DialOptions) | ||||||
|  |  | ||||||
|  | type DialOptions struct { | ||||||
|  | 	// specify a multicast connection | ||||||
|  | 	Multicast bool | ||||||
|  | 	// the dial timeout | ||||||
|  | 	Timeout time.Duration | ||||||
|  | } | ||||||
|  |  | ||||||
| // The tunnel id | // The tunnel id | ||||||
| func Id(id string) Option { | func Id(id string) Option { | ||||||
| 	return func(o *Options) { | 	return func(o *Options) { | ||||||
| @@ -73,3 +84,18 @@ func DefaultOptions() Options { | |||||||
| 		Transport: quic.NewTransport(), | 		Transport: quic.NewTransport(), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Dial options | ||||||
|  |  | ||||||
|  | // Dial multicast sets the multicast option to send only to those mapped | ||||||
|  | func DialMulticast() DialOption { | ||||||
|  | 	return func(o *DialOptions) { | ||||||
|  | 		o.Multicast = true | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DialTimeout(t time.Duration) DialOption { | ||||||
|  | 	return func(o *DialOptions) { | ||||||
|  | 		o.Timeout = t | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ package tunnel | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"io" | 	"io" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| 	"github.com/micro/go-micro/util/log" | 	"github.com/micro/go-micro/util/log" | ||||||
| @@ -10,8 +11,8 @@ import ( | |||||||
|  |  | ||||||
| // session is our pseudo session for transport.Socket | // session is our pseudo session for transport.Socket | ||||||
| type session struct { | type session struct { | ||||||
| 	// unique id based on the remote tunnel id | 	// the tunnel id | ||||||
| 	id string | 	tunnel string | ||||||
| 	// the channel name | 	// the channel name | ||||||
| 	channel string | 	channel string | ||||||
| 	// the session id based on Micro.Tunnel-Session | 	// the session id based on Micro.Tunnel-Session | ||||||
| @@ -28,10 +29,20 @@ type session struct { | |||||||
| 	recv chan *message | 	recv chan *message | ||||||
| 	// wait until we have a connection | 	// wait until we have a connection | ||||||
| 	wait chan bool | 	wait chan bool | ||||||
|  | 	// if the discovery worked | ||||||
|  | 	discovered bool | ||||||
|  | 	// if the session was accepted | ||||||
|  | 	accepted bool | ||||||
| 	// outbound marks the session as outbound dialled connection | 	// outbound marks the session as outbound dialled connection | ||||||
| 	outbound bool | 	outbound bool | ||||||
| 	// lookback marks the session as a loopback on the inbound | 	// lookback marks the session as a loopback on the inbound | ||||||
| 	loopback bool | 	loopback bool | ||||||
|  | 	// if the session is multicast | ||||||
|  | 	multicast bool | ||||||
|  | 	// if the session is broadcast | ||||||
|  | 	broadcast bool | ||||||
|  | 	// the timeout | ||||||
|  | 	timeout time.Duration | ||||||
| 	// the link on which this message was received | 	// the link on which this message was received | ||||||
| 	link string | 	link string | ||||||
| 	// the error response | 	// the error response | ||||||
| @@ -43,7 +54,7 @@ type message struct { | |||||||
| 	// type of message | 	// type of message | ||||||
| 	typ string | 	typ string | ||||||
| 	// tunnel id | 	// tunnel id | ||||||
| 	id string | 	tunnel string | ||||||
| 	// channel name | 	// channel name | ||||||
| 	channel string | 	channel string | ||||||
| 	// the session id | 	// the session id | ||||||
| @@ -52,6 +63,10 @@ type message struct { | |||||||
| 	outbound bool | 	outbound bool | ||||||
| 	// loopback marks the message intended for loopback | 	// loopback marks the message intended for loopback | ||||||
| 	loopback bool | 	loopback bool | ||||||
|  | 	// whether to send as multicast | ||||||
|  | 	multicast bool | ||||||
|  | 	// broadcast sets the broadcast type | ||||||
|  | 	broadcast bool | ||||||
| 	// the link to send the message on | 	// the link to send the message on | ||||||
| 	link string | 	link string | ||||||
| 	// transport data | 	// transport data | ||||||
| @@ -76,10 +91,111 @@ func (s *session) Channel() string { | |||||||
| 	return s.channel | 	return s.channel | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // newMessage creates a new message based on the session | ||||||
|  | func (s *session) newMessage(typ string) *message { | ||||||
|  | 	return &message{ | ||||||
|  | 		typ:       typ, | ||||||
|  | 		tunnel:    s.tunnel, | ||||||
|  | 		channel:   s.channel, | ||||||
|  | 		session:   s.session, | ||||||
|  | 		outbound:  s.outbound, | ||||||
|  | 		loopback:  s.loopback, | ||||||
|  | 		multicast: s.multicast, | ||||||
|  | 		link:      s.link, | ||||||
|  | 		errChan:   s.errChan, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Open will fire the open message for the session. This is called by the dialler. | ||||||
|  | func (s *session) Open() error { | ||||||
|  | 	// create a new message | ||||||
|  | 	msg := s.newMessage("open") | ||||||
|  |  | ||||||
|  | 	// send open message | ||||||
|  | 	s.send <- msg | ||||||
|  |  | ||||||
|  | 	// wait for an error response for send | ||||||
|  | 	select { | ||||||
|  | 	case err := <-msg.errChan: | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	case <-s.closed: | ||||||
|  | 		return io.EOF | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// we don't wait on multicast | ||||||
|  | 	if s.multicast { | ||||||
|  | 		s.accepted = true | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// now wait for the accept | ||||||
|  | 	select { | ||||||
|  | 	case msg = <-s.recv: | ||||||
|  | 		if msg.typ != "accept" { | ||||||
|  | 			log.Debugf("Received non accept message in Open %s", msg.typ) | ||||||
|  | 			return errors.New("failed to connect") | ||||||
|  | 		} | ||||||
|  | 		// set to accepted | ||||||
|  | 		s.accepted = true | ||||||
|  | 		// set link | ||||||
|  | 		s.link = msg.link | ||||||
|  | 	case <-time.After(s.timeout): | ||||||
|  | 		return ErrDialTimeout | ||||||
|  | 	case <-s.closed: | ||||||
|  | 		return io.EOF | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Accept sends the accept response to an open message from a dialled connection | ||||||
|  | func (s *session) Accept() error { | ||||||
|  | 	msg := s.newMessage("accept") | ||||||
|  |  | ||||||
|  | 	// send the accept message | ||||||
|  | 	select { | ||||||
|  | 	case <-s.closed: | ||||||
|  | 		return io.EOF | ||||||
|  | 	case s.send <- msg: | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// wait for send response | ||||||
|  | 	select { | ||||||
|  | 	case err := <-s.errChan: | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	case <-s.closed: | ||||||
|  | 		return io.EOF | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Announce sends an announcement to notify that this session exists. This is primarily used by the listener. | ||||||
|  | func (s *session) Announce() error { | ||||||
|  | 	msg := s.newMessage("announce") | ||||||
|  | 	// we don't need an error back | ||||||
|  | 	msg.errChan = nil | ||||||
|  | 	// we don't need the link | ||||||
|  | 	msg.link = "" | ||||||
|  |  | ||||||
|  | 	select { | ||||||
|  | 	case s.send <- msg: | ||||||
|  | 		return nil | ||||||
|  | 	case <-s.closed: | ||||||
|  | 		return io.EOF | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Send is used to send a message | ||||||
| func (s *session) Send(m *transport.Message) error { | func (s *session) Send(m *transport.Message) error { | ||||||
| 	select { | 	select { | ||||||
| 	case <-s.closed: | 	case <-s.closed: | ||||||
| 		return errors.New("session is closed") | 		return io.EOF | ||||||
| 	default: | 	default: | ||||||
| 		// no op | 		// no op | ||||||
| 	} | 	} | ||||||
| @@ -94,22 +210,18 @@ func (s *session) Send(m *transport.Message) error { | |||||||
| 		data.Header[k] = v | 		data.Header[k] = v | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// append to backlog | 	// create a new message | ||||||
| 	msg := &message{ | 	msg := s.newMessage("session") | ||||||
| 		typ:      "message", | 	// set the data | ||||||
| 		id:       s.id, | 	msg.data = data | ||||||
| 		channel:  s.channel, |  | ||||||
| 		session:  s.session, | 	// if multicast don't set the link | ||||||
| 		outbound: s.outbound, | 	if s.multicast { | ||||||
| 		loopback: s.loopback, | 		msg.link = "" | ||||||
| 		data:     data, |  | ||||||
| 		// specify the link on which to send this |  | ||||||
| 		// it will be blank for dialled sessions |  | ||||||
| 		link: s.link, |  | ||||||
| 		// error chan |  | ||||||
| 		errChan: s.errChan, |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debugf("Appending %+v to send backlog", msg) | 	log.Debugf("Appending %+v to send backlog", msg) | ||||||
|  | 	// send the actual message | ||||||
| 	s.send <- msg | 	s.send <- msg | ||||||
|  |  | ||||||
| 	// wait for an error response | 	// wait for an error response | ||||||
| @@ -123,6 +235,7 @@ func (s *session) Send(m *transport.Message) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Recv is used to receive a message | ||||||
| func (s *session) Recv(m *transport.Message) error { | func (s *session) Recv(m *transport.Message) error { | ||||||
| 	select { | 	select { | ||||||
| 	case <-s.closed: | 	case <-s.closed: | ||||||
| @@ -147,13 +260,25 @@ func (s *session) Recv(m *transport.Message) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Close closes the session | // Close closes the session by sending a close message | ||||||
| func (s *session) Close() error { | func (s *session) Close() error { | ||||||
| 	select { | 	select { | ||||||
| 	case <-s.closed: | 	case <-s.closed: | ||||||
| 		// no op | 		// no op | ||||||
| 	default: | 	default: | ||||||
| 		close(s.closed) | 		close(s.closed) | ||||||
|  |  | ||||||
|  | 		// append to backlog | ||||||
|  | 		msg := s.newMessage("close") | ||||||
|  | 		// no error response on close | ||||||
|  | 		msg.errChan = nil | ||||||
|  |  | ||||||
|  | 		// send the close message | ||||||
|  | 		select { | ||||||
|  | 		case s.send <- msg: | ||||||
|  | 		default: | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,9 +2,19 @@ | |||||||
| package tunnel | package tunnel | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	// ErrDialTimeout is returned by a call to Dial where the timeout occurs | ||||||
|  | 	ErrDialTimeout = errors.New("dial timeout") | ||||||
|  | 	// DefaultDialTimeout is the dial timeout if none is specified | ||||||
|  | 	DefaultDialTimeout = time.Second * 5 | ||||||
|  | ) | ||||||
|  |  | ||||||
| // Tunnel creates a gre tunnel on top of the go-micro/transport. | // Tunnel creates a gre tunnel on top of the go-micro/transport. | ||||||
| // It establishes multiple streams using the Micro-Tunnel-Channel header | // It establishes multiple streams using the Micro-Tunnel-Channel header | ||||||
| // and Micro-Tunnel-Session header. The tunnel id is a hash of | // and Micro-Tunnel-Session header. The tunnel id is a hash of | ||||||
| @@ -18,13 +28,23 @@ type Tunnel interface { | |||||||
| 	// Close closes the tunnel | 	// Close closes the tunnel | ||||||
| 	Close() error | 	Close() error | ||||||
| 	// Connect to a channel | 	// Connect to a channel | ||||||
| 	Dial(channel string) (Session, error) | 	Dial(channel string, opts ...DialOption) (Session, error) | ||||||
| 	// Accept connections on a channel | 	// Accept connections on a channel | ||||||
| 	Listen(channel string) (Listener, error) | 	Listen(channel string) (Listener, error) | ||||||
|  | 	// All the links the tunnel is connected to | ||||||
|  | 	Links() []Link | ||||||
| 	// Name of the tunnel implementation | 	// Name of the tunnel implementation | ||||||
| 	String() string | 	String() string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Link represents internal links to the tunnel | ||||||
|  | type Link interface { | ||||||
|  | 	// The id of the link | ||||||
|  | 	Id() string | ||||||
|  | 	// honours transport socket | ||||||
|  | 	transport.Socket | ||||||
|  | } | ||||||
|  |  | ||||||
| // The listener provides similar constructs to the transport.Listener | // The listener provides similar constructs to the transport.Listener | ||||||
| type Listener interface { | type Listener interface { | ||||||
| 	Accept() (Session, error) | 	Accept() (Session, error) | ||||||
|   | |||||||
| @@ -187,30 +187,15 @@ func testBrokenTunAccept(t *testing.T, tun Tunnel, wait chan bool, wg *sync.Wait | |||||||
| 	if err := c.Recv(m); err != nil { | 	if err := c.Recv(m); err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	tun.Close() |  | ||||||
|  |  | ||||||
| 	// re-start tunnel | 	// close all the links | ||||||
| 	err = tun.Connect() | 	for _, link := range tun.Links() { | ||||||
| 	if err != nil { | 		link.Close() | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer tun.Close() |  | ||||||
|  |  | ||||||
| 	// listen on some virtual address |  | ||||||
| 	tl, err = tun.Listen("test-tunnel") |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// receiver ready; notify sender | 	// receiver ready; notify sender | ||||||
| 	wait <- true | 	wait <- true | ||||||
|  |  | ||||||
| 	// accept a connection |  | ||||||
| 	c, err = tl.Accept() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// accept the message | 	// accept the message | ||||||
| 	m = new(transport.Message) | 	m = new(transport.Message) | ||||||
| 	if err := c.Recv(m); err != nil { | 	if err := c.Recv(m); err != nil { | ||||||
| @@ -279,6 +264,7 @@ func TestReconnectTunnel(t *testing.T) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  | 	defer tunB.Close() | ||||||
|  |  | ||||||
| 	// we manually override the tunnel.ReconnectTime value here | 	// we manually override the tunnel.ReconnectTime value here | ||||||
| 	// this is so that we make the reconnects faster than the default 5s | 	// this is so that we make the reconnects faster than the default 5s | ||||||
| @@ -289,6 +275,7 @@ func TestReconnectTunnel(t *testing.T) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  | 	defer tunA.Close() | ||||||
|  |  | ||||||
| 	wait := make(chan bool) | 	wait := make(chan bool) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user