move tunnel/resolver into network

This commit is contained in:
Asim Aslam
2020-08-23 15:00:27 +01:00
parent 44f281f8d9
commit d60d85de5c
21 changed files with 22 additions and 22 deletions

View File

@@ -0,0 +1,215 @@
// Package broker is a tunnel broker
package broker
import (
"context"
"github.com/micro/go-micro/v3/broker"
"github.com/micro/go-micro/v3/transport"
"github.com/micro/go-micro/v3/network/tunnel"
"github.com/micro/go-micro/v3/network/tunnel/mucp"
)
type tunBroker struct {
opts broker.Options
tunnel tunnel.Tunnel
}
type tunSubscriber struct {
topic string
handler broker.Handler
opts broker.SubscribeOptions
closed chan bool
listener tunnel.Listener
}
type tunEvent struct {
topic string
message *broker.Message
}
// used to access tunnel from options context
type tunnelKey struct{}
type tunnelAddr struct{}
func (t *tunBroker) Init(opts ...broker.Option) error {
for _, o := range opts {
o(&t.opts)
}
return nil
}
func (t *tunBroker) Options() broker.Options {
return t.opts
}
func (t *tunBroker) Address() string {
return t.tunnel.Address()
}
func (t *tunBroker) Connect() error {
return t.tunnel.Connect()
}
func (t *tunBroker) Disconnect() error {
return t.tunnel.Close()
}
func (t *tunBroker) Publish(topic string, m *broker.Message, opts ...broker.PublishOption) error {
// TODO: this is probably inefficient, we might want to just maintain an open connection
// it may be easier to add broadcast to the tunnel
c, err := t.tunnel.Dial(topic, tunnel.DialMode(tunnel.Multicast))
if err != nil {
return err
}
defer c.Close()
return c.Send(&transport.Message{
Header: m.Header,
Body: m.Body,
})
}
func (t *tunBroker) Subscribe(topic string, h broker.Handler, opts ...broker.SubscribeOption) (broker.Subscriber, error) {
l, err := t.tunnel.Listen(topic, tunnel.ListenMode(tunnel.Multicast))
if err != nil {
return nil, err
}
var options broker.SubscribeOptions
for _, o := range opts {
o(&options)
}
tunSub := &tunSubscriber{
topic: topic,
handler: h,
opts: options,
closed: make(chan bool),
listener: l,
}
// start processing
go tunSub.run()
return tunSub, nil
}
func (t *tunBroker) String() string {
return "tunnel"
}
func (t *tunSubscriber) run() {
for {
// accept a new connection
c, err := t.listener.Accept()
if err != nil {
select {
case <-t.closed:
return
default:
continue
}
}
// receive message
m := new(transport.Message)
if err := c.Recv(m); err != nil {
c.Close()
continue
}
// close the connection
c.Close()
// handle the message
go t.handler(&broker.Message{
Header: m.Header,
Body: m.Body,
})
}
}
func (t *tunSubscriber) Options() broker.SubscribeOptions {
return t.opts
}
func (t *tunSubscriber) Topic() string {
return t.topic
}
func (t *tunSubscriber) Unsubscribe() error {
select {
case <-t.closed:
return nil
default:
close(t.closed)
return t.listener.Close()
}
}
func (t *tunEvent) Topic() string {
return t.topic
}
func (t *tunEvent) Message() *broker.Message {
return t.message
}
func (t *tunEvent) Ack() error {
return nil
}
func (t *tunEvent) Error() error {
return nil
}
func NewBroker(opts ...broker.Option) broker.Broker {
options := broker.Options{
Context: context.Background(),
}
for _, o := range opts {
o(&options)
}
t, ok := options.Context.Value(tunnelKey{}).(tunnel.Tunnel)
if !ok {
t = mucp.NewTunnel()
}
a, ok := options.Context.Value(tunnelAddr{}).(string)
if ok {
// initialise address
t.Init(tunnel.Address(a))
}
if len(options.Addrs) > 0 {
// initialise nodes
t.Init(tunnel.Nodes(options.Addrs...))
}
return &tunBroker{
opts: options,
tunnel: t,
}
}
// WithAddress sets the tunnel address
func WithAddress(a string) broker.Option {
return func(o *broker.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, tunnelAddr{}, a)
}
}
// WithTunnel sets the internal tunnel
func WithTunnel(t tunnel.Tunnel) broker.Option {
return func(o *broker.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, tunnelKey{}, t)
}
}

View File

@@ -0,0 +1,84 @@
package mucp
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"github.com/micro/go-micro/v3/network/tunnel"
"github.com/oxtoacart/bpool"
)
var (
// the local buffer pool
// gcmStandardNonceSize from crypto/cipher/gcm.go is 12 bytes
// 100 - is max size of pool
noncePool = bpool.NewBytePool(100, 12)
)
// hash hahes the data into 32 bytes key and returns it
// hash uses sha256 underneath to hash the supplied key
func hash(key []byte) []byte {
sum := sha256.Sum256(key)
return sum[:]
}
// Encrypt encrypts data and returns the encrypted data
func Encrypt(gcm cipher.AEAD, data []byte) ([]byte, error) {
var err error
// get new byte array the size of the nonce from pool
// NOTE: we might use smaller nonce size in the future
nonce := noncePool.Get()
if _, err = rand.Read(nonce); err != nil {
return nil, err
}
defer noncePool.Put(nonce)
// NOTE: we prepend the nonce to the payload
// we need to do this as we need the same nonce
// to decrypt the payload when receiving it
return gcm.Seal(nonce, nonce, data, nil), nil
}
// Decrypt decrypts the payload and returns the decrypted data
func newCipher(key []byte) (cipher.AEAD, error) {
var err error
// generate a new AES cipher using our 32 byte key for decrypting the message
c, err := aes.NewCipher(hash(key))
if err != nil {
return nil, err
}
// gcm or Galois/Counter Mode, is a mode of operation
// for symmetric key cryptographic block ciphers
// - https://en.wikipedia.org/wiki/Galois/Counter_Mode
gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
return gcm, nil
}
func Decrypt(gcm cipher.AEAD, data []byte) ([]byte, error) {
var err error
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, tunnel.ErrDecryptingData
}
// NOTE: we need to parse out nonce from the payload
// we prepend the nonce to every encrypted payload
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
ciphertext, err = gcm.Open(ciphertext[:0], nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return ciphertext, nil
}

View File

@@ -0,0 +1,51 @@
package mucp
import (
"bytes"
"testing"
)
func TestEncrypt(t *testing.T) {
key := []byte("tokenpassphrase")
gcm, err := newCipher(key)
if err != nil {
t.Fatal(err)
}
data := []byte("supersecret")
cipherText, err := Encrypt(gcm, data)
if err != nil {
t.Errorf("failed to encrypt data: %v", err)
}
// verify the cipherText is not the same as data
if bytes.Equal(data, cipherText) {
t.Error("encrypted data are the same as plaintext")
}
}
func TestDecrypt(t *testing.T) {
key := []byte("tokenpassphrase")
gcm, err := newCipher(key)
if err != nil {
t.Fatal(err)
}
data := []byte("supersecret")
cipherText, err := Encrypt(gcm, data)
if err != nil {
t.Errorf("failed to encrypt data: %v", err)
}
plainText, err := Decrypt(gcm, cipherText)
if err != nil {
t.Errorf("failed to decrypt data: %v", err)
}
// verify the plainText is the same as data
if !bytes.Equal(data, plainText) {
t.Error("decrypted data not the same as plaintext")
}
}

535
network/tunnel/mucp/link.go Normal file
View File

@@ -0,0 +1,535 @@
package mucp
import (
"bytes"
"errors"
"io"
"sync"
"time"
"github.com/google/uuid"
"github.com/micro/go-micro/v3/logger"
"github.com/micro/go-micro/v3/transport"
)
type link struct {
transport.Socket
// transport to use for connections
transport transport.Transport
sync.RWMutex
// stops the link
closed chan bool
// metric used to track metrics
metric chan *metric
// link state channel for testing link
state chan *packet
// send queue for sending packets
sendQueue chan *packet
// receive queue for receiving packets
recvQueue chan *packet
// unique id of this link e.g uuid
// which we define for ourselves
id string
// whether its a loopback connection
// this flag is used by the transport listener
// which accepts inbound quic connections
loopback bool
// whether its actually connected
// dialled side sets it to connected
// after sending the message. the
// listener waits for the connect
connected bool
// the last time we received a keepalive
// on this link from the remote side
lastKeepAlive time.Time
// channels keeps a mapping of channels and last seen
channels map[string]time.Time
// the weighted moving average roundtrip
length int64
// weighted moving average of bits flowing
rate float64
// keep an error count on the link
errCount int
}
// packet send over link
type packet struct {
// message to send or received
message *transport.Message
// status returned when sent
status chan error
// receive related error
err error
}
// metric is used to record link rate
type metric struct {
// amount of data sent
data int
// time taken to send
duration time.Duration
// if an error occurred
status error
}
var (
// the 4 byte 0 packet sent to determine the link state
linkRequest = []byte{0, 0, 0, 0}
// the 4 byte 1 filled packet sent to determine link state
linkResponse = []byte{1, 1, 1, 1}
ErrLinkConnectTimeout = errors.New("link connect timeout")
)
func newLink(s transport.Socket) *link {
l := &link{
Socket: s,
id: uuid.New().String(),
lastKeepAlive: time.Now(),
closed: make(chan bool),
channels: make(map[string]time.Time),
state: make(chan *packet, 64),
sendQueue: make(chan *packet, 128),
recvQueue: make(chan *packet, 128),
metric: make(chan *metric, 128),
}
// process inbound/outbound packets
go l.process()
// manage the link state
go l.manage()
return l
}
// setRate sets the bits per second rate as a float64
func (l *link) setRate(bits int64, delta time.Duration) {
// rate of send in bits per nanosecond
rate := float64(bits) / float64(delta.Nanoseconds())
// default the rate if its zero
if l.rate == 0 {
// rate per second
l.rate = rate * 1e9
} else {
// set new rate per second
l.rate = 0.8*l.rate + 0.2*(rate*1e9)
}
}
// setRTT sets a nanosecond based moving average roundtrip time for the link
func (l *link) setRTT(d time.Duration) {
l.Lock()
if l.length <= 0 {
l.length = d.Nanoseconds()
l.Unlock()
return
}
// https://fishi.devtail.io/weblog/2015/04/12/measuring-bandwidth-and-round-trip-time-tcp-connection-inside-application-layer/
length := 0.8*float64(l.length) + 0.2*float64(d.Nanoseconds())
// set new length
l.length = int64(length)
l.Unlock()
}
func (l *link) delChannel(ch string) {
l.Lock()
delete(l.channels, ch)
l.Unlock()
}
func (l *link) getChannel(ch string) time.Time {
l.RLock()
t := l.channels[ch]
l.RUnlock()
return t
}
func (l *link) setChannel(channels ...string) {
l.Lock()
for _, ch := range channels {
l.channels[ch] = time.Now()
}
l.Unlock()
}
// set the keepalive time
func (l *link) keepalive() {
l.Lock()
l.lastKeepAlive = time.Now()
l.Unlock()
}
// process deals with the send queue
func (l *link) process() {
// receive messages
go func() {
for {
m := new(transport.Message)
err := l.recv(m)
if err != nil {
// record the metric
select {
case l.metric <- &metric{status: err}:
default:
}
}
// process new received message
pk := &packet{message: m, err: err}
// this is our link state packet
if m.Header["Micro-Method"] == "link" {
// process link state message
select {
case l.state <- pk:
case <-l.closed:
return
default:
}
continue
}
// process all messages as is
select {
case l.recvQueue <- pk:
case <-l.closed:
return
}
}
}()
// send messages
for {
select {
case pk := <-l.sendQueue:
// send the message
select {
case pk.status <- l.send(pk.message):
case <-l.closed:
return
}
case <-l.closed:
return
}
}
}
// manage manages the link state including rtt packets and channel mapping expiry
func (l *link) manage() {
// tick over every minute to expire and fire rtt packets
t1 := time.NewTicker(time.Minute)
defer t1.Stop()
// used to batch update link metrics
t2 := time.NewTicker(time.Second * 5)
defer t2.Stop()
// get link id
linkId := l.Id()
// used to send link state packets
send := func(b []byte) error {
return l.Send(&transport.Message{
Header: map[string]string{
"Micro-Method": "link",
"Micro-Link-Id": linkId,
}, Body: b,
})
}
// set time now
now := time.Now()
// send the initial rtt request packet
send(linkRequest)
for {
select {
// exit if closed
case <-l.closed:
return
// process link state rtt packets
case p := <-l.state:
if p.err != nil {
continue
}
// check the type of message
switch {
case bytes.Equal(p.message.Body, linkRequest):
if logger.V(logger.TraceLevel, log) {
log.Tracef("Link %s received link request", linkId)
}
// send response
if err := send(linkResponse); err != nil {
l.Lock()
l.errCount++
l.Unlock()
}
case bytes.Equal(p.message.Body, linkResponse):
// set round trip time
d := time.Since(now)
if logger.V(logger.TraceLevel, log) {
log.Tracef("Link %s received link response in %v", linkId, d)
}
// set the RTT
l.setRTT(d)
}
case <-t1.C:
// drop any channel mappings older than 2 minutes
var kill []string
killTime := time.Minute * 2
l.RLock()
for ch, t := range l.channels {
if d := time.Since(t); d > killTime {
kill = append(kill, ch)
}
}
l.RUnlock()
// if nothing to kill don't bother with a wasted lock
if len(kill) == 0 {
continue
}
// kill the channels!
l.Lock()
for _, ch := range kill {
delete(l.channels, ch)
}
l.Unlock()
// fire off a link state rtt packet
now = time.Now()
send(linkRequest)
case <-t2.C:
// get a batch of metrics
batch := l.batch()
// skip if there's no metrics
if len(batch) == 0 {
continue
}
// lock once to record a batch
l.Lock()
for _, metric := range batch {
l.record(metric)
}
l.Unlock()
}
}
}
func (l *link) batch() []*metric {
var metrics []*metric
// pull all the metrics
for {
select {
case m := <-l.metric:
metrics = append(metrics, m)
// non blocking return
default:
return metrics
}
}
}
func (l *link) record(m *metric) {
// there's an error increment the counter and bail
if m.status != nil {
l.errCount++
return
}
// reset the counter
l.errCount = 0
// calculate based on data
if m.data > 0 {
// bit sent
bits := m.data * 1024
// set the rate
l.setRate(int64(bits), m.duration)
}
}
func (l *link) send(m *transport.Message) error {
if m.Header == nil {
m.Header = make(map[string]string)
}
// send the message
return l.Socket.Send(m)
}
// recv a message on the link
func (l *link) recv(m *transport.Message) error {
if m.Header == nil {
m.Header = make(map[string]string)
}
// receive the transport message
return l.Socket.Recv(m)
}
// Delay is the current load on the link
func (l *link) Delay() int64 {
return int64(len(l.sendQueue) + len(l.recvQueue))
}
// Current transfer rate as bits per second (lower is better)
func (l *link) Rate() float64 {
l.RLock()
r := l.rate
l.RUnlock()
return r
}
func (l *link) Loopback() bool {
l.RLock()
lo := l.loopback
l.RUnlock()
return lo
}
// Length returns the roundtrip time as nanoseconds (lower is better).
// Returns 0 where no measurement has been taken.
func (l *link) Length() int64 {
l.RLock()
length := l.length
l.RUnlock()
return length
}
func (l *link) Id() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
func (l *link) Close() error {
l.Lock()
defer l.Unlock()
select {
case <-l.closed:
return nil
default:
l.Socket.Close()
close(l.closed)
}
return nil
}
// Send sencs a message on the link
func (l *link) Send(m *transport.Message) error {
// create a new packet to send over the link
p := &packet{
message: m,
status: make(chan error, 1),
}
// calculate the data sent
dataSent := len(m.Body)
// set header length
for k, v := range m.Header {
dataSent += (len(k) + len(v))
}
// get time now
now := time.Now()
// queue the message
select {
case l.sendQueue <- p:
// in the send queue
case <-l.closed:
return io.EOF
}
// error to use
var err error
// wait for response
select {
case <-l.closed:
return io.EOF
case err = <-p.status:
}
// create a metric with
// time taken, size of package, error status
mt := &metric{
data: dataSent,
duration: time.Since(now),
status: err,
}
// pass back a metric
// do not block
select {
case l.metric <- mt:
default:
}
return nil
}
// Accept accepts a message on the socket
func (l *link) Recv(m *transport.Message) error {
select {
case <-l.closed:
// check if there's any messages left
select {
case pk := <-l.recvQueue:
// check the packet receive error
if pk.err != nil {
return pk.err
}
*m = *pk.message
default:
return io.EOF
}
case pk := <-l.recvQueue:
// check the packet receive error
if pk.err != nil {
return pk.err
}
*m = *pk.message
}
return nil
}
// State can return connected, closed, error
func (l *link) State() string {
select {
case <-l.closed:
return "closed"
default:
l.RLock()
errCount := l.errCount
l.RUnlock()
if errCount > 3 {
return "error"
}
return "connected"
}
}

