diff --git a/codec/grpc/grpc.go b/codec/grpc/grpc.go index d347c31d..86630772 100644 --- a/codec/grpc/grpc.go +++ b/codec/grpc/grpc.go @@ -89,9 +89,22 @@ func (c *Codec) Write(m *codec.Message, b interface{}) error { m.Header[":authority"] = m.Target m.Header["content-type"] = c.ContentType case codec.Response: - m.Header["Trailer"] = "grpc-status, grpc-message" + m.Header["Trailer"] = "grpc-status" //, grpc-message" + m.Header["content-type"] = c.ContentType + m.Header[":status"] = "200" m.Header["grpc-status"] = "0" - m.Header["grpc-message"] = "" + // m.Header["grpc-message"] = "" + case codec.Error: + m.Header["Trailer"] = "grpc-status, grpc-message" + // micro end of stream + if m.Error == "EOS" { + m.Header["grpc-status"] = "0" + } else { + m.Header["grpc-message"] = m.Error + m.Header["grpc-status"] = "13" + } + + return nil } // marshal content diff --git a/server/rpc_codec.go b/server/rpc_codec.go index 19cdc564..2fea1b83 100644 --- a/server/rpc_codec.go +++ b/server/rpc_codec.go @@ -15,9 +15,10 @@ import ( ) type rpcCodec struct { - socket transport.Socket - codec codec.Codec - first bool + socket transport.Socket + codec codec.Codec + first bool + protocol string req *transport.Message buf *readWriteCloser @@ -157,12 +158,27 @@ func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCod rbuf: bytes.NewBuffer(nil), wbuf: bytes.NewBuffer(nil), } + r := &rpcCodec{ - buf: rwc, - codec: c(rwc), - req: req, - socket: socket, + buf: rwc, + codec: c(rwc), + req: req, + socket: socket, + protocol: "mucp", } + + // if grpc pre-load the buffer + // TODO: remove this terrible hack + switch r.codec.String() { + case "grpc": + // set as first + r.first = true + // write the body + rwc.rbuf.Write(req.Body) + // set the protocol + r.protocol = "grpc" + } + return r } @@ -173,27 +189,33 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error { Body: c.req.Body, } - var tm transport.Message + // first message could be pre-loaded + if !c.first { + var tm transport.Message - // read off the socket - if err := c.socket.Recv(&tm); err != nil { - return err - } - // reset the read buffer - c.buf.rbuf.Reset() + // read off the socket + if err := c.socket.Recv(&tm); err != nil { + return err + } + // reset the read buffer + c.buf.rbuf.Reset() - // write the body to the buffer - if _, err := c.buf.rbuf.Write(tm.Body); err != nil { - return err + // write the body to the buffer + if _, err := c.buf.rbuf.Write(tm.Body); err != nil { + return err + } + + // set the message header + m.Header = tm.Header + // set the message body + m.Body = tm.Body + + // set req + c.req = &tm } - // set the message header - m.Header = tm.Header - // set the message body - m.Body = tm.Body - - // set req - c.req = &tm + // disable first + c.first = false // set some internal things getHeaders(&m) @@ -293,5 +315,5 @@ func (c *rpcCodec) Close() error { } func (c *rpcCodec) String() string { - return "rpc" + return c.protocol } diff --git a/server/rpc_server.go b/server/rpc_server.go index c02b893a..43c4bae5 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -63,20 +63,33 @@ func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response) // ServeConn serves a single connection func (s *rpcServer) ServeConn(sock transport.Socket) { + var wg sync.WaitGroup + var mtx sync.RWMutex + // streams are multiplexed on Micro-Stream or Micro-Id header + sockets := make(map[string]*socket.Socket) + defer func() { - // close socket + // wait till done + wg.Wait() + + // close underlying socket sock.Close() + // close the sockets + mtx.Lock() + for id, psock := range sockets { + psock.Close() + delete(sockets, id) + } + mtx.Unlock() + + // recover any panics if r := recover(); r != nil { log.Log("panic recovered: ", r) log.Log(string(debug.Stack())) } }() - // multiplex the streams on a single socket by Micro-Stream - var mtx sync.RWMutex - sockets := make(map[string]*socket.Socket) - for { var msg transport.Message if err := sock.Recv(&msg); err != nil { @@ -94,6 +107,9 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { id = msg.Header["Micro-Id"] } + // we're starting processing + wg.Add(1) + // add to wait group if "wait" is opt-in if s.wg != nil { s.wg.Add(1) @@ -119,6 +135,8 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { s.wg.Done() } + wg.Done() + // continue to the next message continue } @@ -136,28 +154,6 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { sockets[id] = psock mtx.Unlock() - // process the outbound messages from the socket - go func(id string, psock *socket.Socket) { - defer psock.Close() - - for { - // get the message from our internal handler/stream - m := new(transport.Message) - if err := psock.Process(m); err != nil { - // delete the socket - mtx.Lock() - delete(sockets, id) - mtx.Unlock() - return - } - - // send the message back over the socket - if err := sock.Send(m); err != nil { - return - } - } - }(id, psock) - // now walk the usual path // we use this Timeout header to set a server deadline @@ -205,17 +201,23 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { }, Body: []byte(err.Error()), }) + if s.wg != nil { s.wg.Done() } + + wg.Done() + return } } rcodec := newRpcCodec(&msg, psock, cf) + protocol := rcodec.String() // check stream id var stream bool + if v := getHeader("Micro-Stream", msg.Header); len(v) > 0 { stream = true } @@ -259,8 +261,44 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { r = rpcRouter{handler} } + // wait for processing to exit + wg.Add(1) + + // process the outbound messages from the socket + go func(id string, psock *socket.Socket) { + defer func() { + // TODO: don't hack this but if its grpc just break out of the stream + // We do this because the underlying connection is h2 and its a stream + switch protocol { + case "grpc": + sock.Close() + } + + wg.Done() + }() + + for { + // get the message from our internal handler/stream + m := new(transport.Message) + if err := psock.Process(m); err != nil { + // delete the socket + mtx.Lock() + delete(sockets, id) + mtx.Unlock() + return + } + + // send the message back over the socket + if err := sock.Send(m); err != nil { + return + } + } + }(id, psock) + // serve the request in a go routine as this may be a stream go func(id string, psock *socket.Socket) { + defer psock.Close() + // serve the actual request using the request router if err := r.ServeRequest(ctx, request, response); err != nil { // write an error response @@ -285,8 +323,9 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { s.wg.Done() } + // done with this socket + wg.Done() }(id, psock) - } } diff --git a/transport/http_transport.go b/transport/http_transport.go index b1e760a4..e6fd1b05 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -33,6 +33,8 @@ type httpTransportClient struct { once sync.Once sync.RWMutex + + // request must be stored for response processing r chan *http.Request bl []*http.Request buff *bufio.Reader @@ -48,10 +50,18 @@ type httpTransportSocket struct { r *http.Request rw *bufio.ReadWriter + mtx sync.RWMutex + + // the hijacked when using http 1 conn net.Conn // for the first request ch chan *http.Request + // h2 things + buf *bufio.Reader + // indicate if socket is closed + closed chan bool + // local/remote ip local string remote string @@ -161,14 +171,13 @@ func (h *httpTransportClient) Recv(m *Message) error { } func (h *httpTransportClient) Close() error { - err := h.conn.Close() h.once.Do(func() { h.Lock() h.buff.Reset(nil) h.Unlock() close(h.r) }) - return err + return h.conn.Close() } func (h *httpTransportSocket) Local() string { @@ -232,14 +241,23 @@ func (h *httpTransportSocket) Recv(m *Message) error { return nil } + // only process if the socket is open + select { + case <-h.closed: + return io.EOF + default: + // no op + } + // processing http2 request // read streaming body // set max buffer size - buf := make([]byte, 4*1024) + // TODO: adjustable buffer size + buf := make([]byte, 4*1024*1024) // read the request body - n, err := h.r.Body.Read(buf) + n, err := h.buf.Read(buf) // not an eof error if err != nil { return err @@ -290,7 +308,13 @@ func (h *httpTransportSocket) Send(m *Message) error { return rsp.Write(h.conn) } - // http2 request + // only process if the socket is open + select { + case <-h.closed: + return io.EOF + default: + // no op + } // set headers for k, v := range m.Header { @@ -299,6 +323,10 @@ func (h *httpTransportSocket) Send(m *Message) error { // write request _, err := h.w.Write(m.Body) + + // flush the trailers + h.w.(http.Flusher).Flush() + return err } @@ -321,13 +349,29 @@ func (h *httpTransportSocket) error(m *Message) error { return rsp.Write(h.conn) } + return nil } func (h *httpTransportSocket) Close() error { - if h.r.ProtoMajor == 1 { - return h.conn.Close() + h.mtx.Lock() + defer h.mtx.Unlock() + select { + case <-h.closed: + return nil + default: + // close the channel + close(h.closed) + + // close the buffer + h.r.Body.Close() + + // close the connection + if h.r.ProtoMajor == 1 { + return h.conn.Close() + } } + return nil } @@ -374,20 +418,29 @@ func (h *httpTransportListener) Accept(fn func(Socket)) error { con = conn } + // buffered reader + bufr := bufio.NewReader(r.Body) + // save the request ch := make(chan *http.Request, 1) ch <- r - fn(&httpTransportSocket{ + // create a new transport socket + sock := &httpTransportSocket{ ht: h.ht, w: w, r: r, rw: buf, + buf: bufr, ch: ch, conn: con, local: h.Addr(), remote: r.RemoteAddr, - }) + closed: make(chan bool), + } + + // execute the socket + fn(sock) }) // get optional handlers diff --git a/util/socket/socket.go b/util/socket/socket.go index 59bb538d..29ba5006 100644 --- a/util/socket/socket.go +++ b/util/socket/socket.go @@ -32,10 +32,10 @@ func (s *Socket) SetRemote(r string) { // Accept passes a message to the socket which will be processed by the call to Recv func (s *Socket) Accept(m *transport.Message) error { select { - case <-s.closed: - return io.EOF case s.recv <- m: return nil + case <-s.closed: + return io.EOF } return nil } @@ -43,10 +43,17 @@ func (s *Socket) Accept(m *transport.Message) error { // Process takes the next message off the send queue created by a call to Send func (s *Socket) Process(m *transport.Message) error { select { - case <-s.closed: - return io.EOF case msg := <-s.send: *m = *msg + case <-s.closed: + // see if we need to drain + select { + case msg := <-s.send: + *m = *msg + return nil + default: + return io.EOF + } } return nil } @@ -60,13 +67,6 @@ func (s *Socket) Local() string { } func (s *Socket) Send(m *transport.Message) error { - select { - case <-s.closed: - return io.EOF - default: - // no op - } - // make copy msg := &transport.Message{ Header: make(map[string]string), @@ -92,13 +92,6 @@ func (s *Socket) Send(m *transport.Message) error { } func (s *Socket) Recv(m *transport.Message) error { - select { - case <-s.closed: - return io.EOF - default: - // no op - } - // receive a message select { case msg := <-s.recv: