diff --git a/client.go b/client.go index ddad1c7..5938a3a 100644 --- a/client.go +++ b/client.go @@ -16,31 +16,34 @@ var DefaultServerMessages = []ServerMessage{ &ServerCutText{}, } +var ( + DefaultClientHandlers []ClientHandler = []ClientHandler{ + &DefaultClientVersionHandler{}, + &DefaultClientSecurityHandler{}, + &DefaultClientClientInitHandler{}, + &DefaultClientServerInitHandler{}, + // &DefaultClientMessageHandler{}, + } +) + func Connect(ctx context.Context, c net.Conn, cfg *ClientConfig) (*ClientConn, error) { conn, err := NewClientConn(c, cfg) if err != nil { conn.Close() + cfg.ErrorCh <- err return nil, err } - if err := cfg.VersionHandler(cfg, conn); err != nil { - conn.Close() - return nil, err + if len(cfg.Handlers) == 0 { + cfg.Handlers = DefaultClientHandlers } - if err := cfg.SecurityHandler(cfg, conn); err != nil { - conn.Close() - return nil, err - } - - if err := cfg.ClientInitHandler(cfg, conn); err != nil { - conn.Close() - return nil, err - } - - if err := cfg.ServerInitHandler(cfg, conn); err != nil { - conn.Close() - return nil, err + for _, h := range cfg.Handlers { + if err := h.Handle(conn); err != nil { + conn.Close() + cfg.ErrorCh <- err + return nil, err + } } return conn, nil @@ -48,6 +51,10 @@ func Connect(ctx context.Context, c net.Conn, cfg *ClientConfig) (*ClientConn, e var _ Conn = (*ClientConn)(nil) +func (c *ClientConn) Config() interface{} { + return c.cfg +} + func (c *ClientConn) Conn() net.Conn { return c.c } @@ -89,13 +96,13 @@ func (c *ClientConn) ColorMap() *ColorMap { func (c *ClientConn) SetColorMap(cm *ColorMap) { c.colorMap = cm } -func (c *ClientConn) DesktopName() string { +func (c *ClientConn) DesktopName() []byte { return c.desktopName } func (c *ClientConn) PixelFormat() *PixelFormat { return c.pixelFormat } -func (c *ClientConn) SetDesktopName(name string) { +func (c *ClientConn) SetDesktopName(name []byte) { c.desktopName = name } func (c *ClientConn) SetPixelFormat(pf *PixelFormat) error { @@ -136,7 +143,7 @@ type ClientConn struct { colorMap *ColorMap // Name associated with the desktop, sent from the server. - desktopName string + desktopName []byte // Encodings supported by the client. This should not be modified // directly. Instead, SetEncodings() should be used. @@ -153,7 +160,8 @@ type ClientConn struct { // SetPixelFormat method. pixelFormat *PixelFormat - quit chan struct{} + quitCh chan struct{} + errorCh chan error } func NewClientConn(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { @@ -169,7 +177,8 @@ func NewClientConn(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { br: bufio.NewReader(c), bw: bufio.NewWriter(c), encodings: cfg.Encodings, - quit: make(chan struct{}), + quitCh: cfg.QuitCh, + errorCh: cfg.ErrorCh, pixelFormat: cfg.PixelFormat, }, nil } @@ -422,76 +431,82 @@ func (msg *ClientCutText) Write(c Conn) error { return c.Flush() } -// ListenAndHandle listens to a VNC server and handles server messages. -func (c *ClientConn) Handle() error { +type DefaultClientMessageHandler struct{} + +// listens to a VNC server and handles server messages. +func (*DefaultClientMessageHandler) Handle(c Conn) error { + cfg := c.Config().(*ClientConfig) var err error var wg sync.WaitGroup wg.Add(2) defer c.Close() serverMessages := make(map[ServerMessageType]ServerMessage) - for _, m := range c.cfg.ServerMessages { + for _, m := range cfg.ServerMessages { serverMessages[m.Type()] = m } - go func() error { + go func() { defer wg.Done() for { select { - case msg := <-c.cfg.ClientMessageCh: + case msg := <-cfg.ClientMessageCh: if err = msg.Write(c); err != nil { - return err + cfg.ErrorCh <- err + return } - case <-c.quit: - return nil } } }() - go func() error { + go func() { defer wg.Done() for { select { - case <-c.quit: - return nil default: var messageType ServerMessageType if err = binary.Read(c, binary.BigEndian, &messageType); err != nil { - return err + cfg.ErrorCh <- err + return } msg, ok := serverMessages[messageType] if !ok { - return fmt.Errorf("unknown message-type: %v", messageType) + err = fmt.Errorf("unknown message-type: %v", messageType) + cfg.ErrorCh <- err + return } parsedMsg, err := msg.Read(c) if err != nil { - return err + cfg.ErrorCh <- err + return } - c.cfg.ServerMessageCh <- parsedMsg + cfg.ServerMessageCh <- parsedMsg } } }() + wg.Wait() - return err + return nil } -type ClientHandler func(*ClientConfig, Conn) error +type ClientHandler interface { + Handle(Conn) error +} // A ClientConfig structure is used to configure a ClientConn. After // one has been passed to initialize a connection, it must not be modified. type ClientConfig struct { - VersionHandler ClientHandler - SecurityHandler ClientHandler - SecurityHandlers []SecurityHandler - ClientInitHandler ClientHandler - ServerInitHandler ClientHandler - Encodings []Encoding - PixelFormat *PixelFormat - ColorMap *ColorMap - ClientMessageCh chan ClientMessage - ServerMessageCh chan ServerMessage - Exclusive bool - ServerMessages []ServerMessage + Handlers []ClientHandler + SecurityHandlers []SecurityHandler + Encodings []Encoding + PixelFormat *PixelFormat + ColorMap *ColorMap + ClientMessageCh chan ClientMessage + ServerMessageCh chan ServerMessage + Exclusive bool + ServerMessages []ServerMessage + QuitCh chan struct{} + ErrorCh chan error } diff --git a/conn.go b/conn.go index ef572e1..0b01f2b 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( type Conn interface { io.ReadWriteCloser Conn() net.Conn + Config() interface{} Protocol() string PixelFormat() *PixelFormat SetPixelFormat(*PixelFormat) error @@ -19,8 +20,8 @@ type Conn interface { Height() uint16 SetWidth(uint16) SetHeight(uint16) - DesktopName() string - SetDesktopName(string) + DesktopName() []byte + SetDesktopName([]byte) Flush() error SetProtoVersion(string) } diff --git a/handlers.go b/handlers.go index c9ac065..b89dc7c 100644 --- a/handlers.go +++ b/handlers.go @@ -45,7 +45,9 @@ func ParseProtoVersion(pv []byte) (uint, uint, error) { return major, minor, nil } -func ClientVersionHandler(cfg *ClientConfig, c Conn) error { +type DefaultClientVersionHandler struct{} + +func (*DefaultClientVersionHandler) Handle(c Conn) error { var version [ProtoVersionLength]byte if err := binary.Read(c, binary.BigEndian, &version); err != nil { @@ -76,7 +78,9 @@ func ClientVersionHandler(cfg *ClientConfig, c Conn) error { return c.Flush() } -func ServerVersionHandler(cfg *ServerConfig, c Conn) error { +type DefaultServerVersionHandler struct{} + +func (*DefaultServerVersionHandler) Handle(c Conn) error { var version [ProtoVersionLength]byte if err := binary.Write(c, binary.BigEndian, []byte(ProtoVersion38)); err != nil { return err @@ -109,7 +113,10 @@ func ServerVersionHandler(cfg *ServerConfig, c Conn) error { return nil } -func ClientSecurityHandler(cfg *ClientConfig, c Conn) error { +type DefaultClientSecurityHandler struct{} + +func (*DefaultClientSecurityHandler) Handle(c Conn) error { + cfg := c.Config().(*ClientConfig) var numSecurityTypes uint8 if err := binary.Read(c, binary.BigEndian, &numSecurityTypes); err != nil { return err @@ -161,7 +168,10 @@ func ClientSecurityHandler(cfg *ClientConfig, c Conn) error { return nil } -func ServerSecurityHandler(cfg *ServerConfig, c Conn) error { +type DefaultServerSecurityHandler struct{} + +func (*DefaultServerSecurityHandler) Handle(c Conn) error { + cfg := c.Config().(*ServerConfig) if err := binary.Write(c, binary.BigEndian, uint8(len(cfg.SecurityHandlers))); err != nil { return err } @@ -220,7 +230,9 @@ func ServerSecurityHandler(cfg *ServerConfig, c Conn) error { return nil } -func ClientServerInitHandler(cfg *ClientConfig, c Conn) error { +type DefaultClientServerInitHandler struct{} + +func (*DefaultClientServerInitHandler) Handle(c Conn) error { srvInit := &ServerInit{} if err := binary.Read(c, binary.BigEndian, &srvInit.FBWidth); err != nil { @@ -242,42 +254,55 @@ func ClientServerInitHandler(cfg *ClientConfig, c Conn) error { } srvInit.NameText = nameText - c.SetDesktopName(string(srvInit.NameText)) + c.SetDesktopName(srvInit.NameText) c.SetWidth(srvInit.FBWidth) c.SetHeight(srvInit.FBHeight) c.SetPixelFormat(&srvInit.PixelFormat) + + if c.Protocol() == "aten" { + fmt.Printf("$$$$$$\n") + var pad [28]byte + /* 12 + 8 byte unknown + 1 byte IKVMVideoEnable + 1 byte IKVMKMEnable + 1 byte IKVMKickEnable + 1 byte VUSBEnable + */ + if err := binary.Read(c, binary.BigEndian, &pad); err != nil { + return err + } + fmt.Printf("rrrr\n") + } return nil } -func ServerServerInitHandler(cfg *ServerConfig, c Conn) error { - srvInit := &ServerInit{ - FBWidth: c.Width(), - FBHeight: c.Height(), - PixelFormat: *c.PixelFormat(), - NameLength: uint32(len(cfg.DesktopName)), - NameText: []byte(cfg.DesktopName), - } +type DefaultServerServerInitHandler struct{} - if err := binary.Write(c, binary.BigEndian, srvInit.FBWidth); err != nil { +func (*DefaultServerServerInitHandler) Handle(c Conn) error { + if err := binary.Write(c, binary.BigEndian, c.Width()); err != nil { return err } - if err := binary.Write(c, binary.BigEndian, srvInit.FBHeight); err != nil { + if err := binary.Write(c, binary.BigEndian, c.Height()); err != nil { return err } - if err := binary.Write(c, binary.BigEndian, srvInit.PixelFormat); err != nil { + if err := binary.Write(c, binary.BigEndian, c.PixelFormat()); err != nil { return err } - if err := binary.Write(c, binary.BigEndian, srvInit.NameLength); err != nil { + if err := binary.Write(c, binary.BigEndian, uint32(len(c.DesktopName()))); err != nil { return err } - if err := binary.Write(c, binary.BigEndian, srvInit.NameText); err != nil { + if err := binary.Write(c, binary.BigEndian, []byte(c.DesktopName())); err != nil { return err } return c.Flush() } -func ClientClientInitHandler(cfg *ClientConfig, c Conn) error { +type DefaultClientClientInitHandler struct{} + +func (*DefaultClientClientInitHandler) Handle(c Conn) error { + cfg := c.Config().(*ClientConfig) var shared uint8 if cfg.Exclusive { shared = 0 @@ -290,7 +315,9 @@ func ClientClientInitHandler(cfg *ClientConfig, c Conn) error { return c.Flush() } -func ServerClientInitHandler(cfg *ServerConfig, c Conn) error { +type DefaultServerClientInitHandler struct{} + +func (*DefaultServerClientInitHandler) Handle(c Conn) error { var shared uint8 if err := binary.Read(c, binary.BigEndian, &shared); err != nil { return err diff --git a/security.go b/security.go index 86eb989..08c8e05 100644 --- a/security.go +++ b/security.go @@ -73,6 +73,10 @@ func (*ServerAuthNone) Type() SecurityType { return SecTypeNone } +func (*ServerAuthNone) SubType() SecuritySubType { + return SecSubTypeUnknown +} + func (*ServerAuthNone) Auth(c Conn) error { return nil } @@ -115,7 +119,7 @@ func (auth *ClientAuthATEN) Auth(c Conn) error { return err } if (nt&0xffff0ff0)>>0 == 0xaff90fb0 { - fmt.Printf("aten\n") + c.SetProtoVersion("aten") var skip [20]byte binary.Read(c, binary.BigEndian, &skip) fmt.Printf("skip %s\n", skip) @@ -158,13 +162,22 @@ func (auth *ClientAuthATEN) Auth(c Conn) error { sendPassword[i] = 0 } } + if err := binary.Write(c, binary.BigEndian, sendUsername); err != nil { return err } if err := binary.Write(c, binary.BigEndian, sendPassword); err != nil { return err } - return c.Flush() + + if err := c.Flush(); err != nil { + return err + } + + //var pp [10]byte + //binary.Read(c, binary.BigEndian, &pp) + //fmt.Printf("ddd %v\n", pp) + return nil } func (*ClientAuthVeNCrypt02Plain) Type() SecurityType { @@ -270,7 +283,11 @@ func (auth *ClientAuthVeNCrypt02Plain) Auth(c Conn) error { } // ServerAuthVNC is the standard password authentication. See 7.2.2. -type ServerAuthVNC struct{} +type ServerAuthVNC struct { + Challenge []byte + Password []byte + Crypted []byte +} func (*ServerAuthVNC) Type() SecurityType { return SecTypeVNC @@ -279,13 +296,44 @@ func (*ServerAuthVNC) SubType() SecuritySubType { return SecSubTypeUnknown } +func (auth *ServerAuthVNC) WriteChallenge(c Conn) error { + if err := binary.Write(c, binary.BigEndian, auth.Challenge); err != nil { + return err + } + return c.Flush() +} + +func (auth *ServerAuthVNC) ReadChallenge(c Conn) error { + var crypted [16]byte + if err := binary.Read(c, binary.BigEndian, &crypted); err != nil { + return err + } + auth.Crypted = crypted[:] + return nil +} + func (auth *ServerAuthVNC) Auth(c Conn) error { + if err := auth.WriteChallenge(c); err != nil { + return err + } + + if err := auth.ReadChallenge(c); err != nil { + return err + } + + encrypted, err := AuthVNCEncode(auth.Password, auth.Challenge) + if err != nil { + return err + } + if !bytes.Equal(encrypted, auth.Crypted) { + return fmt.Errorf("password invalid") + } return nil } // ClientAuthVNC is the standard password authentication. See 7.2.2. type ClientAuthVNC struct { - Challenge [16]byte + Challenge []byte Password []byte } @@ -300,25 +348,34 @@ func (auth *ClientAuthVNC) Auth(c Conn) error { if len(auth.Password) == 0 { return fmt.Errorf("Security Handshake failed; no password provided for VNCAuth.") } - - if err := binary.Read(c, binary.BigEndian, auth.Challenge); err != nil { + var challenge [16]byte + if err := binary.Read(c, binary.BigEndian, &challenge); err != nil { return err } - auth.encode() + crypted, err := AuthVNCEncode(auth.Password, challenge[:]) + if err != nil { + return err + } // Send the encrypted challenge back to server - if err := binary.Write(c, binary.BigEndian, auth.Challenge); err != nil { + if err := binary.Write(c, binary.BigEndian, crypted); err != nil { return err } return c.Flush() } -func (auth *ClientAuthVNC) encode() error { +func AuthVNCEncode(password []byte, challenge []byte) ([]byte, error) { + if len(password) > 8 { + return nil, fmt.Errorf("password too long") + } + if len(challenge) != 16 { + return nil, fmt.Errorf("challenge size not 16 byte long") + } // Copy password string to 8 byte 0-padded slice key := make([]byte, 8) - copy(key, auth.Password) + copy(key, password) // Each byte of the password needs to be reversed. This is a // non RFC-documented behaviour of VNC clients and servers @@ -331,11 +388,11 @@ func (auth *ClientAuthVNC) encode() error { // Encrypt challenge with key. cipher, err := des.NewCipher(key) if err != nil { - return err + return nil, err } - for i := 0; i < len(auth.Challenge); i += cipher.BlockSize() { - cipher.Encrypt(auth.Challenge[i:i+cipher.BlockSize()], auth.Challenge[i:i+cipher.BlockSize()]) + for i := 0; i < len(challenge); i += cipher.BlockSize() { + cipher.Encrypt(challenge[i:i+cipher.BlockSize()], challenge[i:i+cipher.BlockSize()]) } - return nil + return challenge, nil } diff --git a/server.go b/server.go index f664902..caeaa24 100644 --- a/server.go +++ b/server.go @@ -27,6 +27,10 @@ type ServerInit struct { var _ Conn = (*ServerConn)(nil) +func (c *ServerConn) Config() interface{} { + return c.cfg +} + func (c *ServerConn) Conn() net.Conn { return c.c } @@ -71,13 +75,13 @@ func (c *ServerConn) ColorMap() *ColorMap { func (c *ServerConn) SetColorMap(cm *ColorMap) { c.colorMap = cm } -func (c *ServerConn) DesktopName() string { +func (c *ServerConn) DesktopName() []byte { return c.desktopName } func (c *ServerConn) PixelFormat() *PixelFormat { return c.pixelFormat } -func (c *ServerConn) SetDesktopName(name string) { +func (c *ServerConn) SetDesktopName(name []byte) { c.desktopName = name } func (c *ServerConn) SetPixelFormat(pf *PixelFormat) error { @@ -182,7 +186,7 @@ type ServerConn struct { colorMap *ColorMap // Name associated with the desktop, sent from the server. - desktopName string + desktopName []byte // Encodings supported by the client. This should not be modified // directly. Instead, SetEncodings() should be used. @@ -198,28 +202,35 @@ type ServerConn struct { // be modified. If you wish to set a new pixel format, use the // SetPixelFormat method. pixelFormat *PixelFormat - - quit chan struct{} } -type ServerHandler func(*ServerConfig, Conn) error +type ServerHandler interface { + Handle(Conn) error +} + +var ( + DefaultServerHandlers []ServerHandler = []ServerHandler{ + &DefaultServerVersionHandler{}, + &DefaultServerSecurityHandler{}, + &DefaultServerClientInitHandler{}, + &DefaultServerServerInitHandler{}, + &DefaultServerMessageHandler{}, + } +) type ServerConfig struct { - VersionHandler ServerHandler - SecurityHandler ServerHandler - SecurityHandlers []SecurityHandler - ClientInitHandler ServerHandler - ServerInitHandler ServerHandler - Encodings []Encoding - PixelFormat *PixelFormat - ColorMap *ColorMap - ClientMessageCh chan ClientMessage - ServerMessageCh chan ServerMessage - ErrorCh chan error - ClientMessages []ClientMessage - DesktopName []byte - Height uint16 - Width uint16 + Handlers []ServerHandler + SecurityHandlers []SecurityHandler + Encodings []Encoding + PixelFormat *PixelFormat + ColorMap *ColorMap + ClientMessageCh chan ClientMessage + ServerMessageCh chan ServerMessage + ClientMessages []ClientMessage + DesktopName []byte + Height uint16 + Width uint16 + errorCh chan error } func NewServerConn(c net.Conn, cfg *ServerConfig) (*ServerConn, error) { @@ -236,7 +247,7 @@ func NewServerConn(c net.Conn, cfg *ServerConfig) (*ServerConn, error) { br: bufio.NewReader(c), bw: bufio.NewWriter(c), cfg: cfg, - quit: make(chan struct{}), + desktopName: cfg.DesktopName, encodings: cfg.Encodings, pixelFormat: cfg.PixelFormat, fbWidth: cfg.Width, @@ -256,80 +267,69 @@ func Serve(ctx context.Context, ln net.Listener, cfg *ServerConfig) error { if err != nil { continue } - - if err := cfg.VersionHandler(cfg, conn); err != nil { - conn.Close() - continue + if len(cfg.Handlers) == 0 { + cfg.Handlers = DefaultServerHandlers } - if err := cfg.SecurityHandler(cfg, conn); err != nil { - conn.Close() - continue + for _, h := range cfg.Handlers { + if err := h.Handle(conn); err != nil { + conn.Close() + continue + } } - - if err := cfg.ClientInitHandler(cfg, conn); err != nil { - conn.Close() - continue - } - - if err := cfg.ServerInitHandler(cfg, conn); err != nil { - conn.Close() - continue - } - - go conn.Handle() } } -func (c *ServerConn) Handle() error { +type DefaultServerMessageHandler struct{} + +func (*DefaultServerMessageHandler) Handle(c Conn) error { + cfg := c.Config().(*ServerConfig) var err error var wg sync.WaitGroup defer c.Close() clientMessages := make(map[ClientMessageType]ClientMessage) - for _, m := range c.cfg.ClientMessages { + for _, m := range cfg.ClientMessages { clientMessages[m.Type()] = m } wg.Add(2) // server - go func() error { + go func() { defer wg.Done() for { select { - case msg := <-c.cfg.ServerMessageCh: + case msg := <-cfg.ServerMessageCh: if err = msg.Write(c); err != nil { - return err + cfg.errorCh <- err + return } - case <-c.quit: - return nil } } }() // client - go func() error { + go func() { defer wg.Done() for { select { - case <-c.quit: - return nil default: var messageType ClientMessageType if err := binary.Read(c, binary.BigEndian, &messageType); err != nil { - return err + cfg.errorCh <- err + return } msg, ok := clientMessages[messageType] if !ok { - return fmt.Errorf("unsupported message-type: %v", messageType) - + cfg.errorCh <- fmt.Errorf("unsupported message-type: %v", messageType) + return } parsedMsg, err := msg.Read(c) if err != nil { - fmt.Printf("srv err %s\n", err.Error()) - return err + cfg.errorCh <- err + return } - c.cfg.ClientMessageCh <- parsedMsg + cfg.ClientMessageCh <- parsedMsg } } }()