View File

@@ -0,0 +1,212 @@
package mucp
import (
"io"
"sync"
"github.com/micro/go-micro/v3/logger"
"github.com/micro/go-micro/v3/network/tunnel"
)
type tunListener struct {
// address of the listener
channel string
// token is the tunnel token
token string
// the accept channel
accept chan *session
// the tunnel closed channel
tunClosed chan bool
// the listener session
session *session
// del func to kill listener
delFunc func()
sync.RWMutex
// the channel to close
closed chan bool
}
func (t *tunListener) process() {
// our connection map for session
conns := make(map[string]*session)
defer func() {
// close the sessions
for id, conn := range conns {
conn.Close()
delete(conns, id)
}
// unassign
conns = nil
}()
for {
select {
case <-t.closed:
return
case <-t.tunClosed:
t.Close()
return
// receive a new message
case m := <-t.session.recv:
var sessionId string
var linkId string
switch t.session.mode {
case tunnel.Multicast:
sessionId = "multicast"
linkId = "multicast"
case tunnel.Broadcast:
sessionId = "broadcast"
linkId = "broadcast"
default:
sessionId = m.session
linkId = m.link
}
// get a session
sess, ok := conns[sessionId]
if logger.V(logger.TraceLevel, log) {
log.Tracef("Tunnel listener received channel %s session %s type %s exists: %t", m.channel, m.session, m.typ, ok)
}
if !ok {
// we only process open and session types
switch m.typ {
case "open", "session":
default:
continue
}
// create a new session session
sess = &session{
// the session key
key: []byte(t.token + m.channel + sessionId),
// the id of the remote side
tunnel: m.tunnel,
// the channel
channel: m.channel,
// the session id
session: sessionId,
// tunnel token
token: t.token,
// is loopback conn
loopback: m.loopback,
// the link the message was received on
link: linkId,
// set the connection mode
mode: t.session.mode,
// close chan
closed: make(chan bool),
// recv called by the acceptor
recv: make(chan *message, 128),
// use the internal send buffer
send: t.session.send,
// error channel
errChan: make(chan error, 1),
// set the read timeout
readTimeout: t.session.readTimeout,
}
// save the session
conns[sessionId] = sess
select {
case <-t.closed:
return
// send to accept chan
case t.accept <- sess:
}
}
// an existing session was found
switch m.typ {
case "close":
// don't close multicast sessions
if sess.mode > tunnel.Unicast {
continue
}
// received a close message
select {
// check if the session is closed
case <-sess.closed:
// no op
delete(conns, sessionId)
default:
// only close if unicast session
// close and delete session
close(sess.closed)
delete(conns, sessionId)
}
// continue
continue
case "session":
// operate on this
default:
// non operational type
continue
}
// send this to the accept chan
select {
case <-sess.closed:
delete(conns, sessionId)
case sess.recv <- m:
if logger.V(logger.TraceLevel, log) {
log.Tracef("Tunnel listener sent to recv chan channel %s session %s type %s", m.channel, sessionId, m.typ)
}
}
}
}
}
func (t *tunListener) Channel() string {
return t.channel
}
// Close closes tunnel listener
func (t *tunListener) Close() error {
t.Lock()
defer t.Unlock()
select {
case <-t.closed:
return nil
default:
// close and delete
t.delFunc()
t.session.Close()
close(t.closed)
}
return nil
}
// Everytime accept is called we essentially block till we get a new connection
func (t *tunListener) Accept() (tunnel.Session, error) {
select {
// if the session is closed return
case <-t.closed:
return nil, io.EOF
case <-t.tunClosed:
// close the listener when the tunnel closes
return nil, io.EOF
// wait for a new connection
case c, ok := <-t.accept:
// check if the accept chan is closed
if !ok {
return nil, io.EOF
}
// return without accept
if c.mode != tunnel.Unicast {
return c, nil
}
// send back the accept
if err := c.Accept(); err != nil {
return nil, err
}
return c, nil
}
}

