diff --git a/client/grpc/codec.go b/client/grpc/codec.go index ff377690..0366675a 100644 --- a/client/grpc/codec.go +++ b/client/grpc/codec.go @@ -155,7 +155,7 @@ func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { m = new(codec.Message) } if m.Header == nil { - m.Header = make(map[string]string) + m.Header = make(map[string]string, len(md)) } for k, v := range md { m.Header[k] = strings.Join(v, ",") diff --git a/client/grpc/grpc.go b/client/grpc/grpc.go index 519b4450..0905eac9 100644 --- a/client/grpc/grpc.go +++ b/client/grpc/grpc.go @@ -110,13 +110,18 @@ func (g *grpcClient) next(request client.Request, opts client.CallOptions) (sele } func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.Request, rsp interface{}, opts client.CallOptions) error { + var header map[string]string + address := node.Address - header := make(map[string]string) + header = make(map[string]string) if md, ok := metadata.FromContext(ctx); ok { + header = make(map[string]string, len(md)) for k, v := range md { header[k] = v } + } else { + header = make(map[string]string) } // set timeout in nanoseconds @@ -182,13 +187,17 @@ func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.R } func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client.Request, opts client.CallOptions) (client.Stream, error) { + var header map[string]string + address := node.Address - header := make(map[string]string) if md, ok := metadata.FromContext(ctx); ok { + header = make(map[string]string, len(md)) for k, v := range md { header[k] = v } + } else { + header = make(map[string]string) } // set timeout in nanoseconds diff --git a/client/grpc/response.go b/client/grpc/response.go index 5fd40169..cd4d319c 100644 --- a/client/grpc/response.go +++ b/client/grpc/response.go @@ -27,7 +27,7 @@ func (r *response) Header() map[string]string { if err != nil { return map[string]string{} } - hdr := make(map[string]string) + hdr := make(map[string]string, len(md)) for k, v := range md { hdr[k] = strings.Join(v, ",") } diff --git a/metadata/metadata.go b/metadata/metadata.go index c5fddd84..c41aa284 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -58,7 +58,7 @@ func FromContext(ctx context.Context) (Metadata, bool) { } // capitalise all values - newMD := make(map[string]string) + newMD := make(map[string]string, len(md)) for k, v := range md { newMD[strings.Title(k)] = v } diff --git a/registry/memory/util.go b/registry/memory/util.go index 32a3852d..69cfc8de 100644 --- a/registry/memory/util.go +++ b/registry/memory/util.go @@ -7,12 +7,12 @@ import ( ) func serviceToRecord(s *registry.Service, ttl time.Duration) *record { - metadata := make(map[string]string) + metadata := make(map[string]string, len(s.Metadata)) for k, v := range s.Metadata { metadata[k] = v } - nodes := make(map[string]*node) + nodes := make(map[string]*node, len(s.Nodes)) for _, n := range s.Nodes { nodes[n.Id] = &node{ Node: n, @@ -36,7 +36,7 @@ func serviceToRecord(s *registry.Service, ttl time.Duration) *record { } func recordToService(r *record) *registry.Service { - metadata := make(map[string]string) + metadata := make(map[string]string, len(r.Metadata)) for k, v := range r.Metadata { metadata[k] = v } @@ -52,7 +52,7 @@ func recordToService(r *record) *registry.Service { *response = *e.Response } - metadata := make(map[string]string) + metadata := make(map[string]string, len(e.Metadata)) for k, v := range e.Metadata { metadata[k] = v } @@ -68,7 +68,7 @@ func recordToService(r *record) *registry.Service { nodes := make([]*registry.Node, len(r.Nodes)) i := 0 for _, n := range r.Nodes { - metadata := make(map[string]string) + metadata := make(map[string]string, len(n.Metadata)) for k, v := range n.Metadata { metadata[k] = v } diff --git a/server/grpc/codec.go b/server/grpc/codec.go index eb1b09c5..db9706d0 100644 --- a/server/grpc/codec.go +++ b/server/grpc/codec.go @@ -140,7 +140,7 @@ func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { m = new(codec.Message) } if m.Header == nil { - m.Header = make(map[string]string) + m.Header = make(map[string]string, len(md)) } for k, v := range md { m.Header[k] = strings.Join(v, ",") diff --git a/server/grpc/subscriber.go b/server/grpc/subscriber.go index 7bd39a5f..a29ebcba 100644 --- a/server/grpc/subscriber.go +++ b/server/grpc/subscriber.go @@ -188,7 +188,7 @@ func (g *grpcServer) createSubHandler(sb *subscriber, opts server.Options) broke return err } - hdr := make(map[string]string) + hdr := make(map[string]string, len(msg.Header)) for k, v := range msg.Header { hdr[k] = v } diff --git a/server/rpc_server.go b/server/rpc_server.go index a692573a..eb079883 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -85,7 +85,7 @@ func (s *rpcServer) HandleEvent(e broker.Event) error { } // copy headers - hdr := make(map[string]string) + hdr := make(map[string]string, len(msg.Header)) for k, v := range msg.Header { hdr[k] = v } @@ -262,7 +262,7 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { ct := msg.Header["Content-Type"] // copy the message headers - hdr := make(map[string]string) + hdr := make(map[string]string, len(msg.Header)) for k, v := range msg.Header { hdr[k] = v } diff --git a/transport/http_transport.go b/transport/http_transport.go index ced00c13..3e8bce8b 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -156,7 +156,7 @@ func (h *httpTransportClient) Recv(m *Message) error { m.Body = b if m.Header == nil { - m.Header = make(map[string]string) + m.Header = make(map[string]string, len(rsp.Header)) } for k, v := range rsp.Header { @@ -193,10 +193,6 @@ func (h *httpTransportSocket) Recv(m *Message) error { return errors.New("message passed in is nil") } - if m.Header == nil { - m.Header = make(map[string]string) - } - // process http 1 if h.r.ProtoMajor == 1 { // set timeout if its greater than 0 @@ -228,6 +224,10 @@ func (h *httpTransportSocket) Recv(m *Message) error { r.Body.Close() m.Body = b + if m.Header == nil { + m.Header = make(map[string]string, len(r.Header)) + } + // set headers for k, v := range r.Header { if len(v) > 0 { diff --git a/tunnel/crypto.go b/tunnel/crypto.go index 9f9f45af..e7f5a2f0 100644 --- a/tunnel/crypto.go +++ b/tunnel/crypto.go @@ -5,7 +5,16 @@ import ( "crypto/cipher" "crypto/rand" "crypto/sha256" - "io" + + "github.com/oxtoacart/bpool" +) + +var ( + // the local buffer pool + // gcmStandardNonceSize from crypto/cipher/gcm.go is 12 bytes + // 100 - is max size of pool + noncePool = bpool.NewBytePool(100, 12) + hashPool = bpool.NewBytePool(1024*32, 32) ) // hash hahes the data into 32 bytes key and returns it @@ -13,7 +22,10 @@ import ( func hash(key string) []byte { hasher := sha256.New() hasher.Write([]byte(key)) - return hasher.Sum(nil) + out := hashPool.Get() + defer hashPool.Put(out[:0]) + out = hasher.Sum(out[:0]) + return out } // Encrypt encrypts data and returns the encrypted data @@ -32,12 +44,13 @@ func Encrypt(data []byte, key string) ([]byte, error) { return nil, err } - // create a new byte array the size of the nonce + // get new byte array the size of the nonce from pool // NOTE: we might use smaller nonce size in the future - nonce := make([]byte, gcm.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + nonce := noncePool.Get() + if _, err = rand.Read(nonce); err != nil { return nil, err } + defer noncePool.Put(nonce) // NOTE: we prepend the nonce to the payload // we need to do this as we need the same nonce diff --git a/tunnel/default.go b/tunnel/default.go index 523d1978..398f52ad 100644 --- a/tunnel/default.go +++ b/tunnel/default.go @@ -131,6 +131,7 @@ func (t *tun) newSession(channel, sessionId string) (*session, bool) { recv: make(chan *message, 128), send: t.send, errChan: make(chan error, 1), + key: t.token + channel + sessionId, } // save session diff --git a/tunnel/listener.go b/tunnel/listener.go index 3aff7b85..4e35360b 100644 --- a/tunnel/listener.go +++ b/tunnel/listener.go @@ -77,6 +77,8 @@ func (t *tunListener) process() { // create a new session session sess = &session{ + // the session key + key: t.token + m.channel + sessionId, // the id of the remote side tunnel: m.tunnel, // the channel diff --git a/tunnel/session.go b/tunnel/session.go index fc5b9be9..6dd90a3d 100644 --- a/tunnel/session.go +++ b/tunnel/session.go @@ -47,6 +47,8 @@ type session struct { link string // the error response errChan chan error + // key for session encryption + key string } // message is sent over the send channel @@ -326,22 +328,22 @@ func (s *session) Announce() error { // Send is used to send a message func (s *session) Send(m *transport.Message) error { // encrypt the transport message payload - body, err := Encrypt(m.Body, s.token+s.channel+s.session) + body, err := Encrypt(m.Body, s.key) if err != nil { log.Debugf("failed to encrypt message body: %v", err) return err } - // make copy + // make copy, without rehash and realloc data := &transport.Message{ - Header: make(map[string]string), + Header: make(map[string]string, len(m.Header)), Body: body, } // encrypt all the headers for k, v := range m.Header { // encrypt the transport message payload - val, err := Encrypt([]byte(v), s.token+s.channel+s.session) + val, err := Encrypt([]byte(v), s.key) if err != nil { log.Debugf("failed to encrypt message header %s: %v", k, err) return err @@ -387,14 +389,14 @@ func (s *session) Recv(m *transport.Message) error { default: } - //log.Tracef("Received %+v from recv backlog", msg) log.Tracef("Received %+v from recv backlog", msg) + key := s.token + s.channel + msg.session // decrypt the received payload using the token // we have to used msg.session because multicast has a shared // session id of "multicast" in this session struct on // the listener side - body, err := Decrypt(msg.data.Body, s.token+s.channel+msg.session) + body, err := Decrypt(msg.data.Body, key) if err != nil { log.Debugf("failed to decrypt message body: %v", err) return err @@ -410,7 +412,7 @@ func (s *session) Recv(m *transport.Message) error { return err } // encrypt the transport message payload - val, err := Decrypt([]byte(h), s.token+s.channel+msg.session) + val, err := Decrypt([]byte(h), key) if err != nil { log.Debugf("failed to decrypt message header %s: %v", k, err) return err