Merge pull request #979 from milosgajdos83/tunnel-encrypt
[WIP] Tunnel encryption
This commit is contained in:
commit
c420fa2dec
72
tunnel/crypto.go
Normal file
72
tunnel/crypto.go
Normal file
@ -0,0 +1,72 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
)
|
||||
|
||||
// hash hahes the data into 32 bytes key and returns it
|
||||
// hash uses sha256 underneath to hash the supplied key
|
||||
func hash(key string) []byte {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(key))
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
// Encrypt encrypts data and returns the encrypted data
|
||||
func Encrypt(data []byte, key string) ([]byte, error) {
|
||||
// generate a new AES cipher using our 32 byte key
|
||||
c, err := aes.NewCipher(hash(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// gcm or Galois/Counter Mode, is a mode of operation
|
||||
// for symmetric key cryptographic block ciphers
|
||||
// - https://en.wikipedia.org/wiki/Galois/Counter_Mode
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create a new byte array the size of the nonce
|
||||
// NOTE: we might use smaller nonce size in the future
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// NOTE: we prepend the nonce to the payload
|
||||
// we need to do this as we need the same nonce
|
||||
// to decrypt the payload when receiving it
|
||||
return gcm.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the payload and returns the decrypted data
|
||||
func Decrypt(data []byte, key string) ([]byte, error) {
|
||||
// generate AES cipher for decrypting the message
|
||||
c, err := aes.NewCipher(hash(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// we use GCM to encrypt the payload
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
// NOTE: we need to parse out nonce from the payload
|
||||
// we prepend the nonce to every encrypted payload
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
41
tunnel/crypto_test.go
Normal file
41
tunnel/crypto_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncrypt(t *testing.T) {
|
||||
key := "tokenpassphrase"
|
||||
data := []byte("supersecret")
|
||||
|
||||
cipherText, err := Encrypt(data, key)
|
||||
if err != nil {
|
||||
t.Errorf("failed to encrypt data: %v", err)
|
||||
}
|
||||
|
||||
// verify the cipherText is not the same as data
|
||||
if bytes.Equal(data, cipherText) {
|
||||
t.Error("encrypted data are the same as plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt(t *testing.T) {
|
||||
key := "tokenpassphrase"
|
||||
data := []byte("supersecret")
|
||||
|
||||
cipherText, err := Encrypt(data, key)
|
||||
if err != nil {
|
||||
t.Errorf("failed to encrypt data: %v", err)
|
||||
}
|
||||
|
||||
plainText, err := Decrypt(cipherText, key)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decrypt data: %v", err)
|
||||
}
|
||||
|
||||
// verify the plainText is the same as data
|
||||
if !bytes.Equal(data, plainText) {
|
||||
t.Error("decrypted data not the same as plaintext")
|
||||
}
|
||||
}
|
@ -30,7 +30,7 @@ type tun struct {
|
||||
// the unique id for this tunnel
|
||||
id string
|
||||
|
||||
// tunnel token for authentication
|
||||
// tunnel token for session encryption
|
||||
token string
|
||||
|
||||
// to indicate if we're connected or not
|
||||
@ -119,6 +119,7 @@ func (t *tun) newSession(channel, sessionId string) (*session, bool) {
|
||||
tunnel: t.id,
|
||||
channel: channel,
|
||||
session: sessionId,
|
||||
token: t.token,
|
||||
closed: make(chan bool),
|
||||
recv: make(chan *message, 128),
|
||||
send: t.send,
|
||||
@ -159,7 +160,6 @@ func (t *tun) announce(channel, session string, link *link) {
|
||||
"Micro-Tunnel-Channel": channel,
|
||||
"Micro-Tunnel-Session": session,
|
||||
"Micro-Tunnel-Link": link.id,
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
},
|
||||
}
|
||||
|
||||
@ -292,9 +292,6 @@ func (t *tun) process() {
|
||||
// set the session id
|
||||
newMsg.Header["Micro-Tunnel-Session"] = msg.session
|
||||
|
||||
// set the tunnel token
|
||||
newMsg.Header["Micro-Tunnel-Token"] = t.token
|
||||
|
||||
// send the message via the interface
|
||||
t.RLock()
|
||||
|
||||
@ -447,14 +444,11 @@ func (t *tun) listen(link *link) {
|
||||
return
|
||||
}
|
||||
|
||||
// always ensure we have the correct auth token
|
||||
// TODO: segment the tunnel based on token
|
||||
// e.g use it as the basis
|
||||
token := msg.Header["Micro-Tunnel-Token"]
|
||||
if token != t.token {
|
||||
log.Debugf("Tunnel link %s received invalid token %s", token)
|
||||
return
|
||||
}
|
||||
// TODO: figure out network authentication
|
||||
// for now we use tunnel token to encrypt/decrypt
|
||||
// session communication, but we will probably need
|
||||
// some sort of network authentication (token) to avoid
|
||||
// having rogue actors spamming the network
|
||||
|
||||
// message type
|
||||
mtype := msg.Header["Micro-Tunnel"]
|
||||
@ -702,9 +696,8 @@ func (t *tun) discover(link *link) {
|
||||
// send a discovery message to all links
|
||||
if err := link.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "discover",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
"Micro-Tunnel": "discover",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Debugf("Tunnel failed to send discover to link %s: %v", link.id, err)
|
||||
@ -733,9 +726,8 @@ func (t *tun) keepalive(link *link) {
|
||||
log.Debugf("Tunnel sending keepalive to link: %v", link.Remote())
|
||||
if err := link.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "keepalive",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
"Micro-Tunnel": "keepalive",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Debugf("Error sending keepalive to link %v: %v", link.Remote(), err)
|
||||
@ -765,9 +757,8 @@ func (t *tun) setupLink(node string) (*link, error) {
|
||||
// send the first connect message
|
||||
if err := link.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "connect",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
"Micro-Tunnel": "connect",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
@ -901,9 +892,8 @@ func (t *tun) close() error {
|
||||
for node, link := range t.links {
|
||||
link.Send(&transport.Message{
|
||||
Header: map[string]string{
|
||||
"Micro-Tunnel": "close",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
"Micro-Tunnel-Token": t.token,
|
||||
"Micro-Tunnel": "close",
|
||||
"Micro-Tunnel-Id": t.id,
|
||||
},
|
||||
})
|
||||
link.Close()
|
||||
@ -1157,6 +1147,8 @@ func (t *tun) Listen(channel string, opts ...ListenOption) (Listener, error) {
|
||||
|
||||
tl := &tunListener{
|
||||
channel: channel,
|
||||
// tunnel token
|
||||
token: t.token,
|
||||
// the accept channel
|
||||
accept: make(chan *session, 128),
|
||||
// the channel to close
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
type tunListener struct {
|
||||
// address of the listener
|
||||
channel string
|
||||
// token is the tunnel token
|
||||
token string
|
||||
// the accept channel
|
||||
accept chan *session
|
||||
// the channel to close
|
||||
@ -78,6 +80,8 @@ func (t *tunListener) process() {
|
||||
channel: m.channel,
|
||||
// the session id
|
||||
session: m.session,
|
||||
// tunnel token
|
||||
token: t.token,
|
||||
// is loopback conn
|
||||
loopback: m.loopback,
|
||||
// the link the message was received on
|
||||
|
@ -1,6 +1,7 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
@ -17,6 +18,8 @@ type session struct {
|
||||
channel string
|
||||
// the session id based on Micro.Tunnel-Session
|
||||
session string
|
||||
// token is the session token
|
||||
token string
|
||||
// closed
|
||||
closed chan bool
|
||||
// remote addr
|
||||
@ -301,14 +304,29 @@ func (s *session) Send(m *transport.Message) error {
|
||||
// no op
|
||||
}
|
||||
|
||||
// encrypt the transport message payload
|
||||
body, err := Encrypt(m.Body, s.token+s.channel+s.session)
|
||||
if err != nil {
|
||||
log.Debugf("failed to encrypt message body: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// make copy
|
||||
data := &transport.Message{
|
||||
Header: make(map[string]string),
|
||||
Body: m.Body,
|
||||
Body: body,
|
||||
}
|
||||
|
||||
// encrypt all the headers
|
||||
for k, v := range m.Header {
|
||||
data.Header[k] = v
|
||||
// encrypt the transport message payload
|
||||
val, err := Encrypt([]byte(v), s.token+s.channel+s.session)
|
||||
if err != nil {
|
||||
log.Debugf("failed to encrypt message header %s: %v", k, err)
|
||||
return err
|
||||
}
|
||||
// hex encode the encrypted header value
|
||||
data.Header[k] = hex.EncodeToString(val)
|
||||
}
|
||||
|
||||
// create a new message
|
||||
@ -352,7 +370,35 @@ func (s *session) Recv(m *transport.Message) error {
|
||||
default:
|
||||
}
|
||||
|
||||
log.Tracef("Received %+v from recv backlog", msg)
|
||||
//log.Tracef("Received %+v from recv backlog", msg)
|
||||
log.Debugf("Received %+v from recv backlog", msg)
|
||||
|
||||
// decrypt the received payload using the token
|
||||
body, err := Decrypt(msg.data.Body, s.token+s.channel+s.session)
|
||||
if err != nil {
|
||||
log.Debugf("failed to decrypt message body: %v", err)
|
||||
return err
|
||||
}
|
||||
msg.data.Body = body
|
||||
|
||||
// encrypt all the headers
|
||||
for k, v := range msg.data.Header {
|
||||
// hex decode the header values
|
||||
h, err := hex.DecodeString(v)
|
||||
if err != nil {
|
||||
log.Debugf("failed to decode message header %s: %v", k, err)
|
||||
return err
|
||||
}
|
||||
// encrypt the transport message payload
|
||||
val, err := Decrypt([]byte(h), s.token+s.channel+s.session)
|
||||
if err != nil {
|
||||
log.Debugf("failed to decrypt message header %s: %v", k, err)
|
||||
return err
|
||||
}
|
||||
// hex encode the encrypted header value
|
||||
msg.data.Header[k] = string(val)
|
||||
}
|
||||
|
||||
// set message
|
||||
*m = *msg.data
|
||||
// return nil
|
||||
|
Loading…
x
Reference in New Issue
Block a user