1424
network/tunnel/mucp/mucp.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,347 @@
package mucp
import (
"os"
"sync"
"testing"
"time"
"github.com/micro/go-micro/v3/transport"
"github.com/micro/go-micro/v3/network/tunnel"
)
func testBrokenTunAccept(t *testing.T, tun tunnel.Tunnel, wait chan bool, wg *sync.WaitGroup) {
defer wg.Done()
// listen on some virtual address
tl, err := tun.Listen("test-tunnel")
if err != nil {
t.Fatal(err)
}
// receiver ready; notify sender
wait <- true
// accept a connection
c, err := tl.Accept()
if err != nil {
t.Fatal(err)
}
// accept the message and close the tunnel
// we do this to simulate loss of network connection
m := new(transport.Message)
if err := c.Recv(m); err != nil {
t.Fatal(err)
}
// close all the links
for _, link := range tun.Links() {
link.Close()
}
// receiver ready; notify sender
wait <- true
// accept the message
m = new(transport.Message)
if err := c.Recv(m); err != nil {
t.Fatal(err)
}
// notify the sender we have received
wait <- true
}
func testBrokenTunSend(t *testing.T, tun tunnel.Tunnel, wait chan bool, wg *sync.WaitGroup, reconnect time.Duration) {
defer wg.Done()
// wait for the listener to get ready
<-wait
// dial a new session
c, err := tun.Dial("test-tunnel")
if err != nil {
t.Fatal(err)
}
defer c.Close()
m := transport.Message{
Header: map[string]string{
"test": "send",
},
}
// send the message
if err := c.Send(&m); err != nil {
t.Fatal(err)
}
// wait for the listener to get ready
<-wait
// give it time to reconnect
time.Sleep(reconnect)
// send the message
if err := c.Send(&m); err != nil {
t.Fatal(err)
}
// wait for the listener to receive the message
// c.Send merely enqueues the message to the link send queue and returns
// in order to verify it was received we wait for the listener to tell us
<-wait
}
// testAccept will accept connections on the transport, create a new link and tunnel on top
func testAccept(t *testing.T, tun tunnel.Tunnel, wait chan bool, wg *sync.WaitGroup) {
defer wg.Done()
// listen on some virtual address
tl, err := tun.Listen("test-tunnel")
if err != nil {
t.Fatal(err)
}
// receiver ready; notify sender
wait <- true
// accept a connection
c, err := tl.Accept()
if err != nil {
t.Fatal(err)
}
// get a message
// accept the message
m := new(transport.Message)
if err := c.Recv(m); err != nil {
t.Fatal(err)
}
if v := m.Header["test"]; v != "send" {
t.Fatalf("Accept side expected test:send header. Received: %s", v)
}
// now respond
m.Header["test"] = "accept"
if err := c.Send(m); err != nil {
t.Fatal(err)
}
wait <- true
return
}
// testSend will create a new link to an address and then a tunnel on top
func testSend(t *testing.T, tun tunnel.Tunnel, wait chan bool, wg *sync.WaitGroup) {
defer wg.Done()
// wait for the listener to get ready
<-wait
// dial a new session
c, err := tun.Dial("test-tunnel")
if err != nil {
t.Fatal(err)
}
defer c.Close()
m := transport.Message{
Header: map[string]string{
"test": "send",
},
}
// send the message
if err := c.Send(&m); err != nil {
t.Fatal(err)
}
// now wait for the response
mr := new(transport.Message)
if err := c.Recv(mr); err != nil {
t.Fatal(err)
}
<-wait
if v := mr.Header["test"]; v != "accept" {
t.Fatalf("Message not received from accepted side. Received: %s", v)
}
}
func TestTunnel(t *testing.T) {
// create a new tunnel client
tunA := NewTunnel(
tunnel.Address("127.0.0.1:9096"),
tunnel.Nodes("127.0.0.1:9097"),
)
// create a new tunnel server
tunB := NewTunnel(
tunnel.Address("127.0.0.1:9097"),
)
// start tunB
err := tunB.Connect()
if err != nil {
t.Fatal(err)
}
defer tunB.Close()
// start tunA
err = tunA.Connect()
if err != nil {
t.Fatal(err)
}
defer tunA.Close()
wait := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)
// start the listener
go testAccept(t, tunB, wait, &wg)
wg.Add(1)
// start the client
go testSend(t, tunA, wait, &wg)
// wait until done
wg.Wait()
}
func TestLoopbackTunnel(t *testing.T) {
// create a new tunnel
tun := NewTunnel(
tunnel.Address("127.0.0.1:9096"),
tunnel.Nodes("127.0.0.1:9096"),
)
// start tunnel
err := tun.Connect()
if err != nil {
t.Fatal(err)
}
defer tun.Close()
time.Sleep(500 * time.Millisecond)
wait := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)
// start the listener
go testAccept(t, tun, wait, &wg)
wg.Add(1)
// start the client
go testSend(t, tun, wait, &wg)
// wait until done
wg.Wait()
}
func TestTunnelRTTRate(t *testing.T) {
// create a new tunnel client
tunA := NewTunnel(
tunnel.Address("127.0.0.1:9096"),
tunnel.Nodes("127.0.0.1:9097"),
)
// create a new tunnel server
tunB := NewTunnel(
tunnel.Address("127.0.0.1:9097"),
)
// start tunB
err := tunB.Connect()
if err != nil {
t.Fatal(err)
}
defer tunB.Close()
// start tunA
err = tunA.Connect()
if err != nil {
t.Fatal(err)
}
defer tunA.Close()
wait := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)
// start the listener
go testAccept(t, tunB, wait, &wg)
wg.Add(1)
// start the client
go testSend(t, tunA, wait, &wg)
// wait until done
wg.Wait()
if len(os.Getenv("IN_TRAVIS_CI")) == 0 {
// only needed for debug
for _, link := range tunA.Links() {
t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate())
}
for _, link := range tunB.Links() {
t.Logf("Link %s length %v rate %v", link.Id(), link.Length(), link.Rate())
}
}
}
func TestReconnectTunnel(t *testing.T) {
// we manually override the tunnel.ReconnectTime value here
// this is so that we make the reconnects faster than the default 5s
ReconnectTime = 200 * time.Millisecond
// create a new tunnel client
tunA := NewTunnel(
tunnel.Address("127.0.0.1:9098"),
tunnel.Nodes("127.0.0.1:9099"),
)
// create a new tunnel server
tunB := NewTunnel(
tunnel.Address("127.0.0.1:9099"),
)
// start tunnel
err := tunB.Connect()
if err != nil {
t.Fatal(err)
}
defer tunB.Close()
// start tunnel
err = tunA.Connect()
if err != nil {
t.Fatal(err)
}
defer tunA.Close()
wait := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)
// start tunnel listener
go testBrokenTunAccept(t, tunB, wait, &wg)
wg.Add(1)
// start tunnel sender
go testBrokenTunSend(t, tunA, wait, &wg, ReconnectTime*5)
// wait until done
wg.Wait()
}

