Merge pull request #24 from micro/server_wrapper

Experimental server side wrappers for handlers
This commit is contained in:
Asim 2015-12-02 21:20:01 +00:00
commit ab650630ce
14 changed files with 457 additions and 176 deletions

View File

@ -56,14 +56,6 @@ type client struct {
shutdown bool shutdown bool
} }
// A clientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
// The client calls WriteRequest to write a request to the connection
// and calls ReadResponseHeader and ReadResponseBody in pairs
// to read responses. The client calls Close when finished with the
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
type clientCodec interface { type clientCodec interface {
WriteRequest(*request, interface{}) error WriteRequest(*request, interface{}) error
ReadResponseHeader(*response) error ReadResponseHeader(*response) error
@ -224,8 +216,6 @@ func (call *call) done() {
} }
} }
// NewclientWithCodec is like Newclient but uses the specified
// codec to encode requests and decode responses.
func newClientWithCodec(codec clientCodec) *client { func newClientWithCodec(codec clientCodec) *client {
client := &client{ client := &client{
codec: codec, codec: codec,

View File

@ -16,8 +16,11 @@ type MessageType int
// Takes in a connection/buffer and returns a new Codec // Takes in a connection/buffer and returns a new Codec
type NewCodec func(io.ReadWriteCloser) Codec type NewCodec func(io.ReadWriteCloser) Codec
// Codec encodes/decodes various types of // Codec encodes/decodes various types of messages used within go-micro.
// messages used within go-micro // ReadHeader and ReadBody are called in pairs to read requests/responses
// from the connection. Close is called when finished with the
// connection. ReadBody may be called with a nil argument to force the
// body to be read and discarded.
type Codec interface { type Codec interface {
ReadHeader(*Message, MessageType) error ReadHeader(*Message, MessageType) error
ReadBody(interface{}) error ReadBody(interface{}) error

View File

@ -2,7 +2,6 @@ package main
import ( import (
"fmt" "fmt"
"time"
"github.com/micro/go-micro/client" "github.com/micro/go-micro/client"
"github.com/micro/go-micro/cmd" "github.com/micro/go-micro/cmd"
@ -11,41 +10,6 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
) )
// wrapper example code
// log wrapper logs every time a request is made
type logWrapper struct {
client.Client
}
func (l *logWrapper) Call(ctx context.Context, req client.Request, rsp interface{}) error {
md, _ := c.GetMetadata(ctx)
fmt.Printf("[Log Wrapper] ctx: %v service: %s method: %s\n", md, req.Service(), req.Method())
return l.Client.Call(ctx, req, rsp)
}
// trace wrapper attaches a unique trace ID - timestamp
type traceWrapper struct {
client.Client
}
func (t *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}) error {
ctx = c.WithMetadata(ctx, map[string]string{
"X-Trace-Id": fmt.Sprintf("%d", time.Now().Unix()),
})
return t.Client.Call(ctx, req, rsp)
}
// Implements client.Wrapper as logWrapper
func logWrap(c client.Client) client.Client {
return &logWrapper{c}
}
// Implements client.Wrapper as traceWrapper
func traceWrap(c client.Client) client.Client {
return &traceWrapper{c}
}
// publishes a message // publishes a message
func pub() { func pub() {
msg := client.NewPublication("topic.go.micro.srv.example", &example.Message{ msg := client.NewPublication("topic.go.micro.srv.example", &example.Message{
@ -120,7 +84,6 @@ func stream() {
func main() { func main() {
cmd.Init() cmd.Init()
fmt.Println("\n--- Call example ---\n") fmt.Println("\n--- Call example ---\n")
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
call(i) call(i)
@ -131,19 +94,4 @@ func main() {
fmt.Println("\n--- Publisher example ---\n") fmt.Println("\n--- Publisher example ---\n")
pub() pub()
fmt.Println("\n--- Wrapper example ---\n")
// Wrap the default client
client.DefaultClient = logWrap(client.DefaultClient)
call(0)
// Wrap using client.Wrap option
client.DefaultClient = client.NewClient(
client.Wrap(traceWrap),
client.Wrap(logWrap),
)
call(1)
} }

View File

@ -0,0 +1,40 @@
package main
import (
"fmt"
"github.com/micro/go-micro/client"
"github.com/micro/go-micro/cmd"
c "github.com/micro/go-micro/context"
example "github.com/micro/go-micro/examples/server/proto/example"
"golang.org/x/net/context"
)
// publishes a message
func pub(i int) {
msg := client.NewPublication("topic.go.micro.srv.example", &example.Message{
Say: fmt.Sprintf("This is a publication %d", i),
})
// create context with metadata
ctx := c.WithMetadata(context.Background(), map[string]string{
"X-User-Id": "john",
"X-From-Id": "script",
})
// publish message
if err := client.Publish(ctx, msg); err != nil {
fmt.Println("pub err: ", err)
return
}
fmt.Printf("Published %d: %v\n", i, msg)
}
func main() {
cmd.Init()
fmt.Println("\n--- Publisher example ---\n")
for i := 0; i < 10; i++ {
pub(i)
}
}

View File

@ -0,0 +1,91 @@
package main
import (
"fmt"
"time"
"github.com/micro/go-micro/client"
"github.com/micro/go-micro/cmd"
c "github.com/micro/go-micro/context"
example "github.com/micro/go-micro/examples/server/proto/example"
"golang.org/x/net/context"
)
// wrapper example code
// log wrapper logs every time a request is made
type logWrapper struct {
client.Client
}
func (l *logWrapper) Call(ctx context.Context, req client.Request, rsp interface{}) error {
md, _ := c.GetMetadata(ctx)
fmt.Printf("[Log Wrapper] ctx: %v service: %s method: %s\n", md, req.Service(), req.Method())
return l.Client.Call(ctx, req, rsp)
}
// trace wrapper attaches a unique trace ID - timestamp
type traceWrapper struct {
client.Client
}
func (t *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}) error {
ctx = c.WithMetadata(ctx, map[string]string{
"X-Trace-Id": fmt.Sprintf("%d", time.Now().Unix()),
})
return t.Client.Call(ctx, req, rsp)
}
// Implements client.Wrapper as logWrapper
func logWrap(c client.Client) client.Client {
return &logWrapper{c}
}
// Implements client.Wrapper as traceWrapper
func traceWrap(c client.Client) client.Client {
return &traceWrapper{c}
}
func call(i int) {
// Create new request to service go.micro.srv.example, method Example.Call
req := client.NewRequest("go.micro.srv.example", "Example.Call", &example.Request{
Name: "John",
})
// create context with metadata
ctx := c.WithMetadata(context.Background(), map[string]string{
"X-User-Id": "john",
"X-From-Id": "script",
})
rsp := &example.Response{}
// Call service
if err := client.Call(ctx, req, rsp); err != nil {
fmt.Println("call err: ", err, rsp)
return
}
fmt.Println("Call:", i, "rsp:", rsp.Msg)
}
func main() {
cmd.Init()
fmt.Println("\n--- Log Wrapper example ---\n")
// Wrap the default client
client.DefaultClient = logWrap(client.DefaultClient)
call(0)
fmt.Println("\n--- Log+Trace Wrapper example ---\n")
// Wrap using client.Wrap option
client.DefaultClient = client.NewClient(
client.Wrap(traceWrap),
client.Wrap(logWrap),
)
call(1)
}

View File

@ -6,15 +6,35 @@ import (
"github.com/micro/go-micro/examples/server/handler" "github.com/micro/go-micro/examples/server/handler"
"github.com/micro/go-micro/examples/server/subscriber" "github.com/micro/go-micro/examples/server/subscriber"
"github.com/micro/go-micro/server" "github.com/micro/go-micro/server"
"golang.org/x/net/context"
) )
func logWrapper(fn server.HandlerFunc) server.HandlerFunc {
return func(ctx context.Context, req server.Request, rsp interface{}) error {
log.Infof("[Log Wrapper] Before serving request method: %v", req.Method())
err := fn(ctx, req, rsp)
log.Infof("[Log Wrapper] After serving request")
return err
}
}
func logSubWrapper(fn server.SubscriberFunc) server.SubscriberFunc {
return func(ctx context.Context, req server.Publication) error {
log.Infof("[Log Sub Wrapper] Before serving publication topic: %v", req.Topic())
err := fn(ctx, req)
log.Infof("[Log Sub Wrapper] After serving publication")
return err
}
}
func main() { func main() {
// optionally setup command line usage // optionally setup command line usage
cmd.Init() cmd.Init()
// server.DefaultServer = server.NewServer( server.DefaultServer = server.NewServer(
// server.Codec("application/bson", bson.Codec), server.WrapHandler(logWrapper),
// ) server.WrapSubscriber(logSubWrapper),
)
// Initialise Server // Initialise Server
server.Init( server.Init(
@ -29,19 +49,23 @@ func main() {
) )
// Register Subscribers // Register Subscribers
server.Subscribe( if err := server.Subscribe(
server.NewSubscriber( server.NewSubscriber(
"topic.go.micro.srv.example", "topic.go.micro.srv.example",
new(subscriber.Example), new(subscriber.Example),
), ),
) ); err != nil {
log.Fatal(err)
}
server.Subscribe( if err := server.Subscribe(
server.NewSubscriber( server.NewSubscriber(
"topic.go.micro.srv.example", "topic.go.micro.srv.example",
subscriber.Handler, subscriber.Handler,
), ),
) ); err != nil {
log.Fatal(err)
}
// Run server // Run server
if err := server.Run(); err != nil { if err := server.Run(); err != nil {

View File

@ -13,6 +13,7 @@ func (e *Example) Handle(ctx context.Context, msg *example.Message) error {
return nil return nil
} }
func Handler(msg *example.Message) { func Handler(ctx context.Context, msg *example.Message) error {
log.Info("Function Received message: ", msg.Say) log.Info("Function Received message: ", msg.Say)
return nil
} }

View File

@ -18,6 +18,8 @@ type options struct {
advertise string advertise string
id string id string
version string version string
hdlrWrappers []HandlerWrapper
subWrappers []SubscriberWrapper
} }
func newOptions(opt ...Option) options { func newOptions(opt ...Option) options {
@ -153,3 +155,17 @@ func Metadata(md map[string]string) Option {
o.metadata = md o.metadata = md
} }
} }
// Adds a handler Wrapper to a list of options passed into the server
func WrapHandler(w HandlerWrapper) Option {
return func(o *options) {
o.hdlrWrappers = append(o.hdlrWrappers, w)
}
}
// Adds a subscriber Wrapper to a list of options passed into the server
func WrapSubscriber(w SubscriberWrapper) Option {
return func(o *options) {
o.subWrappers = append(o.subWrappers, w)
}
}

47
server/rpc_request.go Normal file
View File

@ -0,0 +1,47 @@
package server
type rpcRequest struct {
service string
method string
contentType string
request interface{}
stream bool
}
type rpcPublication struct {
topic string
contentType string
message interface{}
}
func (r *rpcRequest) ContentType() string {
return r.contentType
}
func (r *rpcRequest) Service() string {
return r.service
}
func (r *rpcRequest) Method() string {
return r.method
}
func (r *rpcRequest) Request() interface{} {
return r.request
}
func (r *rpcRequest) Stream() bool {
return r.stream
}
func (r *rpcPublication) ContentType() string {
return r.contentType
}
func (r *rpcPublication) Topic() string {
return r.topic
}
func (r *rpcPublication) Message() interface{} {
return r.message
}

View File

@ -28,9 +28,14 @@ type rpcServer struct {
} }
func newRpcServer(opts ...Option) Server { func newRpcServer(opts ...Option) Server {
options := newOptions(opts...)
return &rpcServer{ return &rpcServer{
opts: newOptions(opts...), opts: options,
rpc: newServer(), rpc: &server{
name: options.name,
serviceMap: make(map[string]*service),
hdlrWrappers: options.hdlrWrappers,
},
handlers: make(map[string]Handler), handlers: make(map[string]Handler),
subscribers: make(map[*subscriber][]broker.Subscriber), subscribers: make(map[*subscriber][]broker.Subscriber),
exit: make(chan chan error), exit: make(chan chan error),
@ -43,7 +48,8 @@ func (s *rpcServer) accept(sock transport.Socket) {
return return
} }
cf, err := s.newCodec(msg.Header["Content-Type"]) ct := msg.Header["Content-Type"]
cf, err := s.newCodec(ct)
// TODO: needs better error handling // TODO: needs better error handling
if err != nil { if err != nil {
sock.Send(&transport.Message{ sock.Send(&transport.Message{
@ -66,8 +72,9 @@ func (s *rpcServer) accept(sock transport.Socket) {
delete(hdr, "Content-Type") delete(hdr, "Content-Type")
ctx := c.WithMetadata(context.Background(), hdr) ctx := c.WithMetadata(context.Background(), hdr)
// TODO: needs better error handling // TODO: needs better error handling
if err := s.rpc.ServeRequestWithContext(ctx, codec); err != nil { if err := s.rpc.serveRequest(ctx, codec, ct); err != nil {
log.Errorf("Unexpected error serving request, closing socket: %v", err) log.Errorf("Unexpected error serving request, closing socket: %v", err)
sock.Close() sock.Close()
} }
@ -106,7 +113,7 @@ func (s *rpcServer) NewHandler(h interface{}) Handler {
} }
func (s *rpcServer) Handle(h Handler) error { func (s *rpcServer) Handle(h Handler) error {
if err := s.rpc.Register(h.Handler()); err != nil { if err := s.rpc.register(h.Handler()); err != nil {
return err return err
} }
s.Lock() s.Lock()
@ -128,6 +135,10 @@ func (s *rpcServer) Subscribe(sb Subscriber) error {
return fmt.Errorf("invalid subscriber: no handler functions") return fmt.Errorf("invalid subscriber: no handler functions")
} }
if err := validateSubscriber(sb); err != nil {
return err
}
s.Lock() s.Lock()
_, ok = s.subscribers[sub] _, ok = s.subscribers[sub]
if ok { if ok {
@ -200,7 +211,7 @@ func (s *rpcServer) Register() error {
defer s.Unlock() defer s.Unlock()
for sb, _ := range s.subscribers { for sb, _ := range s.subscribers {
handler := s.createSubHandler(sb) handler := s.createSubHandler(sb, s.opts)
sub, err := config.broker.Subscribe(sb.Topic(), handler) sub, err := config.broker.Subscribe(sb.Topic(), handler)
if err != nil { if err != nil {
return err return err
@ -271,7 +282,7 @@ func (s *rpcServer) Start() error {
registerHealthChecker(s) registerHealthChecker(s)
config := s.Config() config := s.Config()
ts, err := config.transport.Listen(s.opts.address) ts, err := config.transport.Listen(config.address)
if err != nil { if err != nil {
return err return err
} }

View File

@ -18,11 +18,8 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
) )
const (
lastStreamResponseError = "EOS"
)
var ( var (
lastStreamResponseError = errors.New("EOS")
// A value sent as a placeholder for the server's response value when the server // A value sent as a placeholder for the server's response value when the server
// receives an invalid request. It is never decoded by the client since the Response // receives an invalid request. It is never decoded by the client since the Response
// contains an error when it is used. // contains an error when it is used.
@ -43,10 +40,6 @@ type methodType struct {
numCalls uint numCalls uint
} }
func (m *methodType) TakesContext() bool {
return m.ContextType != nil
}
func (m *methodType) NumCalls() (n uint) { func (m *methodType) NumCalls() (n uint) {
m.Lock() m.Lock()
n = m.numCalls n = m.numCalls
@ -76,16 +69,14 @@ type response struct {
// server represents an RPC Server. // server represents an RPC Server.
type server struct { type server struct {
name string
mu sync.Mutex // protects the serviceMap mu sync.Mutex // protects the serviceMap
serviceMap map[string]*service serviceMap map[string]*service
reqLock sync.Mutex // protects freeReq reqLock sync.Mutex // protects freeReq
freeReq *request freeReq *request
respLock sync.Mutex // protects freeResp respLock sync.Mutex // protects freeResp
freeResp *response freeResp *response
} hdlrWrappers []HandlerWrapper
func newServer() *server {
return &server{serviceMap: make(map[string]*service)}
} }
// Is this an exported - upper case - name? // Is this an exported - upper case - name?
@ -104,10 +95,6 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
return isExported(t.Name()) || t.PkgPath() == "" return isExported(t.Name()) || t.PkgPath() == ""
} }
func (server *server) Register(rcvr interface{}) error {
return server.register(rcvr, "", false)
}
// prepareMethod returns a methodType for the provided method or nil // prepareMethod returns a methodType for the provided method or nil
// in case if the method was unsuitable. // in case if the method was unsuitable.
func prepareMethod(method reflect.Method) *methodType { func prepareMethod(method reflect.Method) *methodType {
@ -122,11 +109,6 @@ func prepareMethod(method reflect.Method) *methodType {
} }
switch mtype.NumIn() { switch mtype.NumIn() {
case 3:
// normal method
argType = mtype.In(1)
replyType = mtype.In(2)
contextType = nil
case 4: case 4:
// method that takes a context // method that takes a context
argType = mtype.In(2) argType = mtype.In(2)
@ -188,7 +170,7 @@ func prepareMethod(method reflect.Method) *methodType {
return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream} return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
} }
func (server *server) register(rcvr interface{}, name string, useName bool) error { func (server *server) register(rcvr interface{}) error {
server.mu.Lock() server.mu.Lock()
defer server.mu.Unlock() defer server.mu.Unlock()
if server.serviceMap == nil { if server.serviceMap == nil {
@ -198,13 +180,10 @@ func (server *server) register(rcvr interface{}, name string, useName bool) erro
s.typ = reflect.TypeOf(rcvr) s.typ = reflect.TypeOf(rcvr)
s.rcvr = reflect.ValueOf(rcvr) s.rcvr = reflect.ValueOf(rcvr)
sname := reflect.Indirect(s.rcvr).Type().Name() sname := reflect.Indirect(s.rcvr).Type().Name()
if useName {
sname = name
}
if sname == "" { if sname == "" {
log.Fatal("rpc: no service name for type", s.typ.String()) log.Fatal("rpc: no service name for type", s.typ.String())
} }
if !isExported(sname) && !useName { if !isExported(sname) {
s := "rpc Register: type " + sname + " is not exported" s := "rpc Register: type " + sname + " is not exported"
log.Print(s) log.Print(s)
return errors.New(s) return errors.New(s)
@ -251,28 +230,42 @@ func (server *server) sendResponse(sending *sync.Mutex, req *request, reply inte
return err return err
} }
func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec serverCodec) { func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec serverCodec, ct string) {
mtype.Lock() mtype.Lock()
mtype.numCalls++ mtype.numCalls++
mtype.Unlock() mtype.Unlock()
function := mtype.method.Func function := mtype.method.Func
var returnValues []reflect.Value var returnValues []reflect.Value
if !mtype.stream { r := &rpcRequest{
service: s.name,
// Invoke the method, providing a new value for the reply. contentType: ct,
if mtype.TakesContext() { method: req.ServiceMethod,
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), argv, replyv}) request: argv.Interface(),
} else {
returnValues = function.Call([]reflect.Value{s.rcvr, argv, replyv})
} }
if !mtype.stream {
fn := func(ctx context.Context, req Request, rsp interface{}) error {
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(req.Request()), reflect.ValueOf(rsp)})
// The return value for the method is an error. // The return value for the method is an error.
errInter := returnValues[0].Interface() if err := returnValues[0].Interface(); err != nil {
errmsg := "" return err.(error)
if errInter != nil {
errmsg = errInter.(error).Error()
} }
return nil
}
for i := len(server.hdlrWrappers); i > 0; i-- {
fn = server.hdlrWrappers[i-1](fn)
}
errmsg := ""
err := fn(ctx, r, replyv.Interface())
if err != nil {
errmsg = err.Error()
}
server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true) server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true)
server.freeRequest(req) server.freeRequest(req)
return return
@ -314,22 +307,31 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex,
} }
// Invoke the method, providing a new value for the reply. // Invoke the method, providing a new value for the reply.
if mtype.TakesContext() { fn := func(ctx context.Context, req Request, rspFn interface{}) error {
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), argv, reflect.ValueOf(sendReply)}) returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(req.Request()), reflect.ValueOf(rspFn)})
} else { if err := returnValues[0].Interface(); err != nil {
returnValues = function.Call([]reflect.Value{s.rcvr, argv, reflect.ValueOf(sendReply)})
}
errInter := returnValues[0].Interface()
errmsg := ""
if errInter != nil {
// the function returned an error, we use that // the function returned an error, we use that
errmsg = errInter.(error).Error() return err.(error)
} else if lastError != nil { } else if lastError != nil {
// we had an error inside sendReply, we use that // we had an error inside sendReply, we use that
errmsg = lastError.Error() return lastError
} else { } else {
// no error, we send the special EOS error // no error, we send the special EOS error
errmsg = lastStreamResponseError return lastStreamResponseError
}
return nil
}
for i := len(server.hdlrWrappers); i > 0; i-- {
fn = server.hdlrWrappers[i-1](fn)
}
// client.Stream request
r.stream = true
errmsg := ""
if err := fn(ctx, r, reflect.ValueOf(sendReply).Interface()); err != nil {
errmsg = err.Error()
} }
// this is the last packet, we don't do anything with // this is the last packet, we don't do anything with
@ -346,7 +348,7 @@ func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
return reflect.Zero(m.ContextType) return reflect.Zero(m.ContextType)
} }
func (server *server) ServeRequestWithContext(ctx context.Context, codec serverCodec) error { func (server *server) serveRequest(ctx context.Context, codec serverCodec, ct string) error {
sending := new(sync.Mutex) sending := new(sync.Mutex)
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil { if err != nil {
@ -360,7 +362,7 @@ func (server *server) ServeRequestWithContext(ctx context.Context, codec serverC
} }
return err return err
} }
service.call(ctx, server, sending, mtype, req, argv, replyv, codec) service.call(ctx, server, sending, mtype, req, argv, replyv, codec, ct)
return nil return nil
} }
@ -474,13 +476,6 @@ func (server *server) readRequestHeader(codec serverCodec) (service *service, mt
return return
} }
// A serverCodec implements reading of RPC requests and writing of
// RPC responses for the server side of an RPC session.
// The server calls ReadRequestHeader and ReadRequestBody in pairs
// to read requests from the connection, and it calls WriteResponse to
// write a response back. The server calls Close when finished with the
// connection. ReadRequestBody may be called with a nil
// argument to force the body of the request to be read and discarded.
type serverCodec interface { type serverCodec interface {
ReadRequestHeader(*request) error ReadRequestHeader(*request) error
ReadRequestBody(interface{}) error ReadRequestBody(interface{}) error

View File

@ -31,6 +31,21 @@ type Server interface {
Stop() error Stop() error
} }
type Publication interface {
Topic() string
Message() interface{}
ContentType() string
}
type Request interface {
Service() string
Method() string
ContentType() string
Request() interface{}
// indicates whether the response should be streaming
Stream() bool
}
type Option func(*options) type Option func(*options)
var ( var (

21
server/server_wrapper.go Normal file
View File

@ -0,0 +1,21 @@
package server
import (
"golang.org/x/net/context"
)
// HandlerFunc represents a single method of a handler. It's used primarily
// for the wrappers. What's handed to the actual method is the concrete
// request and response types.
type HandlerFunc func(ctx context.Context, req Request, rsp interface{}) error
// SubscriberFunc represents a single method of a subscriber. It's used primarily
// for the wrappers. What's handed to the actual method is the concrete
// publication message.
type SubscriberFunc func(ctx context.Context, msg Publication) error
// HandlerWrapper wraps the HandlerFunc and returns the equivalent
type HandlerWrapper func(HandlerFunc) HandlerFunc
// SubscriberWrapper wraps the SubscriberFunc and returns the equivalent
type SubscriberWrapper func(SubscriberFunc) SubscriberFunc

View File

@ -2,6 +2,7 @@ package server
import ( import (
"bytes" "bytes"
"fmt"
"reflect" "reflect"
"github.com/micro/go-micro/broker" "github.com/micro/go-micro/broker"
@ -11,6 +12,10 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
) )
const (
subSig = "func(context.Context, interface{}) error"
)
type handler struct { type handler struct {
method reflect.Value method reflect.Value
reqType reflect.Type reqType reflect.Type
@ -94,16 +99,66 @@ func newSubscriber(topic string, sub interface{}) Subscriber {
} }
} }
func (s *rpcServer) createSubHandler(sb *subscriber) broker.Handler { func validateSubscriber(sub Subscriber) error {
return func(msg *broker.Message) { typ := reflect.TypeOf(sub.Subscriber())
cf, err := s.newCodec(msg.Header["Content-Type"]) var argType reflect.Type
if err != nil {
return if typ.Kind() == reflect.Func {
name := "Func"
switch typ.NumIn() {
case 2:
argType = typ.In(1)
default:
return fmt.Errorf("subscriber %v takes wrong number of args: %v required signature %s", name, typ.NumIn(), subSig)
}
if !isExportedOrBuiltinType(argType) {
return fmt.Errorf("subscriber %v argument type not exported: %v", name, argType)
}
if typ.NumOut() != 1 {
return fmt.Errorf(
"subscriber %v.%v has wrong number of outs: %v require signature %s",
name, typ.NumOut(), subSig)
}
if returnType := typ.Out(0); returnType != typeOfError {
return fmt.Errorf("subscriber %v returns %v not error", name, returnType.String())
}
} else {
hdlr := reflect.ValueOf(sub.Subscriber())
name := reflect.Indirect(hdlr).Type().Name()
for m := 0; m < typ.NumMethod(); m++ {
method := typ.Method(m)
switch method.Type.NumIn() {
case 3:
argType = method.Type.In(2)
default:
return fmt.Errorf("subscriber %v.%v takes wrong number of args: %v required signature %s",
name, method.Name, method.Type.NumIn(), subSig)
} }
b := &buffer{bytes.NewBuffer(msg.Body)} if !isExportedOrBuiltinType(argType) {
co := cf(b) return fmt.Errorf("%v argument type not exported: %v", name, argType)
if err := co.ReadHeader(&codec.Message{}, codec.Publication); err != nil { }
if method.Type.NumOut() != 1 {
return fmt.Errorf(
"subscriber %v.%v has wrong number of outs: %v require signature %s",
name, method.Name, method.Type.NumOut(), subSig)
}
if returnType := method.Type.Out(0); returnType != typeOfError {
return fmt.Errorf("subscriber %v.%v returns %v not error", name, method.Name, returnType.String())
}
}
}
return nil
}
func (s *rpcServer) createSubHandler(sb *subscriber, opts options) broker.Handler {
return func(msg *broker.Message) {
ct := msg.Header["Content-Type"]
cf, err := s.newCodec(ct)
if err != nil {
return return
} }
@ -113,9 +168,10 @@ func (s *rpcServer) createSubHandler(sb *subscriber) broker.Handler {
} }
delete(hdr, "Content-Type") delete(hdr, "Content-Type")
ctx := c.WithMetadata(context.Background(), hdr) ctx := c.WithMetadata(context.Background(), hdr)
rctx := reflect.ValueOf(ctx)
for _, handler := range sb.handlers { for i := 0; i < len(sb.handlers); i++ {
handler := sb.handlers[i]
var isVal bool var isVal bool
var req reflect.Value var req reflect.Value
@ -125,26 +181,49 @@ func (s *rpcServer) createSubHandler(sb *subscriber) broker.Handler {
req = reflect.New(handler.reqType) req = reflect.New(handler.reqType)
isVal = true isVal = true
} }
if isVal {
req = req.Elem()
}
b := &buffer{bytes.NewBuffer(msg.Body)}
co := cf(b)
defer co.Close()
if err := co.ReadHeader(&codec.Message{}, codec.Publication); err != nil {
continue
}
if err := co.ReadBody(req.Interface()); err != nil { if err := co.ReadBody(req.Interface()); err != nil {
continue continue
} }
if isVal { fn := func(ctx context.Context, msg Publication) error {
req = req.Elem()
}
var vals []reflect.Value var vals []reflect.Value
if sb.typ.Kind() != reflect.Func { if sb.typ.Kind() != reflect.Func {
vals = append(vals, sb.rcvr) vals = append(vals, sb.rcvr)
} }
if handler.ctxType != nil { if handler.ctxType != nil {
vals = append(vals, rctx) vals = append(vals, reflect.ValueOf(ctx))
} }
vals = append(vals, req) vals = append(vals, reflect.ValueOf(msg.Message()))
go handler.method.Call(vals)
returnValues := handler.method.Call(vals)
if err := returnValues[0].Interface(); err != nil {
return err.(error)
}
return nil
}
for i := len(opts.subWrappers); i > 0; i-- {
fn = opts.subWrappers[i-1](fn)
}
go fn(ctx, &rpcPublication{
topic: sb.topic,
contentType: ct,
message: req.Interface(),
})
} }
} }
} }