micro/server/rpc_router.go

468 lines
12 KiB
Go
Raw Normal View History

2015-12-03 03:59:32 +03:00
package server
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// TODO: this router code should eventually be removed or replaced.
import (
2018-03-03 14:53:52 +03:00
"context"
"errors"
"io"
"reflect"
"strings"
"sync"
"unicode"
"unicode/utf8"
2017-05-11 22:43:42 +03:00
"github.com/micro/go-log"
2019-01-08 18:38:25 +03:00
"github.com/micro/go-micro/codec"
)
// lastStreamResponseError is what a streaming handler invocation returns when
// the handler itself finished without error. Its text, "EOS", ends up in the
// Error field of the final response written for the stream.
var lastStreamResponseError = errors.New("EOS")

// invalidRequest is written as the reply body in place of a real reply
// whenever the response carries an error message, so a client never needs to
// decode it.
var invalidRequest = struct{}{}

// typeOfError is the reflect.Type of the error interface. reflect.TypeOf takes
// an empty interface, so passing an error value would yield its dynamic type;
// going through a nil *error and Elem() yields the interface type itself.
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
// methodType records the reflected signature of a handler method accepted by
// prepareMethod.
type methodType struct {
	// NOTE(review): this lock was described as protecting counters, but the
	// struct has no counter fields and nothing in this file locks it; confirm
	// it is unused elsewhere in the package before removing it.
	sync.Mutex

	// method is the reflected method, including its Func value.
	method reflect.Method

	// ContextType is the first parameter after the receiver; ArgType is the
	// request (or Stream) parameter; ReplyType is the response parameter and
	// stays nil for streaming methods.
	ArgType, ReplyType, ContextType reflect.Type

	// stream is true for methods of the (context, Stream) form.
	stream bool
}
// service is a handler registered with the router: the receiver value plus
// every method on it that prepareMethod accepted.
type service struct {
	// name is the key the service is registered under in router.serviceMap.
	name string
	// rcvr is the value the methods are invoked on; typ is its type.
	rcvr reflect.Value
	typ  reflect.Type
	// method maps a method name to its reflected signature.
	method map[string]*methodType
}
type request struct {
2019-01-08 18:38:25 +03:00
msg *codec.Message
next *request // for free list in Server
}
type response struct {
2019-01-08 18:38:25 +03:00
msg *codec.Message
next *response // for free list in Server
}
2019-01-07 17:44:40 +03:00
// router represents an RPC router.
type router struct {
name string
mu sync.Mutex // protects the serviceMap
serviceMap map[string]*service
reqLock sync.Mutex // protects freeReq
freeReq *request
respLock sync.Mutex // protects freeResp
freeResp *response
hdlrWrappers []HandlerWrapper
}
2019-01-07 17:44:40 +03:00
func newRpcRouter(opts Options) *router {
return &router{
name: opts.Name,
hdlrWrappers: opts.HdlrWrappers,
serviceMap: make(map[string]*service),
}
}
// isExported reports whether name begins with an upper-case letter, which is
// how Go marks an identifier as exported. The empty string is not exported.
func isExported(name string) bool {
	first, _ := utf8.DecodeRuneInString(name)
	return unicode.IsUpper(first)
}