View File

@@ -0,0 +1,503 @@
package mucp
import (
"crypto/cipher"
"encoding/base32"
"io"
"sync"
"time"
"github.com/micro/go-micro/v3/logger"
"github.com/micro/go-micro/v3/transport"
"github.com/micro/go-micro/v3/network/tunnel"
)
// session is our pseudo session for transport.Socket
type session struct {
// the tunnel id
tunnel string
// the channel name
channel string
// the session id based on Micro.Tunnel-Session
session string
// token is the session token
token string
// closed
closed chan bool
// remote addr
remote string
// local addr
local string
// send chan
send chan *message
// recv chan
recv chan *message
// if the discovery worked
discovered bool
// if the session was accepted
accepted bool
// outbound marks the session as outbound dialled connection
outbound bool
// lookback marks the session as a loopback on the inbound
loopback bool
// mode of the connection
mode tunnel.Mode
// the dial timeout
dialTimeout time.Duration
// the read timeout
readTimeout time.Duration
// the link on which this message was received
link string
// the error response
errChan chan error
// key for session encryption
key []byte
// cipher for session
gcm cipher.AEAD
sync.RWMutex
}
// message is sent over the send channel
type message struct {
// type of message
typ string
// tunnel id
tunnel string
// channel name
channel string
// the session id
session string
// outbound marks the message as outbound
outbound bool
// loopback marks the message intended for loopback
loopback bool
// mode of the connection
mode tunnel.Mode
// the link to send the message on
link string
// transport data
data *transport.Message
// the error channel
errChan chan error
}
func (s *session) Remote() string {
return s.remote
}
func (s *session) Local() string {
return s.local
}
func (s *session) Link() string {
return s.link
}
func (s *session) Id() string {
return s.session
}
func (s *session) Channel() string {
return s.channel
}
// newMessage creates a new message based on the session
func (s *session) newMessage(typ string) *message {
return &message{
typ: typ,
tunnel: s.tunnel,
channel: s.channel,
session: s.session,
outbound: s.outbound,
loopback: s.loopback,
mode: s.mode,
link: s.link,
errChan: s.errChan,
}
}
func (s *session) sendMsg(msg *message) error {
select {
case <-s.closed:
return io.EOF
case s.send <- msg:
return nil
}
}
func (s *session) wait(msg *message) error {
// wait for an error response
select {
case err := <-msg.errChan:
if err != nil {
return err
}
case <-s.closed:
return io.EOF
}
return nil
}
// waitFor waits for the message type required until the timeout specified
func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) {
now := time.Now()
after := func(timeout time.Duration) <-chan time.Time {
if timeout < time.Duration(0) {
return nil
}
// get the delta
d := time.Since(now)
// dial timeout minus time since
wait := timeout - d
if wait < time.Duration(0) {
wait = time.Duration(0)
}
return time.After(wait)
}
// wait for the message type
for {
select {
case msg := <-s.recv:
// there may be no message type
if len(msgType) == 0 {
return msg, nil
}
// ignore what we don't want
if msg.typ != msgType {
if logger.V(logger.DebugLevel, log) {
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
}
continue
}
// got the message
return msg, nil
case <-after(timeout):
return nil, tunnel.ErrReadTimeout
case <-s.closed:
// check pending message queue
select {
case msg := <-s.recv:
// there may be no message type
if len(msgType) == 0 {
return msg, nil
}
// ignore what we don't want
if msg.typ != msgType {
if logger.V(logger.DebugLevel, log) {
log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType)
}
continue
}
// got the message
return msg, nil
default:
// non blocking
}
return nil, io.EOF
}
}
}
// Discover attempts to discover the link for a specific channel.
// This is only used by the tunnel.Dial when first connecting.
func (s *session) Discover() error {
// create a new discovery message for this channel
msg := s.newMessage("discover")
// broadcast the message to all links
msg.mode = tunnel.Broadcast
// its an outbound connection since we're dialling
msg.outbound = true
// don't set the link since we don't know where it is
msg.link = ""
// if multicast then set that as session
if s.mode == tunnel.Multicast {
msg.session = "multicast"
}
// send discover message
if err := s.sendMsg(msg); err != nil {
return err
}
// set time now
now := time.Now()
// after strips down the dial timeout
after := func() time.Duration {
d := time.Since(now)
// dial timeout minus time since
wait := s.dialTimeout - d
// make sure its always > 0
if wait < time.Duration(0) {
return time.Duration(0)
}
return wait
}
// the discover message is sent out, now
// wait to hear back about the sent message
select {
case <-time.After(after()):
return tunnel.ErrDialTimeout
case err := <-s.errChan:
if err != nil {
return err
}
}
// bail early if its not unicast
// we don't need to wait for the announce
if s.mode != tunnel.Unicast {
s.discovered = true
s.accepted = true
return nil
}
// wait for announce
_, err := s.waitFor("announce", after())
if err != nil {
return err
}
// set discovered
s.discovered = true
return nil
}
// Open will fire the open message for the session. This is called by the dialler.
// This is to indicate that we want to create a new session.
func (s *session) Open() error {
// create a new message
msg := s.newMessage("open")
// send open message
if err := s.sendMsg(msg); err != nil {
return err
}
// wait for an error response for send
if err := s.wait(msg); err != nil {
return err
}
// now wait for the accept message to be returned
msg, err := s.waitFor("accept", s.dialTimeout)
if err != nil {
return err
}
// set to accepted
s.accepted = true
// set link
s.link = msg.link
return nil
}
// Accept sends the accept response to an open message from a dialled connection
func (s *session) Accept() error {
msg := s.newMessage("accept")
// send the accept message
if err := s.sendMsg(msg); err != nil {
return err
}
// wait for send response
return s.wait(msg)
}
// Announce sends an announcement to notify that this session exists.
// This is primarily used by the listener.
func (s *session) Announce() error {
msg := s.newMessage("announce")
// we don't need an error back
msg.errChan = nil
// announce to all
msg.mode = tunnel.Broadcast
// we don't need the link
msg.link = ""
// send announce message
return s.sendMsg(msg)
}
// Send is used to send a message
func (s *session) Send(m *transport.Message) error {
var err error
s.RLock()
gcm := s.gcm
s.RUnlock()
if gcm == nil {
gcm, err = newCipher(s.key)
if err != nil {
return err
}
s.Lock()
s.gcm = gcm
s.Unlock()
}
// encrypt the transport message payload
body, err := Encrypt(gcm, m.Body)
if err != nil {
log.Debugf("failed to encrypt message body: %v", err)
return err
}
// make copy, without rehash and realloc
data := &transport.Message{
Header: make(map[string]string, len(m.Header)),
Body: body,
}
// encrypt all the headers
for k, v := range m.Header {
// encrypt the transport message payload
val, err := Encrypt(s.gcm, []byte(v))
if err != nil {
log.Debugf("failed to encrypt message header %s: %v", k, err)
return err
}
// add the encrypted header value
data.Header[k] = base32.StdEncoding.EncodeToString(val)
}
// create a new message
msg := s.newMessage("session")
// set the data
msg.data = data
// if multicast don't set the link
if s.mode != tunnel.Unicast {
msg.link = ""
}
if logger.V(logger.TraceLevel, log) {
log.Tracef("Appending to send backlog: %v", msg)
}
// send the actual message
if err := s.sendMsg(msg); err != nil {
return err
}
// wait for an error response
return s.wait(msg)
}
// Recv is used to receive a message
func (s *session) Recv(m *transport.Message) error {
var msg *message
msg, err := s.waitFor("", s.readTimeout)
if err != nil {
return err
}
// check the error if one exists
select {
case err := <-msg.errChan:
return err
default:
}
if logger.V(logger.TraceLevel, log) {
log.Tracef("Received from recv backlog: %v", msg)
}
gcm, err := newCipher([]byte(s.token + s.channel + msg.session))
if err != nil {
if logger.V(logger.ErrorLevel, log) {
log.Errorf("unable to create cipher: %v", err)
}
return err
}
// decrypt the received payload using the token
// we have to used msg.session because multicast has a shared
// session id of "multicast" in this session struct on
// the listener side
msg.data.Body, err = Decrypt(gcm, msg.data.Body)
if err != nil {
if logger.V(logger.DebugLevel, log) {
log.Debugf("failed to decrypt message body: %v", err)
}
return err
}
// dencrypt all the headers
for k, v := range msg.data.Header {
// decode the header values
h, err := base32.StdEncoding.DecodeString(v)
if err != nil {
if logger.V(logger.DebugLevel, log) {
log.Debugf("failed to decode message header %s: %v", k, err)
}
return err
}
// dencrypt the transport message payload
val, err := Decrypt(gcm, h)
if err != nil {
if logger.V(logger.DebugLevel, log) {
log.Debugf("failed to decrypt message header %s: %v", k, err)
}
return err
}
// add decrypted header value
msg.data.Header[k] = string(val)
}
// set the link
// TODO: decruft, this is only for multicast
// since the session is now a single session
// likely provide as part of message.Link()
msg.data.Header["Micro-Link"] = msg.link
// set message
*m = *msg.data
// return nil
return nil
}
// Close closes the session by sending a close message
func (s *session) Close() error {
select {
case <-s.closed:
// no op
default:
close(s.closed)
// don't send close on multicast or broadcast
if s.mode != tunnel.Unicast {
return nil
}
// append to backlog
msg := s.newMessage("close")
// no error response on close
msg.errChan = nil
// send the close message
select {
case s.send <- msg:
case <-time.After(time.Millisecond * 10):
}
}
return nil
}

145
network/tunnel/options.go Normal file
View File

@@ -0,0 +1,145 @@
package tunnel
import (
"time"
"github.com/google/uuid"
"github.com/micro/go-micro/v3/transport"
"github.com/micro/go-micro/v3/transport/grpc"
)
var (
// DefaultAddress is default tunnel bind address
DefaultAddress = ":0"
// The shared default token
DefaultToken = "go.micro.tunnel"
)
type Option func(*Options)
// Options provides network configuration options
type Options struct {
// Id is tunnel id
Id string
// Address is tunnel address
Address string
// Nodes are remote nodes
Nodes []string
// The shared auth token
Token string
// Transport listens to incoming connections
Transport transport.Transport
}
type DialOption func(*DialOptions)
type DialOptions struct {
// Link specifies the link to use
Link string
// specify mode of the session
Mode Mode
// Wait for connection to be accepted
Wait bool
// the dial timeout
Timeout time.Duration
}
type ListenOption func(*ListenOptions)
type ListenOptions struct {
// specify mode of the session
Mode Mode
// The read timeout
Timeout time.Duration
}
// The tunnel id
func Id(id string) Option {
return func(o *Options) {
o.Id = id
}
}
// The tunnel address
func Address(a string) Option {
return func(o *Options) {
o.Address = a
}
}
// Nodes specify remote network nodes
func Nodes(n ...string) Option {
return func(o *Options) {
o.Nodes = n
}
}
// Token sets the shared token for auth
func Token(t string) Option {
return func(o *Options) {
o.Token = t
}
}
// Transport listens for incoming connections
func Transport(t transport.Transport) Option {
return func(o *Options) {
o.Transport = t
}
}
// Listen options
func ListenMode(m Mode) ListenOption {
return func(o *ListenOptions) {
o.Mode = m
}
}
// Timeout for reads and writes on the listener session
func ListenTimeout(t time.Duration) ListenOption {
return func(o *ListenOptions) {
o.Timeout = t
}
}
// Dial options
// Dial multicast sets the multicast option to send only to those mapped
func DialMode(m Mode) DialOption {
return func(o *DialOptions) {
o.Mode = m
}
}
// DialTimeout sets the dial timeout of the connection
func DialTimeout(t time.Duration) DialOption {
return func(o *DialOptions) {
o.Timeout = t
}
}
// DialLink specifies the link to pin this connection to.
// This is not applicable if the multicast option is set.
func DialLink(id string) DialOption {
return func(o *DialOptions) {
o.Link = id
}
}
// DialWait specifies whether to wait for the connection
// to be accepted before returning the session
func DialWait(b bool) DialOption {
return func(o *DialOptions) {
o.Wait = b
}
}
// DefaultOptions returns router default options
func DefaultOptions() Options {
return Options{
Id: uuid.New().String(),
Address: DefaultAddress,
Token: DefaultToken,
Transport: grpc.NewTransport(),
}
}

