diff --git a/tunnel/default.go b/tunnel/default.go index eb26eaa1..ff86bb6b 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -386,6 +386,13 @@ func (t *tun) Connect() error { return nil } +func (t *tun) Init(opts ...Option) error { + for _, o := range opts { + o(&t.options) + } + return nil +} + // Dial an address func (t *tun) Dial(addr string) (Conn, error) { c, ok := t.newSocket(addr, t.newSession()) diff --git a/tunnel/transport/listener.go b/tunnel/transport/listener.go new file mode 100644 index 00000000..b7a7280c --- /dev/null +++ b/tunnel/transport/listener.go @@ -0,0 +1,30 @@ +package transport + +import ( + "github.com/micro/go-micro/transport" + "github.com/micro/go-micro/tunnel" +) + +type tunListener struct { + l tunnel.Listener +} + +func (t *tunListener) Addr() string { + return t.l.Addr() +} + +func (t *tunListener) Close() error { + return t.l.Close() +} + +func (t *tunListener) Accept(fn func(socket transport.Socket)) error { + for { + // accept connection + c, err := t.l.Accept() + if err != nil { + return err + } + // execute the function + go fn(c) + } +} diff --git a/tunnel/transport/transport.go b/tunnel/transport/transport.go new file mode 100644 index 00000000..d37468d2 --- /dev/null +++ b/tunnel/transport/transport.go @@ -0,0 +1,113 @@ +// Package transport provides a tunnel transport +package transport + +import ( + "context" + + "github.com/micro/go-micro/transport" + "github.com/micro/go-micro/tunnel" +) + +type tunTransport struct { + options transport.Options + + tunnel tunnel.Tunnel +} + +type tunnelKey struct{} + +type transportKey struct{} + +func (t *tunTransport) Init(opts ...transport.Option) error { + for _, o := range opts { + o(&t.options) + } + + // close the existing tunnel + if t.tunnel != nil { + t.tunnel.Close() + } + + // get the tunnel + tun, ok := t.options.Context.Value(tunnelKey{}).(tunnel.Tunnel) + if !ok { + tun = tunnel.NewTunnel() + } + + // get the transport + tr, ok := t.options.Context.Value(transportKey{}).(transport.Transport) + if ok { + tun.Init(tunnel.Transport(tr)) + } + + // set the tunnel + t.tunnel = tun + + return nil +} + +func (t *tunTransport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) { + if err := t.tunnel.Connect(); err != nil { + return nil, err + } + + c, err := t.tunnel.Dial(addr) + if err != nil { + return nil, err + } + + return c, nil +} + +func (t *tunTransport) Listen(addr string, opts ...transport.ListenOption) (transport.Listener, error) { + if err := t.tunnel.Connect(); err != nil { + return nil, err + } + + l, err := t.tunnel.Listen(addr) + if err != nil { + return nil, err + } + + return &tunListener{l}, nil +} + +func (t *tunTransport) Options() transport.Options { + return t.options +} + +func (t *tunTransport) String() string { + return "tunnel" +} + +// NewTransport honours the initialiser used in +func NewTransport(opts ...transport.Option) transport.Transport { + t := &tunTransport{ + options: transport.Options{}, + } + + // initialise + t.Init(opts...) + + return t +} + +// WithTransport sets the internal tunnel +func WithTunnel(t tunnel.Tunnel) transport.Option { + return func(o *transport.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, tunnelKey{}, t) + } +} + +// WithTransport sets the internal transport +func WithTransport(t transport.Transport) transport.Option { + return func(o *transport.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, transportKey{}, t) + } +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index bcb54fd1..3c84c7eb 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -10,6 +10,7 @@ import ( // and Micro-Tunnel-Session header. The tunnel id is a hash of // the address being requested. type Tunnel interface { + Init(opts ...Option) error // Connect connects the tunnel Connect() error // Close closes the tunnel