// isExportedOrBuiltinType reports whether t, after removing any number of
// pointer indirections, is an exported named type or a type with no package
// path (a predeclared type such as int or error, or an unnamed type such as
// []int).
func isExportedOrBuiltinType(t reflect.Type) bool {
	base := t
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}
	// PkgPath is non-empty for exported named types too, so the type name
	// has to be checked first.
	if isExported(base.Name()) {
		return true
	}
	return base.PkgPath() == ""
}
// prepareMethod returns a methodType for the provided method or nil
// in case if the method was unsuitable.
func prepareMethod(method reflect.Method) *methodType {
mtype := method.Type
mname := method.Name
var replyType, argType, contextType reflect.Type
2015-12-18 04:01:59 +03:00
var stream bool
// Method must be exported.
if method.PkgPath != "" {
return nil
}
switch mtype.NumIn() {
2015-12-18 04:01:59 +03:00
case 3:
// assuming streaming
argType = mtype.In(2)
contextType = mtype.In(1)
stream = true
case 4:
// method that takes a context
argType = mtype.In(2)
replyType = mtype.In(3)
contextType = mtype.In(1)
default:
2017-05-11 22:43:42 +03:00
log.Log("method", mname, "of", mtype, "has wrong number of ins:", mtype.NumIn())
return nil
}
2015-12-18 04:01:59 +03:00
if stream {
// check stream type
2018-04-14 20:15:09 +03:00
streamType := reflect.TypeOf((*Stream)(nil)).Elem()
2015-12-18 04:01:59 +03:00
if !argType.Implements(streamType) {
2018-04-14 20:15:09 +03:00
log.Log(mname, "argument does not implement Stream interface:", argType)
return nil
}
2015-12-18 04:01:59 +03:00
} else {
// if not stream check the replyType
// First arg need not be a pointer.
if !isExportedOrBuiltinType(argType) {
2017-05-11 22:43:42 +03:00
log.Log(mname, "argument type not exported:", argType)
return nil
}
2015-12-18 04:01:59 +03:00
if replyType.Kind() != reflect.Ptr {
2017-05-11 22:43:42 +03:00
log.Log("method", mname, "reply type not a pointer:", replyType)
return nil
}
2015-12-18 04:01:59 +03:00
// Reply type must be exported.
if !isExportedOrBuiltinType(replyType) {
2017-05-11 22:43:42 +03:00
log.Log("method", mname, "reply type not exported:", replyType)
return nil
}
}
// Method needs one out.
if mtype.NumOut() != 1 {
2017-05-11 22:43:42 +03:00
log.Log("method", mname, "has wrong number of outs:", mtype.NumOut())
return nil
}
// The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError {
2017-05-11 22:43:42 +03:00
log.Log("method", mname, "returns", returnType.String(), "not error")
return nil
}
return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
}
2019-01-09 22:28:13 +03:00
func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, cc codec.Writer, errmsg string, last bool) (err error) {
2019-01-08 18:38:25 +03:00
msg := new(codec.Message)
msg.Type = codec.Response
2019-01-07 17:44:40 +03:00
resp := router.getResponse()
2019-01-08 18:38:25 +03:00
resp.msg = msg
// Encode the response header
2019-01-08 18:38:25 +03:00
resp.msg.Method = req.msg.Method
if errmsg != "" {
2019-01-08 18:38:25 +03:00
resp.msg.Error = errmsg
reply = invalidRequest
}
2019-01-08 18:38:25 +03:00
resp.msg.Id = req.msg.Id
sending.Lock()
2019-01-08 18:38:25 +03:00
err = cc.Write(resp.msg, reply)
sending.Unlock()
2019-01-07 17:44:40 +03:00
router.freeResponse(resp)
return err
}
2019-01-09 22:28:13 +03:00
func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, cc codec.Writer) {
function := mtype.method.Func
var returnValues []reflect.Value
2015-12-02 23:56:50 +03:00
r := &rpcRequest{
2019-01-09 19:20:57 +03:00
service: req.msg.Target,
2019-01-08 23:32:47 +03:00
contentType: req.msg.Header["Content-Type"],
2019-01-08 18:38:25 +03:00
method: req.msg.Method,
2019-01-09 19:20:57 +03:00
body: req.msg.Body,
2015-12-02 23:56:50 +03:00
}
if !mtype.stream {
2015-12-02 23:56:50 +03:00
fn := func(ctx context.Context, req Request, rsp interface{}) error {
2019-01-09 19:20:57 +03:00
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)})
// The return value for the method is an error.
if err := returnValues[0].Interface(); err != nil {
return err.(error)
}
return nil
}
2019-01-07 17:44:40 +03:00
for i := len(router.hdlrWrappers); i > 0; i-- {
fn = router.hdlrWrappers[i-1](fn)
}
errmsg := ""
2015-12-02 23:56:50 +03:00
err := fn(ctx, r, replyv.Interface())
if err != nil {
errmsg = err.Error()
}
2019-01-09 19:20:57 +03:00
err = router.sendResponse(sending, req, replyv.Interface(), cc, errmsg, true)
if err != nil {
log.Log("rpc call: unable to send response: ", err)
}
2019-01-07 17:44:40 +03:00
router.freeRequest(req)
return
}
// declare a local error to see if we errored out already
// keep track of the type, to make sure we return
// the same one consistently
var lastError error
2015-12-18 04:01:59 +03:00
stream := &rpcStream{
context: ctx,
2019-01-09 22:28:13 +03:00
codec: cc.(codec.Codec),
2015-12-18 04:01:59 +03:00
request: r,
2019-01-08 18:38:25 +03:00
id: req.msg.Id,
}
// Invoke the method, providing a new value for the reply.
2015-12-18 04:01:59 +03:00
fn := func(ctx context.Context, req Request, stream interface{}) error {
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
if err := returnValues[0].Interface(); err != nil {
// the function returned an error, we use that
return err.(error)
} else if lastError != nil {
// we had an error inside sendReply, we use that
return lastError
} else {
// no error, we send the special EOS error
return lastStreamResponseError
}
}
2019-01-07 17:44:40 +03:00
for i := len(router.hdlrWrappers); i > 0; i-- {
fn = router.hdlrWrappers[i-1](fn)
}
2015-12-02 23:56:50 +03:00
// client.Stream request
r.stream = true
errmsg := ""
2015-12-18 04:01:59 +03:00
if err := fn(ctx, r, stream); err != nil {
errmsg = err.Error()
}
// this is the last packet, we don't do anything with
// the error here (well sendStreamResponse will log it
// already)
2019-01-09 19:20:57 +03:00
router.sendResponse(sending, req, nil, cc, errmsg, true)
2019-01-07 17:44:40 +03:00
router.freeRequest(req)
}
// prepareContext converts ctx into the reflect.Value passed as the method's
// context argument. A nil ctx has no reflect value, so the zero value of the
// method's ContextType is used in that case.
func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
	value := reflect.ValueOf(ctx)
	if !value.IsValid() {
		return reflect.Zero(m.ContextType)
	}
	return value
}
2019-01-07 17:44:40 +03:00
func (router *router) getRequest() *request {
router.reqLock.Lock()
req := router.freeReq
if req == nil {
req = new(request)
} else {
2019-01-07 17:44:40 +03:00
router.freeReq = req.next
*req = request{}
}
2019-01-07 17:44:40 +03:00
router.reqLock.Unlock()
return req
}
2019-01-07 17:44:40 +03:00
func (router *router) freeRequest(req *request) {
router.reqLock.Lock()
req.next = router.freeReq
router.freeReq = req
router.reqLock.Unlock()
}
2019-01-07 17:44:40 +03:00
func (router *router) getResponse() *response {
router.respLock.Lock()
resp := router.freeResp
if resp == nil {
resp = new(response)
} else {
2019-01-07 17:44:40 +03:00
router.freeResp = resp.next
*resp = response{}
}
2019-01-07 17:44:40 +03:00
router.respLock.Unlock()
return resp
}
2019-01-07 17:44:40 +03:00
func (router *router) freeResponse(resp *response) {
router.respLock.Lock()
resp.next = router.freeResp
router.freeResp = resp
router.respLock.Unlock()
}
2019-01-09 19:20:57 +03:00
func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) {
cc := r.Codec()
2019-01-08 18:38:25 +03:00
service, mtype, req, keepReading, err = router.readHeader(cc)
if err != nil {
if !keepReading {
return
}
// discard body
2019-01-08 18:38:25 +03:00
cc.ReadBody(nil)
return
}
2015-12-18 04:01:59 +03:00
// is it a streaming request? then we don't read the body
if mtype.stream {
2019-01-08 18:38:25 +03:00
cc.ReadBody(nil)
2015-12-18 04:01:59 +03:00
return
}
// Decode the argument value.
argIsValue := false // if true, need to indirect before calling.
if mtype.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(mtype.ArgType.Elem())
} else {
argv = reflect.New(mtype.ArgType)
argIsValue = true
}
// argv guaranteed to be a pointer now.
2019-01-08 18:38:25 +03:00
if err = cc.ReadBody(argv.Interface()); err != nil {
return
}
if argIsValue {
argv = argv.Elem()
}
if !mtype.stream {
replyv = reflect.New(mtype.ReplyType.Elem())
}
return
}
2019-01-09 22:28:13 +03:00
func (router *router) readHeader(cc codec.Reader) (service *service, mtype *methodType, req *request, keepReading bool, err error) {
// Grab the request header.
2019-01-08 18:38:25 +03:00
msg := new(codec.Message)
msg.Type = codec.Request
2019-01-07 17:44:40 +03:00
req = router.getRequest()
2019-01-08 18:38:25 +03:00
req.msg = msg
err = cc.ReadHeader(msg, msg.Type)
if err != nil {
req = nil
if err == io.EOF || err == io.ErrUnexpectedEOF {
return
}
2019-01-07 17:44:40 +03:00
err = errors.New("rpc: router cannot decode request: " + err.Error())
return
}
// We read the header successfully. If we see an error now,
// we can still recover and move on to the next request.
keepReading = true
2019-01-08 18:38:25 +03:00
serviceMethod := strings.Split(req.msg.Method, ".")
if len(serviceMethod) != 2 {
2019-01-08 18:38:25 +03:00
err = errors.New("rpc: service/method request ill-formed: " + req.msg.Method)
return
}
// Look up the request.
2019-01-07 17:44:40 +03:00
router.mu.Lock()
service = router.serviceMap[serviceMethod[0]]
router.mu.Unlock()
if service == nil {
2019-01-08 18:38:25 +03:00
err = errors.New("rpc: can't find service " + req.msg.Method)
return
}
mtype = service.method[serviceMethod[1]]
if mtype == nil {
2019-01-08 18:38:25 +03:00
err = errors.New("rpc: can't find method " + req.msg.Method)
}
return
}
2019-01-09 12:06:30 +03:00
func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler {
return newRpcHandler(h, opts...)
}
// Handle registers h with the router under h.Name().
//
// It returns an error, and registers nothing, in any of these cases:
//   - the name is empty or not exported;
//   - a service with that name already exists;
//   - h.Handler() returns nil;
//   - no method of the receiver is accepted by prepareMethod.
func (router *router) Handle(h Handler) error {
	router.mu.Lock()
	defer router.mu.Unlock()
	if router.serviceMap == nil {
		router.serviceMap = make(map[string]*service)
	}

	name := h.Name()
	if len(name) == 0 {
		return errors.New("rpc.Handle: handler has no name")
	}
	if !isExported(name) {
		return errors.New("rpc.Handle: type " + name + " is not exported")
	}

	rcvr := h.Handler()
	s := new(service)
	s.typ = reflect.TypeOf(rcvr)
	s.rcvr = reflect.ValueOf(rcvr)

	// check name
	if _, present := router.serviceMap[name]; present {
		return errors.New("rpc.Handle: service already defined: " + name)
	}

	// reflect.TypeOf returns nil for a nil interface value, and calling
	// NumMethod on a nil reflect.Type below would panic.
	if s.typ == nil {
		return errors.New("rpc.Handle: handler has nil receiver: " + name)
	}

	s.name = name
	s.method = make(map[string]*methodType)

	// Install the methods
	for m := 0; m < s.typ.NumMethod(); m++ {
		method := s.typ.Method(m)
		if mt := prepareMethod(method); mt != nil {
			s.method[method.Name] = mt
		}
	}

	// Check there are methods
	if len(s.method) == 0 {
		return errors.New("rpc Register: type " + s.name + " has no exported methods of suitable type")
	}

	// save handler
	router.serviceMap[s.name] = s
	return nil
}
2019-01-09 22:11:47 +03:00
func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response) error {
2019-01-09 12:06:30 +03:00
sending := new(sync.Mutex)
2019-01-09 19:20:57 +03:00
service, mtype, req, argv, replyv, keepReading, err := router.readRequest(r)
2019-01-09 12:06:30 +03:00
if err != nil {
if !keepReading {
return err
}
// send a response if we actually managed to read a header.
if req != nil {
2019-01-09 22:28:13 +03:00
router.sendResponse(sending, req, invalidRequest, rsp.Codec(), err.Error(), true)
2019-01-09 12:06:30 +03:00
router.freeRequest(req)
}
return err
}
2019-01-09 22:28:13 +03:00
service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
2019-01-09 12:06:30 +03:00
return nil
}