View File

@@ -0,0 +1,30 @@
package transport
import (
"github.com/micro/go-micro/v3/transport"
"github.com/micro/go-micro/v3/network/tunnel"
)
type tunListener struct {
l tunnel.Listener
}
func (t *tunListener) Addr() string {
return t.l.Channel()
}
func (t *tunListener) Close() error {
return t.l.Close()
}
func (t *tunListener) Accept(fn func(socket transport.Socket)) error {
for {
// accept connection
c, err := t.l.Accept()
if err != nil {
return err
}
// execute the function
go fn(c)
}
}

View File

@@ -0,0 +1,114 @@
// Package transport provides a tunnel transport
package transport
import (
"context"
"github.com/micro/go-micro/v3/transport"
"github.com/micro/go-micro/v3/network/tunnel"
"github.com/micro/go-micro/v3/network/tunnel/mucp"
)
type tunTransport struct {
options transport.Options
tunnel tunnel.Tunnel
}
type tunnelKey struct{}
type transportKey struct{}
func (t *tunTransport) Init(opts ...transport.Option) error {
for _, o := range opts {
o(&t.options)
}
// close the existing tunnel
if t.tunnel != nil {
t.tunnel.Close()
}
// get the tunnel
tun, ok := t.options.Context.Value(tunnelKey{}).(tunnel.Tunnel)
if !ok {
tun = mucp.NewTunnel()
}
// get the transport
tr, ok := t.options.Context.Value(transportKey{}).(transport.Transport)
if ok {
tun.Init(tunnel.Transport(tr))
}
// set the tunnel
t.tunnel = tun
return nil
}
func (t *tunTransport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) {
if err := t.tunnel.Connect(); err != nil {
return nil, err
}
c, err := t.tunnel.Dial(addr)
if err != nil {
return nil, err
}
return c, nil
}
func (t *tunTransport) Listen(addr string, opts ...transport.ListenOption) (transport.Listener, error) {
if err := t.tunnel.Connect(); err != nil {
return nil, err
}
l, err := t.tunnel.Listen(addr)
if err != nil {
return nil, err
}
return &tunListener{l}, nil
}
func (t *tunTransport) Options() transport.Options {
return t.options
}
func (t *tunTransport) String() string {
return "tunnel"
}
// NewTransport honours the initialiser used in
func NewTransport(opts ...transport.Option) transport.Transport {
t := &tunTransport{
options: transport.Options{},
}
// initialise
t.Init(opts...)
return t
}
// WithTransport sets the internal tunnel
func WithTunnel(t tunnel.Tunnel) transport.Option {
return func(o *transport.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, tunnelKey{}, t)
}
}
// WithTransport sets the internal transport
func WithTransport(t transport.Transport) transport.Option {
return func(o *transport.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, transportKey{}, t)
}
}

