diff --git a/client/rpc_client.go b/client/rpc_client.go index 07b7d6e0..754e4329 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -96,19 +96,14 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, } } - var grr error c, err := r.pool.Get(address, transport.WithTimeout(opts.DialTimeout)) if err != nil { return errors.InternalServerError("go.micro.client", "connection error: %v", err) } - defer func() { - // defer execution of release - r.pool.Release(c, grr) - }() seq := atomic.LoadUint64(&r.seq) atomic.AddUint64(&r.seq, 1) - codec := newRpcCodec(msg, c, cf) + codec := newRpcCodec(msg, c, cf, "") rsp := &rpcResponse{ socket: c, @@ -116,15 +111,19 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, } stream := &rpcStream{ + id: fmt.Sprintf("%v", seq), context: ctx, request: req, response: rsp, codec: codec, closed: make(chan bool), - id: fmt.Sprintf("%v", seq), + release: func(err error) { r.pool.Release(c, err) }, + sendEOS: false, } + // close the stream on exiting this function defer stream.Close() + // wait for error response ch := make(chan error, 1) go func() { @@ -150,14 +149,26 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, ch <- nil }() + var grr error + select { case err := <-ch: grr = err return err case <-ctx.Done(): - grr = ctx.Err() - return errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) + grr = errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) } + + // set the stream error + if grr != nil { + stream.Lock() + stream.err = grr + stream.Unlock() + + return grr + } + + return nil } func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request, opts CallOptions) (Stream, error) { @@ -201,12 +212,18 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout)) } - c, err := r.opts.Transport.Dial(address, dOpts...) + c, err := r.pool.Get(address, dOpts...) if err != nil { return nil, errors.InternalServerError("go.micro.client", "connection error: %v", err) } - codec := newRpcCodec(msg, c, cf) + // increment the sequence number + seq := atomic.LoadUint64(&r.seq) + atomic.AddUint64(&r.seq, 1) + id := fmt.Sprintf("%v", seq) + + // create codec with stream id + codec := newRpcCodec(msg, c, cf, id) rsp := &rpcResponse{ socket: c, @@ -219,16 +236,24 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request } stream := &rpcStream{ + id: id, context: ctx, request: req, response: rsp, - closed: make(chan bool), codec: codec, + // used to close the stream + closed: make(chan bool), + // signal the end of stream, + sendEOS: true, + // release func + release: func(err error) { r.pool.Release(c, err) }, } + // wait for error response ch := make(chan error, 1) go func() { + // send the first message ch <- stream.Send(req.Body()) }() @@ -242,6 +267,12 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request } if grr != nil { + // set the error + stream.Lock() + stream.err = grr + stream.Unlock() + + // close the stream stream.Close() return nil, grr } diff --git a/client/rpc_codec.go b/client/rpc_codec.go index 6ff84a64..c20537ea 100644 --- a/client/rpc_codec.go +++ b/client/rpc_codec.go @@ -39,6 +39,9 @@ type rpcCodec struct { req *transport.Message buf *readWriteCloser + + // signify if its a stream + stream string } type readWriteCloser struct { @@ -113,7 +116,7 @@ func getHeaders(m *codec.Message) { } } -func setHeaders(m *codec.Message) { +func setHeaders(m *codec.Message, stream string) { set := func(hdr, v string) { if len(v) == 0 { return @@ -126,6 +129,11 @@ func setHeaders(m *codec.Message) { set("Micro-Service", m.Target) set("Micro-Method", m.Method) set("Micro-Endpoint", m.Endpoint) + set("Micro-Error", m.Error) + + if len(stream) > 0 { + set("Micro-Stream", stream) + } } // setupProtocol sets up the old protocol @@ -149,7 +157,7 @@ func setupProtocol(msg *transport.Message, node *registry.Node) codec.NewCodec { return defaultCodecs[msg.Header["Content-Type"]] } -func newRpcCodec(req *transport.Message, client transport.Client, c codec.NewCodec) codec.Codec { +func newRpcCodec(req *transport.Message, client transport.Client, c codec.NewCodec, stream string) codec.Codec { rwc := &readWriteCloser{ wbuf: bytes.NewBuffer(nil), rbuf: bytes.NewBuffer(nil), @@ -159,6 +167,7 @@ func newRpcCodec(req *transport.Message, client transport.Client, c codec.NewCod client: client, codec: c(rwc), req: req, + stream: stream, } return r } @@ -177,7 +186,7 @@ func (c *rpcCodec) Write(m *codec.Message, body interface{}) error { } // set the mucp headers - setHeaders(m) + setHeaders(m, c.stream) // if body is bytes Frame don't encode if body != nil { @@ -240,6 +249,12 @@ func (c *rpcCodec) ReadHeader(m *codec.Message, r codec.MessageType) error { func (c *rpcCodec) ReadBody(b interface{}) error { // read body + // read raw data + if v, ok := b.(*raw.Frame); ok { + v.Data = c.buf.rbuf.Bytes() + return nil + } + if err := c.codec.ReadBody(b); err != nil { return errors.InternalServerError("go.micro.client.codec", err.Error()) } diff --git a/client/rpc_stream.go b/client/rpc_stream.go index f605c11e..f904d2dd 100644 --- a/client/rpc_stream.go +++ b/client/rpc_stream.go @@ -18,6 +18,12 @@ type rpcStream struct { response Response codec codec.Codec context context.Context + + // signal whether we should send EOS + sendEOS bool + + // release releases the connection back to the pool + release func(err error) } func (r *rpcStream) isClosed() bool { @@ -120,6 +126,26 @@ func (r *rpcStream) Close() error { return nil default: close(r.closed) - return r.codec.Close() + + // send the end of stream message + if r.sendEOS { + // no need to check for error + r.codec.Write(&codec.Message{ + Id: r.id, + Target: r.request.Service(), + Method: r.request.Method(), + Endpoint: r.request.Endpoint(), + Type: codec.Error, + Error: lastStreamResponseError, + }, nil) + } + + err := r.codec.Close() + + // release the connection + r.release(r.Error()) + + // return the codec error + return err } } diff --git a/proxy/mucp/mucp.go b/proxy/mucp/mucp.go index 3f54d9c0..095c10d8 100644 --- a/proxy/mucp/mucp.go +++ b/proxy/mucp/mucp.go @@ -220,6 +220,23 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server // create new request with raw bytes body creq := p.Client.NewRequest(service, endpoint, &bytes.Frame{body}, client.WithContentType(req.ContentType())) + // not a stream so make a client.Call request + if !req.Stream() { + crsp := new(bytes.Frame) + + // make a call to the backend + if err := p.Client.Call(ctx, creq, crsp, opts...); err != nil { + return err + } + + // write the response + if err := rsp.Write(crsp.Data); err != nil { + return err + } + + return nil + } + // create new stream stream, err := p.Client.Stream(ctx, creq, opts...) if err != nil { @@ -227,7 +244,7 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server } defer stream.Close() - // create client request read loop + // create client request read loop if streaming go readLoop(req, stream) // get raw response diff --git a/server/rpc_codec.go b/server/rpc_codec.go index 797fa79c..19cdc564 100644 --- a/server/rpc_codec.go +++ b/server/rpc_codec.go @@ -154,11 +154,10 @@ func setupProtocol(msg *transport.Message) codec.NewCodec { func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCodec) codec.Codec { rwc := &readWriteCloser{ - rbuf: bytes.NewBuffer(req.Body), + rbuf: bytes.NewBuffer(nil), wbuf: bytes.NewBuffer(nil), } r := &rpcCodec{ - first: true, buf: rwc, codec: c(rwc), req: req, @@ -174,33 +173,27 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error { Body: c.req.Body, } - // if its a follow on request read it - if !c.first { - var tm transport.Message + var tm transport.Message - // read off the socket - if err := c.socket.Recv(&tm); err != nil { - return err - } - // reset the read buffer - c.buf.rbuf.Reset() + // read off the socket + if err := c.socket.Recv(&tm); err != nil { + return err + } + // reset the read buffer + c.buf.rbuf.Reset() - // write the body to the buffer - if _, err := c.buf.rbuf.Write(tm.Body); err != nil { - return err - } - - // set the message header - m.Header = tm.Header - // set the message body - m.Body = tm.Body - - // set req - c.req = &tm + // write the body to the buffer + if _, err := c.buf.rbuf.Write(tm.Body); err != nil { + return err } - // no longer first read - c.first = false + // set the message header + m.Header = tm.Header + // set the message body + m.Body = tm.Body + + // set req + c.req = &tm // set some internal things getHeaders(&m) diff --git a/server/rpc_server.go b/server/rpc_server.go index 516825db..5fb4af29 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -19,6 +19,7 @@ import ( "github.com/micro/go-micro/util/addr" log "github.com/micro/go-micro/util/log" mnet "github.com/micro/go-micro/util/net" + "github.com/micro/go-micro/util/socket" ) type rpcServer struct { @@ -70,23 +71,99 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { } }() + // multiplex the streams on a single socket by Micro-Stream + var mtx sync.RWMutex + sockets := make(map[string]*socket.Socket) + for { var msg transport.Message if err := sock.Recv(&msg); err != nil { return } + // use Micro-Stream as the stream identifier + // in the event its blank we'll always process + // on the same socket + id := msg.Header["Micro-Stream"] + + // if there's no stream id then its a standard request + // use the Micro-Id + if len(id) == 0 { + id = msg.Header["Micro-Id"] + } + // add to wait group if "wait" is opt-in if s.wg != nil { s.wg.Add(1) } + // check we have an existing socket + mtx.RLock() + psock, ok := sockets[id] + mtx.RUnlock() + + // got the socket + if ok { + // accept the message + if err := psock.Accept(&msg); err != nil { + // delete the socket + mtx.Lock() + delete(sockets, id) + mtx.Unlock() + } + + // done(1) + if s.wg != nil { + s.wg.Done() + } + + // continue to the next message + continue + } + + // no socket was found + psock = socket.New() + psock.SetLocal(sock.Local()) + psock.SetRemote(sock.Remote()) + + // load the socket + psock.Accept(&msg) + + // save a new socket + mtx.Lock() + sockets[id] = psock + mtx.Unlock() + + // process the outbound messages from the socket + go func(id string, psock *socket.Socket) { + defer psock.Close() + + for { + // get the message from our internal handler/stream + m := new(transport.Message) + if err := psock.Process(m); err != nil { + // delete the socket + mtx.Lock() + delete(sockets, id) + mtx.Unlock() + return + } + + // send the message back over the socket + if err := sock.Send(m); err != nil { + return + } + } + }(id, psock) + + // now walk the usual path + // we use this Timeout header to set a server deadline to := msg.Header["Timeout"] // we use this Content-Type header to identify the codec needed ct := msg.Header["Content-Type"] - // strip our headers + // copy the message headers hdr := make(map[string]string) for k, v := range msg.Header { hdr[k] = v @@ -96,17 +173,17 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { hdr["Local"] = sock.Local() hdr["Remote"] = sock.Remote() - // create new context + // create new context with the metadata ctx := metadata.NewContext(context.Background(), hdr) - // set the timeout if we have it + // set the timeout from the header if we have it if len(to) > 0 { if n, err := strconv.ParseUint(to, 10, 64); err == nil { ctx, _ = context.WithTimeout(ctx, time.Duration(n)) } } - // no content type + // if there's no content type default it if len(ct) == 0 { msg.Header["Content-Type"] = DefaultContentType ct = DefaultContentType @@ -133,7 +210,13 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { } } - rcodec := newRpcCodec(&msg, sock, cf) + rcodec := newRpcCodec(&msg, psock, cf) + + // check stream id + var stream bool + if v := getHeader("Micro-Stream", msg.Header); len(v) > 0 { + stream = true + } // internal request request := &rpcRequest{ @@ -144,15 +227,14 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { codec: rcodec, header: msg.Header, body: msg.Body, - socket: sock, - stream: true, - first: true, + socket: psock, + stream: stream, } // internal response response := &rpcResponse{ header: make(map[string]string), - socket: sock, + socket: psock, codec: rcodec, } @@ -175,25 +257,34 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { r = rpcRouter{handler} } - // serve the actual request using the request router - if err := r.ServeRequest(ctx, request, response); err != nil { - // write an error response - err = rcodec.Write(&codec.Message{ - Header: msg.Header, - Error: err.Error(), - Type: codec.Error, - }, nil) - // could not write the error response - if err != nil { - log.Logf("rpc: unable to write error response: %v", err) + // serve the request in a go routine as this may be a stream + go func(id string, psock *socket.Socket) { + // serve the actual request using the request router + if err := r.ServeRequest(ctx, request, response); err != nil { + // write an error response + err = rcodec.Write(&codec.Message{ + Header: msg.Header, + Error: err.Error(), + Type: codec.Error, + }, nil) + + // could not write the error response + if err != nil { + log.Logf("rpc: unable to write error response: %v", err) + } } + + mtx.Lock() + delete(sockets, id) + mtx.Unlock() + + // once done serving signal we're done if s.wg != nil { s.wg.Done() } - return - } + }(id, psock) - // done + // signal we're done if s.wg != nil { s.wg.Done() } diff --git a/server/rpc_stream.go b/server/rpc_stream.go index 185f1ff9..a4e64af8 100644 --- a/server/rpc_stream.go +++ b/server/rpc_stream.go @@ -2,6 +2,8 @@ package server import ( "context" + "errors" + "io" "sync" "github.com/micro/go-micro/codec" @@ -59,6 +61,20 @@ func (r *rpcStream) Recv(msg interface{}) error { return err } + // check the error + if len(req.Error) > 0 { + // Check the client closed the stream + switch req.Error { + case lastStreamResponseError.Error(): + // discard body + r.codec.ReadBody(nil) + r.err = io.EOF + return io.EOF + default: + return errors.New(req.Error) + } + } + // we need to stay up to date with sequence numbers r.id = req.Id if err := r.codec.ReadBody(msg); err != nil { diff --git a/util/socket/socket.go b/util/socket/socket.go new file mode 100644 index 00000000..59bb538d --- /dev/null +++ b/util/socket/socket.go @@ -0,0 +1,137 @@ +// Package socket provides a pseudo socket +package socket + +import ( + "io" + + "github.com/micro/go-micro/transport" +) + +// Socket is our pseudo socket for transport.Socket +type Socket struct { + // closed + closed chan bool + // remote addr + remote string + // local addr + local string + // send chan + send chan *transport.Message + // recv chan + recv chan *transport.Message +} + +func (s *Socket) SetLocal(l string) { + s.local = l +} + +func (s *Socket) SetRemote(r string) { + s.remote = r +} + +// Accept passes a message to the socket which will be processed by the call to Recv +func (s *Socket) Accept(m *transport.Message) error { + select { + case <-s.closed: + return io.EOF + case s.recv <- m: + return nil + } + return nil +} + +// Process takes the next message off the send queue created by a call to Send +func (s *Socket) Process(m *transport.Message) error { + select { + case <-s.closed: + return io.EOF + case msg := <-s.send: + *m = *msg + } + return nil +} + +func (s *Socket) Remote() string { + return s.remote +} + +func (s *Socket) Local() string { + return s.local +} + +func (s *Socket) Send(m *transport.Message) error { + select { + case <-s.closed: + return io.EOF + default: + // no op + } + + // make copy + msg := &transport.Message{ + Header: make(map[string]string), + Body: make([]byte, len(m.Body)), + } + + // copy headers + for k, v := range m.Header { + msg.Header[k] = v + } + + // copy body + copy(msg.Body, m.Body) + + // send a message + select { + case s.send <- msg: + case <-s.closed: + return io.EOF + } + + return nil +} + +func (s *Socket) Recv(m *transport.Message) error { + select { + case <-s.closed: + return io.EOF + default: + // no op + } + + // receive a message + select { + case msg := <-s.recv: + // set message + *m = *msg + case <-s.closed: + return io.EOF + } + + // return nil + return nil +} + +// Close closes the socket +func (s *Socket) Close() error { + select { + case <-s.closed: + // no op + default: + close(s.closed) + } + return nil +} + +// New returns a new pseudo socket which can be used in the place of a transport socket. +// Messages are sent to the socket via Accept and receives from the socket via Process. +// SetLocal/SetRemote should be called before using the socket. +func New() *Socket { + return &Socket{ + closed: make(chan bool), + local: "local", + remote: "remote", + send: make(chan *transport.Message, 128), + recv: make(chan *transport.Message, 128), + } +}