diff --git a/broker/memory/memory.go b/broker/memory/memory.go index 36554237..1f9bea86 100644 --- a/broker/memory/memory.go +++ b/broker/memory/memory.go @@ -2,6 +2,7 @@ package memory import ( + "context" "errors" "math/rand" "sync" @@ -9,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/micro/go-micro/v2/broker" + log "github.com/micro/go-micro/v2/logger" maddr "github.com/micro/go-micro/v2/util/addr" mnet "github.com/micro/go-micro/v2/util/net" ) @@ -23,8 +25,9 @@ type memoryBroker struct { } type memoryEvent struct { + opts broker.Options topic string - message *broker.Message + message interface{} } type memorySubscriber struct { @@ -85,7 +88,7 @@ func (m *memoryBroker) Init(opts ...broker.Option) error { return nil } -func (m *memoryBroker) Publish(topic string, message *broker.Message, opts ...broker.PublishOption) error { +func (m *memoryBroker) Publish(topic string, msg *broker.Message, opts ...broker.PublishOption) error { m.RLock() if !m.connected { m.RUnlock() @@ -98,9 +101,21 @@ func (m *memoryBroker) Publish(topic string, message *broker.Message, opts ...br return nil } + var v interface{} + if m.opts.Codec != nil { + buf, err := m.opts.Codec.Marshal(msg) + if err != nil { + return err + } + v = buf + } else { + v = msg + } + p := &memoryEvent{ topic: topic, - message: message, + message: v, + opts: m.opts, } for _, sub := range subs { @@ -163,7 +178,19 @@ func (m *memoryEvent) Topic() string { } func (m *memoryEvent) Message() *broker.Message { - return m.message + switch v := m.message.(type) { + case *broker.Message: + return v + case []byte: + msg := &broker.Message{} + if err := m.opts.Codec.Unmarshal(v, msg); err != nil { + log.Errorf("[memory]: failed to unmarshal: %v\n", err) + return nil + } + return msg + } + + return nil } func (m *memoryEvent) Ack() error { @@ -184,7 +211,10 @@ func (m *memorySubscriber) Unsubscribe() error { } func NewBroker(opts ...broker.Option) broker.Broker { - var options broker.Options + options := broker.Options{ + Context: context.Background(), + } + rand.Seed(time.Now().UnixNano()) for _, o := range opts { o(&options)