102
network/tunnel/tunnel.go Normal file
View File

@@ -0,0 +1,102 @@
// Package tunnel provides gre network tunnelling
package tunnel
import (
"errors"
"time"
"github.com/micro/go-micro/v3/transport"
)
const (
// send over one link
Unicast Mode = iota
// send to all channel listeners
Multicast
// send to all links
Broadcast
)
var (
// DefaultDialTimeout is the dial timeout if none is specified
DefaultDialTimeout = time.Second * 5
// ErrDialTimeout is returned by a call to Dial where the timeout occurs
ErrDialTimeout = errors.New("dial timeout")
// ErrDiscoverChan is returned when we failed to receive the "announce" back from a discovery
ErrDiscoverChan = errors.New("failed to discover channel")
// ErrLinkNotFound is returned when a link is specified at dial time and does not exist
ErrLinkNotFound = errors.New("link not found")
// ErrLinkDisconnected is returned when a link we attempt to send to is disconnected
ErrLinkDisconnected = errors.New("link not connected")
// ErrLinkLoppback is returned when attempting to send an outbound message over loopback link
ErrLinkLoopback = errors.New("link is loopback")
// ErrLinkRemote is returned when attempting to send a loopback message over remote link
ErrLinkRemote = errors.New("link is remote")
// ErrReadTimeout is a timeout on session.Recv
ErrReadTimeout = errors.New("read timeout")
// ErrDecryptingData is for when theres a nonce error
ErrDecryptingData = errors.New("error decrypting data")
)
// Mode of the session
type Mode uint8
// Tunnel creates a gre tunnel on top of the go-micro/transport.
// It establishes multiple streams using the Micro-Tunnel-Channel header
// and Micro-Tunnel-Session header. The tunnel id is a hash of
// the address being requested.
type Tunnel interface {
// Init initializes tunnel with options
Init(opts ...Option) error
// Address returns the address the tunnel is listening on
Address() string
// Connect connects the tunnel
Connect() error
// Close closes the tunnel
Close() error
// Links returns all the links the tunnel is connected to
Links() []Link
// Dial allows a client to connect to a channel
Dial(channel string, opts ...DialOption) (Session, error)
// Listen allows to accept connections on a channel
Listen(channel string, opts ...ListenOption) (Listener, error)
// String returns the name of the tunnel implementation
String() string
}
// Link represents internal links to the tunnel
type Link interface {
// Id returns the link unique Id
Id() string
// Delay is the current load on the link (lower is better)
Delay() int64
// Length returns the roundtrip time as nanoseconds (lower is better)
Length() int64
// Current transfer rate as bits per second (lower is better)
Rate() float64
// Is this a loopback link
Loopback() bool
// State of the link: connected/closed/error
State() string
// honours transport socket
transport.Socket
}
// The listener provides similar constructs to the transport.Listener
type Listener interface {
Accept() (Session, error)
Channel() string
Close() error
}
// Session is a unique session created when dialling or accepting connections on the tunnel
type Session interface {
// The unique session id
Id() string
// The channel name
Channel() string
// The link the session is on
Link() string
// a transport socket
transport.Socket
}