package tunnel

import (
	"bytes"
	"errors"
	"io"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/micro/go-micro/v2/transport"
	"github.com/micro/go-micro/v2/util/log"
)

// link wraps a transport.Socket with buffered send/receive queues, periodic
// link-state (round trip) probing, channel-mapping expiry and send metrics
// (rate and consecutive error count).
type link struct {
	transport.Socket

	// transport to use for connections
	transport transport.Transport

	sync.RWMutex

	// closed is closed to stop the link
	closed chan bool

	// metric receives per-send (or per-receive-error) metrics which are
	// folded into rate/errCount in batches
	metric chan *metric

	// state receives link state packets (Micro-Method: link)
	state chan *packet

	// send queue for sending packets
	sendQueue chan *packet

	// receive queue for receiving packets
	recvQueue chan *packet

	// unique id of this link e.g uuid
	// which we define for ourselves
	id string

	// whether it's a loopback connection
	// this flag is used by the transport listener
	// which accepts inbound quic connections
	loopback bool

	// whether it's actually connected
	// dialled side sets it to connected
	// after sending the message. the
	// listener waits for the connect
	connected bool

	// the last time we received a keepalive
	// on this link from the remote side
	lastKeepAlive time.Time

	// channels keeps a mapping of channels and last seen
	channels map[string]time.Time

	// the weighted moving average roundtrip length in nanoseconds
	length int64

	// weighted moving average of bits per second flowing
	rate float64

	// count of consecutive errors on the link (reset by a successful send)
	errCount int
}

// packet is a message sent or received over the link.
type packet struct {
	// message to send or received
	message *transport.Message

	// status returned when sent
	status chan error

	// receive related error
	err error
}

// metric is used to record the link rate and error state.
type metric struct {
	// amount of data sent, in bytes
	data int

	// time taken to send
	duration time.Duration

	// if an error occurred
	status error
}

var (
	// linkRequest is the 4-byte all-zero body of a link state (rtt) request.
	linkRequest = []byte{0, 0, 0, 0}

	// linkResponse is the 4-byte all-one body of a link state (rtt) response.
	linkResponse = []byte{1, 1, 1, 1}

	// ErrLinkConnectTimeout reports that a link connect timed out.
	ErrLinkConnectTimeout = errors.New("link connect timeout")
)

// newLink wraps s in a link and starts the goroutines that move packets
// between the queues and the socket (process) and maintain link state
// (manage). Both goroutines run until the link is closed.
func newLink(s transport.Socket) *link {
	l := &link{
		Socket:        s,
		id:            uuid.New().String(),
		lastKeepAlive: time.Now(),
		closed:        make(chan bool),
		channels:      make(map[string]time.Time),
		state:         make(chan *packet, 64),
		sendQueue:     make(chan *packet, 128),
		recvQueue:     make(chan *packet, 128),
		metric:        make(chan *metric, 128),
	}

	// process inbound/outbound packets
	go l.process()
	// manage the link state
	go l.manage()

	return l
}

// connect dials addr with the link's transport and makes the resulting
// socket the link's socket.
// NOTE(review): newLink does not set l.transport, so callers must set it
// before calling connect; the previous socket (if any) is not closed here.
func (l *link) connect(addr string) error {
	c, err := l.transport.Dial(addr)
	if err != nil {
		return err
	}

	l.Lock()
	l.Socket = c
	l.Unlock()
	return nil
}

// accept makes sock the link's socket. It always returns nil.
func (l *link) accept(sock transport.Socket) error {
	l.Lock()
	l.Socket = sock
	l.Unlock()
	return nil
}

// setLoopback marks whether this link is a loopback connection.
func (l *link) setLoopback(v bool) {
	l.Lock()
	l.loopback = v
	l.Unlock()
}

// setRate folds a sample of bits sent over delta into the moving average
// bits-per-second rate. The caller must hold the write lock.
func (l *link) setRate(bits int64, delta time.Duration) {
	// a zero or negative duration would produce +Inf/NaN, which would then
	// stick in the moving average forever; skip such samples
	if delta <= 0 {
		return
	}

	// rate of send in bits per nanosecond, scaled to bits per second
	rate := float64(bits) / float64(delta.Nanoseconds()) * 1e9

	// default the rate if it's zero
	if l.rate == 0 {
		l.rate = rate
		return
	}

	// set new rate per second
	l.rate = 0.8*l.rate + 0.2*rate
}

