diff --git a/client/client_wrapper.go b/client/client_wrapper.go index ecbad74f..382747a3 100644 --- a/client/client_wrapper.go +++ b/client/client_wrapper.go @@ -36,3 +36,6 @@ Example usage: // Wrapper wraps a client and returns a client type Wrapper func(Client) Client + +// StreamWrapper wraps a Stream and returns the equivalent +type StreamWrapper func(Streamer) Streamer diff --git a/client/rpc_stream.go b/client/rpc_stream.go index e118057a..0932539a 100644 --- a/client/rpc_stream.go +++ b/client/rpc_stream.go @@ -9,6 +9,7 @@ import ( "golang.org/x/net/context" ) +// Implements the streamer interface type rpcStream struct { sync.RWMutex seq uint64 diff --git a/examples/client/main.go b/examples/client/main.go index ff93690b..810ee280 100644 --- a/examples/client/main.go +++ b/examples/client/main.go @@ -56,9 +56,7 @@ func call(i int) { func stream() { // Create new request to service go.micro.srv.example, method Example.Call - req := client.NewRequest("go.micro.srv.example", "Example.Stream", &example.StreamingRequest{ - Count: int64(10), - }) + req := client.NewRequest("go.micro.srv.example", "Example.Stream", &example.StreamingRequest{}) stream, err := client.Stream(context.Background(), req) if err != nil { @@ -66,6 +64,15 @@ func stream() { return } + fmt.Println("sending request") + if err := stream.Send(&example.StreamingRequest{ + Count: int64(10), + }); err != nil { + fmt.Println("err", err) + return + } + fmt.Println("sent request") + for stream.Error() == nil { rsp := &example.StreamingResponse{} err := stream.Recv(rsp) @@ -88,14 +95,14 @@ func stream() { func main() { cmd.Init() - fmt.Println("\n--- Call example ---\n") - for i := 0; i < 10; i++ { - call(i) - } + // fmt.Println("\n--- Call example ---\n") + // for i := 0; i < 10; i++ { + // call(i) + // } fmt.Println("\n--- Streamer example ---\n") stream() - fmt.Println("\n--- Publisher example ---\n") - pub() + // fmt.Println("\n--- Publisher example ---\n") + // pub() } diff --git a/examples/server/handler/example.go b/examples/server/handler/example.go index 6e84cf72..05ebbff5 100644 --- a/examples/server/handler/example.go +++ b/examples/server/handler/example.go @@ -18,16 +18,24 @@ func (e *Example) Call(ctx context.Context, req *example.Request, rsp *example.R return nil } -func (e *Example) Stream(ctx context.Context, req *example.StreamingRequest, response func(interface{}) error) error { +func (e *Example) Stream(ctx context.Context, stream server.Streamer) error { + log.Info("Executing streaming handler") + req := &example.StreamingRequest{} + + // We just want to receive 1 request and then process here + if err := stream.Recv(req); err != nil { + log.Errorf("Error receiving streaming request: %v", err) + return err + } + log.Infof("Received Example.Stream request with count: %d", req.Count) + for i := 0; i < int(req.Count); i++ { log.Infof("Responding: %d", i) - r := &example.StreamingResponse{ + if err := stream.Send(&example.StreamingResponse{ Count: int64(i), - } - - if err := response(r); err != nil { + }); err != nil { return err } } diff --git a/server/rpc_stream.go b/server/rpc_stream.go index 94b111ae..060585a6 100644 --- a/server/rpc_stream.go +++ b/server/rpc_stream.go @@ -1,14 +1,13 @@ package server import ( - "errors" - "io" "log" "sync" "golang.org/x/net/context" ) +// Implements the Streamer interface type rpcStream struct { sync.RWMutex seq uint64 @@ -39,7 +38,7 @@ func (r *rpcStream) Send(msg interface{}) error { Seq: seq, } - err := codec.WriteResponse(&resp, msg, false) + err := r.codec.WriteResponse(&resp, msg, false) if err != nil { log.Println("rpc: writing response:", err) } @@ -52,13 +51,13 @@ func (r *rpcStream) Recv(msg interface{}) error { req := request{} - if err := codec.ReadRequestHeader(&req); err != nil { + if err := r.codec.ReadRequestHeader(&req); err != nil { // discard body - codec.ReadRequestBody(nil) + r.codec.ReadRequestBody(nil) return err } - if err = codec.ReadRequestBody(msg); err != nil { + if err := r.codec.ReadRequestBody(msg); err != nil { return err } diff --git a/server/rpcplus_server.go b/server/rpcplus_server.go index 62fca91c..ff808444 100644 --- a/server/rpcplus_server.go +++ b/server/rpcplus_server.go @@ -102,14 +102,19 @@ func prepareMethod(method reflect.Method) *methodType { mtype := method.Type mname := method.Name var replyType, argType, contextType reflect.Type + var stream bool - stream := false // Method must be exported. if method.PkgPath != "" { return nil } switch mtype.NumIn() { + case 3: + // assuming streaming + argType = mtype.In(2) + contextType = mtype.In(1) + stream = true case 4: // method that takes a context argType = mtype.In(2) @@ -120,44 +125,34 @@ func prepareMethod(method reflect.Method) *methodType { return nil } - // First arg need not be a pointer. - if !isExportedOrBuiltinType(argType) { - log.Println(mname, "argument type not exported:", argType) - return nil - } + if stream { + // check stream type + streamType := reflect.TypeOf((*Streamer)(nil)).Elem() + if !argType.Implements(streamType) { + log.Println(mname, "argument does not implement Streamer interface:", argType) + return nil + } + } else { + // if not stream check the replyType - // the second argument will tell us if it's a streaming call - // or a regular call - if replyType.Kind() == reflect.Func { - // this is a streaming call - stream = true - if replyType.NumIn() != 1 { - log.Println("method", mname, "sendReply has wrong number of ins:", replyType.NumIn()) - return nil - } - if replyType.In(0).Kind() != reflect.Interface { - log.Println("method", mname, "sendReply parameter type not an interface:", replyType.In(0)) - return nil - } - if replyType.NumOut() != 1 { - log.Println("method", mname, "sendReply has wrong number of outs:", replyType.NumOut()) - return nil - } - if returnType := replyType.Out(0); returnType != typeOfError { - log.Println("method", mname, "sendReply returns", returnType.String(), "not error") + // First arg need not be a pointer. + if !isExportedOrBuiltinType(argType) { + log.Println(mname, "argument type not exported:", argType) return nil } - } else if replyType.Kind() != reflect.Ptr { - log.Println("method", mname, "reply type not a pointer:", replyType) - return nil + if replyType.Kind() != reflect.Ptr { + log.Println("method", mname, "reply type not a pointer:", replyType) + return nil + } + + // Reply type must be exported. + if !isExportedOrBuiltinType(replyType) { + log.Println("method", mname, "reply type not exported:", replyType) + return nil + } } - // Reply type must be exported. - if !isExportedOrBuiltinType(replyType) { - log.Println("method", mname, "reply type not exported:", replyType) - return nil - } // Method needs one out. if mtype.NumOut() != 1 { log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) @@ -242,10 +237,11 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, service: s.name, contentType: ct, method: req.ServiceMethod, - request: argv.Interface(), } if !mtype.stream { + r.request = argv.Interface() + fn := func(ctx context.Context, req Request, rsp interface{}) error { returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(req.Request()), reflect.ValueOf(rsp)}) @@ -276,40 +272,16 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, // keep track of the type, to make sure we return // the same one consistently var lastError error - var firstType reflect.Type - sendReply := func(oneReply interface{}) error { - - // we already triggered an error, we're done - if lastError != nil { - return lastError - } - - // check the oneReply has the right type using reflection - typ := reflect.TypeOf(oneReply) - if firstType == nil { - firstType = typ - } else { - if firstType != typ { - log.Println("passing wrong type to sendReply", - firstType, "!=", typ) - lastError = errors.New("rpc: passing wrong type to sendReply") - return lastError - } - } - - lastError = server.sendResponse(sending, req, oneReply, codec, "", false) - if lastError != nil { - return lastError - } - - // we manage to send, we're good - return nil + stream := &rpcStream{ + context: ctx, + codec: codec, + request: r, } // Invoke the method, providing a new value for the reply. - fn := func(ctx context.Context, req Request, rspFn interface{}) error { - returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(req.Request()), reflect.ValueOf(rspFn)}) + fn := func(ctx context.Context, req Request, stream interface{}) error { + returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)}) if err := returnValues[0].Interface(); err != nil { // the function returned an error, we use that return err.(error) @@ -331,7 +303,7 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, r.stream = true errmsg := "" - if err := fn(ctx, r, reflect.ValueOf(sendReply).Interface()); err != nil { + if err := fn(ctx, r, stream); err != nil { errmsg = err.Error() } @@ -418,6 +390,12 @@ func (server *server) readRequest(codec serverCodec) (service *service, mtype *m return } + // is it a streaming request? then we don't read the body + if mtype.stream { + codec.ReadRequestBody(nil) + return + } + // Decode the argument value. argIsValue := false // if true, need to indirect before calling. if mtype.ArgType.Kind() == reflect.Ptr { diff --git a/server/server_wrapper.go b/server/server_wrapper.go index 45b6d845..45d2c46c 100644 --- a/server/server_wrapper.go +++ b/server/server_wrapper.go @@ -20,4 +20,8 @@ type HandlerWrapper func(HandlerFunc) HandlerFunc // SubscriberWrapper wraps the SubscriberFunc and returns the equivalent type SubscriberWrapper func(SubscriberFunc) SubscriberFunc -type StreamWrapper func(Streamer) Streamer +// StreamerWrapper wraps a Streamer interface and returns the equivalent. +// Because streams exist for the lifetime of a method invocation this +// is a convenient way to wrap a Stream as its in use for trace, monitoring, +// metrics, etc. +type StreamerWrapper func(Streamer) Streamer