diff --git a/broker/mqtt/mqtt.go b/broker/mqtt/mqtt.go new file mode 100644 index 00000000..13bf21a5 --- /dev/null +++ b/broker/mqtt/mqtt.go @@ -0,0 +1,248 @@ +package mqtt + +/* + MQTT is a go-micro Broker for the MQTT protocol. + This can be integrated with any broker that supports MQTT, + including Mosquito and AWS IoT. + + TODO: Strip encoding? + Where brokers don't support headers we're actually + encoding the broker.Message in json to simplify usage + and cross broker compatibility. To actually use the + MQTT broker more widely on the internet we may need to + support stripping the encoding. + + Note: Because of the way the MQTT library works, when you + unsubscribe from a topic it will unsubscribe all subscribers. + TODO: Perhaps create a unique client per subscription. + Becomes slightly more difficult to track a disconnect. + +*/ + +import ( + "encoding/json" + "errors" + "fmt" + "log" + "math/rand" + "strconv" + "strings" + "time" + + "github.com/eclipse/paho.mqtt.golang" + "github.com/micro/go-micro/broker" +) + +type mqttBroker struct { + addrs []string + opts broker.Options + client mqtt.Client +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func setAddrs(addrs []string) []string { + var cAddrs []string + + for _, addr := range addrs { + if len(addr) == 0 { + continue + } + + var scheme string + var host string + var port int + + // split on scheme + parts := strings.Split(addr, "://") + + // no scheme + if len(parts) < 2 { + // default tcp scheme + scheme = "tcp" + parts = strings.Split(parts[0], ":") + // got scheme + } else { + scheme = parts[0] + parts = strings.Split(parts[1], ":") + } + + // no parts + if len(parts) == 0 { + continue + } + + // check scheme + switch scheme { + case "tcp", "ssl", "ws": + default: + continue + } + + if len(parts) < 2 { + // no port + host = parts[0] + + switch scheme { + case "tcp": + port = 1883 + case "ssl": + port = 8883 + case "ws": + // support secure port + port = 80 + default: + port = 1883 + } + // got host port + } else { + host = parts[0] + port, _ = strconv.Atoi(parts[1]) + } + + addr = fmt.Sprintf("%s://%s:%d", scheme, host, port) + cAddrs = append(cAddrs, addr) + + } + + // default an address if we have none + if len(cAddrs) == 0 { + cAddrs = []string{"tcp://127.0.0.1:1883"} + } + + return cAddrs +} + +func newClient(addrs []string, opts broker.Options) mqtt.Client { + // create opts + cOpts := mqtt.NewClientOptions() + cOpts.SetClientID(fmt.Sprintf("%d%d", time.Now().UnixNano(), rand.Intn(10))) + cOpts.SetCleanSession(false) + + // setup tls + if opts.TLSConfig != nil { + cOpts.SetTLSConfig(opts.TLSConfig) + } + + // add brokers + for _, addr := range addrs { + cOpts.AddBroker(addr) + } + + return mqtt.NewClient(cOpts) +} + +func newBroker(opts ...broker.Option) broker.Broker { + var options broker.Options + for _, o := range opts { + o(&options) + } + + addrs := setAddrs(options.Addrs) + client := newClient(addrs, options) + + return &mqttBroker{ + opts: options, + client: client, + addrs: addrs, + } +} + +func (m *mqttBroker) Options() broker.Options { + return m.opts +} + +func (m *mqttBroker) Address() string { + return strings.Join(m.addrs, ",") +} + +func (m *mqttBroker) Connect() error { + if m.client.IsConnected() { + return nil + } + + if t := m.client.Connect(); t.Wait() && t.Error() != nil { + return t.Error() + } + + return nil +} + +func (m *mqttBroker) Disconnect() error { + if !m.client.IsConnected() { + return nil + } + m.client.Disconnect(0) + return nil +} + +func (m *mqttBroker) Init(opts ...broker.Option) error { + if m.client.IsConnected() { + return errors.New("cannot init while connected") + } + + for _, o := range opts { + o(&m.opts) + } + + m.addrs = setAddrs(m.opts.Addrs) + m.client = newClient(m.addrs, m.opts) + return nil +} + +func (m *mqttBroker) Publish(topic string, msg *broker.Message, opts ...broker.PublishOption) error { + if !m.client.IsConnected() { + return errors.New("not connected") + } + + b, err := json.Marshal(msg) + if err != nil { + return err + } + + t := m.client.Publish(topic, 1, false, b) + return t.Error() +} + +func (m *mqttBroker) Subscribe(topic string, h broker.Handler, opts ...broker.SubscribeOption) (broker.Subscriber, error) { + if !m.client.IsConnected() { + return nil, errors.New("not connected") + } + + var options broker.SubscribeOptions + for _, o := range opts { + o(&options) + } + + t := m.client.Subscribe(topic, 1, func(c mqtt.Client, m mqtt.Message) { + var msg *broker.Message + if err := json.Unmarshal(m.Payload(), &msg); err != nil { + log.Println(err) + return + } + + if err := h(&mqttPub{topic: topic, msg: msg}); err != nil { + log.Println(err) + } + }) + + if t.Wait() && t.Error() != nil { + return nil, t.Error() + } + + return &mqttSub{ + opts: options, + client: m.client, + topic: topic, + }, nil +} + +func (m *mqttBroker) String() string { + return "mqtt" +} + +func NewBroker(opts ...broker.Option) broker.Broker { + return newBroker(opts...) +} diff --git a/broker/mqtt/mqtt_handler.go b/broker/mqtt/mqtt_handler.go new file mode 100644 index 00000000..00ff30e5 --- /dev/null +++ b/broker/mqtt/mqtt_handler.go @@ -0,0 +1,44 @@ +package mqtt + +import ( + "github.com/eclipse/paho.mqtt.golang" + "github.com/micro/go-micro/broker" +) + +// mqttPub is a broker.Publication +type mqttPub struct { + topic string + msg *broker.Message +} + +// mqttPub is a broker.Subscriber +type mqttSub struct { + opts broker.SubscribeOptions + topic string + client mqtt.Client +} + +func (m *mqttPub) Ack() error { + return nil +} + +func (m *mqttPub) Topic() string { + return m.topic +} + +func (m *mqttPub) Message() *broker.Message { + return m.msg +} + +func (m *mqttSub) Options() broker.SubscribeOptions { + return m.opts +} + +func (m *mqttSub) Topic() string { + return m.topic +} + +func (m *mqttSub) Unsubscribe() error { + t := m.client.Unsubscribe(m.topic) + return t.Error() +} diff --git a/broker/mqtt/mqtt_mock.go b/broker/mqtt/mqtt_mock.go new file mode 100644 index 00000000..748de41a --- /dev/null +++ b/broker/mqtt/mqtt_mock.go @@ -0,0 +1,171 @@ +package mqtt + +import ( + "math/rand" + "sync" + "time" + + "github.com/eclipse/paho.mqtt.golang" +) + +type mockClient struct { + sync.Mutex + connected bool + exit chan bool + + subs map[string][]mqtt.MessageHandler +} + +type mockMessage struct { + id uint16 + topic string + qos byte + retained bool + payload interface{} +} + +var ( + _ mqtt.Client = newMockClient() + _ mqtt.Message = newMockMessage("mock", 0, false, nil) +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func newMockClient() mqtt.Client { + return &mockClient{ + subs: make(map[string][]mqtt.MessageHandler), + } +} + +func newMockMessage(topic string, qos byte, retained bool, payload interface{}) mqtt.Message { + return &mockMessage{ + id: uint16(rand.Int()), + topic: topic, + qos: qos, + retained: retained, + payload: payload, + } +} + +func (m *mockMessage) Duplicate() bool { + return false +} + +func (m *mockMessage) Qos() byte { + return m.qos +} + +func (m *mockMessage) Retained() bool { + return m.retained +} + +func (m *mockMessage) Topic() string { + return m.topic +} + +func (m *mockMessage) MessageID() uint16 { + return m.id +} + +func (m *mockMessage) Payload() []byte { + return m.payload.([]byte) +} + +func (m *mockClient) IsConnected() bool { + m.Lock() + defer m.Unlock() + return m.connected +} + +func (m *mockClient) Connect() mqtt.Token { + m.Lock() + defer m.Unlock() + + if m.connected { + return nil + } + + m.connected = true + m.exit = make(chan bool) + return &mqtt.ConnectToken{} +} + +func (m *mockClient) Disconnect(uint) { + m.Lock() + defer m.Unlock() + + if !m.connected { + return + } + + m.connected = false + + select { + case <-m.exit: + return + default: + close(m.exit) + } +} + +func (m *mockClient) Publish(topic string, qos byte, retained bool, payload interface{}) mqtt.Token { + m.Lock() + defer m.Unlock() + + if !m.connected { + return nil + } + + msg := newMockMessage(topic, qos, retained, payload) + + for _, sub := range m.subs[topic] { + sub(m, msg) + } + + return &mqtt.PublishToken{} +} + +func (m *mockClient) Subscribe(topic string, qos byte, h mqtt.MessageHandler) mqtt.Token { + m.Lock() + defer m.Unlock() + + if !m.connected { + return nil + } + + m.subs[topic] = append(m.subs[topic], h) + + return &mqtt.SubscribeToken{} +} + +func (m *mockClient) SubscribeMultiple(topics map[string]byte, h mqtt.MessageHandler) mqtt.Token { + m.Lock() + defer m.Unlock() + + if !m.connected { + return nil + } + + for topic, _ := range topics { + m.subs[topic] = append(m.subs[topic], h) + } + + return &mqtt.SubscribeToken{} +} + +func (m *mockClient) Unsubscribe(topics ...string) mqtt.Token { + m.Lock() + defer m.Unlock() + + if !m.connected { + return nil + } + + for _, topic := range topics { + delete(m.subs, topic) + } + + return &mqtt.UnsubscribeToken{} +} diff --git a/broker/mqtt/mqtt_test.go b/broker/mqtt/mqtt_test.go new file mode 100644 index 00000000..6d7fd1cd --- /dev/null +++ b/broker/mqtt/mqtt_test.go @@ -0,0 +1,89 @@ +package mqtt + +import ( + "testing" + + "github.com/eclipse/paho.mqtt.golang" + "github.com/micro/go-micro/broker" +) + +func TestMQTTMock(t *testing.T) { + c := newMockClient() + + if tk := c.Connect(); tk == nil { + t.Fatal("got nil token") + } + + if tk := c.Subscribe("mock", 0, func(cm mqtt.Client, m mqtt.Message) { + t.Logf("Received payload %+v", string(m.Payload())) + }); tk == nil { + t.Fatal("got nil token") + } + + if tk := c.Publish("mock", 0, false, []byte(`hello world`)); tk == nil { + t.Fatal("got nil token") + } + + if tk := c.Unsubscribe("mock"); tk == nil { + t.Fatal("got nil token") + } + + c.Disconnect(0) +} + +func TestMQTTHandler(t *testing.T) { + p := &mqttPub{ + topic: "mock", + msg: &broker.Message{Body: []byte(`hello`)}, + } + + if p.Topic() != "mock" { + t.Fatal("Expected topic mock got", p.Topic()) + } + + if string(p.Message().Body) != "hello" { + t.Fatal("Expected `hello` message got %s", string(p.Message().Body)) + } + + s := &mqttSub{ + topic: "mock", + client: newMockClient(), + } + + s.client.Connect() + + if s.Topic() != "mock" { + t.Fatal("Expected topic mock got", s.Topic()) + } + + if err := s.Unsubscribe(); err != nil { + t.Fatal("Error unsubscribing", err) + } + + s.client.Disconnect(0) +} + +func TestMQTT(t *testing.T) { + b := NewBroker() + + if err := b.Init(); err != nil { + t.Fatal(err) + } + + // use mock client + b.(*mqttBroker).client = newMockClient() + + if tk := b.(*mqttBroker).client.Connect(); tk == nil { + t.Fatal("got nil token") + } + + if err := b.Publish("mock", &broker.Message{Body: []byte(`hello`)}); err != nil { + t.Fatal(err) + } + + if err := b.Disconnect(); err != nil { + t.Fatal(err) + } + + b.(*mqttBroker).client.Disconnect(0) +} diff --git a/cmd/cmd.go b/cmd/cmd.go index c134a0ca..eedd9acd 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -10,6 +10,7 @@ import ( "github.com/micro/cli" "github.com/micro/go-micro/broker" + "github.com/micro/go-micro/broker/mqtt" "github.com/micro/go-micro/client" "github.com/micro/go-micro/registry" "github.com/micro/go-micro/selector" @@ -118,6 +119,7 @@ var ( DefaultBrokers = map[string]func(...broker.Option) broker.Broker{ "http": broker.NewBroker, + "mqtt": mqtt.NewBroker, } DefaultRegistries = map[string]func(...registry.Option) registry.Registry{