diff --git a/.gitignore b/.gitignore index 5531d53..cdf7bbf 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ .glide/ example/client/client example/server/server +example/proxy/proxy diff --git a/client.go b/client.go index 2dc3396..8897c91 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "fmt" "net" + "sync" ) var DefaultServerMessages = []ServerMessage{ @@ -21,41 +22,58 @@ func Connect(ctx context.Context, c net.Conn, cfg *ClientConfig) (*ClientConn, e conn.Close() return nil, err } + if err := cfg.VersionHandler(cfg, conn); err != nil { conn.Close() return nil, err } + if err := cfg.SecurityHandler(cfg, conn); err != nil { conn.Close() return nil, err } + if err := cfg.ClientInitHandler(cfg, conn); err != nil { conn.Close() return nil, err } + if err := cfg.ServerInitHandler(cfg, conn); err != nil { conn.Close() return nil, err } - /* - // Send client-to-server messages. - encs := conn.encodings - if err := conn.SetEncodings(encs); err != nil { - conn.Close() - return nil, Errorf("failure calling SetEncodings; %s", err) - } - pf := conn.pixelFormat - if err := conn.SetPixelFormat(pf); err != nil { - conn.Close() - return nil, Errorf("failure calling SetPixelFormat; %s", err) - } - */ + return conn, nil } var _ Conn = (*ClientConn)(nil) +func (c *ClientConn) Conn() net.Conn { + return c.c +} + +func (c *ClientConn) SetProtoVersion(pv string) { + c.protocol = pv +} + +func (c *ClientConn) SetEncodings(encs []EncodingType) error { + + msg := &SetEncodings{ + MsgType: SetEncodingsMsgType, + EncNum: uint16(len(encs)), + Encodings: encs, + } + + return msg.Write(c) +} + +func (c *ClientConn) UnreadByte() error { + return c.br.UnreadByte() +} + func (c *ClientConn) Flush() error { + c.m.Lock() + defer c.m.Unlock() return c.bw.Flush() } @@ -68,6 +86,8 @@ func (c *ClientConn) Read(buf []byte) (int, error) { } func (c *ClientConn) Write(buf []byte) (int, error) { + c.m.Lock() + defer c.m.Unlock() return c.bw.Write(buf) } @@ -84,6 +104,13 @@ func (c *ClientConn) DesktopName() string { func (c *ClientConn) PixelFormat() *PixelFormat { return c.pixelFormat } +func (c *ClientConn) SetDesktopName(name string) { + c.desktopName = name +} +func (c *ClientConn) SetPixelFormat(pf *PixelFormat) error { + c.pixelFormat = pf + return nil +} func (c *ClientConn) Encodings() []Encoding { return c.encodings } @@ -110,7 +137,7 @@ type ClientConn struct { bw *bufio.Writer cfg *ClientConfig protocol string - + m sync.Mutex // If the pixel format uses a color map, then this is the color // map that is used. This should not be modified directly, since // the data comes from the server. @@ -151,6 +178,7 @@ func NewClientConn(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { br: bufio.NewReader(c), bw: bufio.NewWriter(c), encodings: cfg.Encodings, + quit: make(chan struct{}), pixelFormat: cfg.PixelFormat, }, nil } @@ -179,7 +207,7 @@ type SetPixelFormat struct { } func (msg *SetPixelFormat) Type() ClientMessageType { - return msg.MsgType + return SetPixelFormatMsgType } func (msg *SetPixelFormat) Write(c Conn) error { @@ -201,7 +229,7 @@ func (msg *SetPixelFormat) Write(c Conn) error { } func (msg *SetPixelFormat) Read(c Conn) error { - return binary.Read(c, binary.BigEndian, msg) + return binary.Read(c, binary.BigEndian, &msg) } // SetEncodings holds the wire format message, sans encoding-type field. @@ -209,34 +237,33 @@ type SetEncodings struct { MsgType ClientMessageType _ [1]byte // padding EncNum uint16 // number-of-encodings - Encodings []Encoding + Encodings []EncodingType } func (msg *SetEncodings) Type() ClientMessageType { - return msg.MsgType + return SetEncodingsMsgType } func (msg *SetEncodings) Read(c Conn) error { - if err := binary.Read(c, binary.BigEndian, msg.MsgType); err != nil { + if err := binary.Read(c, binary.BigEndian, &msg.MsgType); err != nil { return err } - var pad [1]byte if err := binary.Read(c, binary.BigEndian, &pad); err != nil { return err } - if err := binary.Read(c, binary.BigEndian, msg.EncNum); err != nil { + if err := binary.Read(c, binary.BigEndian, &msg.EncNum); err != nil { return err } - - var enc Encoding + var enc EncodingType for i := uint16(0); i < msg.EncNum; i++ { if err := binary.Read(c, binary.BigEndian, &enc); err != nil { return err } msg.Encodings = append(msg.Encodings, enc) } + c.SetEncodings(msg.Encodings) return nil } @@ -275,15 +302,18 @@ type FramebufferUpdateRequest struct { } func (msg *FramebufferUpdateRequest) Type() ClientMessageType { - return msg.MsgType + return FramebufferUpdateRequestMsgType } func (msg *FramebufferUpdateRequest) Read(c Conn) error { - return binary.Read(c, binary.BigEndian, msg) + return binary.Read(c, binary.BigEndian, &msg) } func (msg *FramebufferUpdateRequest) Write(c Conn) error { - return binary.Write(c, binary.BigEndian, msg) + if err := binary.Write(c, binary.BigEndian, msg); err != nil { + return err + } + return c.Flush() } // KeyEvent holds the wire format message. @@ -295,15 +325,18 @@ type KeyEvent struct { } func (msg *KeyEvent) Type() ClientMessageType { - return msg.MsgType + return KeyEventMsgType } func (msg *KeyEvent) Read(c Conn) error { - return binary.Read(c, binary.BigEndian, msg) + return binary.Read(c, binary.BigEndian, &msg) } func (msg *KeyEvent) Write(c Conn) error { - return binary.Write(c, binary.BigEndian, msg) + if err := binary.Write(c, binary.BigEndian, msg); err != nil { + return err + } + return c.Flush() } // PointerEventMessage holds the wire format message. @@ -314,15 +347,18 @@ type PointerEvent struct { } func (msg *PointerEvent) Type() ClientMessageType { - return msg.MsgType + return PointerEventMsgType } func (msg *PointerEvent) Read(c Conn) error { - return binary.Read(c, binary.BigEndian, msg) + return binary.Read(c, binary.BigEndian, &msg) } func (msg *PointerEvent) Write(c Conn) error { - return binary.Write(c, binary.BigEndian, msg) + if err := binary.Write(c, binary.BigEndian, msg); err != nil { + return err + } + return c.Flush() } // ClientCutText holds the wire format message, sans the text field. @@ -334,11 +370,11 @@ type ClientCutText struct { } func (msg *ClientCutText) Type() ClientMessageType { - return msg.MsgType + return ClientCutTextMsgType } func (msg *ClientCutText) Read(c Conn) error { - if err := binary.Read(c, binary.BigEndian, msg.MsgType); err != nil { + if err := binary.Read(c, binary.BigEndian, &msg.MsgType); err != nil { return err } @@ -347,7 +383,7 @@ func (msg *ClientCutText) Read(c Conn) error { return err } - if err := binary.Read(c, binary.BigEndian, msg.Length); err != nil { + if err := binary.Read(c, binary.BigEndian, &msg.Length); err != nil { return err } @@ -387,7 +423,8 @@ func (msg *ClientCutText) Write(c Conn) error { // ListenAndHandle listens to a VNC server and handles server messages. func (c *ClientConn) Handle() error { var err error - + var wg sync.WaitGroup + wg.Add(2) defer c.Close() serverMessages := make(map[ServerMessageType]ServerMessage) @@ -395,41 +432,50 @@ func (c *ClientConn) Handle() error { serverMessages[m.Type()] = m } -clientLoop: - for { - select { - case msg := <-c.cfg.ServerMessageCh: - if err = msg.Write(c); err != nil { - return err + go func() error { + defer wg.Done() + for { + select { + case msg := <-c.cfg.ClientMessageCh: + if err = msg.Write(c); err != nil { + return err + } + case <-c.quit: + return nil } - case <-c.quit: - break clientLoop } - } + }() -serverLoop: - for { - select { - case <-c.quit: - break serverLoop - default: - var messageType ServerMessageType - if err = binary.Read(c, binary.BigEndian, &messageType); err != nil { - break serverLoop + go func() error { + defer wg.Done() + for { + select { + case <-c.quit: + return nil + default: + var messageType ServerMessageType + if err = binary.Read(c, binary.BigEndian, &messageType); err != nil { + return err + } + if err := c.UnreadByte(); err != nil { + return err + } + msg, ok := serverMessages[messageType] + if !ok { + return fmt.Errorf("unknown message-type: %v", messageType) + } + if err = msg.Read(c); err != nil { + return err + } + if c.cfg.ServerMessageCh == nil { + continue + } + c.cfg.ServerMessageCh <- msg } - msg, ok := serverMessages[messageType] - if !ok { - break serverLoop - } - if err = msg.Read(c); err != nil { - break serverLoop - } - if c.cfg.ServerMessageCh == nil { - continue - } - c.cfg.ServerMessageCh <- msg } - } + }() + wg.Wait() + fmt.Printf("tttt\n") return err } @@ -440,7 +486,7 @@ type ClientHandler func(*ClientConfig, Conn) error type ClientConfig struct { VersionHandler ClientHandler SecurityHandler ClientHandler - SecurityHandlers []ClientHandler + SecurityHandlers []SecurityHandler ClientInitHandler ClientHandler ServerInitHandler ClientHandler Encodings []Encoding diff --git a/conn.go b/conn.go index 01abb52..ef572e1 100644 --- a/conn.go +++ b/conn.go @@ -1,18 +1,26 @@ package vnc -import "io" +import ( + "io" + "net" +) type Conn interface { io.ReadWriteCloser + Conn() net.Conn Protocol() string PixelFormat() *PixelFormat + SetPixelFormat(*PixelFormat) error ColorMap() *ColorMap SetColorMap(*ColorMap) Encodings() []Encoding + SetEncodings([]EncodingType) error Width() uint16 Height() uint16 SetWidth(uint16) SetHeight(uint16) DesktopName() string + SetDesktopName(string) Flush() error + SetProtoVersion(string) } diff --git a/encoding.go b/encoding.go index 417acff..020cc74 100644 --- a/encoding.go +++ b/encoding.go @@ -1,19 +1,37 @@ package vnc +import ( + "bytes" + "encoding/binary" + "fmt" +) + // EncodingType represents a known VNC encoding type. type EncodingType int32 //go:generate stringer -type=EncodingType const ( - EncRaw EncodingType = 0 - EncCopyRect EncodingType = 1 - EncRRE EncodingType = 2 - EncHextile EncodingType = 5 + EncRaw EncodingType = 0 + EncCopyRect EncodingType = 1 + EncRRE EncodingType = 2 + EncCoRRE EncodingType = 4 + EncHextile EncodingType = 5 + EncZlib EncodingType = 6 + EncTight EncodingType = 7 + EncZlibHex EncodingType = 8 + EncUltra1 EncodingType = 9 + EncUltra2 EncodingType = 10 + EncJPEG EncodingType = 21 + EncJRLE EncodingType = 22 + //EncRichCursor EncodingType = 0xFFFFFF11 + //EncPointerPos EncodingType = 0xFFFFFF18 + //EncLastRec EncodingType = 0xFFFFFF20 EncTRLE EncodingType = 15 EncZRLE EncodingType = 16 EncColorPseudo EncodingType = -239 EncDesktopSizePseudo EncodingType = -223 + EncClientRedirect EncodingType = -311 ) type Encoding interface { @@ -44,29 +62,30 @@ func (enc *RawEncoding) Write(c Conn, rect *Rectangle) error { // Read implements the Encoding interface. func (enc *RawEncoding) Read(c Conn, rect *Rectangle) error { - /* - var buf bytes.Buffer - pf := c.PixelFormat() - cm := c.ColorMap() - bytesPerPixel := int(pf.BPP / 8) - n := rect.Area() * bytesPerPixel - if err := c.receiveN(&buf, n); err != nil { - return fmt.Errorf("unable to read rectangle with raw encoding: %s", err) - } - - colors := make([]Color, rect.Area()) - for y := uint16(0); y < rect.Height; y++ { - for x := uint16(0); x < rect.Width; x++ { - color := NewColor(pf, cm) - if err := color.Unmarshal(buf.Next(bytesPerPixel)); err != nil { - return nil, err - } - colors[int(y)*int(rect.Width)+int(x)] = *color + buf := bytes.NewBuffer(nil) + pf := c.PixelFormat() + cm := c.ColorMap() + bytesPerPixel := int(pf.BPP / 8) + n := rect.Area() * bytesPerPixel + data := make([]byte, n) + fmt.Printf("eeee\n") + if err := binary.Read(c, binary.BigEndian, &data); err != nil { + return err + } + buf.Write(data) + defer buf.Reset() + colors := make([]Color, rect.Area()) + for y := uint16(0); y < rect.Height; y++ { + for x := uint16(0); x < rect.Width; x++ { + color := NewColor(pf, cm) + if err := color.Unmarshal(buf.Next(bytesPerPixel)); err != nil { + return err } + colors[int(y)*int(rect.Width)+int(x)] = *color } + } - return &RawEncoding{colors}, nil - */ + enc.Colors = colors return nil } diff --git a/encodingtype_string.go b/encodingtype_string.go index fded0b8..ac99e3e 100644 --- a/encodingtype_string.go +++ b/encodingtype_string.go @@ -5,34 +5,44 @@ package vnc import "fmt" const ( - _EncodingType_name_0 = "EncColorPseudo" - _EncodingType_name_1 = "EncDesktopSizePseudo" - _EncodingType_name_2 = "EncRawEncCopyRectEncRRE" - _EncodingType_name_3 = "EncHextile" - _EncodingType_name_4 = "EncTRLEEncZRLE" + _EncodingType_name_0 = "EncClientRedirect" + _EncodingType_name_1 = "EncColorPseudo" + _EncodingType_name_2 = "EncDesktopSizePseudo" + _EncodingType_name_3 = "EncRawEncCopyRectEncRRE" + _EncodingType_name_4 = "EncCoRREEncHextileEncZlibEncTightEncZlibHexEncUltra1EncUltra2" + _EncodingType_name_5 = "EncTRLEEncZRLE" + _EncodingType_name_6 = "EncJPEGEncJRLE" ) var ( - _EncodingType_index_0 = [...]uint8{0, 14} - _EncodingType_index_1 = [...]uint8{0, 20} - _EncodingType_index_2 = [...]uint8{0, 6, 17, 23} - _EncodingType_index_3 = [...]uint8{0, 10} - _EncodingType_index_4 = [...]uint8{0, 7, 14} + _EncodingType_index_0 = [...]uint8{0, 17} + _EncodingType_index_1 = [...]uint8{0, 14} + _EncodingType_index_2 = [...]uint8{0, 20} + _EncodingType_index_3 = [...]uint8{0, 6, 17, 23} + _EncodingType_index_4 = [...]uint8{0, 8, 18, 25, 33, 43, 52, 61} + _EncodingType_index_5 = [...]uint8{0, 7, 14} + _EncodingType_index_6 = [...]uint8{0, 7, 14} ) func (i EncodingType) String() string { switch { - case i == -239: + case i == -311: return _EncodingType_name_0 - case i == -223: + case i == -239: return _EncodingType_name_1 + case i == -223: + return _EncodingType_name_2 case 0 <= i && i <= 2: - return _EncodingType_name_2[_EncodingType_index_2[i]:_EncodingType_index_2[i+1]] - case i == 5: - return _EncodingType_name_3 + return _EncodingType_name_3[_EncodingType_index_3[i]:_EncodingType_index_3[i+1]] + case 4 <= i && i <= 10: + i -= 4 + return _EncodingType_name_4[_EncodingType_index_4[i]:_EncodingType_index_4[i+1]] case 15 <= i && i <= 16: i -= 15 - return _EncodingType_name_4[_EncodingType_index_4[i]:_EncodingType_index_4[i+1]] + return _EncodingType_name_5[_EncodingType_index_5[i]:_EncodingType_index_5[i+1]] + case 21 <= i && i <= 22: + i -= 21 + return _EncodingType_name_6[_EncodingType_index_6[i]:_EncodingType_index_6[i+1]] default: return fmt.Sprintf("EncodingType(%d)", i) } diff --git a/example/proxy/main.go b/example/proxy/main.go index b686739..3388d68 100644 --- a/example/proxy/main.go +++ b/example/proxy/main.go @@ -2,69 +2,111 @@ package main import ( "context" - "flag" "log" "net" - "os" - "time" - vnc "github.com/kward/go-vnc" - "github.com/kward/go-vnc/logging" - "github.com/kward/go-vnc/messages" - "github.com/kward/go-vnc/rfbflags" + vnc "github.com/vtolstov/go-vnc" ) func main() { - flag.Parse() - logging.V(logging.FnDeclLevel) - ln, err := net.Listen("tcp", os.Args[1]) + ln, err := net.Listen("tcp", ":5900") if err != nil { log.Fatalf("Error listen. %v", err) } - // Negotiate connection with the server. - sch := make(chan vnc.ClientMessage) + schServer := make(chan vnc.ClientMessage) + schClient := make(chan vnc.ServerMessage) - // handle client messages. - vcc := vnc.NewServerConfig() - vcc.Auth = []vnc.ServerAuth{&vnc.ServerAuthNone{}} - vcc.ClientMessageCh = sch - go vnc.Serve(context.Background(), ln, vcc) + scfg := &vnc.ServerConfig{ + Width: 800, + Height: 600, + VersionHandler: vnc.ServerVersionHandler, + SecurityHandler: vnc.ServerSecurityHandler, + SecurityHandlers: []vnc.SecurityHandler{&vnc.ClientAuthNone{}}, + ClientInitHandler: vnc.ServerClientInitHandler, + ServerInitHandler: vnc.ServerServerInitHandler, + Encodings: []vnc.Encoding{&vnc.RawEncoding{}}, + PixelFormat: vnc.PixelFormat24bit, + ClientMessageCh: schServer, + ServerMessageCh: schClient, + ClientMessages: vnc.DefaultClientMessages, + DesktopName: []byte("vnc proxy"), + } + go vnc.Serve(context.Background(), ln, scfg) - nc, err := net.Dial("tcp", os.Args[1]) + c, err := net.Dial("tcp", "127.0.0.1:5944") if err != nil { - log.Fatalf("Error connecting to VNC host. %v", err) + log.Fatalf("Error dial. %v", err) + } + cchServer := make(chan vnc.ServerMessage) + cchClient := make(chan vnc.ClientMessage) + + ccfg := &vnc.ClientConfig{ + VersionHandler: vnc.ClientVersionHandler, + SecurityHandler: vnc.ClientSecurityHandler, + SecurityHandlers: []vnc.SecurityHandler{&vnc.ClientAuthNone{}}, + ClientInitHandler: vnc.ClientClientInitHandler, + ServerInitHandler: vnc.ClientServerInitHandler, + PixelFormat: vnc.PixelFormat24bit, + ClientMessageCh: cchClient, + ServerMessageCh: cchServer, + ServerMessages: vnc.DefaultServerMessages, + Encodings: []vnc.Encoding{&vnc.RawEncoding{}}, } - // Negotiate connection with the server. - cch := make(chan vnc.ServerMessage) - vc, err := vnc.Connect(context.Background(), nc, - &vnc.ClientConfig{ - Auth: []vnc.ClientAuth{&vnc.ClientAuthNone{}}, - ServerMessageCh: cch, - }) - + cc, err := vnc.Connect(context.Background(), c, ccfg) if err != nil { - log.Fatalf("Error negotiating connection to VNC host. %v", err) + log.Fatalf("Error dial. %v", err) } + defer cc.Close() + go cc.Handle() - // Listen and handle server messages. - go vc.ListenAndHandle() - - // Process messages coming in on the ServerMessage channel. for { - msg := <-ch - switch msg.Type() { - case messages.FramebufferUpdate: - log.Println("Received FramebufferUpdate message.") - default: - log.Printf("Received message type:%v msg:%v\n", msg.Type(), msg) + select { + case msg := <-cchClient: + switch msg.Type() { + default: + log.Printf("00 Received message type:%v msg:%v\n", msg.Type(), msg) + } + case msg := <-cchServer: + switch msg.Type() { + default: + log.Printf("01 Received message type:%v msg:%v\n", msg.Type(), msg) + } + case msg := <-schClient: + switch msg.Type() { + default: + log.Printf("10 Received message type:%v msg:%v\n", msg.Type(), msg) + } + case msg := <-schServer: + log.Printf("11 Received message type:%v msg:%v\n", msg.Type(), msg) + switch msg.Type() { + case vnc.SetEncodingsMsgType: + encRaw := &vnc.RawEncoding{} + msg1 := &vnc.SetEncodings{ + MsgType: vnc.SetEncodingsMsgType, + EncNum: 1, + Encodings: []vnc.EncodingType{encRaw.Type()}, + } + if err := msg1.Write(cc); err != nil { + log.Fatalf("err %v\n", err) + } + msg2 := &vnc.FramebufferUpdateRequest{ + MsgType: vnc.FramebufferUpdateRequestMsgType, + Inc: 0, + X: 0, + Y: 0, + Width: cc.Width(), + Height: cc.Height(), + } + if err := msg2.Write(cc); err != nil { + log.Fatalf("err %v\n", err) + } + default: + if err := msg.Write(cc); err != nil { + log.Fatalf("err %v\n", err) + } + } } } - - // Process messages coming in on the ClientMessage channel. - for { - msg := <-ch - msg.Write( - } } diff --git a/example/server/main.go b/example/server/main.go index 4a4f868..899f46e 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -2,8 +2,12 @@ package main import ( "context" + "fmt" + "image" "log" + "math" "net" + "time" vnc "github.com/vtolstov/go-vnc" ) @@ -17,10 +21,16 @@ func main() { chServer := make(chan vnc.ClientMessage) chClient := make(chan vnc.ServerMessage) + im := image.NewRGBA(image.Rect(0, 0, width, height)) + tick := time.NewTicker(time.Second / 2) + defer tick.Stop() + cfg := &vnc.ServerConfig{ + Width: 800, + Height: 600, VersionHandler: vnc.ServerVersionHandler, SecurityHandler: vnc.ServerSecurityHandler, - SecurityHandlers: []vnc.ServerHandler{vnc.ServerSecurityNoneHandler}, + SecurityHandlers: []vnc.SecurityHandler{&vnc.ClientAuthNone{}}, ClientInitHandler: vnc.ServerClientInitHandler, ServerInitHandler: vnc.ServerServerInitHandler, Encodings: []vnc.Encoding{&vnc.RawEncoding{}}, @@ -33,10 +43,51 @@ func main() { // Process messages coming in on the ClientMessage channel. for { - msg := <-chClient - switch msg.Type() { - default: - log.Printf("Received message type:%v msg:%v\n", msg.Type(), msg) + select { + case <-tick.C: + drawImage(im, 0) + fmt.Printf("tick\n") + case msg := <-chClient: + switch msg.Type() { + default: + log.Printf("11 Received message type:%v msg:%v\n", msg.Type(), msg) + } + case msg := <-chServer: + switch msg.Type() { + default: + log.Printf("22 Received message type:%v msg:%v\n", msg.Type(), msg) + } + } + } +} + +const ( + width = 800 + height = 600 +) + +func drawImage(im *image.RGBA, anim int) { + pos := 0 + const border = 50 + for y := 0; y < height; y++ { + for x := 0; x < width; x++ { + var r, g, b uint8 + switch { + case x < border*2.5 && x < int((1.1+math.Sin(float64(y+anim*2)/40))*border): + r = 255 + case x > width-border*2.5 && x > width-int((1.1+math.Sin(math.Pi+float64(y+anim*2)/40))*border): + g = 255 + case y < border*2.5 && y < int((1.1+math.Sin(float64(x+anim*2)/40))*border): + r, g = 255, 255 + case y > height-border*2.5 && y > height-int((1.1+math.Sin(math.Pi+float64(x+anim*2)/40))*border): + b = 255 + default: + r, g, b = uint8(x+anim), uint8(y+anim), uint8(x+y+anim*3) + } + im.Pix[pos] = r + im.Pix[pos+1] = g + im.Pix[pos+2] = b + pos += 4 // skipping alpha } } } diff --git a/handlers.go b/handlers.go index 8c40940..8cade5e 100644 --- a/handlers.go +++ b/handlers.go @@ -1,5 +1,10 @@ package vnc +import ( + "encoding/binary" + "fmt" +) + // ClientMessage is the interface type ClientMessage interface { Type() ClientMessageType @@ -14,22 +19,264 @@ type ServerMessage interface { Write(Conn) error } +const ProtoVersionLength = 12 + +const ( + ProtoVersionUnknown = "" + ProtoVersion33 = "RFB 003.003\n" + ProtoVersion38 = "RFB 003.008\n" +) + +func ParseProtoVersion(pv []byte) (uint, uint, error) { + var major, minor uint + + if len(pv) < ProtoVersionLength { + return 0, 0, fmt.Errorf("ProtocolVersion message too short (%v < %v)", len(pv), ProtoVersionLength) + } + + l, err := fmt.Sscanf(string(pv), "RFB %d.%d\n", &major, &minor) + if l != 2 { + return 0, 0, fmt.Errorf("error parsing ProtocolVersion.") + } + if err != nil { + return 0, 0, err + } + + return major, minor, nil +} + +func ClientVersionHandler(cfg *ClientConfig, c Conn) error { + var version [ProtoVersionLength]byte + + if err := binary.Read(c, binary.BigEndian, &version); err != nil { + return err + } + + major, minor, err := ParseProtoVersion(version[:]) + if err != nil { + return err + } + + pv := ProtoVersionUnknown + if major == 3 { + if minor >= 8 { + pv = ProtoVersion38 + } else if minor >= 3 { + pv = ProtoVersion38 + } + } + if pv == ProtoVersionUnknown { + return fmt.Errorf("ProtocolVersion handshake failed; unsupported version '%v'", string(version[:])) + } + + c.SetProtoVersion(pv) + + if err := binary.Write(c, binary.BigEndian, []byte(pv)); err != nil { + return err + } + return c.Flush() +} + func ServerVersionHandler(cfg *ServerConfig, c Conn) error { + var version [ProtoVersionLength]byte + if err := binary.Write(c, binary.BigEndian, []byte(ProtoVersion38)); err != nil { + return err + } + if err := c.Flush(); err != nil { + return err + } + if err := binary.Read(c, binary.BigEndian, &version); err != nil { + return err + } + + major, minor, err := ParseProtoVersion(version[:]) + if err != nil { + return err + } + + pv := ProtoVersionUnknown + if major == 3 { + if minor >= 8 { + pv = ProtoVersion38 + } else if minor >= 3 { + pv = ProtoVersion33 + } + } + if pv == ProtoVersionUnknown { + return fmt.Errorf("ProtocolVersion handshake failed; unsupported version '%v'", string(version[:])) + } + + c.SetProtoVersion(pv) + return nil +} + +func ClientSecurityHandler(cfg *ClientConfig, c Conn) error { + var numSecurityTypes uint8 + if err := binary.Read(c, binary.BigEndian, &numSecurityTypes); err != nil { + return err + } + secTypes := make([]SecurityType, numSecurityTypes) + if err := binary.Read(c, binary.BigEndian, &secTypes); err != nil { + return err + } + + if err := binary.Write(c, binary.BigEndian, cfg.SecurityHandlers[0].Type()); err != nil { + return err + } + + if err := c.Flush(); err != nil { + return err + } + + var authCode uint32 + if err := binary.Read(c, binary.BigEndian, &authCode); err != nil { + return err + } + + if authCode == 1 { + var reasonLength uint32 + if err := binary.Read(c, binary.BigEndian, &reasonLength); err != nil { + return err + } + reasonText := make([]byte, reasonLength) + if err := binary.Read(c, binary.BigEndian, &reasonText); err != nil { + return err + } + return fmt.Errorf("%s", reasonText) + } + return nil } func ServerSecurityHandler(cfg *ServerConfig, c Conn) error { + if err := binary.Write(c, binary.BigEndian, uint8(len(cfg.SecurityHandlers))); err != nil { + return err + } + + for _, sectype := range cfg.SecurityHandlers { + if err := binary.Write(c, binary.BigEndian, sectype.Type()); err != nil { + return err + } + } + + if err := c.Flush(); err != nil { + return err + } + + var secType SecurityType + if err := binary.Read(c, binary.BigEndian, &secType); err != nil { + return err + } + + secTypes := make(map[SecurityType]SecurityHandler) + for _, sType := range cfg.SecurityHandlers { + secTypes[sType.Type()] = sType + } + + sType, ok := secTypes[secType] + if !ok { + return fmt.Errorf("server type %d not implemented") + } + + var authCode uint32 + authErr := sType.Auth(c) + if authErr != nil { + authCode = uint32(1) + } + + if err := binary.Write(c, binary.BigEndian, authCode); err != nil { + return err + } + if err := c.Flush(); err != nil { + return err + } + + if authErr != nil { + if err := binary.Write(c, binary.BigEndian, len(authErr.Error())); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, []byte(authErr.Error())); err != nil { + return err + } + return authErr + } + return nil } -func ServerSecurityNoneHandler(cfg *ServerConfig, c Conn) error { +func ClientServerInitHandler(cfg *ClientConfig, c Conn) error { + srvInit := &ServerInit{} + + if err := binary.Read(c, binary.BigEndian, &srvInit.FBWidth); err != nil { + return err + } + if err := binary.Read(c, binary.BigEndian, &srvInit.FBHeight); err != nil { + return err + } + if err := binary.Read(c, binary.BigEndian, &srvInit.PixelFormat); err != nil { + return err + } + if err := binary.Read(c, binary.BigEndian, &srvInit.NameLength); err != nil { + return err + } + + nameText := make([]byte, srvInit.NameLength) + if err := binary.Read(c, binary.BigEndian, nameText); err != nil { + return err + } + + srvInit.NameText = nameText + c.SetDesktopName(string(srvInit.NameText)) + c.SetWidth(srvInit.FBWidth) + c.SetHeight(srvInit.FBHeight) + c.SetPixelFormat(&srvInit.PixelFormat) return nil } func ServerServerInitHandler(cfg *ServerConfig, c Conn) error { - return nil + srvInit := &ServerInit{ + FBWidth: c.Width(), + FBHeight: c.Height(), + PixelFormat: *c.PixelFormat(), + NameLength: uint32(len(cfg.DesktopName)), + NameText: []byte(cfg.DesktopName), + } + + if err := binary.Write(c, binary.BigEndian, srvInit.FBWidth); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, srvInit.FBHeight); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, srvInit.PixelFormat); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, srvInit.NameLength); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, srvInit.NameText); err != nil { + return err + } + + return c.Flush() +} + +func ClientClientInitHandler(cfg *ClientConfig, c Conn) error { + if err := binary.Write(c, binary.BigEndian, cfg.Exclusive); err != nil { + return err + } + return c.Flush() } func ServerClientInitHandler(cfg *ServerConfig, c Conn) error { + var shared uint8 + if err := binary.Read(c, binary.BigEndian, &shared); err != nil { + return err + } + /* TODO + if shared != 1 { + c.SetShared(false) + } + */ return nil } diff --git a/image.go b/image.go index f5eee4b..24d4da7 100644 --- a/image.go +++ b/image.go @@ -1,6 +1,8 @@ package vnc import ( + "encoding/binary" + "fmt" "image" ) @@ -29,9 +31,14 @@ func NewColor(pf *PixelFormat, cm *ColorMap) *Color { type Rectangle struct { X, Y uint16 Width, Height uint16 + EncType EncodingType Enc Encoding } +func NewRectangle() *Rectangle { + return &Rectangle{} +} + // Marshal implements the Marshaler interface. func (c *Color) Marshal() ([]byte, error) { order := c.pf.order() @@ -106,50 +113,54 @@ func colorsToImage(x, y, width, height uint16, colors []Color) *image.RGBA64 { } // Marshal implements the Marshaler interface. -func (r *Rectangle) Marshal() ([]byte, error) { - /* - buf := bytes.NewBuffer(nil) +func (r *Rectangle) Write(c Conn) error { + if err := binary.Write(c, binary.BigEndian, r.X); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, r.Y); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, r.Width); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, r.Height); err != nil { + return err + } + if err := binary.Write(c, binary.BigEndian, r.EncType); err != nil { + return err + } - var msg Rectangle - msg.X, msg.Y, msg.W, msg.H = r.X, r.Y, r.Width, r.Height - msg.E = r.Enc.Type() - if err := binary.Write(buf, binary.BigEndian, msg); err != nil { - return nil, err - } - - bytes, err := r.Enc.Marshal() - if err != nil { - return nil, err - } - if err := binary.Write(buf, binary.BigEndian, bytes); err != nil { - return nil, err - } - - return buf.Bytes(), nil - */ - return nil, nil + if err := r.Enc.Write(c, r); err != nil { + return err + } + return c.Flush() } -// Unmarshal implements the Unmarshaler interface. -func (r *Rectangle) Unmarshal(data []byte) error { - /* - buf := bytes.NewBuffer(data) +func (r *Rectangle) Read(c Conn) error { + fmt.Printf("qqq\n") + var err error + if err = binary.Read(c, binary.BigEndian, &r.X); err != nil { + return err + } + if err = binary.Read(c, binary.BigEndian, &r.Y); err != nil { + return err + } + if err = binary.Read(c, binary.BigEndian, &r.Width); err != nil { + return err + } + if err = binary.Read(c, binary.BigEndian, &r.Height); err != nil { + return err + } + if err = binary.Read(c, binary.BigEndian, &r.EncType); err != nil { + return err + } + fmt.Printf("rrrr %#+v\n", r) + switch r.EncType { + case EncRaw: + r.Enc = &RawEncoding{} + } - var msg Rectangle - if err := binary.Read(buf, binary.BigEndian, &msg); err != nil { - return err - } - r.X, r.Y, r.Width, r.Height = msg.X, msg.Y, msg.W, msg.H - - switch msg.E { - case encodings.Raw: - r.Enc = &RawEncoding{} - default: - return fmt.Errorf("unable to unmarshal encoding %v", msg.E) - } - return nil - */ - return nil + return r.Enc.Read(c, r) } // Area returns the total area in pixels of the Rectangle. diff --git a/pixel_format.go b/pixel_format.go index 0234581..9f11084 100644 --- a/pixel_format.go +++ b/pixel_format.go @@ -13,6 +13,7 @@ import ( var ( PixelFormat8bit *PixelFormat = NewPixelFormat(8) PixelFormat16bit *PixelFormat = NewPixelFormat(16) + PixelFormat24bit *PixelFormat = NewPixelFormat(24) PixelFormat32bit *PixelFormat = NewPixelFormat(32) ) @@ -31,7 +32,7 @@ const pixelFormatLen = 16 // NewPixelFormat returns a populated PixelFormat structure. func NewPixelFormat(bpp uint8) *PixelFormat { - bigEndian := uint8(1) + bigEndian := uint8(0) rgbMax := uint16(math.Exp2(float64(bpp))) - 1 var ( tc = uint8(1) @@ -53,18 +54,18 @@ func NewPixelFormat(bpp uint8) *PixelFormat { func (pf *PixelFormat) Marshal() ([]byte, error) { // Validation checks. switch pf.BPP { - case 8, 16, 32: + case 8, 16, 24, 32: default: - return nil, fmt.Errorf("Invalid BPP value %v; must be 8, 16, or 32.", pf.BPP) + return nil, fmt.Errorf("Invalid BPP value %v; must be 8, 16, 24 or 32.", pf.BPP) } if pf.Depth < pf.BPP { return nil, fmt.Errorf("Invalid Depth value %v; cannot be < BPP", pf.Depth) } switch pf.Depth { - case 8, 16, 32: + case 8, 16, 24, 32: default: - return nil, fmt.Errorf("Invalid Depth value %v; must be 8, 16, or 32.", pf.Depth) + return nil, fmt.Errorf("Invalid Depth value %v; must be 8, 16, 24 or 32.", pf.Depth) } // Create the slice of bytes diff --git a/security.go b/security.go new file mode 100644 index 0000000..009af9e --- /dev/null +++ b/security.go @@ -0,0 +1,180 @@ +package vnc + +import ( + "crypto/des" + "encoding/binary" + "fmt" +) + +type SecurityType uint8 + +const ( + SecTypeUnknown = SecurityType(0) + SecTypeNone = SecurityType(1) + SecTypeVNC = SecurityType(2) + SecTypeVeNCrypt = SecurityType(19) +) + +type SecuritySubType uint32 + +const ( + SecSubTypeUnknown = SecuritySubType(0) +) + +const ( + SecSubTypeVeNCrypt01Unknown = SecuritySubType(0) + SecSubTypeVeNCrypt01Plain = SecuritySubType(19) + SecSubTypeVeNCrypt01TLSNone = SecuritySubType(20) + SecSubTypeVeNCrypt01TLSVNC = SecuritySubType(21) + SecSubTypeVeNCrypt01TLSPlain = SecuritySubType(22) + SecSubTypeVeNCrypt01X509None = SecuritySubType(23) + SecSubTypeVeNCrypt01X509VNC = SecuritySubType(24) + SecSubTypeVeNCrypt01X509Plain = SecuritySubType(25) +) + +const ( + SecSubTypeVeNCrypt02Unknown = SecuritySubType(0) + SecSubTypeVeNCrypt02Plain = SecuritySubType(256) + SecSubTypeVeNCrypt02TLSNone = SecuritySubType(257) + SecSubTypeVeNCrypt02TLSVNC = SecuritySubType(258) + SecSubTypeVeNCrypt02TLSPlain = SecuritySubType(259) + SecSubTypeVeNCrypt02X509None = SecuritySubType(260) + SecSubTypeVeNCrypt02X509VNC = SecuritySubType(261) + SecSubTypeVeNCrypt02X509Plain = SecuritySubType(262) +) + +type SecurityHandler interface { + Type() SecurityType + SubType() SecuritySubType + Auth(Conn) error +} + +type ClientAuthNone struct{} + +func (*ClientAuthNone) Type() SecurityType { + return SecTypeNone +} + +func (*ClientAuthNone) SubType() SecuritySubType { + return SecSubTypeUnknown +} + +func (*ClientAuthNone) Auth(conn Conn) error { + return nil +} + +// ServerAuthNone is the "none" authentication. See 7.2.1. +type ServerAuthNone struct{} + +func (*ServerAuthNone) Type() SecurityType { + return SecTypeNone +} + +func (*ServerAuthNone) Auth(c Conn) error { + return nil +} + +func (*ClientAuthVeNCrypt02Plain) SubType() SecuritySubType { + return SecSubTypeVeNCrypt02Plain +} + +// ClientAuthVeNCryptPlain see https://www.berrange.com/~dan/vencrypt.txt +type ClientAuthVeNCrypt02Plain struct { + Username []byte + Password []byte +} + +func (auth *ClientAuthVeNCrypt02Plain) Auth(c Conn) error { + if len(auth.Password) == 0 || len(auth.Username) == 0 { + return fmt.Errorf("Security Handshake failed; no username and/or password provided for VeNCryptAuth.") + } + + if err := binary.Write(c, binary.BigEndian, uint32(len(auth.Username))); err != nil { + return err + } + + if err := binary.Write(c, binary.BigEndian, uint32(len(auth.Password))); err != nil { + return err + } + + if err := binary.Write(c, binary.BigEndian, auth.Username); err != nil { + return err + } + + if err := binary.Write(c, binary.BigEndian, auth.Password); err != nil { + return err + } + + return c.Flush() +} + +// ServerAuthVNC is the standard password authentication. See 7.2.2. +type ServerAuthVNC struct{} + +func (*ServerAuthVNC) Type() SecurityType { + return SecTypeVNC +} +func (*ServerAuthVNC) SubType() SecuritySubType { + return SecSubTypeUnknown +} + +func (auth *ServerAuthVNC) Auth(c Conn) error { + return nil +} + +// ClientAuthVNC is the standard password authentication. See 7.2.2. +type ClientAuthVNC struct { + Challenge [16]byte + Password []byte +} + +func (*ClientAuthVNC) Type() SecurityType { + return SecTypeVNC +} +func (*ClientAuthVNC) SubType() SecuritySubType { + return SecSubTypeUnknown +} + +func (auth *ClientAuthVNC) Auth(c Conn) error { + if len(auth.Password) == 0 { + return fmt.Errorf("Security Handshake failed; no password provided for VNCAuth.") + } + + if err := binary.Read(c, binary.BigEndian, auth.Challenge); err != nil { + return err + } + + auth.encode() + + // Send the encrypted challenge back to server + if err := binary.Write(c, binary.BigEndian, auth.Challenge); err != nil { + return err + } + + return c.Flush() +} + +func (auth *ClientAuthVNC) encode() error { + // Copy password string to 8 byte 0-padded slice + key := make([]byte, 8) + copy(key, auth.Password) + + // Each byte of the password needs to be reversed. This is a + // non RFC-documented behaviour of VNC clients and servers + for i := range key { + key[i] = (key[i]&0x55)<<1 | (key[i]&0xAA)>>1 // Swap adjacent bits + key[i] = (key[i]&0x33)<<2 | (key[i]&0xCC)>>2 // Swap adjacent pairs + key[i] = (key[i]&0x0F)<<4 | (key[i]&0xF0)>>4 // Swap the 2 halves + } + + // Encrypt challenge with key. + cipher, err := des.NewCipher(key) + if err != nil { + return err + } + for i := 0; i < len(auth.Challenge); i += cipher.BlockSize() { + cipher.Encrypt(auth.Challenge[i:i+cipher.BlockSize()], auth.Challenge[i:i+cipher.BlockSize()]) + } + + return nil +} diff --git a/server.go b/server.go index 8166b65..fc48496 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "fmt" "net" + "sync" ) var DefaultClientMessages = []ClientMessage{ @@ -17,9 +18,43 @@ var DefaultClientMessages = []ClientMessage{ &ClientCutText{}, } +type ServerInit struct { + FBWidth, FBHeight uint16 + PixelFormat PixelFormat + NameLength uint32 + NameText []byte +} + var _ Conn = (*ServerConn)(nil) +func (c *ServerConn) UnreadByte() error { + return c.br.UnreadByte() +} + +func (c *ServerConn) Conn() net.Conn { + return c.c +} + +func (c *ServerConn) SetEncodings(encs []EncodingType) error { + encodings := make(map[EncodingType]Encoding) + for _, enc := range c.cfg.Encodings { + encodings[enc.Type()] = enc + } + for _, encType := range encs { + if enc, ok := encodings[encType]; ok { + c.encodings = append(c.encodings, enc) + } + } + return nil +} + +func (c *ServerConn) SetProtoVersion(pv string) { + c.protocol = pv +} + func (c *ServerConn) Flush() error { + c.m.Lock() + defer c.m.Unlock() return c.bw.Flush() } @@ -41,6 +76,8 @@ func (c *ServerConn) Read(buf []byte) (int, error) { } func (c *ServerConn) Write(buf []byte) (int, error) { + c.m.Lock() + defer c.m.Unlock() return c.bw.Write(buf) } @@ -57,6 +94,13 @@ func (c *ServerConn) DesktopName() string { func (c *ServerConn) PixelFormat() *PixelFormat { return c.pixelFormat } +func (c *ServerConn) SetDesktopName(name string) { + c.desktopName = name +} +func (c *ServerConn) SetPixelFormat(pf *PixelFormat) error { + c.pixelFormat = pf + return nil +} func (c *ServerConn) Encodings() []Encoding { return c.encodings } @@ -94,20 +138,20 @@ const ( // FramebufferUpdate holds a FramebufferUpdate wire format message. type FramebufferUpdate struct { MsgType ServerMessageType - NumRect uint16 // number-of-rectangles _ [1]byte // pad + NumRect uint16 // number-of-rectangles Rects []Rectangle // rectangles } func (msg *FramebufferUpdate) Type() ServerMessageType { - return msg.MsgType + return FramebufferUpdateMsgType } func (msg *FramebufferUpdate) Read(c Conn) error { if err := binary.Read(c, binary.BigEndian, msg.MsgType); err != nil { return err } - + fmt.Printf("qqqqq\n") var pad [1]byte if err := binary.Read(c, binary.BigEndian, &pad); err != nil { return err @@ -116,17 +160,15 @@ func (msg *FramebufferUpdate) Read(c Conn) error { if err := binary.Read(c, binary.BigEndian, msg.NumRect); err != nil { return err } - /* - // Extract rectangles. - rects := make([]Rectangle, msg.NumRect) - for i := uint16(0); i < msg.NumRect; i++ { - rect := NewRectangle(c) - if err := rect.Read(c); err != nil { - return err - } - msg.Rects = append(msg.Rects, *rect) + msg.Rects = make([]Rectangle, msg.NumRect) + for i := uint16(0); i < msg.NumRect; i++ { + rect := NewRectangle() + if err := rect.Read(c); err != nil { + return err } - */ + msg.Rects = append(msg.Rects, *rect) + } + return nil } @@ -141,13 +183,11 @@ func (msg *FramebufferUpdate) Write(c Conn) error { if err := binary.Write(c, binary.BigEndian, msg.NumRect); err != nil { return err } - /* - for _, rect := range msg.Rects { - if err := rect.Write(c); err != nil { - return err - } + for _, rect := range msg.Rects { + if err := rect.Write(c); err != nil { + return err } - */ + } return c.Flush() } @@ -157,7 +197,7 @@ type ServerConn struct { br *bufio.Reader bw *bufio.Writer protocol string - + m sync.Mutex // If the pixel format uses a color map, then this is the color // map that is used. This should not be modified directly, since // the data comes from the server. @@ -190,7 +230,7 @@ type ServerHandler func(*ServerConfig, Conn) error type ServerConfig struct { VersionHandler ServerHandler SecurityHandler ServerHandler - SecurityHandlers []ServerHandler + SecurityHandlers []SecurityHandler ClientInitHandler ServerHandler ServerInitHandler ServerHandler Encodings []Encoding @@ -199,35 +239,46 @@ type ServerConfig struct { ClientMessageCh chan ClientMessage ServerMessageCh chan ServerMessage ClientMessages []ClientMessage + DesktopName []byte + Height uint16 + Width uint16 } func NewServerConn(c net.Conn, cfg *ServerConfig) (*ServerConn, error) { if cfg.ClientMessageCh == nil { return nil, fmt.Errorf("ClientMessageCh nil") } + if len(cfg.ClientMessages) == 0 { return nil, fmt.Errorf("ClientMessage 0") } + return &ServerConn{ c: c, br: bufio.NewReader(c), bw: bufio.NewWriter(c), cfg: cfg, + quit: make(chan struct{}), encodings: cfg.Encodings, pixelFormat: cfg.PixelFormat, + fbWidth: cfg.Width, + fbHeight: cfg.Height, }, nil } func Serve(ctx context.Context, ln net.Listener, cfg *ServerConfig) error { for { + c, err := ln.Accept() if err != nil { continue } + conn, err := NewServerConn(c, cfg) if err != nil { continue } + if err := cfg.VersionHandler(cfg, conn); err != nil { conn.Close() continue @@ -237,62 +288,80 @@ func Serve(ctx context.Context, ln net.Listener, cfg *ServerConfig) error { conn.Close() continue } + if err := cfg.ClientInitHandler(cfg, conn); err != nil { conn.Close() continue } + if err := cfg.ServerInitHandler(cfg, conn); err != nil { conn.Close() continue } + go conn.Handle() } } func (c *ServerConn) Handle() error { var err error + var wg sync.WaitGroup + defer c.Close() clientMessages := make(map[ClientMessageType]ClientMessage) for _, m := range c.cfg.ClientMessages { clientMessages[m.Type()] = m } + wg.Add(2) -serverLoop: - for { - select { - case msg := <-c.cfg.ServerMessageCh: - if err = msg.Write(c); err != nil { - return err + // server + go func() error { + defer wg.Done() + for { + select { + case msg := <-c.cfg.ServerMessageCh: + if err = msg.Write(c); err != nil { + return err + } + case <-c.quit: + return nil } - c.Flush() - case <-c.quit: - break serverLoop } - } + }() -clientLoop: - for { - select { - case <-c.quit: - break clientLoop - default: - var messageType ClientMessageType - if err := binary.Read(c, binary.BigEndian, &messageType); err != nil { - break clientLoop - } + // client + go func() error { + defer wg.Done() + for { + select { + case <-c.quit: + return nil + default: + var messageType ClientMessageType + if err := binary.Read(c, binary.BigEndian, &messageType); err != nil { + return err + } - msg, ok := clientMessages[messageType] - if !ok { - err = fmt.Errorf("unsupported message-type: %v", messageType) - break clientLoop - } - if err := msg.Read(c); err != nil { - break clientLoop - } + if err := c.UnreadByte(); err != nil { + return err + } - c.cfg.ClientMessageCh <- msg + msg, ok := clientMessages[messageType] + if !ok { + return fmt.Errorf("unsupported message-type: %v", messageType) + + } + + if err := msg.Read(c); err != nil { + return err + } + + c.cfg.ClientMessageCh <- msg + } } - } + }() + + wg.Wait() return nil } @@ -304,7 +373,7 @@ type ServerCutText struct { } func (msg *ServerCutText) Type() ServerMessageType { - return msg.MsgType + return ServerCutTextMsgType } func (msg *ServerCutText) Read(c Conn) error { @@ -353,7 +422,7 @@ type Bell struct { } func (msg *Bell) Type() ServerMessageType { - return msg.MsgType + return BellMsgType } func (msg *Bell) Read(c Conn) error { @@ -361,7 +430,10 @@ func (msg *Bell) Read(c Conn) error { } func (msg *Bell) Write(c Conn) error { - return binary.Write(c, binary.BigEndian, msg.MsgType) + if err := binary.Write(c, binary.BigEndian, msg.MsgType); err != nil { + return err + } + return c.Flush() } type SetColorMapEntries struct { @@ -373,7 +445,7 @@ type SetColorMapEntries struct { } func (msg *SetColorMapEntries) Type() ServerMessageType { - return msg.MsgType + return SetColorMapEntriesMsgType } func (msg *SetColorMapEntries) Read(c Conn) error {