Rework use of context

This commit is contained in:
Asim 2015-05-23 11:53:40 +01:00
parent d67c1ba111
commit 3db51216b2
16 changed files with 79 additions and 248 deletions

View File

@ -2,14 +2,15 @@ package client
import ( import (
"github.com/myodc/go-micro/transport" "github.com/myodc/go-micro/transport"
"golang.org/x/net/context"
) )
type Client interface { type Client interface {
NewRequest(string, string, interface{}) Request NewRequest(string, string, interface{}) Request
NewProtoRequest(string, string, interface{}) Request NewProtoRequest(string, string, interface{}) Request
NewJsonRequest(string, string, interface{}) Request NewJsonRequest(string, string, interface{}) Request
Call(Request, interface{}) error Call(context.Context, Request, interface{}) error
CallRemote(string, string, Request, interface{}) error CallRemote(context.Context, string, Request, interface{}) error
} }
type options struct { type options struct {
@ -28,12 +29,12 @@ func Transport(t transport.Transport) Option {
} }
} }
func Call(request Request, response interface{}) error { func Call(ctx context.Context, request Request, response interface{}) error {
return DefaultClient.Call(request, response) return DefaultClient.Call(ctx, request, response)
} }
func CallRemote(address, path string, request Request, response interface{}) error { func CallRemote(ctx context.Context, address string, request Request, response interface{}) error {
return DefaultClient.CallRemote(address, path, request, response) return DefaultClient.CallRemote(ctx, address, request, response)
} }
func NewRequest(service, method string, request interface{}) Request { func NewRequest(service, method string, request interface{}) Request {

View File

@ -1,8 +0,0 @@
package client
type Headers interface {
Add(string, string)
Del(string)
Get(string) string
Set(string, string)
}

View File

@ -5,5 +5,4 @@ type Request interface {
Method() string Method() string
ContentType() string ContentType() string
Request() interface{} Request() interface{}
Headers() Headers
} }

View File

@ -7,13 +7,16 @@ import (
"net/http" "net/http"
"time" "time"
c "github.com/myodc/go-micro/context"
"github.com/myodc/go-micro/errors" "github.com/myodc/go-micro/errors"
"github.com/myodc/go-micro/registry" "github.com/myodc/go-micro/registry"
"github.com/myodc/go-micro/transport" "github.com/myodc/go-micro/transport"
rpc "github.com/youtube/vitess/go/rpcplus" rpc "github.com/youtube/vitess/go/rpcplus"
js "github.com/youtube/vitess/go/rpcplus/jsonrpc" js "github.com/youtube/vitess/go/rpcplus/jsonrpc"
pb "github.com/youtube/vitess/go/rpcplus/pbrpc" pb "github.com/youtube/vitess/go/rpcplus/pbrpc"
ctx "golang.org/x/net/context"
"golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -34,14 +37,14 @@ func (t *headerRoundTripper) RoundTrip(r *http.Request) (*http.Response, error)
return t.r.RoundTrip(r) return t.r.RoundTrip(r)
} }
func (r *RpcClient) call(address, path string, request Request, response interface{}) error { func (r *RpcClient) call(ctx context.Context, address string, request Request, response interface{}) error {
switch request.ContentType() { switch request.ContentType() {
case "application/grpc": case "application/grpc":
cc, err := grpc.Dial(address) cc, err := grpc.Dial(address)
if err != nil { if err != nil {
return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error connecting to server: %v", err)) return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error connecting to server: %v", err))
} }
if err := grpc.Invoke(ctx.Background(), path, request.Request(), response, cc); err != nil { if err := grpc.Invoke(ctx, request.Method(), request.Request(), response, cc); err != nil {
return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err))
} }
return nil return nil
@ -77,10 +80,10 @@ func (r *RpcClient) call(address, path string, request Request, response interfa
Body: reqB.Bytes(), Body: reqB.Bytes(),
} }
h, _ := request.Headers().(http.Header) md, ok := c.GetMetaData(ctx)
for k, v := range h { if ok {
if len(v) > 0 { for k, v := range md {
msg.Header[k] = v[0] msg.Header[k] = v
} }
} }
@ -129,12 +132,12 @@ func (r *RpcClient) call(address, path string, request Request, response interfa
return nil return nil
} }
func (r *RpcClient) CallRemote(address, path string, request Request, response interface{}) error { func (r *RpcClient) CallRemote(ctx context.Context, address string, request Request, response interface{}) error {
return r.call(address, path, request, response) return r.call(ctx, address, request, response)
} }
// TODO: Call(..., opts *Options) error { // TODO: Call(..., opts *Options) error {
func (r *RpcClient) Call(request Request, response interface{}) error { func (r *RpcClient) Call(ctx context.Context, request Request, response interface{}) error {
service, err := registry.GetService(request.Service()) service, err := registry.GetService(request.Service())
if err != nil { if err != nil {
return errors.InternalServerError("go.micro.client", err.Error()) return errors.InternalServerError("go.micro.client", err.Error())
@ -152,7 +155,7 @@ func (r *RpcClient) Call(request Request, response interface{}) error {
address = fmt.Sprintf("%s:%d", address, node.Port()) address = fmt.Sprintf("%s:%d", address, node.Port())
} }
return r.call(address, "", request, response) return r.call(ctx, address, request, response)
} }
func (r *RpcClient) NewRequest(service, method string, request interface{}) Request { func (r *RpcClient) NewRequest(service, method string, request interface{}) Request {

View File

@ -1,13 +1,10 @@
package client package client
import (
"net/http"
)
type RpcRequest struct { type RpcRequest struct {
service, method, contentType string service string
request interface{} method string
headers http.Header contentType string
request interface{}
} }
func newRpcRequest(service, method string, request interface{}, contentType string) *RpcRequest { func newRpcRequest(service, method string, request interface{}, contentType string) *RpcRequest {
@ -16,7 +13,6 @@ func newRpcRequest(service, method string, request interface{}, contentType stri
method: method, method: method,
request: request, request: request,
contentType: contentType, contentType: contentType,
headers: make(http.Header),
} }
} }
@ -24,10 +20,6 @@ func (r *RpcRequest) ContentType() string {
return r.contentType return r.contentType
} }
func (r *RpcRequest) Headers() Headers {
return r.headers
}
func (r *RpcRequest) Service() string { func (r *RpcRequest) Service() string {
return r.service return r.service
} }

22
context/context.go Normal file
View File

@ -0,0 +1,22 @@
package context
import (
"golang.org/x/net/context"
)
type key int
const (
mdKey = key(0)
)
type MetaData map[string]string
func GetMetaData(ctx context.Context) (MetaData, bool) {
md, ok := ctx.Value(mdKey).(MetaData)
return md, ok
}
func WithMetaData(ctx context.Context, md MetaData) context.Context {
return context.WithValue(ctx, mdKey, md)
}

View File

@ -5,16 +5,17 @@ import (
h "github.com/grpc/grpc-common/go/helloworld" h "github.com/grpc/grpc-common/go/helloworld"
"github.com/myodc/go-micro/client" "github.com/myodc/go-micro/client"
"golang.org/x/net/context"
) )
// run github.com/grpc/grpc-common/go/greeter_server/main.go // run github.com/grpc/grpc-common/go/greeter_server/main.go
func main() { func main() {
req := client.NewRpcRequest("helloworld.Greeter", "SayHello", &h.HelloRequest{ req := client.NewRpcRequest("helloworld.Greeter", "helloworld.Greeter/SayHello", &h.HelloRequest{
Name: "John", Name: "John",
}, "application/grpc") }, "application/grpc")
rsp := &h.HelloReply{} rsp := &h.HelloReply{}
err := client.CallRemote("localhost:50051", "helloworld.Greeter/SayHello", req, rsp) err := client.CallRemote(context.Background(), "localhost:50051", req, rsp)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }

View File

@ -5,7 +5,9 @@ import (
"github.com/myodc/go-micro/client" "github.com/myodc/go-micro/client"
"github.com/myodc/go-micro/cmd" "github.com/myodc/go-micro/cmd"
c "github.com/myodc/go-micro/context"
example "github.com/myodc/go-micro/template/proto/example" example "github.com/myodc/go-micro/template/proto/example"
"golang.org/x/net/context"
) )
func main() { func main() {
@ -16,14 +18,16 @@ func main() {
Name: "John", Name: "John",
}) })
// Set arbitrary headers // create context with metadata
req.Headers().Set("X-User-Id", "john") ctx := c.WithMetaData(context.Background(), map[string]string{
req.Headers().Set("X-From-Id", "script") "X-User-Id": "john",
"X-From-Id": "script",
})
rsp := &example.Response{} rsp := &example.Response{}
// Call service // Call service
if err := client.Call(req, rsp); err != nil { if err := client.Call(ctx, req, rsp); err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }

View File

@ -1,35 +0,0 @@
package server
import (
"time"
"code.google.com/p/go.net/context"
)
type ctx struct{}
func (ctx *ctx) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
func (ctx *ctx) Done() <-chan struct{} {
return nil
}
func (ctx *ctx) Err() error {
return nil
}
func (ctx *ctx) Value(key interface{}) interface{} {
return nil
}
func newContext(parent context.Context, s *serverContext) context.Context {
return context.WithValue(parent, "serverContext", s)
}
// return server.Context
func NewContext(ctx context.Context) (Context, bool) {
c, ok := ctx.Value("serverContext").(*serverContext)
return c, ok
}

View File

@ -1,8 +0,0 @@
package server
type Headers interface {
Add(string, string)
Del(string)
Get(string) string
Set(string, string)
}

View File

@ -1,6 +0,0 @@
package server
type Request interface {
Headers() Headers
Session(string) string
}

View File

@ -4,11 +4,14 @@ import (
"bytes" "bytes"
"sync" "sync"
log "github.com/golang/glog" c "github.com/myodc/go-micro/context"
"github.com/myodc/go-micro/transport" "github.com/myodc/go-micro/transport"
log "github.com/golang/glog"
rpc "github.com/youtube/vitess/go/rpcplus" rpc "github.com/youtube/vitess/go/rpcplus"
js "github.com/youtube/vitess/go/rpcplus/jsonrpc" js "github.com/youtube/vitess/go/rpcplus/jsonrpc"
pb "github.com/youtube/vitess/go/rpcplus/pbrpc" pb "github.com/youtube/vitess/go/rpcplus/pbrpc"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -26,7 +29,6 @@ var (
) )
func (s *RpcServer) accept(sock transport.Socket) { func (s *RpcServer) accept(sock transport.Socket) {
// serveCtx := getServerContext(req)
var msg transport.Message var msg transport.Message
if err := sock.Recv(&msg); err != nil { if err := sock.Recv(&msg); err != nil {
return return
@ -50,17 +52,21 @@ func (s *RpcServer) accept(sock transport.Socket) {
cc = js.NewServerCodec(buf) cc = js.NewServerCodec(buf)
default: default:
return return
// return nil, errors.InternalServerError("go.micro.server", fmt.Sprintf("Unsupported content-type: %v", req.Header.Get("Content-Type")))
} }
//ctx := newContext(&ctx{}, serveCtx) // strip our headers
if err := s.rpc.ServeRequestWithContext(context.Background(), cc); err != nil { ct := msg.Header["Content-Type"]
delete(msg.Header, "Content-Type")
ctx := c.WithMetaData(context.Background(), msg.Header)
if err := s.rpc.ServeRequestWithContext(ctx, cc); err != nil {
return return
} }
sock.Send(&transport.Message{ sock.Send(&transport.Message{
Header: map[string]string{ Header: map[string]string{
"Content-Type": msg.Header["Content-Type"], "Content-Type": ct,
}, },
Body: rsp.Bytes(), Body: rsp.Bytes(),
}) })

View File

@ -1,120 +0,0 @@
package server
import (
"net/http"
"sync"
log "github.com/golang/glog"
"github.com/myodc/go-micro/client"
)
var ctxs = struct {
sync.Mutex
m map[*http.Request]*serverContext
}{
m: make(map[*http.Request]*serverContext),
}
// A server context interface
type Context interface {
Request() Request // the request made to the server
Headers() Headers // the response headers
NewRequest(string, string, interface{}) client.Request // a new scoped client request
NewProtoRequest(string, string, interface{}) client.Request // a new scoped client request
NewJsonRequest(string, string, interface{}) client.Request // a new scoped client request
}
// context represents the context of an in-flight HTTP request.
// It implements the appengine.Context and http.ResponseWriter interfaces.
type serverContext struct {
req *serverRequest
outCode int
outHeader http.Header
outBody []byte
}
// Copied from $GOROOT/src/pkg/net/http/transfer.go. Some response status
// codes do not permit a response body (nor response entity headers such as
// Content-Length, Content-Type, etc).
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}
func getServerContext(req *http.Request) *serverContext {
ctxs.Lock()
c := ctxs.m[req]
ctxs.Unlock()
if c == nil {
// Someone passed in an http.Request that is not in-flight.
panic("NewContext passed an unknown http.Request")
}
return c
}
func (c *serverContext) NewRequest(service, method string, request interface{}) client.Request {
req := client.NewRequest(service, method, request)
// TODO: set headers and scope
req.Headers().Set("X-User-Session", c.Request().Session("X-User-Session"))
return req
}
func (c *serverContext) NewProtoRequest(service, method string, request interface{}) client.Request {
req := client.NewProtoRequest(service, method, request)
// TODO: set headers and scope
req.Headers().Set("X-User-Session", c.Request().Session("X-User-Session"))
return req
}
func (c *serverContext) NewJsonRequest(service, method string, request interface{}) client.Request {
req := client.NewJsonRequest(service, method, request)
// TODO: set headers and scope
req.Headers().Set("X-User-Session", c.Request().Session("X-User-Session"))
return req
}
// The response headers
func (c *serverContext) Headers() Headers {
return c.outHeader
}
// The response headers
func (c *serverContext) Header() http.Header {
return c.outHeader
}
// The request made to the server
func (c *serverContext) Request() Request {
return c.req
}
func (c *serverContext) Write(b []byte) (int, error) {
if c.outCode == 0 {
c.WriteHeader(http.StatusOK)
}
if len(b) > 0 && !bodyAllowedForStatus(c.outCode) {
return 0, http.ErrBodyNotAllowed
}
c.outBody = append(c.outBody, b...)
return len(b), nil
}
func (c *serverContext) WriteHeader(code int) {
if c.outCode != 0 {
log.Error("WriteHeader called multiple times on request.")
return
}
c.outCode = code
}
func GetContext(r *http.Request) *serverContext {
return getServerContext(r)
}

View File

@ -1,25 +0,0 @@
package server
import (
"net/http"
)
type serverRequest struct {
req *http.Request
}
func (s *serverRequest) Headers() Headers {
return s.req.Header
}
func (s *serverRequest) Session(name string) string {
if sess := s.Headers().Get(name); len(sess) > 0 {
return sess
}
c, err := s.req.Cookie(name)
if err != nil {
return ""
}
return c.Value
}

View File

@ -1,17 +1,23 @@
package handler package handler
import ( import (
"code.google.com/p/go.net/context"
log "github.com/golang/glog" log "github.com/golang/glog"
c "github.com/myodc/go-micro/context"
"github.com/myodc/go-micro/server" "github.com/myodc/go-micro/server"
example "github.com/myodc/go-micro/template/proto/example" example "github.com/myodc/go-micro/template/proto/example"
"golang.org/x/net/context"
) )
type Example struct{} type Example struct{}
func (e *Example) Call(ctx context.Context, req *example.Request, rsp *example.Response) error { func (e *Example) Call(ctx context.Context, req *example.Request, rsp *example.Response) error {
log.Info("Received Example.Call request") md, ok := c.GetMetaData(ctx)
if ok {
log.Infof("Received Example.Call request with metadata: %v", md)
} else {
log.Info("Received Example.Call request")
}
rsp.Msg = server.Id + ": Hello " + req.Name rsp.Msg = server.Id + ": Hello " + req.Name

View File

@ -54,7 +54,6 @@ func (h *HttpTransportClient) Send(m *Message) (*Message, error) {
URL: &url.URL{ URL: &url.URL{
Scheme: "http", Scheme: "http",
Host: h.addr, Host: h.addr,
// Path: path,
}, },
Header: header, Header: header,
Body: buf, Body: buf,