package server // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // // Meh, we need to get rid of this shit import ( "errors" "io" "reflect" "strings" "sync" "unicode" "unicode/utf8" "github.com/micro/go-log" "golang.org/x/net/context" ) var ( lastStreamResponseError = errors.New("EOS") // A value sent as a placeholder for the server's response value when the server // receives an invalid request. It is never decoded by the client since the Response // contains an error when it is used. invalidRequest = struct{}{} // Precompute the reflect type for error. Can't use error directly // because Typeof takes an empty interface value. This is annoying. typeOfError = reflect.TypeOf((*error)(nil)).Elem() ) type methodType struct { sync.Mutex // protects counters method reflect.Method ArgType reflect.Type ReplyType reflect.Type ContextType reflect.Type stream bool } type service struct { name string // name of service rcvr reflect.Value // receiver of methods for the service typ reflect.Type // type of the receiver method map[string]*methodType // registered methods } type request struct { ServiceMethod string // format: "Service.Method" Seq uint64 // sequence number chosen by client next *request // for free list in Server } type response struct { ServiceMethod string // echoes that of the Request Seq uint64 // echoes that of the request Error string // error, if any. next *response // for free list in Server } // server represents an RPC Server. type server struct { name string mu sync.Mutex // protects the serviceMap serviceMap map[string]*service reqLock sync.Mutex // protects freeReq freeReq *request respLock sync.Mutex // protects freeResp freeResp *response hdlrWrappers []HandlerWrapper } // Is this an exported - upper case - name? func isExported(name string) bool { rune, _ := utf8.DecodeRuneInString(name) return unicode.IsUpper(rune) } // Is this type exported or a builtin? func isExportedOrBuiltinType(t reflect.Type) bool { for t.Kind() == reflect.Ptr { t = t.Elem() } // PkgPath will be non-empty even for an exported type, // so we need to check the type name as well. return isExported(t.Name()) || t.PkgPath() == "" } // prepareMethod returns a methodType for the provided method or nil // in case if the method was unsuitable. func prepareMethod(method reflect.Method) *methodType { mtype := method.Type mname := method.Name var replyType, argType, contextType reflect.Type var stream bool // Method must be exported. if method.PkgPath != "" { return nil } switch mtype.NumIn() { case 3: // assuming streaming argType = mtype.In(2) contextType = mtype.In(1) stream = true case 4: // method that takes a context argType = mtype.In(2) replyType = mtype.In(3) contextType = mtype.In(1) default: log.Log("method", mname, "of", mtype, "has wrong number of ins:", mtype.NumIn()) return nil } if stream { // check stream type streamType := reflect.TypeOf((*Streamer)(nil)).Elem() if !argType.Implements(streamType) { log.Log(mname, "argument does not implement Streamer interface:", argType) return nil } } else { // if not stream check the replyType // First arg need not be a pointer. if !isExportedOrBuiltinType(argType) { log.Log(mname, "argument type not exported:", argType) return nil } if replyType.Kind() != reflect.Ptr { log.Log("method", mname, "reply type not a pointer:", replyType) return nil } // Reply type must be exported. if !isExportedOrBuiltinType(replyType) { log.Log("method", mname, "reply type not exported:", replyType) return nil } } // Method needs one out. if mtype.NumOut() != 1 { log.Log("method", mname, "has wrong number of outs:", mtype.NumOut()) return nil } // The return type of the method must be error. if returnType := mtype.Out(0); returnType != typeOfError { log.Log("method", mname, "returns", returnType.String(), "not error") return nil } return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream} } func (server *server) register(rcvr interface{}) error { server.mu.Lock() defer server.mu.Unlock() if server.serviceMap == nil { server.serviceMap = make(map[string]*service) } s := new(service) s.typ = reflect.TypeOf(rcvr) s.rcvr = reflect.ValueOf(rcvr) sname := reflect.Indirect(s.rcvr).Type().Name() if sname == "" { log.Log("rpc: no service name for type", s.typ.String()) return errors.New("rpc: no service name for type" + s.typ.String()) } if !isExported(sname) { s := "rpc Register: type " + sname + " is not exported" log.Log(s) return errors.New(s) } if _, present := server.serviceMap[sname]; present { return errors.New("rpc: service already defined: " + sname) } s.name = sname s.method = make(map[string]*methodType) // Install the methods for m := 0; m < s.typ.NumMethod(); m++ { method := s.typ.Method(m) if mt := prepareMethod(method); mt != nil { s.method[method.Name] = mt } } if len(s.method) == 0 { s := "rpc Register: type " + sname + " has no exported methods of suitable type" log.Log(s) return errors.New(s) } server.serviceMap[s.name] = s return nil } func (server *server) sendResponse(sending *sync.Mutex, req *request, reply interface{}, codec serverCodec, errmsg string, last bool) (err error) { resp := server.getResponse() // Encode the response header resp.ServiceMethod = req.ServiceMethod if errmsg != "" { resp.Error = errmsg reply = invalidRequest } resp.Seq = req.Seq sending.Lock() err = codec.WriteResponse(resp, reply, last) sending.Unlock() server.freeResponse(resp) return err } func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec serverCodec, ct string) { function := mtype.method.Func var returnValues []reflect.Value r := &rpcRequest{ service: server.name, contentType: ct, method: req.ServiceMethod, } if !mtype.stream { r.request = argv.Interface() fn := func(ctx context.Context, req Request, rsp interface{}) error { returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(req.Request()), reflect.ValueOf(rsp)}) // The return value for the method is an error. if err := returnValues[0].Interface(); err != nil { return err.(error) } return nil } for i := len(server.hdlrWrappers); i > 0; i-- { fn = server.hdlrWrappers[i-1](fn) } errmsg := "" err := fn(ctx, r, replyv.Interface()) if err != nil { errmsg = err.Error() } server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true) server.freeRequest(req) return } // declare a local error to see if we errored out already // keep track of the type, to make sure we return // the same one consistently var lastError error stream := &rpcStream{ context: ctx, codec: codec, request: r, seq: req.Seq, } // Invoke the method, providing a new value for the reply. fn := func(ctx context.Context, req Request, stream interface{}) error { returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)}) if err := returnValues[0].Interface(); err != nil { // the function returned an error, we use that return err.(error) } else if lastError != nil { // we had an error inside sendReply, we use that return lastError } else { // no error, we send the special EOS error return lastStreamResponseError } } for i := len(server.hdlrWrappers); i > 0; i-- { fn = server.hdlrWrappers[i-1](fn) } // client.Stream request r.stream = true errmsg := "" if err := fn(ctx, r, stream); err != nil { errmsg = err.Error() } // this is the last packet, we don't do anything with // the error here (well sendStreamResponse will log it // already) server.sendResponse(sending, req, nil, codec, errmsg, true) server.freeRequest(req) } func (m *methodType) prepareContext(ctx context.Context) reflect.Value { if contextv := reflect.ValueOf(ctx); contextv.IsValid() { return contextv } return reflect.Zero(m.ContextType) } func (server *server) serveRequest(ctx context.Context, codec serverCodec, ct string) error { sending := new(sync.Mutex) service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) if err != nil { if !keepReading { return err } // send a response if we actually managed to read a header. if req != nil { server.sendResponse(sending, req, invalidRequest, codec, err.Error(), true) server.freeRequest(req) } return err } service.call(ctx, server, sending, mtype, req, argv, replyv, codec, ct) return nil } func (server *server) getRequest() *request { server.reqLock.Lock() req := server.freeReq if req == nil { req = new(request) } else { server.freeReq = req.next *req = request{} } server.reqLock.Unlock() return req } func (server *server) freeRequest(req *request) { server.reqLock.Lock() req.next = server.freeReq server.freeReq = req server.reqLock.Unlock() } func (server *server) getResponse() *response { server.respLock.Lock() resp := server.freeResp if resp == nil { resp = new(response) } else { server.freeResp = resp.next *resp = response{} } server.respLock.Unlock() return resp } func (server *server) freeResponse(resp *response) { server.respLock.Lock() resp.next = server.freeResp server.freeResp = resp server.respLock.Unlock() } func (server *server) readRequest(codec serverCodec) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { service, mtype, req, keepReading, err = server.readRequestHeader(codec) if err != nil { if !keepReading { return } // discard body codec.ReadRequestBody(nil) return } // is it a streaming request? then we don't read the body if mtype.stream { codec.ReadRequestBody(nil) return } // Decode the argument value. argIsValue := false // if true, need to indirect before calling. if mtype.ArgType.Kind() == reflect.Ptr { argv = reflect.New(mtype.ArgType.Elem()) } else { argv = reflect.New(mtype.ArgType) argIsValue = true } // argv guaranteed to be a pointer now. if err = codec.ReadRequestBody(argv.Interface()); err != nil { return } if argIsValue { argv = argv.Elem() } if !mtype.stream { replyv = reflect.New(mtype.ReplyType.Elem()) } return } func (server *server) readRequestHeader(codec serverCodec) (service *service, mtype *methodType, req *request, keepReading bool, err error) { // Grab the request header. req = server.getRequest() err = codec.ReadRequestHeader(req, true) if err != nil { req = nil if err == io.EOF || err == io.ErrUnexpectedEOF { return } err = errors.New("rpc: server cannot decode request: " + err.Error()) return } // We read the header successfully. If we see an error now, // we can still recover and move on to the next request. keepReading = true serviceMethod := strings.Split(req.ServiceMethod, ".") if len(serviceMethod) != 2 { err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod) return } // Look up the request. server.mu.Lock() service = server.serviceMap[serviceMethod[0]] server.mu.Unlock() if service == nil { err = errors.New("rpc: can't find service " + req.ServiceMethod) return } mtype = service.method[serviceMethod[1]] if mtype == nil { err = errors.New("rpc: can't find method " + req.ServiceMethod) } return } type serverCodec interface { ReadRequestHeader(*request, bool) error ReadRequestBody(interface{}) error WriteResponse(*response, interface{}, bool) error Close() error }