Merge branch 'master' into func

This commit is contained in:
Asim Aslam 2017-05-31 19:55:03 +01:00
commit 0f1ec6ea0f
5 changed files with 71 additions and 19 deletions

View File

@ -1,14 +1,14 @@
// Package errors is an interface for defining detailed errors // Package errors provides a way to return detailed information
// for an RPC request error. The error is normally JSON encoded.
package errors package errors
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
) )
// Errors provide a way to return detailed information // Error implements the error interface.
// for an RPC request error. The error is normally
// JSON encoded.
type Error struct { type Error struct {
Id string `json:"id"` Id string `json:"id"`
Code int32 `json:"code"` Code int32 `json:"code"`
@ -21,6 +21,7 @@ func (e *Error) Error() string {
return string(b) return string(b)
} }
// New generates a custom error.
func New(id, detail string, code int32) error { func New(id, detail string, code int32) error {
return &Error{ return &Error{
Id: id, Id: id,
@ -30,6 +31,8 @@ func New(id, detail string, code int32) error {
} }
} }
// Parse tries to parse a JSON string into an error. If that
// fails, it will set the given string as the error detail.
func Parse(err string) *Error { func Parse(err string) *Error {
e := new(Error) e := new(Error)
errr := json.Unmarshal([]byte(err), e) errr := json.Unmarshal([]byte(err), e)
@ -39,47 +42,52 @@ func Parse(err string) *Error {
return e return e
} }
func BadRequest(id, detail string) error { // BadRequest generates a 400 error.
func BadRequest(id, format string, a ...interface{}) error {
return &Error{ return &Error{
Id: id, Id: id,
Code: 400, Code: 400,
Detail: detail, Detail: fmt.Sprintf(format, a...),
Status: http.StatusText(400), Status: http.StatusText(400),
} }
} }
func Unauthorized(id, detail string) error { // Unauthorized generates a 401 error.
func Unauthorized(id, format string, a ...interface{}) error {
return &Error{ return &Error{
Id: id, Id: id,
Code: 401, Code: 401,
Detail: detail, Detail: fmt.Sprintf(format, a...),
Status: http.StatusText(401), Status: http.StatusText(401),
} }
} }
func Forbidden(id, detail string) error { // Forbidden generates a 403 error.
func Forbidden(id, format string, a ...interface{}) error {
return &Error{ return &Error{
Id: id, Id: id,
Code: 403, Code: 403,
Detail: detail, Detail: fmt.Sprintf(format, a...),
Status: http.StatusText(403), Status: http.StatusText(403),
} }
} }
func NotFound(id, detail string) error { // NotFound generates a 404 error.
func NotFound(id, format string, a ...interface{}) error {
return &Error{ return &Error{
Id: id, Id: id,
Code: 404, Code: 404,
Detail: detail, Detail: fmt.Sprintf(format, a...),
Status: http.StatusText(404), Status: http.StatusText(404),
} }
} }
func InternalServerError(id, detail string) error { // InternalServerError generates a 500 error.
func InternalServerError(id, format string, a ...interface{}) error {
return &Error{ return &Error{
Id: id, Id: id,
Code: 500, Code: 500,
Detail: detail, Detail: fmt.Sprintf(format, a...),
Status: http.StatusText(500), Status: http.StatusText(500),
} }
} }

View File

@ -6,6 +6,17 @@ import (
type serverKey struct{} type serverKey struct{}
func wait(ctx context.Context) bool {
if ctx == nil {
return false
}
wait, ok := ctx.Value("wait").(bool)
if !ok {
return false
}
return wait
}
func FromContext(ctx context.Context) (Server, bool) { func FromContext(ctx context.Context) (Server, bool) {
c, ok := ctx.Value(serverKey{}).(Server) c, ok := ctx.Value(serverKey{}).(Server)
return c, ok return c, ok

View File

@ -165,6 +165,16 @@ func RegisterTTL(t time.Duration) Option {
} }
} }
// Wait tells the server to wait for requests to finish before exiting
func Wait(b bool) Option {
return func(o *Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, "wait", b)
}
}
// Adds a handler Wrapper to a list of options passed into the server // Adds a handler Wrapper to a list of options passed into the server
func WrapHandler(w HandlerWrapper) Option { func WrapHandler(w HandlerWrapper) Option {
return func(o *Options) { return func(o *Options) {

View File

@ -30,6 +30,8 @@ type rpcServer struct {
subscribers map[*subscriber][]broker.Subscriber subscribers map[*subscriber][]broker.Subscriber
// used for first registration // used for first registration
registered bool registered bool
// graceful exit
wg sync.WaitGroup
} }
func newRpcServer(opts ...Option) Server { func newRpcServer(opts ...Option) Server {
@ -100,11 +102,18 @@ func (s *rpcServer) accept(sock transport.Socket) {
} }
} }
// add to wait group
s.wg.Add(1)
// TODO: needs better error handling // TODO: needs better error handling
if err := s.rpc.serveRequest(ctx, codec, ct); err != nil { if err := s.rpc.serveRequest(ctx, codec, ct); err != nil {
log.Logf("Unexpected error serving request, closing socket: %v", err) log.Logf("Unexpected error serving request, closing socket: %v", err)
s.wg.Done()
return return
} }
// finish request
s.wg.Done()
} }
} }
@ -371,8 +380,18 @@ func (s *rpcServer) Start() error {
go ts.Accept(s.accept) go ts.Accept(s.accept)
go func() { go func() {
// wait for exit
ch := <-s.exit ch := <-s.exit
// wait for requests to finish
if wait(s.opts.Context) {
s.wg.Wait()
}
// close transport listener
ch <- ts.Close() ch <- ts.Close()
// disconnect the broker
config.Broker.Disconnect() config.Broker.Disconnect()
}() }()

View File

@ -226,11 +226,15 @@ func (s *rpcServer) createSubHandler(sb *subscriber, opts Options) broker.Handle
fn = opts.SubWrappers[i-1](fn) fn = opts.SubWrappers[i-1](fn)
} }
go fn(ctx, &rpcPublication{ s.wg.Add(1)
go func() {
fn(ctx, &rpcPublication{
topic: sb.topic, topic: sb.topic,
contentType: ct, contentType: ct,
message: req.Interface(), message: req.Interface(),
}) })
s.wg.Done()
}()
} }
return nil return nil
} }