// setRTT sets a nanosecond based moving average roundtrip time for the link.
func (l *link) setRTT(d time.Duration) {
	l.Lock()
	defer l.Unlock()

	// first measurement: take it as-is
	if l.length <= 0 {
		l.length = d.Nanoseconds()
		return
	}

	// https://fishi.devtail.io/weblog/2015/04/12/measuring-bandwidth-and-round-trip-time-tcp-connection-inside-application-layer/
	length := 0.8*float64(l.length) + 0.2*float64(d.Nanoseconds())
	l.length = int64(length)
}

// delChannel removes the mapping for ch.
func (l *link) delChannel(ch string) {
	l.Lock()
	delete(l.channels, ch)
	l.Unlock()
}

// getChannel returns when ch was last seen, or the zero time if unknown.
func (l *link) getChannel(ch string) time.Time {
	l.RLock()
	t := l.channels[ch]
	l.RUnlock()
	return t
}

// setChannel marks each of channels as seen now.
func (l *link) setChannel(channels ...string) {
	l.Lock()
	for _, ch := range channels {
		l.channels[ch] = time.Now()
	}
	l.Unlock()
}

// keepalive records that a keepalive was just received.
func (l *link) keepalive() {
	l.Lock()
	l.lastKeepAlive = time.Now()
	l.Unlock()
}

// process moves packets between the socket and the queues until the link is
// closed. A background goroutine reads from the socket, routing link state
// packets to l.state (dropped if that queue is full) and everything else,
// including receive errors, to l.recvQueue. The calling goroutine drains
// l.sendQueue, writing each message to the socket and reporting the result
// on the packet's status channel.
func (l *link) process() {
	// receive messages
	go func() {
		for {
			m := new(transport.Message)
			err := l.recv(m)
			if err != nil {
				// record the metric, never blocking
				select {
				case l.metric <- &metric{status: err}:
				default:
				}
			}

			// process new received message
			pk := &packet{message: m, err: err}

			// this is our link state packet
			if m.Header["Micro-Method"] == "link" {
				// process link state message
				select {
				case l.state <- pk:
				case <-l.closed:
					return
				default:
				}
				continue
			}

			// process all messages as is
			select {
			case l.recvQueue <- pk:
			case <-l.closed:
				return
			}
		}
	}()

	// send messages
	for {
		select {
		case pk := <-l.sendQueue:
			// status is buffered (see Send) so this does not block
			select {
			case pk.status <- l.send(pk.message):
			case <-l.closed:
				return
			}
		case <-l.closed:
			return
		}
	}
}

// manage manages the link state until the link is closed: it answers and
// times rtt probes, fires a new rtt probe and expires stale channel mappings
// every minute, and folds queued metrics into the rate/error count every
// five seconds.
func (l *link) manage() {
	// tick over every minute to expire channels and fire rtt packets
	t1 := time.NewTicker(time.Minute)
	defer t1.Stop()

	// used to batch update link metrics
	t2 := time.NewTicker(time.Second * 5)
	defer t2.Stop()

	// get link id
	linkId := l.Id()

	// used to send link state packets
	send := func(b []byte) error {
		return l.Send(&transport.Message{
			Header: map[string]string{
				"Micro-Method":  "link",
				"Micro-Link-Id": linkId,
			},
			Body: b,
		})
	}

	// sendRequest fires an rtt probe; failures are also reported by Send
	// through the metric channel, so they are only logged here
	sendRequest := func() {
		if err := send(linkRequest); err != nil {
			log.Tracef("Link %s failed to send link request: %v", linkId, err)
		}
	}

	// now is when the most recent rtt request was sent
	now := time.Now()

	// send the initial rtt request packet
	sendRequest()

	for {
		select {
		// exit if closed
		case <-l.closed:
			return
		// process link state rtt packets
		case p := <-l.state:
			if p.err != nil {
				continue
			}

			// check the type of message
			switch {
			case bytes.Equal(p.message.Body, linkRequest):
				log.Tracef("Link %s received link request", linkId)

				// send response; Send already reports a failure through the
				// metric channel (which record counts), so don't count it twice
				if err := send(linkResponse); err != nil {
					log.Tracef("Link %s failed to send link response: %v", linkId, err)
				}
			case bytes.Equal(p.message.Body, linkResponse):
				// set round trip time
				d := time.Since(now)
				log.Tracef("Link %s received link response in %v", linkId, d)
				// set the RTT
				l.setRTT(d)
			}
		case <-t1.C:
			// drop any channel mappings older than 2 minutes
			l.expireChannels(time.Minute * 2)

			// fire off a link state rtt packet on every tick, not only when
			// a channel happened to expire
			now = time.Now()
			sendRequest()
		case <-t2.C:
			// get a batch of metrics
			batch := l.batch()

			// skip if there's no metrics
			if len(batch) == 0 {
				continue
			}

			// lock once to record a batch
			l.Lock()
			for _, m := range batch {
				l.record(m)
			}
			l.Unlock()
		}
	}
}

// expireChannels drops channel mappings last seen more than maxAge ago.
func (l *link) expireChannels(maxAge time.Duration) {
	var kill []string

	l.RLock()
	for ch, t := range l.channels {
		if time.Since(t) > maxAge {
			kill = append(kill, ch)
		}
	}
	l.RUnlock()

	// if nothing to kill don't bother with a wasted write lock
	if len(kill) == 0 {
		return
	}

	// kill the channels! re-check under the write lock so a mapping that
	// setChannel refreshed after the scan above is kept
	l.Lock()
	for _, ch := range kill {
		if t, ok := l.channels[ch]; ok && time.Since(t) > maxAge {
			delete(l.channels, ch)
		}
	}
	l.Unlock()
}

// batch drains all currently queued metrics without blocking.
func (l *link) batch() []*metric {
	var metrics []*metric

	// pull all the metrics
	for {
		select {
		case m := <-l.metric:
			metrics = append(metrics, m)
		// non blocking return
		default:
			return metrics
		}
	}
}

// record folds a single metric into the link state. An error increments the
// consecutive error count; a success resets it and updates the rate. The
// caller must hold the write lock.
func (l *link) record(m *metric) {
	// there's an error increment the counter and bail
	if m.status != nil {
		l.errCount++
		return
	}

	// reset the counter
	l.errCount = 0

	// calculate based on data
	if m.data > 0 {
		// m.data is a byte count: 8 bits per byte
		bits := int64(m.data) * 8

		// set the rate
		l.setRate(bits, m.duration)
	}
}

// send writes m directly to the underlying socket, ensuring a non-nil header.
func (l *link) send(m *transport.Message) error {
	if m.Header == nil {
		m.Header = make(map[string]string)
	}
	// send the message
	return l.Socket.Send(m)
}

// recv reads the next message from the underlying socket into m.
func (l *link) recv(m *transport.Message) error {
	if m.Header == nil {
		m.Header = make(map[string]string)
	}
	// receive the transport message
	return l.Socket.Recv(m)
}

// Delay is the current load on the link: the number of packets waiting in
// the send and receive queues.
func (l *link) Delay() int64 {
	return int64(len(l.sendQueue) + len(l.recvQueue))
}

// Rate returns the moving average transfer rate in bits per second.
// NOTE(review): this was previously documented as "lower is better"; a
// throughput figure is usually better when higher — confirm against callers.
func (l *link) Rate() float64 {
	l.RLock()
	r := l.rate
	l.RUnlock()
	return r
}

// Loopback reports whether this link is a loopback connection.
func (l *link) Loopback() bool {
	l.RLock()
	lo := l.loopback
	l.RUnlock()
	return lo
}

// Length returns the roundtrip time as nanoseconds (lower is better).
// Returns 0 where no measurement has been taken.
func (l *link) Length() int64 { l.RLock() length := l.length l.RUnlock() return length } func (l *link) Id() string { l.RLock() id := l.id l.RUnlock() return id } func (l *link) Close() error { l.Lock() defer l.Unlock() select { case <-l.closed: return nil default: l.Socket.Close() close(l.closed) } return nil } // Send sencs a message on the link func (l *link) Send(m *transport.Message) error { // create a new packet to send over the link p := &packet{ message: m, status: make(chan error, 1), } // calculate the data sent dataSent := len(m.Body) // set header length for k, v := range m.Header { dataSent += (len(k) + len(v)) } // get time now now := time.Now() // queue the message select { case l.sendQueue <- p: // in the send queue case <-l.closed: return io.EOF } // error to use var err error // wait for response select { case <-l.closed: return io.EOF case err = <-p.status: } // create a metric with // time taken, size of package, error status mt := &metric{ data: dataSent, duration: time.Since(now), status: err, } // pass back a metric // do not block select { case l.metric <- mt: default: } return nil } // Accept accepts a message on the socket func (l *link) Recv(m *transport.Message) error { select { case <-l.closed: // check if there's any messages left select { case pk := <-l.recvQueue: // check the packet receive error if pk.err != nil { return pk.err } *m = *pk.message default: return io.EOF } case pk := <-l.recvQueue: // check the packet receive error if pk.err != nil { return pk.err } *m = *pk.message } return nil } // State can return connected, closed, error func (l *link) State() string { select { case <-l.closed: return "closed" default: l.RLock() errCount := l.errCount l.RUnlock() if errCount > 3 { return "error" } return "connected" } }