diff --git a/broker/broker.go b/broker/broker.go index 369a5352..53b15d4e 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -20,6 +20,8 @@ var ( ErrDisconnected = errors.New("broker disconnected") // ErrInvalidMessage returns when invalid Message passed ErrInvalidMessage = errors.New("invalid message") + // ErrInvalidHandler returns when subscriber passed to Subscribe + ErrInvalidHandler = errors.New("invalid handler") // DefaultGracefulTimeout DefaultGracefulTimeout = 5 * time.Second ) @@ -87,9 +89,3 @@ type Subscriber interface { // Unsubscribe from topic Unsubscribe(ctx context.Context) error } - -// MessageHandler func signature for single message processing -type MessageHandler func(Message) error - -// MessagesHandler func signature for batch message processing -type MessagesHandler func([]Message) error diff --git a/broker/memory/memory.go b/broker/memory/memory.go index 1ca50e87..86430a4c 100644 --- a/broker/memory/memory.go +++ b/broker/memory/memory.go @@ -2,6 +2,7 @@ package broker import ( "context" + "strings" "sync" "go.unistack.org/micro/v4/broker" @@ -34,6 +35,30 @@ type memoryMessage struct { opts broker.PublishOptions } +func (m *memoryMessage) Ack() error { + return nil +} + +func (m *memoryMessage) Body() []byte { + return m.body +} + +func (m *memoryMessage) Header() metadata.Metadata { + return m.hdr +} + +func (m *memoryMessage) Context() context.Context { + return m.ctx +} + +func (m *memoryMessage) Topic() string { + return "" +} + +func (m *memoryMessage) Unmarshal(dst interface{}, opts ...codec.Option) error { + return m.c.Unmarshal(m.body, dst) +} + type Subscriber struct { ctx context.Context exit chan bool @@ -43,25 +68,38 @@ type Subscriber struct { opts broker.SubscribeOptions } -func (m *Broker) Options() broker.Options { - return m.opts +func (b *Broker) newCodec(ct string) (codec.Codec, error) { + if idx := strings.IndexRune(ct, ';'); idx >= 0 { + ct = ct[:idx] + } + b.RLock() + c, ok := b.opts.Codecs[ct] + b.RUnlock() + if ok { + return c, nil + } + return nil, codec.ErrUnknownContentType } -func (m *Broker) Address() string { - return m.addr +func (b *Broker) Options() broker.Options { + return b.opts } -func (m *Broker) Connect(ctx context.Context) error { +func (b *Broker) Address() string { + return b.addr +} + +func (b *Broker) Connect(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() default: } - m.Lock() - defer m.Unlock() + b.Lock() + defer b.Unlock() - if m.connected { + if b.connected { return nil } @@ -75,65 +113,79 @@ func (m *Broker) Connect(ctx context.Context) error { // set addr with port addr = mnet.HostPort(addr, 10000+i) - m.addr = addr - m.connected = true + b.addr = addr + b.connected = true return nil } -func (m *Broker) Disconnect(ctx context.Context) error { +func (b *Broker) Disconnect(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() default: } - m.Lock() - defer m.Unlock() + b.Lock() + defer b.Unlock() - if !m.connected { + if !b.connected { return nil } - m.connected = false + b.connected = false return nil } -func (m *Broker) Init(opts ...broker.Option) error { +func (b *Broker) Init(opts ...broker.Option) error { for _, o := range opts { - o(&m.opts) + o(&b.opts) } - m.funcPublish = m.fnPublish - m.funcSubscribe = m.fnSubscribe + b.funcPublish = b.fnPublish + b.funcSubscribe = b.fnSubscribe - m.opts.Hooks.EachPrev(func(hook options.Hook) { + b.opts.Hooks.EachPrev(func(hook options.Hook) { switch h := hook.(type) { case broker.HookPublish: - m.funcPublish = h(m.funcPublish) + b.funcPublish = h(b.funcPublish) case broker.HookSubscribe: - m.funcSubscribe = h(m.funcSubscribe) + b.funcSubscribe = h(b.funcSubscribe) } }) return nil } -func (m *Broker) Publish(ctx context.Context, topic string, messages ...broker.Message) error { - return m.funcPublish(ctx, topic, messages...) +func (b *Broker) NewMessage(ctx context.Context, hdr metadata.Metadata, body interface{}, opts ...broker.PublishOption) (broker.Message, error) { + options := broker.NewPublishOptions(opts...) + m := &memoryMessage{ctx: ctx, hdr: hdr, opts: options} + c, err := b.newCodec(m.opts.ContentType) + if err == nil { + m.body, err = c.Marshal(body) + } + if err != nil { + return nil, err + } + + return m, nil } -func (m *Broker) fnPublish(ctx context.Context, topic string, messages ...broker.Message) error { - return m.publish(ctx, topic, messages...) +func (b *Broker) Publish(ctx context.Context, topic string, messages ...broker.Message) error { + return b.funcPublish(ctx, topic, messages...) } -func (m *Broker) publish(ctx context.Context, topic string, messages ...broker.Message) error { - m.RLock() - if !m.connected { - m.RUnlock() +func (b *Broker) fnPublish(ctx context.Context, topic string, messages ...broker.Message) error { + return b.publish(ctx, topic, messages...) +} + +func (b *Broker) publish(ctx context.Context, topic string, messages ...broker.Message) error { + b.RLock() + if !b.connected { + b.RUnlock() return broker.ErrNotConnected } - m.RUnlock() + b.RUnlock() select { case <-ctx.Done(): @@ -141,9 +193,9 @@ func (m *Broker) publish(ctx context.Context, topic string, messages ...broker.M default: } - m.RLock() - subs, ok := m.subscribers[topic] - m.RUnlock() + b.RLock() + subs, ok := b.subscribers[topic] + b.RUnlock() if !ok { return nil } @@ -152,24 +204,28 @@ func (m *Broker) publish(ctx context.Context, topic string, messages ...broker.M for _, sub := range subs { switch s := sub.handler.(type) { - case broker.MessageHandler: + default: + if b.opts.Logger.V(logger.ErrorLevel) { + b.opts.Logger.Error(ctx, "broker handler error", broker.ErrInvalidHandler) + } + case func(broker.Message) error: for _, message := range messages { if err = s(message); err == nil && sub.opts.AutoAck { err = message.Ack() } if err != nil { - if m.opts.Logger.V(logger.ErrorLevel) { - m.opts.Logger.Error(m.opts.Context, "broker handler error", err) + if b.opts.Logger.V(logger.ErrorLevel) { + b.opts.Logger.Error(ctx, "broker handler error", err) } } } - case broker.MessagesHandler: + case func([]broker.Message) error: if err = s(messages); err == nil && sub.opts.AutoAck { for _, message := range messages { err = message.Ack() if err != nil { - if m.opts.Logger.V(logger.ErrorLevel) { - m.opts.Logger.Error(m.opts.Context, "broker handler error", err) + if b.opts.Logger.V(logger.ErrorLevel) { + b.opts.Logger.Error(ctx, "broker handler error", err) } } } @@ -180,17 +236,21 @@ func (m *Broker) publish(ctx context.Context, topic string, messages ...broker.M return nil } -func (m *Broker) Subscribe(ctx context.Context, topic string, handler interface{}, opts ...broker.SubscribeOption) (broker.Subscriber, error) { - return m.funcSubscribe(ctx, topic, handler, opts...) +func (b *Broker) Subscribe(ctx context.Context, topic string, handler interface{}, opts ...broker.SubscribeOption) (broker.Subscriber, error) { + return b.funcSubscribe(ctx, topic, handler, opts...) } -func (m *Broker) fnSubscribe(ctx context.Context, topic string, handler interface{}, opts ...broker.SubscribeOption) (broker.Subscriber, error) { - m.RLock() - if !m.connected { - m.RUnlock() +func (b *Broker) fnSubscribe(ctx context.Context, topic string, handler interface{}, opts ...broker.SubscribeOption) (broker.Subscriber, error) { + if err := broker.IsValidHandler(handler); err != nil { + return nil, err + } + + b.RLock() + if !b.connected { + b.RUnlock() return nil, broker.ErrNotConnected } - m.RUnlock() + b.RUnlock() sid, err := id.New() if err != nil { @@ -208,63 +268,47 @@ func (m *Broker) fnSubscribe(ctx context.Context, topic string, handler interfac ctx: ctx, } - m.Lock() - m.subscribers[topic] = append(m.subscribers[topic], sub) - m.Unlock() + b.Lock() + b.subscribers[topic] = append(b.subscribers[topic], sub) + b.Unlock() go func() { <-sub.exit - m.Lock() - newSubscribers := make([]*Subscriber, 0, len(m.subscribers)-1) - for _, sb := range m.subscribers[topic] { + b.Lock() + newSubscribers := make([]*Subscriber, 0, len(b.subscribers)-1) + for _, sb := range b.subscribers[topic] { if sb.id == sub.id { continue } newSubscribers = append(newSubscribers, sb) } - m.subscribers[topic] = newSubscribers - m.Unlock() + b.subscribers[topic] = newSubscribers + b.Unlock() }() return sub, nil } -func (m *Broker) String() string { +func (b *Broker) String() string { return "memory" } -func (m *Broker) Name() string { - return m.opts.Name +func (b *Broker) Name() string { + return b.opts.Name } -func (m *Broker) Live() bool { +func (b *Broker) Live() bool { return true } -func (m *Broker) Ready() bool { +func (b *Broker) Ready() bool { return true } -func (m *Broker) Health() bool { +func (b *Broker) Health() bool { return true } -func (m *memoryMessage) Topic() string { - return m.topic -} - -func (m *memoryMessage) Body() []byte { - return m.body -} - -func (m *memoryMessage) Ack() error { - return nil -} - -func (m *memoryMessage) Context() context.Context { - return m.ctx -} - func (m *Subscriber) Options() broker.SubscribeOptions { return m.opts } diff --git a/broker/memory/memory_test.go b/broker/memory/memory_test.go index 936c3d7d..2d2fd4de 100644 --- a/broker/memory/memory_test.go +++ b/broker/memory/memory_test.go @@ -5,12 +5,23 @@ import ( "fmt" "testing" + "go.uber.org/atomic" "go.unistack.org/micro/v4/broker" + "go.unistack.org/micro/v4/codec" "go.unistack.org/micro/v4/metadata" ) +type hldr struct { + c atomic.Int64 +} + +func (h *hldr) Handler(m broker.Message) error { + h.c.Add(1) + return nil +} + func TestMemoryBroker(t *testing.T) { - b := NewBroker() + b := NewBroker(broker.Codec("application/octet-stream", codec.NewCodec())) ctx := context.Background() if err := b.Init(); err != nil { @@ -22,28 +33,27 @@ func TestMemoryBroker(t *testing.T) { } topic := "test" - count := 10 + count := int64(10) - fn := func(_ broker.Message) error { - return nil - } + h := &hldr{} - sub, err := b.Subscribe(ctx, topic, fn) + sub, err := b.Subscribe(ctx, topic, h.Handler) if err != nil { t.Fatalf("Unexpected error subscribing %v", err) } - msgs := make([]*broker.Message, 0, count) - for i := 0; i < count; i++ { - message, err := b.NewMessage(ctx, metadata.Pairs() - Header: map[string]string{ - metadata.HeaderTopic: topic, - "foo": "bar", - "id": fmt.Sprintf("%d", i), - }, + for i := int64(0); i < count; i++ { + message, err := b.NewMessage(ctx, + metadata.Pairs( + "foo", "bar", + "id", fmt.Sprintf("%d", i), + ), []byte(`"hello world"`), + broker.PublishContentType("application/octet-stream"), + ) + if err != nil { + t.Fatal(err) } - msgs = append(msgs, message) if err := b.Publish(ctx, topic, message); err != nil { t.Fatalf("Unexpected error publishing %d err: %v", i, err) @@ -57,4 +67,8 @@ func TestMemoryBroker(t *testing.T) { if err := b.Disconnect(ctx); err != nil { t.Fatalf("Unexpected connect error %v", err) } + + if h.c.Load() != count { + t.Fatal("invalid messages count received") + } } diff --git a/broker/noop.go b/broker/noop.go index c3e84e18..4d160f3d 100644 --- a/broker/noop.go +++ b/broker/noop.go @@ -118,10 +118,6 @@ func (m *noopMessage) Context() context.Context { return m.ctx } -func (m *noopMessage) Error() error { - return nil -} - func (m *noopMessage) Topic() string { return "" } diff --git a/broker/options.go b/broker/options.go index a5894794..a9900921 100644 --- a/broker/options.go +++ b/broker/options.go @@ -133,8 +133,8 @@ func Addrs(addrs ...string) Option { } } -// Codecs sets the codec used for encoding/decoding messages -func Codecs(ct string, c codec.Codec) Option { +// Codec sets the codec used for encoding/decoding messages +func Codec(ct string, c codec.Codec) Option { return func(o *Options) { o.Codecs[ct] = c } diff --git a/broker/subscriber.go b/broker/subscriber.go index 3470a5cd..bdced69c 100644 --- a/broker/subscriber.go +++ b/broker/subscriber.go @@ -1,6 +1,4 @@ -//go:build ignore - -package server +package broker import ( "fmt" @@ -10,7 +8,8 @@ import ( ) const ( - subSig = "func(context.Context, interface{}) error" + messageSig = "func(broker.Message) error" + messagesSig = "func([]broker.Message) error" ) // Precompute the reflect type for error. Can't use error directly @@ -33,31 +32,31 @@ func isExportedOrBuiltinType(t reflect.Type) bool { return isExported(t.Name()) || t.PkgPath() == "" } -// ValidateSubscriber func signature -func ValidateSubscriber(sub Subscriber) error { - typ := reflect.TypeOf(sub.Subscriber()) +// IsValidHandler func signature +func IsValidHandler(sub interface{}) error { + typ := reflect.TypeOf(sub) var argType reflect.Type switch typ.Kind() { case reflect.Func: name := "Func" switch typ.NumIn() { - case 2: - argType = typ.In(1) + case 1: + argType = typ.In(0) default: - return fmt.Errorf("subscriber %v takes wrong number of args: %v required signature %s", name, typ.NumIn(), subSig) + return fmt.Errorf("subscriber %v takes wrong number of args: %v required signature %s", name, typ.NumIn(), messageSig) } if !isExportedOrBuiltinType(argType) { return fmt.Errorf("subscriber %v argument type not exported: %v", name, argType) } if typ.NumOut() != 1 { return fmt.Errorf("subscriber %v has wrong number of return values: %v require signature %s", - name, typ.NumOut(), subSig) + name, typ.NumOut(), messageSig) } if returnType := typ.Out(0); returnType != typeOfError { return fmt.Errorf("subscriber %v returns %v not error", name, returnType.String()) } default: - hdlr := reflect.ValueOf(sub.Subscriber()) + hdlr := reflect.ValueOf(sub) name := reflect.Indirect(hdlr).Type().Name() for m := 0; m < typ.NumMethod(); m++ { @@ -67,7 +66,7 @@ func ValidateSubscriber(sub Subscriber) error { argType = method.Type.In(2) default: return fmt.Errorf("subscriber %v.%v takes wrong number of args: %v required signature %s", - name, method.Name, method.Type.NumIn(), subSig) + name, method.Name, method.Type.NumIn(), messageSig) } if !isExportedOrBuiltinType(argType) { @@ -76,7 +75,7 @@ func ValidateSubscriber(sub Subscriber) error { if method.Type.NumOut() != 1 { return fmt.Errorf( "subscriber %v.%v has wrong number of return values: %v require signature %s", - name, method.Name, method.Type.NumOut(), subSig) + name, method.Name, method.Type.NumOut(), messageSig) } if returnType := method.Type.Out(0); returnType != typeOfError { return fmt.Errorf("subscriber %v.%v returns %v not error", name, method.Name, returnType.String()) diff --git a/logger/slog/slog_test.go b/logger/slog/slog_test.go index 18a24297..e8e6d353 100644 --- a/logger/slog/slog_test.go +++ b/logger/slog/slog_test.go @@ -412,9 +412,10 @@ func Test_WithContextAttrFunc(t *testing.T) { } attrs := make([]interface{}, 0, 10) for k, v := range md { - switch k { - case "X-Request-Id", "Phone", "External-Id", "Source-Service", "X-App-Install-Id", "Client-Id", "Client-Ip": - attrs = append(attrs, strings.ToLower(k), v) + key := strings.ToLower(k) + switch key { + case "x-request-id", "phone", "external-Id", "source-service", "x-app-install-id", "client-id", "client-ip": + attrs = append(attrs, key, v[0]) } } return attrs diff --git a/metadata/context.go b/metadata/context.go deleted file mode 100644 index 471b144a..00000000 --- a/metadata/context.go +++ /dev/null @@ -1,181 +0,0 @@ -//go:build !exclude - -// Package metadata is a way of defining message headers -package metadata - -import ( - "context" -) - -type ( - mdIncomingKey struct{} - mdOutgoingKey struct{} - mdKey struct{} -) - -// FromIncomingContext returns metadata from incoming ctx -// returned metadata shoud not be modified or race condition happens -func FromIncomingContext(ctx context.Context) (Metadata, bool) { - if ctx == nil { - return nil, false - } - md, ok := ctx.Value(mdIncomingKey{}).(*rawMetadata) - if !ok || md.md == nil { - return nil, false - } - return md.md, ok -} - -// MustIncomingContext returns metadata from incoming ctx -// returned metadata shoud not be modified or race condition happens. -// If metadata not exists panics. -func MustIncomingContext(ctx context.Context) Metadata { - md, ok := FromIncomingContext(ctx) - if !ok { - panic("missing metadata") - } - return md -} - -// FromOutgoingContext returns metadata from outgoing ctx -// returned metadata shoud not be modified or race condition happens -func FromOutgoingContext(ctx context.Context) (Metadata, bool) { - if ctx == nil { - return nil, false - } - md, ok := ctx.Value(mdOutgoingKey{}).(*rawMetadata) - if !ok || md.md == nil { - return nil, false - } - return md.md, ok -} - -// MustOutgoingContext returns metadata from outgoing ctx -// returned metadata shoud not be modified or race condition happens. -// If metadata not exists panics. -func MustOutgoingContext(ctx context.Context) Metadata { - md, ok := FromOutgoingContext(ctx) - if !ok { - panic("missing metadata") - } - return md -} - -// FromContext returns metadata from the given context -// returned metadata shoud not be modified or race condition happens -func FromContext(ctx context.Context) (Metadata, bool) { - if ctx == nil { - return nil, false - } - md, ok := ctx.Value(mdKey{}).(*rawMetadata) - if !ok || md.md == nil { - return nil, false - } - return md.md, ok -} - -// MustContext returns metadata from the given context -// returned metadata shoud not be modified or race condition happens -func MustContext(ctx context.Context) Metadata { - md, ok := FromContext(ctx) - if !ok { - panic("missing metadata") - } - return md -} - -// NewContext creates a new context with the given metadata -func NewContext(ctx context.Context, md Metadata) context.Context { - if ctx == nil { - ctx = context.Background() - } - return context.WithValue(ctx, mdKey{}, &rawMetadata{md}) -} - -// SetOutgoingContext modify outgoing context with given metadata -func SetOutgoingContext(ctx context.Context, md Metadata) bool { - if ctx == nil { - return false - } - if omd, ok := ctx.Value(mdOutgoingKey{}).(*rawMetadata); ok { - omd.md = md - return true - } - return false -} - -// SetIncomingContext modify incoming context with given metadata -func SetIncomingContext(ctx context.Context, md Metadata) bool { - if ctx == nil { - return false - } - if omd, ok := ctx.Value(mdIncomingKey{}).(*rawMetadata); ok { - omd.md = md - return true - } - return false -} - -// NewIncomingContext creates a new context with incoming metadata attached -func NewIncomingContext(ctx context.Context, md Metadata) context.Context { - if ctx == nil { - ctx = context.Background() - } - return context.WithValue(ctx, mdIncomingKey{}, &rawMetadata{md}) -} - -// NewOutgoingContext creates a new context with outcoming metadata attached -func NewOutgoingContext(ctx context.Context, md Metadata) context.Context { - if ctx == nil { - ctx = context.Background() - } - return context.WithValue(ctx, mdOutgoingKey{}, &rawMetadata{md}) -} - -// AppendOutgoingContext apends new md to context -func AppendOutgoingContext(ctx context.Context, kv ...string) context.Context { - md, ok := Pairs(kv...) - if !ok { - return ctx - } - omd, ok := FromOutgoingContext(ctx) - if !ok { - return NewOutgoingContext(ctx, md) - } - for k, v := range md { - omd.Set(k, v) - } - return ctx -} - -// AppendIncomingContext apends new md to context -func AppendIncomingContext(ctx context.Context, kv ...string) context.Context { - md, ok := Pairs(kv...) - if !ok { - return ctx - } - omd, ok := FromIncomingContext(ctx) - if !ok { - return NewIncomingContext(ctx, md) - } - for k, v := range md { - omd.Set(k, v) - } - return ctx -} - -// AppendContext apends new md to context -func AppendContext(ctx context.Context, kv ...string) context.Context { - md, ok := Pairs(kv...) - if !ok { - return ctx - } - omd, ok := FromContext(ctx) - if !ok { - return NewContext(ctx, md) - } - for k, v := range md { - omd.Set(k, v) - } - return ctx -} diff --git a/metadata/context_test.go b/metadata/context_test.go deleted file mode 100644 index deaa020a..00000000 --- a/metadata/context_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package metadata - -import ( - "context" - "testing" -) - -func TestFromNilContext(t *testing.T) { - // nolint: staticcheck - c, ok := FromContext(nil) - if ok || c != nil { - t.Fatal("FromContext not works") - } -} - -func TestNewNilContext(t *testing.T) { - // nolint: staticcheck - ctx := NewContext(nil, New(0)) - - c, ok := FromContext(ctx) - if c == nil || !ok { - t.Fatal("NewContext not works") - } -} - -func TestFromContext(t *testing.T) { - ctx := context.WithValue(context.TODO(), mdKey{}, &rawMetadata{New(0)}) - - c, ok := FromContext(ctx) - if c == nil || !ok { - t.Fatal("FromContext not works") - } -} - -func TestNewContext(t *testing.T) { - ctx := NewContext(context.TODO(), New(0)) - - c, ok := FromContext(ctx) - if c == nil || !ok { - t.Fatal("NewContext not works") - } -} - -func TestFromIncomingContext(t *testing.T) { - ctx := context.WithValue(context.TODO(), mdIncomingKey{}, &rawMetadata{New(0)}) - - c, ok := FromIncomingContext(ctx) - if c == nil || !ok { - t.Fatal("FromIncomingContext not works") - } -} - -func TestFromOutgoingContext(t *testing.T) { - ctx := context.WithValue(context.TODO(), mdOutgoingKey{}, &rawMetadata{New(0)}) - - c, ok := FromOutgoingContext(ctx) - if c == nil || !ok { - t.Fatal("FromOutgoingContext not works") - } -} - -func TestSetIncomingContext(t *testing.T) { - md := New(1) - md.Set("key", "val") - ctx := context.WithValue(context.TODO(), mdIncomingKey{}, &rawMetadata{}) - if !SetIncomingContext(ctx, md) { - t.Fatal("SetIncomingContext not works") - } - md, ok := FromIncomingContext(ctx) - if md == nil || !ok { - t.Fatal("SetIncomingContext not works") - } else if v, ok := md.Get("key"); !ok || v != "val" { - t.Fatal("SetIncomingContext not works") - } -} - -func TestSetOutgoingContext(t *testing.T) { - md := New(1) - md.Set("key", "val") - ctx := context.WithValue(context.TODO(), mdOutgoingKey{}, &rawMetadata{}) - if !SetOutgoingContext(ctx, md) { - t.Fatal("SetOutgoingContext not works") - } - md, ok := FromOutgoingContext(ctx) - if md == nil || !ok { - t.Fatal("SetOutgoingContext not works") - } else if v, ok := md.Get("key"); !ok || v != "val" { - t.Fatal("SetOutgoingContext not works") - } -} - -func TestNewIncomingContext(t *testing.T) { - md := New(1) - md.Set("key", "val") - ctx := NewIncomingContext(context.TODO(), md) - - c, ok := FromIncomingContext(ctx) - if c == nil || !ok { - t.Fatal("NewIncomingContext not works") - } -} - -func TestNewOutgoingContext(t *testing.T) { - md := New(1) - md.Set("key", "val") - ctx := NewOutgoingContext(context.TODO(), md) - - c, ok := FromOutgoingContext(ctx) - if c == nil || !ok { - t.Fatal("NewOutgoingContext not works") - } -} - -func TestAppendIncomingContext(t *testing.T) { - md := New(1) - md.Set("key1", "val1") - ctx := AppendIncomingContext(context.TODO(), "key2", "val2") - - nmd, ok := FromIncomingContext(ctx) - if nmd == nil || !ok { - t.Fatal("AppendIncomingContext not works") - } - if v, ok := nmd.Get("key2"); !ok || v != "val2" { - t.Fatal("AppendIncomingContext not works") - } -} - -func TestAppendOutgoingContext(t *testing.T) { - md := New(1) - md.Set("key1", "val1") - ctx := AppendOutgoingContext(context.TODO(), "key2", "val2") - - nmd, ok := FromOutgoingContext(ctx) - if nmd == nil || !ok { - t.Fatal("AppendOutgoingContext not works") - } - if v, ok := nmd.Get("key2"); !ok || v != "val2" { - t.Fatal("AppendOutgoingContext not works") - } -} diff --git a/metadata/headers.go b/metadata/headers.go new file mode 100644 index 00000000..7e670eee --- /dev/null +++ b/metadata/headers.go @@ -0,0 +1,19 @@ +// Package metadata is a way of defining message headers +package metadata + +var ( + // HeaderTopic is the header name that contains topic name + HeaderTopic = "Micro-Topic" + // HeaderContentType specifies content type of message + HeaderContentType = "Content-Type" + // HeaderEndpoint specifies endpoint in service + HeaderEndpoint = "Micro-Endpoint" + // HeaderService specifies service + HeaderService = "Micro-Service" + // HeaderTimeout specifies timeout of operation + HeaderTimeout = "Micro-Timeout" + // HeaderAuthorization specifies Authorization header + HeaderAuthorization = "Authorization" + // HeaderXRequestID specifies request id + HeaderXRequestID = "X-Request-Id" +) diff --git a/metadata/metadata.go b/metadata/metadata.go index 4f67efbf..7664d0ea 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -1,43 +1,461 @@ -//go:build !exclude - -// Package metadata is a way of defining message headers package metadata import ( + "context" + "fmt" "net/textproto" - "sort" "strings" ) -var ( - // HeaderTopic is the header name that contains topic name - HeaderTopic = "Micro-Topic" - // HeaderContentType specifies content type of message - HeaderContentType = "Content-Type" - // HeaderEndpoint specifies endpoint in service - HeaderEndpoint = "Micro-Endpoint" - // HeaderService specifies service - HeaderService = "Micro-Service" - // HeaderTimeout specifies timeout of operation - HeaderTimeout = "Micro-Timeout" - // HeaderAuthorization specifies Authorization header - HeaderAuthorization = "Authorization" - // HeaderXRequestID specifies request id - HeaderXRequestID = "X-Request-Id" -) - -// Metadata is our way of representing request headers internally. -// They're used at the RPC level and translate back and forth -// from Transport headers. -type Metadata map[string]string - -type rawMetadata struct { - md Metadata -} - // defaultMetadataSize used when need to init new Metadata var defaultMetadataSize = 2 +// Metadata is a mapping from metadata keys to values. Users should use the following +// two convenience functions New and Pairs to generate Metadata. +type Metadata map[string][]string + +// New creates an zero Metadata. +func New(l int) Metadata { + if l == 0 { + l = defaultMetadataSize + } + md := make(Metadata, l) + return md +} + +// NewWithMetadata creates an Metadata from a given key-value map. +func NewWithMetadata(m map[string]string) Metadata { + md := make(Metadata, len(m)) + for key, val := range m { + md[key] = append(md[key], val) + } + return md +} + +// Pairs returns an Metadata formed by the mapping of key, value ... +// Pairs panics if len(kv) is odd. +func Pairs(kv ...string) Metadata { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: Pairs got the odd number of input pairs for metadata: %d", len(kv))) + } + md := make(Metadata, len(kv)/2) + for i := 0; i < len(kv); i += 2 { + md[kv[i]] = append(md[kv[i]], kv[i+1]) + } + return md +} + +// Len returns the number of items in Metadata. +func (md Metadata) Len() int { + return len(md) +} + +// Copy returns a copy of Metadata. +func Copy(src Metadata) Metadata { + out := make(Metadata, len(src)) + for k, v := range src { + out[k] = copyOf(v) + } + return out +} + +// Copy returns a copy of Metadata. +func (md Metadata) Copy() Metadata { + out := make(Metadata, len(md)) + for k, v := range md { + out[k] = copyOf(v) + } + return out +} + +// AsHTTP1 returns a copy of Metadata +// with CanonicalMIMEHeaderKey. +func (md Metadata) AsHTTP1() map[string][]string { + out := make(map[string][]string, len(md)) + for k, v := range md { + out[textproto.CanonicalMIMEHeaderKey(k)] = copyOf(v) + } + return out +} + +// AsHTTP1 returns a copy of Metadata +// with strings.ToLower. +func (md Metadata) AsHTTP2() map[string][]string { + out := make(map[string][]string, len(md)) + for k, v := range md { + out[strings.ToLower(k)] = copyOf(v) + } + return out +} + +// CopyTo copies Metadata to out. +func (md Metadata) CopyTo(out Metadata) { + for k, v := range md { + out[k] = copyOf(v) + } +} + +// Get obtains the values for a given key. +func (md Metadata) MustGet(k string) []string { + v, ok := md.Get(k) + if !ok { + panic("missing metadata key") + } + return v +} + +// Get obtains the values for a given key. +func (md Metadata) Get(k string) ([]string, bool) { + v, ok := md[k] + if !ok { + v, ok = md[strings.ToLower(k)] + } + if !ok { + v, ok = md[textproto.CanonicalMIMEHeaderKey(k)] + } + return v, ok +} + +// MustGetJoined obtains the values for a given key +// with joined values with "," symbol +func (md Metadata) MustGetJoined(k string) string { + v, ok := md.GetJoined(k) + if !ok { + panic("missing metadata key") + } + return v +} + +// GetJoined obtains the values for a given key +// with joined values with "," symbol +func (md Metadata) GetJoined(k string) (string, bool) { + v, ok := md.Get(k) + if !ok { + return "", ok + } + return strings.Join(v, ","), true +} + +// Set sets the value of a given key with a slice of values. +func (md Metadata) Add(key string, vals ...string) { + if len(vals) == 0 { + return + } + md[key] = vals +} + +// Set sets the value of a given key with a slice of values. +func (md Metadata) Set(kvs ...string) { + if len(kvs)%2 == 1 { + panic(fmt.Sprintf("metadata: Set got an odd number of input pairs for metadata: %d", len(kvs))) + } + + for i := 0; i < len(kvs); i += 2 { + md[kvs[i]] = append(md[kvs[i]], kvs[i+1]) + } +} + +// Append adds the values to key k, not overwriting what was already stored at +// that key. +func (md Metadata) Append(key string, vals ...string) { + if len(vals) == 0 { + return + } + md[key] = append(md[key], vals...) +} + +// Del removes the values for a given keys k. +func (md Metadata) Del(k ...string) { + for i := range k { + delete(md, k[i]) + delete(md, strings.ToLower(k[i])) + delete(md, textproto.CanonicalMIMEHeaderKey(k[i])) + } +} + +// Join joins any number of Metadatas into a single Metadata. +// +// The order of values for each key is determined by the order in which the Metadatas +// containing those values are presented to Join. +func Join(mds ...Metadata) Metadata { + out := Metadata{} + for _, Metadata := range mds { + for k, v := range Metadata { + out[k] = append(out[k], v...) + } + } + return out +} + +type ( + metadataIncomingKey struct{} + metadataOutgoingKey struct{} + metadataCurrentKey struct{} +) + +// NewContext creates a new context with Metadata attached. Metadata must +// not be modified after calling this function. +func NewContext(ctx context.Context, md Metadata) context.Context { + return context.WithValue(ctx, metadataCurrentKey{}, rawMetadata{md: md}) +} + +// NewIncomingContext creates a new context with incoming Metadata attached. Metadata must +// not be modified after calling this function. +func NewIncomingContext(ctx context.Context, md Metadata) context.Context { + return context.WithValue(ctx, metadataIncomingKey{}, rawMetadata{md: md}) +} + +// NewOutgoingContext creates a new context with outgoing Metadata attached. If used +// in conjunction with AppendOutgoingContext, NewOutgoingContext will +// overwrite any previously-appended metadata. Metadata must not be modified after +// calling this function. +func NewOutgoingContext(ctx context.Context, md Metadata) context.Context { + return context.WithValue(ctx, metadataOutgoingKey{}, rawMetadata{md: md}) +} + +// AppendContext returns a new context with the provided kv merged +// with any existing metadata in the context. Please refer to the documentation +// of Pairs for a description of kv. +func AppendContext(ctx context.Context, kv ...string) context.Context { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: AppendContext got an odd number of input pairs for metadata: %d", len(kv))) + } + md, _ := ctx.Value(metadataCurrentKey{}).(rawMetadata) + added := make([][]string, len(md.added)+1) + copy(added, md.added) + kvCopy := make([]string, 0, len(kv)) + for i := 0; i < len(kv); i += 2 { + kvCopy = append(kvCopy, strings.ToLower(kv[i]), kv[i+1]) + } + added[len(added)-1] = kvCopy + return context.WithValue(ctx, metadataCurrentKey{}, rawMetadata{md: md.md, added: added}) +} + +// AppendIncomingContext returns a new context with the provided kv merged +// with any existing metadata in the context. Please refer to the documentation +// of Pairs for a description of kv. +func AppendIncomingContext(ctx context.Context, kv ...string) context.Context { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: AppendIncomingContext got an odd number of input pairs for metadata: %d", len(kv))) + } + md, _ := ctx.Value(metadataIncomingKey{}).(rawMetadata) + added := make([][]string, len(md.added)+1) + copy(added, md.added) + kvCopy := make([]string, 0, len(kv)) + for i := 0; i < len(kv); i += 2 { + kvCopy = append(kvCopy, strings.ToLower(kv[i]), kv[i+1]) + } + added[len(added)-1] = kvCopy + return context.WithValue(ctx, metadataIncomingKey{}, rawMetadata{md: md.md, added: added}) +} + +// AppendOutgoingContext returns a new context with the provided kv merged +// with any existing metadata in the context. Please refer to the documentation +// of Pairs for a description of kv. +func AppendOutgoingContext(ctx context.Context, kv ...string) context.Context { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: AppendOutgoingContext got an odd number of input pairs for metadata: %d", len(kv))) + } + md, _ := ctx.Value(metadataOutgoingKey{}).(rawMetadata) + added := make([][]string, len(md.added)+1) + copy(added, md.added) + kvCopy := make([]string, 0, len(kv)) + for i := 0; i < len(kv); i += 2 { + kvCopy = append(kvCopy, strings.ToLower(kv[i]), kv[i+1]) + } + added[len(added)-1] = kvCopy + return context.WithValue(ctx, metadataOutgoingKey{}, rawMetadata{md: md.md, added: added}) +} + +// FromContext returns the metadata in ctx if it exists. +func FromContext(ctx context.Context) (Metadata, bool) { + raw, ok := ctx.Value(metadataCurrentKey{}).(rawMetadata) + if !ok { + return nil, false + } + metadataSize := len(raw.md) + for i := range raw.added { + metadataSize += len(raw.added[i]) / 2 + } + + out := make(Metadata, metadataSize) + for k, v := range raw.md { + out[k] = copyOf(v) + } + for _, added := range raw.added { + if len(added)%2 == 1 { + panic(fmt.Sprintf("metadata: FromContext got an odd number of input pairs for metadata: %d", len(added))) + } + + for i := 0; i < len(added); i += 2 { + out[added[i]] = append(out[added[i]], added[i+1]) + } + } + return out, true +} + +// MustContext returns the metadata in ctx. +func MustContext(ctx context.Context) Metadata { + md, ok := FromContext(ctx) + if !ok { + panic("missing metadata") + } + return md +} + +// FromIncomingContext returns the incoming metadata in ctx if it exists. +func FromIncomingContext(ctx context.Context) (Metadata, bool) { + raw, ok := ctx.Value(metadataIncomingKey{}).(rawMetadata) + if !ok { + return nil, false + } + metadataSize := len(raw.md) + for i := range raw.added { + metadataSize += len(raw.added[i]) / 2 + } + + out := make(Metadata, metadataSize) + for k, v := range raw.md { + out[k] = copyOf(v) + } + for _, added := range raw.added { + if len(added)%2 == 1 { + panic(fmt.Sprintf("metadata: FromIncomingContext got an odd number of input pairs for metadata: %d", len(added))) + } + + for i := 0; i < len(added); i += 2 { + out[added[i]] = append(out[added[i]], added[i+1]) + } + } + return out, true +} + +// MustIncomingContext returns the incoming metadata in ctx. +func MustIncomingContext(ctx context.Context) Metadata { + md, ok := FromIncomingContext(ctx) + if !ok { + panic("missing metadata") + } + return md +} + +// ValueFromIncomingContext returns the metadata value corresponding to the metadata +// key from the incoming metadata if it exists. Keys are matched in a case insensitive +// manner. +func ValueFromIncomingContext(ctx context.Context, key string) []string { + raw, ok := ctx.Value(metadataIncomingKey{}).(rawMetadata) + if !ok { + return nil + } + + if v, ok := raw.md[key]; ok { + return copyOf(v) + } + for k, v := range raw.md { + // Case insensitive comparison: Metadata is a map, and there's no guarantee + // that the Metadata attached to the context is created using our helper + // functions. + if strings.EqualFold(k, key) { + return copyOf(v) + } + } + return nil +} + +// ValueFromCurrentContext returns the metadata value corresponding to the metadata +// key from the incoming metadata if it exists. Keys are matched in a case insensitive +// manner. +func ValueFromCurrentContext(ctx context.Context, key string) []string { + md, ok := ctx.Value(metadataCurrentKey{}).(rawMetadata) + if !ok { + return nil + } + + if v, ok := md.md[key]; ok { + return copyOf(v) + } + for k, v := range md.md { + // Case insensitive comparison: Metadata is a map, and there's no guarantee + // that the Metadata attached to the context is created using our helper + // functions. + if strings.EqualFold(k, key) { + return copyOf(v) + } + } + return nil +} + +// MustOutgoingContext returns the outgoing metadata in ctx. +func MustOutgoingContext(ctx context.Context) Metadata { + md, ok := FromOutgoingContext(ctx) + if !ok { + panic("missing metadata") + } + return md +} + +// ValueFromOutgoingContext returns the metadata value corresponding to the metadata +// key from the incoming metadata if it exists. Keys are matched in a case insensitive +// manner. +func ValueFromOutgoingContext(ctx context.Context, key string) []string { + md, ok := ctx.Value(metadataOutgoingKey{}).(rawMetadata) + if !ok { + return nil + } + + if v, ok := md.md[key]; ok { + return copyOf(v) + } + for k, v := range md.md { + // Case insensitive comparison: Metadata is a map, and there's no guarantee + // that the Metadata attached to the context is created using our helper + // functions. + if strings.EqualFold(k, key) { + return copyOf(v) + } + } + return nil +} + +func copyOf(v []string) []string { + vals := make([]string, len(v)) + copy(vals, v) + return vals +} + +// FromOutgoingContext returns the outgoing metadata in ctx if it exists. +func FromOutgoingContext(ctx context.Context) (Metadata, bool) { + raw, ok := ctx.Value(metadataOutgoingKey{}).(rawMetadata) + if !ok { + return nil, false + } + + metadataSize := len(raw.md) + for i := range raw.added { + metadataSize += len(raw.added[i]) / 2 + } + + out := make(Metadata, metadataSize) + for k, v := range raw.md { + out[k] = copyOf(v) + } + for _, added := range raw.added { + if len(added)%2 == 1 { + panic(fmt.Sprintf("metadata: FromOutgoingContext got an odd number of input pairs for metadata: %d", len(added))) + } + + for i := 0; i < len(added); i += 2 { + out[added[i]] = append(out[added[i]], added[i+1]) + } + } + return out, ok +} + +type rawMetadata struct { + md Metadata + added [][]string +} + // Iterator used to iterate over metadata with order type Iterator struct { md Metadata @@ -46,6 +464,7 @@ type Iterator struct { cnt int } +/* // Next advance iterator to next element func (iter *Iterator) Next(k, v *string) bool { if iter.cur+1 > iter.cnt { @@ -53,122 +472,19 @@ func (iter *Iterator) Next(k, v *string) bool { } *k = iter.keys[iter.cur] - *v = iter.md[*k] + *v = iter.Metadata[*k] iter.cur++ return true } // Iterator returns the itarator for metadata in sorted order -func (md Metadata) Iterator() *Iterator { - iter := &Iterator{md: md, cnt: len(md)} +func (Metadata Metadata) Iterator() *Iterator { + iter := &Iterator{Metadata: Metadata, cnt: len(Metadata)} iter.keys = make([]string, 0, iter.cnt) - for k := range md { + for k := range Metadata { iter.keys = append(iter.keys, k) } sort.Strings(iter.keys) return iter } - -func (md Metadata) MustGet(key string) string { - val, ok := md.Get(key) - if !ok { - panic("missing metadata key") - } - return val -} - -// Len returns the number of items. -func (md Metadata) Len() int { - return len(md) -} - -// Get returns value from metadata by key -func (md Metadata) Get(key string) (string, bool) { - // fast path - val, ok := md[key] - if !ok { - // slow path - val, ok = md[textproto.CanonicalMIMEHeaderKey(key)] - if !ok { - val, ok = md[strings.ToLower(key)] - } - } - return val, ok -} - -// Set is used to store value in metadata -func (md Metadata) Set(kv ...string) { - if len(kv)%2 == 1 { - kv = kv[:len(kv)-1] - } - for idx := 0; idx < len(kv); idx += 2 { - md[textproto.CanonicalMIMEHeaderKey(kv[idx])] = kv[idx+1] - } -} - -// Del is used to remove value from metadata -func (md Metadata) Del(keys ...string) { - for _, key := range keys { - // fast path - delete(md, key) - // slow path - delete(md, textproto.CanonicalMIMEHeaderKey(key)) - // very slow path - delete(md, strings.ToLower(key)) - } -} - -// Copy makes a copy of the metadata -func (md Metadata) CopyTo(dst Metadata) { - for k, v := range md { - dst[k] = v - } -} - -// Copy makes a copy of the metadata -func Copy(md Metadata, exclude ...string) Metadata { - nmd := New(len(md)) - for k, v := range md { - nmd[k] = v - } - nmd.Del(exclude...) - return nmd -} - -// New return new sized metadata -func New(size int) Metadata { - if size == 0 { - size = defaultMetadataSize - } - return make(Metadata, size) -} - -// Merge merges metadata to existing metadata, overwriting if specified -func Merge(omd Metadata, mmd Metadata, overwrite bool) Metadata { - var ok bool - nmd := Copy(omd) - for key, val := range mmd { - _, ok = nmd[key] - switch { - case ok && !overwrite: - continue - case val != "": - nmd[key] = val - case ok && val == "": - nmd.Del(key) - } - } - return nmd -} - -// Pairs from which metadata created -func Pairs(kv ...string) Metadata { - if len(kv)%2 == 1 { - return nil - } - md := New(len(kv) / 2) - for idx := 0; idx < len(kv); idx += 2 { - md[kv[idx]] = kv[idx+1] - } - return md -} +*/ diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index 47d53cf9..0891021f 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -5,10 +5,21 @@ import ( "testing" ) +/* +func TestAppendOutgoingContextModify(t *testing.T) { + md := Pairs("key1", "val1") + ctx := NewOutgoingContext(context.TODO(), md) + nctx := AppendOutgoingContext(ctx, "key1", "val3", "key2", "val2") + _ = nctx + omd := MustOutgoingContext(nctx) + fmt.Printf("%#+v\n", omd) +} +*/ + func TestLowercase(t *testing.T) { md := New(1) - md["x-request-id"] = "12345" - v, ok := md.Get("X-Request-Id") + md["x-request-id"] = []string{"12345"} + v, ok := md.GetJoined("X-Request-Id") if !ok || v == "" { t.Fatalf("metadata invalid %#+v", md) } @@ -38,15 +49,12 @@ func TestMultipleUsage(t *testing.T) { func TestMetadataSetMultiple(t *testing.T) { md := New(4) - md.Set("key1", "val1", "key2", "val2", "key3") + md.Set("key1", "val1", "key2", "val2") - if v, ok := md.Get("key1"); !ok || v != "val1" { + if v, ok := md.GetJoined("key1"); !ok || v != "val1" { t.Fatalf("invalid kv %#+v", md) } - if v, ok := md.Get("key2"); !ok || v != "val2" { - t.Fatalf("invalid kv %#+v", md) - } - if _, ok := md.Get("key3"); ok { + if v, ok := md.GetJoined("key2"); !ok || v != "val2" { t.Fatalf("invalid kv %#+v", md) } } @@ -64,22 +72,12 @@ func TestAppend(t *testing.T) { } func TestPairs(t *testing.T) { - md, ok := Pairs("key1", "val1", "key2", "val2") - if !ok { - t.Fatal("odd number of kv") - } - if _, ok = md.Get("key1"); !ok { + md := Pairs("key1", "val1", "key2", "val2") + if _, ok := md.Get("key1"); !ok { t.Fatal("key1 not found") } } -func testCtx(ctx context.Context) { - md := New(2) - md.Set("Key1", "Val1_new") - md.Set("Key3", "Val3") - SetOutgoingContext(ctx, md) -} - func TestPassing(t *testing.T) { ctx := context.TODO() md1 := New(2) @@ -87,37 +85,24 @@ func TestPassing(t *testing.T) { md1.Set("Key2", "Val2") ctx = NewIncomingContext(ctx, md1) - testCtx(ctx) + _, ok := FromOutgoingContext(ctx) if ok { t.Fatalf("create outgoing context") } - ctx = NewOutgoingContext(ctx, New(1)) - testCtx(ctx) + ctx = NewOutgoingContext(ctx, md1) + md, ok := FromOutgoingContext(ctx) if !ok { t.Fatalf("missing metadata from outgoing context") } - if v, ok := md.Get("Key1"); !ok || v != "Val1_new" { + if v, ok := md.Get("Key1"); !ok || v[0] != "Val1" { t.Fatalf("invalid metadata value %#+v", md) } } -func TestMerge(t *testing.T) { - omd := Metadata{ - "key1": "val1", - } - mmd := Metadata{ - "key2": "val2", - } - - nmd := Merge(omd, mmd, true) - if len(nmd) != 2 { - t.Fatalf("merge failed: %v", nmd) - } -} - +/* func TestIterator(_ *testing.T) { md := Metadata{ "1Last": "last", @@ -132,24 +117,25 @@ func TestIterator(_ *testing.T) { // fmt.Printf("k: %s, v: %s\n", k, v) } } +*/ func TestMedataCanonicalKey(t *testing.T) { md := New(1) md.Set("x-request-id", "12345") - v, ok := md.Get("x-request-id") + v, ok := md.GetJoined("x-request-id") if !ok { t.Fatalf("failed to get x-request-id") } else if v != "12345" { t.Fatalf("invalid metadata value: %s != %s", "12345", v) } - v, ok = md.Get("X-Request-Id") + v, ok = md.GetJoined("X-Request-Id") if !ok { t.Fatalf("failed to get x-request-id") } else if v != "12345" { t.Fatalf("invalid metadata value: %s != %s", "12345", v) } - v, ok = md.Get("X-Request-ID") + v, ok = md.GetJoined("X-Request-ID") if !ok { t.Fatalf("failed to get x-request-id") } else if v != "12345" { @@ -162,7 +148,7 @@ func TestMetadataSet(t *testing.T) { md.Set("Key", "val") - val, ok := md.Get("Key") + val, ok := md.GetJoined("Key") if !ok { t.Fatal("key Key not found") } @@ -173,8 +159,8 @@ func TestMetadataSet(t *testing.T) { func TestMetadataDelete(t *testing.T) { md := Metadata{ - "Foo": "bar", - "Baz": "empty", + "Foo": []string{"bar"}, + "Baz": []string{"empty"}, } md.Del("Baz") @@ -184,25 +170,16 @@ func TestMetadataDelete(t *testing.T) { } } -func TestNilContext(t *testing.T) { - var ctx context.Context - - _, ok := FromContext(ctx) - if ok { - t.Fatal("nil context") - } -} - func TestMetadataCopy(t *testing.T) { md := Metadata{ - "Foo": "bar", - "Bar": "baz", + "Foo": []string{"bar"}, + "Bar": []string{"baz"}, } cp := Copy(md) for k, v := range md { - if cv := cp[k]; cv != v { + if cv := cp[k]; cv[0] != v[0] { t.Fatalf("Got %s:%s for %s:%s", k, cv, k, v) } } @@ -210,7 +187,7 @@ func TestMetadataCopy(t *testing.T) { func TestMetadataContext(t *testing.T) { md := Metadata{ - "Foo": "bar", + "Foo": []string{"bar"}, } ctx := NewContext(context.TODO(), md) @@ -220,7 +197,7 @@ func TestMetadataContext(t *testing.T) { t.Errorf("Unexpected error retrieving metadata, got %t", ok) } - if emd["Foo"] != md["Foo"] { + if emd["Foo"][0] != md["Foo"][0] { t.Errorf("Expected key: %s val: %s, got key: %s val: %s", "Foo", md["Foo"], "Foo", emd["Foo"]) } @@ -229,13 +206,88 @@ func TestMetadataContext(t *testing.T) { } } -func TestCopy(t *testing.T) { - md := New(2) - md.Set("key1", "val1", "key2", "val2") - nmd := Copy(md, "key2") - if len(nmd) != 1 { - t.Fatal("Copy exclude not works") - } else if nmd["Key1"] != "val1" { - t.Fatal("Copy exclude not works") +func TestFromContext(t *testing.T) { + ctx := context.WithValue(context.TODO(), metadataCurrentKey{}, rawMetadata{md: New(0)}) + + c, ok := FromContext(ctx) + if c == nil || !ok { + t.Fatal("FromContext not works") + } +} + +func TestNewContext(t *testing.T) { + ctx := NewContext(context.TODO(), New(0)) + + c, ok := FromContext(ctx) + if c == nil || !ok { + t.Fatal("NewContext not works") + } +} + +func TestFromIncomingContext(t *testing.T) { + ctx := context.WithValue(context.TODO(), metadataIncomingKey{}, rawMetadata{md: New(0)}) + + c, ok := FromIncomingContext(ctx) + if c == nil || !ok { + t.Fatal("FromIncomingContext not works") + } +} + +func TestFromOutgoingContext(t *testing.T) { + ctx := context.WithValue(context.TODO(), metadataOutgoingKey{}, rawMetadata{md: New(0)}) + + c, ok := FromOutgoingContext(ctx) + if c == nil || !ok { + t.Fatal("FromOutgoingContext not works") + } +} + +func TestNewIncomingContext(t *testing.T) { + md := New(1) + md.Set("key", "val") + ctx := NewIncomingContext(context.TODO(), md) + + c, ok := FromIncomingContext(ctx) + if c == nil || !ok { + t.Fatal("NewIncomingContext not works") + } +} + +func TestNewOutgoingContext(t *testing.T) { + md := New(1) + md.Set("key", "val") + ctx := NewOutgoingContext(context.TODO(), md) + + c, ok := FromOutgoingContext(ctx) + if c == nil || !ok { + t.Fatal("NewOutgoingContext not works") + } +} + +func TestAppendIncomingContext(t *testing.T) { + md := New(1) + md.Set("key1", "val1") + ctx := AppendIncomingContext(context.TODO(), "key2", "val2") + + nmd, ok := FromIncomingContext(ctx) + if nmd == nil || !ok { + t.Fatal("AppendIncomingContext not works") + } + if v, ok := nmd.GetJoined("key2"); !ok || v != "val2" { + t.Fatal("AppendIncomingContext not works") + } +} + +func TestAppendOutgoingContext(t *testing.T) { + md := New(1) + md.Set("key1", "val1") + ctx := AppendOutgoingContext(context.TODO(), "key2", "val2") + + nmd, ok := FromOutgoingContext(ctx) + if nmd == nil || !ok { + t.Fatal("AppendOutgoingContext not works") + } + if v, ok := nmd.GetJoined("key2"); !ok || v != "val2" { + t.Fatal("AppendOutgoingContext not works") } } diff --git a/micro_test.go b/micro_test.go index f92c223e..f8990f6f 100644 --- a/micro_test.go +++ b/micro_test.go @@ -8,6 +8,7 @@ import ( "go.unistack.org/micro/v4/broker" "go.unistack.org/micro/v4/fsm" + "go.unistack.org/micro/v4/metadata" ) func TestAs(t *testing.T) { @@ -61,6 +62,8 @@ func TestAs(t *testing.T) { } } +var _ broker.Broker = (*bro)(nil) + type bro struct { name string } @@ -87,23 +90,18 @@ func (p *bro) Connect(_ context.Context) error { return nil } // Disconnect disconnect from broker func (p *bro) Disconnect(_ context.Context) error { return nil } -// Publish message, msg can be single broker.Message or []broker.Message -func (p *bro) Publish(_ context.Context, _ string, _ *broker.Message, _ ...broker.PublishOption) error { - return nil -} - -// BatchPublish messages to broker with multiple topics -func (p *bro) BatchPublish(_ context.Context, _ []*broker.Message, _ ...broker.PublishOption) error { - return nil -} - -// BatchSubscribe subscribes to topic messages via handler -func (p *bro) BatchSubscribe(_ context.Context, _ string, _ broker.BatchHandler, _ ...broker.SubscribeOption) (broker.Subscriber, error) { +// NewMessage creates new message +func (p *bro) NewMessage(_ context.Context, _ metadata.Metadata, _ interface{}, _ ...broker.PublishOption) (broker.Message, error) { return nil, nil } +// Publish message, msg can be single broker.Message or []broker.Message +func (p *bro) Publish(_ context.Context, _ string, _ ...broker.Message) error { + return nil +} + // Subscribe subscribes to topic message via handler -func (p *bro) Subscribe(_ context.Context, _ string, _ broker.Handler, _ ...broker.SubscribeOption) (broker.Subscriber, error) { +func (p *bro) Subscribe(_ context.Context, _ string, _ interface{}, _ ...broker.SubscribeOption) (broker.Subscriber, error) { return nil, nil } diff --git a/register/memory/memory.go b/register/memory/memory.go index 43518e98..ebf93178 100644 --- a/register/memory/memory.go +++ b/register/memory/memory.go @@ -6,9 +6,10 @@ import ( "sync" "time" - "go.unistack.org/micro/v4/logger" - "go.unistack.org/micro/v4/register" - "go.unistack.org/micro/v4/util/id" + "go.unistack.org/micro/v3/logger" + "go.unistack.org/micro/v3/metadata" + "go.unistack.org/micro/v3/register" + "go.unistack.org/micro/v3/util/id" ) var ( @@ -23,10 +24,9 @@ type node struct { } type record struct { - Name string - Version string - Metadata map[string]string - Nodes map[string]*node + Name string + Version string + Nodes map[string]*node } type memory struct { @@ -160,19 +160,14 @@ func (m *memory) Register(_ context.Context, s *register.Service, opts ...regist continue } - metadata := make(map[string]string, len(n.Metadata)) - - // make copy of metadata - for k, v := range n.Metadata { - metadata[k] = v - } + md := metadata.Copy(n.Metadata) // add the node srvs[s.Name][s.Version].Nodes[n.ID] = &node{ Node: ®ister.Node{ ID: n.ID, Address: n.Address, - Metadata: metadata, + Metadata: md, }, TTL: options.TTL, LastSeen: time.Now(), @@ -452,23 +447,15 @@ func serviceToRecord(s *register.Service, ttl time.Duration) *record { } func recordToService(r *record, namespace string) *register.Service { - metadata := make(map[string]string, len(r.Metadata)) - for k, v := range r.Metadata { - metadata[k] = v - } - nodes := make([]*register.Node, len(r.Nodes)) i := 0 for _, n := range r.Nodes { - md := make(map[string]string, len(n.Metadata)) - for k, v := range n.Metadata { - md[k] = v - } + nmd := metadata.Copy(n.Metadata) nodes[i] = ®ister.Node{ ID: n.ID, Address: n.Address, - Metadata: md, + Metadata: nmd, } i++ } diff --git a/register/memory/memory_test.go b/register/memory/memory_test.go index 348a6c90..71f4cd77 100644 --- a/register/memory/memory_test.go +++ b/register/memory/memory_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "go.unistack.org/micro/v4/register" + "go.unistack.org/micro/v3/register" ) var testData = map[string][]*register.Service{ diff --git a/server/registry.go b/server/registry.go index a6d34fed..9d4bf7dc 100644 --- a/server/registry.go +++ b/server/registry.go @@ -77,10 +77,6 @@ func NewRegisterService(s Server) (*register.Service, error) { } node.Metadata = metadata.Copy(opts.Metadata) - node.Metadata["server"] = s.String() - node.Metadata["broker"] = opts.Broker.String() - node.Metadata["register"] = opts.Register.String() - return ®ister.Service{ Name: opts.Name, Version: opts.Version,