Pull rpcplus in local for the time being. We need to modify it for server side middleware
This commit is contained in:
		
							
								
								
									
										490
									
								
								server/rpcplus_server.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										490
									
								
								server/rpcplus_server.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,490 @@ | ||||
| // Copyright 2009 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
| // | ||||
| // Meh, we need to get rid of this shit | ||||
| package server | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"unicode" | ||||
| 	"unicode/utf8" | ||||
|  | ||||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	lastStreamResponseError = "EOS" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// A value sent as a placeholder for the server's response value when the server | ||||
| 	// receives an invalid request. It is never decoded by the client since the Response | ||||
| 	// contains an error when it is used. | ||||
| 	invalidRequest = struct{}{} | ||||
|  | ||||
| 	// Precompute the reflect type for error. Can't use error directly | ||||
| 	// because Typeof takes an empty interface value. This is annoying. | ||||
| 	typeOfError = reflect.TypeOf((*error)(nil)).Elem() | ||||
| ) | ||||
|  | ||||
| type methodType struct { | ||||
| 	sync.Mutex  // protects counters | ||||
| 	method      reflect.Method | ||||
| 	ArgType     reflect.Type | ||||
| 	ReplyType   reflect.Type | ||||
| 	ContextType reflect.Type | ||||
| 	stream      bool | ||||
| 	numCalls    uint | ||||
| } | ||||
|  | ||||
| func (m *methodType) TakesContext() bool { | ||||
| 	return m.ContextType != nil | ||||
| } | ||||
|  | ||||
| func (m *methodType) NumCalls() (n uint) { | ||||
| 	m.Lock() | ||||
| 	n = m.numCalls | ||||
| 	m.Unlock() | ||||
| 	return n | ||||
| } | ||||
|  | ||||
| type service struct { | ||||
| 	name   string                 // name of service | ||||
| 	rcvr   reflect.Value          // receiver of methods for the service | ||||
| 	typ    reflect.Type           // type of the receiver | ||||
| 	method map[string]*methodType // registered methods | ||||
| } | ||||
|  | ||||
| type request struct { | ||||
| 	ServiceMethod string   // format: "Service.Method" | ||||
| 	Seq           uint64   // sequence number chosen by client | ||||
| 	next          *request // for free list in Server | ||||
| } | ||||
|  | ||||
| type response struct { | ||||
| 	ServiceMethod string    // echoes that of the Request | ||||
| 	Seq           uint64    // echoes that of the request | ||||
| 	Error         string    // error, if any. | ||||
| 	next          *response // for free list in Server | ||||
| } | ||||
|  | ||||
| // server represents an RPC Server. | ||||
| type server struct { | ||||
| 	mu         sync.Mutex // protects the serviceMap | ||||
| 	serviceMap map[string]*service | ||||
| 	reqLock    sync.Mutex // protects freeReq | ||||
| 	freeReq    *request | ||||
| 	respLock   sync.Mutex // protects freeResp | ||||
| 	freeResp   *response | ||||
| } | ||||
|  | ||||
| func newServer() *server { | ||||
| 	return &server{serviceMap: make(map[string]*service)} | ||||
| } | ||||
|  | ||||
| // Is this an exported - upper case - name? | ||||
| func isExported(name string) bool { | ||||
| 	rune, _ := utf8.DecodeRuneInString(name) | ||||
| 	return unicode.IsUpper(rune) | ||||
| } | ||||
|  | ||||
| // Is this type exported or a builtin? | ||||
| func isExportedOrBuiltinType(t reflect.Type) bool { | ||||
| 	for t.Kind() == reflect.Ptr { | ||||
| 		t = t.Elem() | ||||
| 	} | ||||
| 	// PkgPath will be non-empty even for an exported type, | ||||
| 	// so we need to check the type name as well. | ||||
| 	return isExported(t.Name()) || t.PkgPath() == "" | ||||
| } | ||||
|  | ||||
| func (server *server) Register(rcvr interface{}) error { | ||||
| 	return server.register(rcvr, "", false) | ||||
| } | ||||
|  | ||||
| // prepareMethod returns a methodType for the provided method or nil | ||||
| // in case if the method was unsuitable. | ||||
| func prepareMethod(method reflect.Method) *methodType { | ||||
| 	mtype := method.Type | ||||
| 	mname := method.Name | ||||
| 	var replyType, argType, contextType reflect.Type | ||||
|  | ||||
| 	stream := false | ||||
| 	// Method must be exported. | ||||
| 	if method.PkgPath != "" { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	switch mtype.NumIn() { | ||||
| 	case 3: | ||||
| 		// normal method | ||||
| 		argType = mtype.In(1) | ||||
| 		replyType = mtype.In(2) | ||||
| 		contextType = nil | ||||
| 	case 4: | ||||
| 		// method that takes a context | ||||
| 		argType = mtype.In(2) | ||||
| 		replyType = mtype.In(3) | ||||
| 		contextType = mtype.In(1) | ||||
| 	default: | ||||
| 		log.Println("method", mname, "of", mtype, "has wrong number of ins:", mtype.NumIn()) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// First arg need not be a pointer. | ||||
| 	if !isExportedOrBuiltinType(argType) { | ||||
| 		log.Println(mname, "argument type not exported:", argType) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// the second argument will tell us if it's a streaming call | ||||
| 	// or a regular call | ||||
| 	if replyType.Kind() == reflect.Func { | ||||
| 		// this is a streaming call | ||||
| 		stream = true | ||||
| 		if replyType.NumIn() != 1 { | ||||
| 			log.Println("method", mname, "sendReply has wrong number of ins:", replyType.NumIn()) | ||||
| 			return nil | ||||
| 		} | ||||
| 		if replyType.In(0).Kind() != reflect.Interface { | ||||
| 			log.Println("method", mname, "sendReply parameter type not an interface:", replyType.In(0)) | ||||
| 			return nil | ||||
| 		} | ||||
| 		if replyType.NumOut() != 1 { | ||||
| 			log.Println("method", mname, "sendReply has wrong number of outs:", replyType.NumOut()) | ||||
| 			return nil | ||||
| 		} | ||||
| 		if returnType := replyType.Out(0); returnType != typeOfError { | ||||
| 			log.Println("method", mname, "sendReply returns", returnType.String(), "not error") | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 	} else if replyType.Kind() != reflect.Ptr { | ||||
| 		log.Println("method", mname, "reply type not a pointer:", replyType) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// Reply type must be exported. | ||||
| 	if !isExportedOrBuiltinType(replyType) { | ||||
| 		log.Println("method", mname, "reply type not exported:", replyType) | ||||
| 		return nil | ||||
| 	} | ||||
| 	// Method needs one out. | ||||
| 	if mtype.NumOut() != 1 { | ||||
| 		log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) | ||||
| 		return nil | ||||
| 	} | ||||
| 	// The return type of the method must be error. | ||||
| 	if returnType := mtype.Out(0); returnType != typeOfError { | ||||
| 		log.Println("method", mname, "returns", returnType.String(), "not error") | ||||
| 		return nil | ||||
| 	} | ||||
| 	return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream} | ||||
| } | ||||
|  | ||||
| func (server *server) register(rcvr interface{}, name string, useName bool) error { | ||||
| 	server.mu.Lock() | ||||
| 	defer server.mu.Unlock() | ||||
| 	if server.serviceMap == nil { | ||||
| 		server.serviceMap = make(map[string]*service) | ||||
| 	} | ||||
| 	s := new(service) | ||||
| 	s.typ = reflect.TypeOf(rcvr) | ||||
| 	s.rcvr = reflect.ValueOf(rcvr) | ||||
| 	sname := reflect.Indirect(s.rcvr).Type().Name() | ||||
| 	if useName { | ||||
| 		sname = name | ||||
| 	} | ||||
| 	if sname == "" { | ||||
| 		log.Fatal("rpc: no service name for type", s.typ.String()) | ||||
| 	} | ||||
| 	if !isExported(sname) && !useName { | ||||
| 		s := "rpc Register: type " + sname + " is not exported" | ||||
| 		log.Print(s) | ||||
| 		return errors.New(s) | ||||
| 	} | ||||
| 	if _, present := server.serviceMap[sname]; present { | ||||
| 		return errors.New("rpc: service already defined: " + sname) | ||||
| 	} | ||||
| 	s.name = sname | ||||
| 	s.method = make(map[string]*methodType) | ||||
|  | ||||
| 	// Install the methods | ||||
| 	for m := 0; m < s.typ.NumMethod(); m++ { | ||||
| 		method := s.typ.Method(m) | ||||
| 		if mt := prepareMethod(method); mt != nil { | ||||
| 			s.method[method.Name] = mt | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(s.method) == 0 { | ||||
| 		s := "rpc Register: type " + sname + " has no exported methods of suitable type" | ||||
| 		log.Print(s) | ||||
| 		return errors.New(s) | ||||
| 	} | ||||
| 	server.serviceMap[s.name] = s | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (server *server) sendResponse(sending *sync.Mutex, req *request, reply interface{}, codec serverCodec, errmsg string, last bool) (err error) { | ||||
| 	resp := server.getResponse() | ||||
| 	// Encode the response header | ||||
| 	resp.ServiceMethod = req.ServiceMethod | ||||
| 	if errmsg != "" { | ||||
| 		resp.Error = errmsg | ||||
| 		reply = invalidRequest | ||||
| 	} | ||||
| 	resp.Seq = req.Seq | ||||
| 	sending.Lock() | ||||
| 	err = codec.WriteResponse(resp, reply, last) | ||||
| 	if err != nil { | ||||
| 		log.Println("rpc: writing response:", err) | ||||
| 	} | ||||
| 	sending.Unlock() | ||||
| 	server.freeResponse(resp) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, codec serverCodec) { | ||||
| 	mtype.Lock() | ||||
| 	mtype.numCalls++ | ||||
| 	mtype.Unlock() | ||||
| 	function := mtype.method.Func | ||||
| 	var returnValues []reflect.Value | ||||
|  | ||||
| 	if !mtype.stream { | ||||
|  | ||||
| 		// Invoke the method, providing a new value for the reply. | ||||
| 		if mtype.TakesContext() { | ||||
| 			returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), argv, replyv}) | ||||
| 		} else { | ||||
| 			returnValues = function.Call([]reflect.Value{s.rcvr, argv, replyv}) | ||||
| 		} | ||||
|  | ||||
| 		// The return value for the method is an error. | ||||
| 		errInter := returnValues[0].Interface() | ||||
| 		errmsg := "" | ||||
| 		if errInter != nil { | ||||
| 			errmsg = errInter.(error).Error() | ||||
| 		} | ||||
| 		server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true) | ||||
| 		server.freeRequest(req) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// declare a local error to see if we errored out already | ||||
| 	// keep track of the type, to make sure we return | ||||
| 	// the same one consistently | ||||
| 	var lastError error | ||||
| 	var firstType reflect.Type | ||||
|  | ||||
| 	sendReply := func(oneReply interface{}) error { | ||||
|  | ||||
| 		// we already triggered an error, we're done | ||||
| 		if lastError != nil { | ||||
| 			return lastError | ||||
| 		} | ||||
|  | ||||
| 		// check the oneReply has the right type using reflection | ||||
| 		typ := reflect.TypeOf(oneReply) | ||||
| 		if firstType == nil { | ||||
| 			firstType = typ | ||||
| 		} else { | ||||
| 			if firstType != typ { | ||||
| 				log.Println("passing wrong type to sendReply", | ||||
| 					firstType, "!=", typ) | ||||
| 				lastError = errors.New("rpc: passing wrong type to sendReply") | ||||
| 				return lastError | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		lastError = server.sendResponse(sending, req, oneReply, codec, "", false) | ||||
| 		if lastError != nil { | ||||
| 			return lastError | ||||
| 		} | ||||
|  | ||||
| 		// we manage to send, we're good | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// Invoke the method, providing a new value for the reply. | ||||
| 	if mtype.TakesContext() { | ||||
| 		returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), argv, reflect.ValueOf(sendReply)}) | ||||
| 	} else { | ||||
| 		returnValues = function.Call([]reflect.Value{s.rcvr, argv, reflect.ValueOf(sendReply)}) | ||||
| 	} | ||||
| 	errInter := returnValues[0].Interface() | ||||
| 	errmsg := "" | ||||
| 	if errInter != nil { | ||||
| 		// the function returned an error, we use that | ||||
| 		errmsg = errInter.(error).Error() | ||||
| 	} else if lastError != nil { | ||||
| 		// we had an error inside sendReply, we use that | ||||
| 		errmsg = lastError.Error() | ||||
| 	} else { | ||||
| 		// no error, we send the special EOS error | ||||
| 		errmsg = lastStreamResponseError | ||||
| 	} | ||||
|  | ||||
| 	// this is the last packet, we don't do anything with | ||||
| 	// the error here (well sendStreamResponse will log it | ||||
| 	// already) | ||||
| 	server.sendResponse(sending, req, nil, codec, errmsg, true) | ||||
| 	server.freeRequest(req) | ||||
| } | ||||
|  | ||||
| func (m *methodType) prepareContext(ctx context.Context) reflect.Value { | ||||
| 	if contextv := reflect.ValueOf(ctx); contextv.IsValid() { | ||||
| 		return contextv | ||||
| 	} | ||||
| 	return reflect.Zero(m.ContextType) | ||||
| } | ||||
|  | ||||
| func (server *server) ServeRequestWithContext(ctx context.Context, codec serverCodec) error { | ||||
| 	sending := new(sync.Mutex) | ||||
| 	service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) | ||||
| 	if err != nil { | ||||
| 		if !keepReading { | ||||
| 			return err | ||||
| 		} | ||||
| 		// send a response if we actually managed to read a header. | ||||
| 		if req != nil { | ||||
| 			server.sendResponse(sending, req, invalidRequest, codec, err.Error(), true) | ||||
| 			server.freeRequest(req) | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| 	service.call(ctx, server, sending, mtype, req, argv, replyv, codec) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (server *server) getRequest() *request { | ||||
| 	server.reqLock.Lock() | ||||
| 	req := server.freeReq | ||||
| 	if req == nil { | ||||
| 		req = new(request) | ||||
| 	} else { | ||||
| 		server.freeReq = req.next | ||||
| 		*req = request{} | ||||
| 	} | ||||
| 	server.reqLock.Unlock() | ||||
| 	return req | ||||
| } | ||||
|  | ||||
| func (server *server) freeRequest(req *request) { | ||||
| 	server.reqLock.Lock() | ||||
| 	req.next = server.freeReq | ||||
| 	server.freeReq = req | ||||
| 	server.reqLock.Unlock() | ||||
| } | ||||
|  | ||||
| func (server *server) getResponse() *response { | ||||
| 	server.respLock.Lock() | ||||
| 	resp := server.freeResp | ||||
| 	if resp == nil { | ||||
| 		resp = new(response) | ||||
| 	} else { | ||||
| 		server.freeResp = resp.next | ||||
| 		*resp = response{} | ||||
| 	} | ||||
| 	server.respLock.Unlock() | ||||
| 	return resp | ||||
| } | ||||
|  | ||||
| func (server *server) freeResponse(resp *response) { | ||||
| 	server.respLock.Lock() | ||||
| 	resp.next = server.freeResp | ||||
| 	server.freeResp = resp | ||||
| 	server.respLock.Unlock() | ||||
| } | ||||
|  | ||||
| func (server *server) readRequest(codec serverCodec) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { | ||||
| 	service, mtype, req, keepReading, err = server.readRequestHeader(codec) | ||||
| 	if err != nil { | ||||
| 		if !keepReading { | ||||
| 			return | ||||
| 		} | ||||
| 		// discard body | ||||
| 		codec.ReadRequestBody(nil) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Decode the argument value. | ||||
| 	argIsValue := false // if true, need to indirect before calling. | ||||
| 	if mtype.ArgType.Kind() == reflect.Ptr { | ||||
| 		argv = reflect.New(mtype.ArgType.Elem()) | ||||
| 	} else { | ||||
| 		argv = reflect.New(mtype.ArgType) | ||||
| 		argIsValue = true | ||||
| 	} | ||||
| 	// argv guaranteed to be a pointer now. | ||||
| 	if err = codec.ReadRequestBody(argv.Interface()); err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if argIsValue { | ||||
| 		argv = argv.Elem() | ||||
| 	} | ||||
|  | ||||
| 	if !mtype.stream { | ||||
| 		replyv = reflect.New(mtype.ReplyType.Elem()) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (server *server) readRequestHeader(codec serverCodec) (service *service, mtype *methodType, req *request, keepReading bool, err error) { | ||||
| 	// Grab the request header. | ||||
| 	req = server.getRequest() | ||||
| 	err = codec.ReadRequestHeader(req) | ||||
| 	if err != nil { | ||||
| 		req = nil | ||||
| 		if err == io.EOF || err == io.ErrUnexpectedEOF { | ||||
| 			return | ||||
| 		} | ||||
| 		err = errors.New("rpc: server cannot decode request: " + err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// We read the header successfully. If we see an error now, | ||||
| 	// we can still recover and move on to the next request. | ||||
| 	keepReading = true | ||||
|  | ||||
| 	serviceMethod := strings.Split(req.ServiceMethod, ".") | ||||
| 	if len(serviceMethod) != 2 { | ||||
| 		err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod) | ||||
| 		return | ||||
| 	} | ||||
| 	// Look up the request. | ||||
| 	server.mu.Lock() | ||||
| 	service = server.serviceMap[serviceMethod[0]] | ||||
| 	server.mu.Unlock() | ||||
| 	if service == nil { | ||||
| 		err = errors.New("rpc: can't find service " + req.ServiceMethod) | ||||
| 		return | ||||
| 	} | ||||
| 	mtype = service.method[serviceMethod[1]] | ||||
| 	if mtype == nil { | ||||
| 		err = errors.New("rpc: can't find method " + req.ServiceMethod) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // A serverCodec implements reading of RPC requests and writing of | ||||
| // RPC responses for the server side of an RPC session. | ||||
| // The server calls ReadRequestHeader and ReadRequestBody in pairs | ||||
| // to read requests from the connection, and it calls WriteResponse to | ||||
| // write a response back. The server calls Close when finished with the | ||||
| // connection. ReadRequestBody may be called with a nil | ||||
| // argument to force the body of the request to be read and discarded. | ||||
| type serverCodec interface { | ||||
| 	ReadRequestHeader(*request) error | ||||
| 	ReadRequestBody(interface{}) error | ||||
| 	WriteResponse(*response, interface{}, bool) error | ||||
|  | ||||
| 	Close() error | ||||
| } | ||||
		Reference in New Issue
	
	Block a user