next steps

Signed-off-by: Vasiliy Tolstov <v.tolstov@selfip.ru>
This commit is contained in:
Василий Толстов 2017-06-13 01:52:07 +03:00
parent a45bca15e3
commit dcb1e176c4
12 changed files with 947 additions and 259 deletions

1
.gitignore vendored
View File

@ -14,3 +14,4 @@
.glide/ .glide/
example/client/client example/client/client
example/server/server example/server/server
example/proxy/proxy

136
client.go
View File

@ -6,6 +6,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"sync"
) )
var DefaultServerMessages = []ServerMessage{ var DefaultServerMessages = []ServerMessage{
@ -21,41 +22,58 @@ func Connect(ctx context.Context, c net.Conn, cfg *ClientConfig) (*ClientConn, e
conn.Close() conn.Close()
return nil, err return nil, err
} }
if err := cfg.VersionHandler(cfg, conn); err != nil { if err := cfg.VersionHandler(cfg, conn); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
if err := cfg.SecurityHandler(cfg, conn); err != nil { if err := cfg.SecurityHandler(cfg, conn); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
if err := cfg.ClientInitHandler(cfg, conn); err != nil { if err := cfg.ClientInitHandler(cfg, conn); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
if err := cfg.ServerInitHandler(cfg, conn); err != nil { if err := cfg.ServerInitHandler(cfg, conn); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
/*
// Send client-to-server messages.
encs := conn.encodings
if err := conn.SetEncodings(encs); err != nil {
conn.Close()
return nil, Errorf("failure calling SetEncodings; %s", err)
}
pf := conn.pixelFormat
if err := conn.SetPixelFormat(pf); err != nil {
conn.Close()
return nil, Errorf("failure calling SetPixelFormat; %s", err)
}
*/
return conn, nil return conn, nil
} }
var _ Conn = (*ClientConn)(nil) var _ Conn = (*ClientConn)(nil)
func (c *ClientConn) Conn() net.Conn {
return c.c
}
func (c *ClientConn) SetProtoVersion(pv string) {
c.protocol = pv
}
func (c *ClientConn) SetEncodings(encs []EncodingType) error {
msg := &SetEncodings{
MsgType: SetEncodingsMsgType,
EncNum: uint16(len(encs)),
Encodings: encs,
}
return msg.Write(c)
}
func (c *ClientConn) UnreadByte() error {
return c.br.UnreadByte()
}
func (c *ClientConn) Flush() error { func (c *ClientConn) Flush() error {
c.m.Lock()
defer c.m.Unlock()
return c.bw.Flush() return c.bw.Flush()
} }
@ -68,6 +86,8 @@ func (c *ClientConn) Read(buf []byte) (int, error) {
} }
func (c *ClientConn) Write(buf []byte) (int, error) { func (c *ClientConn) Write(buf []byte) (int, error) {
c.m.Lock()
defer c.m.Unlock()
return c.bw.Write(buf) return c.bw.Write(buf)
} }
@ -84,6 +104,13 @@ func (c *ClientConn) DesktopName() string {
func (c *ClientConn) PixelFormat() *PixelFormat { func (c *ClientConn) PixelFormat() *PixelFormat {
return c.pixelFormat return c.pixelFormat
} }
func (c *ClientConn) SetDesktopName(name string) {
c.desktopName = name
}
func (c *ClientConn) SetPixelFormat(pf *PixelFormat) error {
c.pixelFormat = pf
return nil
}
func (c *ClientConn) Encodings() []Encoding { func (c *ClientConn) Encodings() []Encoding {
return c.encodings return c.encodings
} }
@ -110,7 +137,7 @@ type ClientConn struct {
bw *bufio.Writer bw *bufio.Writer
cfg *ClientConfig cfg *ClientConfig
protocol string protocol string
m sync.Mutex
// If the pixel format uses a color map, then this is the color // If the pixel format uses a color map, then this is the color
// map that is used. This should not be modified directly, since // map that is used. This should not be modified directly, since
// the data comes from the server. // the data comes from the server.
@ -151,6 +178,7 @@ func NewClientConn(c net.Conn, cfg *ClientConfig) (*ClientConn, error) {
br: bufio.NewReader(c), br: bufio.NewReader(c),
bw: bufio.NewWriter(c), bw: bufio.NewWriter(c),
encodings: cfg.Encodings, encodings: cfg.Encodings,
quit: make(chan struct{}),
pixelFormat: cfg.PixelFormat, pixelFormat: cfg.PixelFormat,
}, nil }, nil
} }
@ -179,7 +207,7 @@ type SetPixelFormat struct {
} }
func (msg *SetPixelFormat) Type() ClientMessageType { func (msg *SetPixelFormat) Type() ClientMessageType {
return msg.MsgType return SetPixelFormatMsgType
} }
func (msg *SetPixelFormat) Write(c Conn) error { func (msg *SetPixelFormat) Write(c Conn) error {
@ -201,7 +229,7 @@ func (msg *SetPixelFormat) Write(c Conn) error {
} }
func (msg *SetPixelFormat) Read(c Conn) error { func (msg *SetPixelFormat) Read(c Conn) error {
return binary.Read(c, binary.BigEndian, msg) return binary.Read(c, binary.BigEndian, &msg)
} }
// SetEncodings holds the wire format message, sans encoding-type field. // SetEncodings holds the wire format message, sans encoding-type field.
@ -209,34 +237,33 @@ type SetEncodings struct {
MsgType ClientMessageType MsgType ClientMessageType
_ [1]byte // padding _ [1]byte // padding
EncNum uint16 // number-of-encodings EncNum uint16 // number-of-encodings
Encodings []Encoding Encodings []EncodingType
} }
func (msg *SetEncodings) Type() ClientMessageType { func (msg *SetEncodings) Type() ClientMessageType {
return msg.MsgType return SetEncodingsMsgType
} }
func (msg *SetEncodings) Read(c Conn) error { func (msg *SetEncodings) Read(c Conn) error {
if err := binary.Read(c, binary.BigEndian, msg.MsgType); err != nil { if err := binary.Read(c, binary.BigEndian, &msg.MsgType); err != nil {
return err return err
} }
var pad [1]byte var pad [1]byte
if err := binary.Read(c, binary.BigEndian, &pad); err != nil { if err := binary.Read(c, binary.BigEndian, &pad); err != nil {
return err return err
} }
if err := binary.Read(c, binary.BigEndian, msg.EncNum); err != nil { if err := binary.Read(c, binary.BigEndian, &msg.EncNum); err != nil {
return err return err
} }
var enc EncodingType
var enc Encoding
for i := uint16(0); i < msg.EncNum; i++ { for i := uint16(0); i < msg.EncNum; i++ {
if err := binary.Read(c, binary.BigEndian, &enc); err != nil { if err := binary.Read(c, binary.BigEndian, &enc); err != nil {
return err return err
} }
msg.Encodings = append(msg.Encodings, enc) msg.Encodings = append(msg.Encodings, enc)
} }
c.SetEncodings(msg.Encodings)
return nil return nil
} }
@ -275,15 +302,18 @@ type FramebufferUpdateRequest struct {
} }
func (msg *FramebufferUpdateRequest) Type() ClientMessageType { func (msg *FramebufferUpdateRequest) Type() ClientMessageType {
return msg.MsgType return FramebufferUpdateRequestMsgType
} }
func (msg *FramebufferUpdateRequest) Read(c Conn) error { func (msg *FramebufferUpdateRequest) Read(c Conn) error {
return binary.Read(c, binary.BigEndian, msg) return binary.Read(c, binary.BigEndian, &msg)
} }
func (msg *FramebufferUpdateRequest) Write(c Conn) error { func (msg *FramebufferUpdateRequest) Write(c Conn) error {
return binary.Write(c, binary.BigEndian, msg) if err := binary.Write(c, binary.BigEndian, msg); err != nil {
return err
}
return c.Flush()
} }
// KeyEvent holds the wire format message. // KeyEvent holds the wire format message.
@ -295,15 +325,18 @@ type KeyEvent struct {
} }
func (msg *KeyEvent) Type() ClientMessageType { func (msg *KeyEvent) Type() ClientMessageType {
return msg.MsgType return KeyEventMsgType
} }
func (msg *KeyEvent) Read(c Conn) error { func (msg *KeyEvent) Read(c Conn) error {
return binary.Read(c, binary.BigEndian, msg) return binary.Read(c, binary.BigEndian, &msg)
} }
func (msg *KeyEvent) Write(c Conn) error { func (msg *KeyEvent) Write(c Conn) error {
return binary.Write(c, binary.BigEndian, msg) if err := binary.Write(c, binary.BigEndian, msg); err != nil {
return err
}
return c.Flush()
} }
// PointerEventMessage holds the wire format message. // PointerEventMessage holds the wire format message.
@ -314,15 +347,18 @@ type PointerEvent struct {
} }
func (msg *PointerEvent) Type() ClientMessageType { func (msg *PointerEvent) Type() ClientMessageType {
return msg.MsgType return PointerEventMsgType
} }
func (msg *PointerEvent) Read(c Conn) error { func (msg *PointerEvent) Read(c Conn) error {
return binary.Read(c, binary.BigEndian, msg) return binary.Read(c, binary.BigEndian, &msg)
} }
func (msg *PointerEvent) Write(c Conn) error { func (msg *PointerEvent) Write(c Conn) error {
return binary.Write(c, binary.BigEndian, msg) if err := binary.Write(c, binary.BigEndian, msg); err != nil {
return err
}
return c.Flush()
} }
// ClientCutText holds the wire format message, sans the text field. // ClientCutText holds the wire format message, sans the text field.
@ -334,11 +370,11 @@ type ClientCutText struct {
} }
func (msg *ClientCutText) Type() ClientMessageType { func (msg *ClientCutText) Type() ClientMessageType {
return msg.MsgType return ClientCutTextMsgType
} }
func (msg *ClientCutText) Read(c Conn) error { func (msg *ClientCutText) Read(c Conn) error {
if err := binary.Read(c, binary.BigEndian, msg.MsgType); err != nil { if err := binary.Read(c, binary.BigEndian, &msg.MsgType); err != nil {
return err return err
} }
@ -347,7 +383,7 @@ func (msg *ClientCutText) Read(c Conn) error {
return err return err
} }
if err := binary.Read(c, binary.BigEndian, msg.Length); err != nil { if err := binary.Read(c, binary.BigEndian, &msg.Length); err != nil {
return err return err
} }
@ -387,7 +423,8 @@ func (msg *ClientCutText) Write(c Conn) error {
// ListenAndHandle listens to a VNC server and handles server messages. // ListenAndHandle listens to a VNC server and handles server messages.
func (c *ClientConn) Handle() error { func (c *ClientConn) Handle() error {
var err error var err error
var wg sync.WaitGroup
wg.Add(2)
defer c.Close() defer c.Close()
serverMessages := make(map[ServerMessageType]ServerMessage) serverMessages := make(map[ServerMessageType]ServerMessage)
@ -395,34 +432,40 @@ func (c *ClientConn) Handle() error {
serverMessages[m.Type()] = m serverMessages[m.Type()] = m
} }
clientLoop: go func() error {
defer wg.Done()
for { for {
select { select {
case msg := <-c.cfg.ServerMessageCh: case msg := <-c.cfg.ClientMessageCh:
if err = msg.Write(c); err != nil { if err = msg.Write(c); err != nil {
return err return err
} }
case <-c.quit: case <-c.quit:
break clientLoop return nil
} }
} }
}()
serverLoop: go func() error {
defer wg.Done()
for { for {
select { select {
case <-c.quit: case <-c.quit:
break serverLoop return nil
default: default:
var messageType ServerMessageType var messageType ServerMessageType
if err = binary.Read(c, binary.BigEndian, &messageType); err != nil { if err = binary.Read(c, binary.BigEndian, &messageType); err != nil {
break serverLoop return err
}
if err := c.UnreadByte(); err != nil {
return err
} }
msg, ok := serverMessages[messageType] msg, ok := serverMessages[messageType]
if !ok { if !ok {
break serverLoop return fmt.Errorf("unknown message-type: %v", messageType)
} }
if err = msg.Read(c); err != nil { if err = msg.Read(c); err != nil {
break serverLoop return err
} }
if c.cfg.ServerMessageCh == nil { if c.cfg.ServerMessageCh == nil {
continue continue
@ -430,6 +473,9 @@ serverLoop:
c.cfg.ServerMessageCh <- msg c.cfg.ServerMessageCh <- msg
} }
} }
}()
wg.Wait()
fmt.Printf("tttt\n")
return err return err
} }
@ -440,7 +486,7 @@ type ClientHandler func(*ClientConfig, Conn) error
type ClientConfig struct { type ClientConfig struct {
VersionHandler ClientHandler VersionHandler ClientHandler
SecurityHandler ClientHandler SecurityHandler ClientHandler
SecurityHandlers []ClientHandler SecurityHandlers []SecurityHandler
ClientInitHandler ClientHandler ClientInitHandler ClientHandler
ServerInitHandler ClientHandler ServerInitHandler ClientHandler
Encodings []Encoding Encodings []Encoding

10
conn.go
View File

@ -1,18 +1,26 @@
package vnc package vnc
import "io" import (
"io"
"net"
)
type Conn interface { type Conn interface {
io.ReadWriteCloser io.ReadWriteCloser
Conn() net.Conn
Protocol() string Protocol() string
PixelFormat() *PixelFormat PixelFormat() *PixelFormat
SetPixelFormat(*PixelFormat) error
ColorMap() *ColorMap ColorMap() *ColorMap
SetColorMap(*ColorMap) SetColorMap(*ColorMap)
Encodings() []Encoding Encodings() []Encoding
SetEncodings([]EncodingType) error
Width() uint16 Width() uint16
Height() uint16 Height() uint16
SetWidth(uint16) SetWidth(uint16)
SetHeight(uint16) SetHeight(uint16)
DesktopName() string DesktopName() string
SetDesktopName(string)
Flush() error Flush() error
SetProtoVersion(string)
} }

View File

@ -1,5 +1,11 @@
package vnc package vnc
import (
"bytes"
"encoding/binary"
"fmt"
)
// EncodingType represents a known VNC encoding type. // EncodingType represents a known VNC encoding type.
type EncodingType int32 type EncodingType int32
@ -9,11 +15,23 @@ const (
EncRaw EncodingType = 0 EncRaw EncodingType = 0
EncCopyRect EncodingType = 1 EncCopyRect EncodingType = 1
EncRRE EncodingType = 2 EncRRE EncodingType = 2
EncCoRRE EncodingType = 4
EncHextile EncodingType = 5 EncHextile EncodingType = 5
EncZlib EncodingType = 6
EncTight EncodingType = 7
EncZlibHex EncodingType = 8
EncUltra1 EncodingType = 9
EncUltra2 EncodingType = 10
EncJPEG EncodingType = 21
EncJRLE EncodingType = 22
//EncRichCursor EncodingType = 0xFFFFFF11
//EncPointerPos EncodingType = 0xFFFFFF18
//EncLastRec EncodingType = 0xFFFFFF20
EncTRLE EncodingType = 15 EncTRLE EncodingType = 15
EncZRLE EncodingType = 16 EncZRLE EncodingType = 16
EncColorPseudo EncodingType = -239 EncColorPseudo EncodingType = -239
EncDesktopSizePseudo EncodingType = -223 EncDesktopSizePseudo EncodingType = -223
EncClientRedirect EncodingType = -311
) )
type Encoding interface { type Encoding interface {
@ -44,29 +62,30 @@ func (enc *RawEncoding) Write(c Conn, rect *Rectangle) error {
// Read implements the Encoding interface. // Read implements the Encoding interface.
func (enc *RawEncoding) Read(c Conn, rect *Rectangle) error { func (enc *RawEncoding) Read(c Conn, rect *Rectangle) error {
/* buf := bytes.NewBuffer(nil)
var buf bytes.Buffer
pf := c.PixelFormat() pf := c.PixelFormat()
cm := c.ColorMap() cm := c.ColorMap()
bytesPerPixel := int(pf.BPP / 8) bytesPerPixel := int(pf.BPP / 8)
n := rect.Area() * bytesPerPixel n := rect.Area() * bytesPerPixel
if err := c.receiveN(&buf, n); err != nil { data := make([]byte, n)
return fmt.Errorf("unable to read rectangle with raw encoding: %s", err) fmt.Printf("eeee\n")
if err := binary.Read(c, binary.BigEndian, &data); err != nil {
return err
} }
buf.Write(data)
defer buf.Reset()
colors := make([]Color, rect.Area()) colors := make([]Color, rect.Area())
for y := uint16(0); y < rect.Height; y++ { for y := uint16(0); y < rect.Height; y++ {
for x := uint16(0); x < rect.Width; x++ { for x := uint16(0); x < rect.Width; x++ {
color := NewColor(pf, cm) color := NewColor(pf, cm)
if err := color.Unmarshal(buf.Next(bytesPerPixel)); err != nil { if err := color.Unmarshal(buf.Next(bytesPerPixel)); err != nil {
return nil, err return err
} }
colors[int(y)*int(rect.Width)+int(x)] = *color colors[int(y)*int(rect.Width)+int(x)] = *color
} }
} }
return &RawEncoding{colors}, nil enc.Colors = colors
*/
return nil return nil
} }

View File

@ -5,34 +5,44 @@ package vnc
import "fmt" import "fmt"
const ( const (
_EncodingType_name_0 = "EncColorPseudo" _EncodingType_name_0 = "EncClientRedirect"
_EncodingType_name_1 = "EncDesktopSizePseudo" _EncodingType_name_1 = "EncColorPseudo"
_EncodingType_name_2 = "EncRawEncCopyRectEncRRE" _EncodingType_name_2 = "EncDesktopSizePseudo"
_EncodingType_name_3 = "EncHextile" _EncodingType_name_3 = "EncRawEncCopyRectEncRRE"
_EncodingType_name_4 = "EncTRLEEncZRLE" _EncodingType_name_4 = "EncCoRREEncHextileEncZlibEncTightEncZlibHexEncUltra1EncUltra2"
_EncodingType_name_5 = "EncTRLEEncZRLE"
_EncodingType_name_6 = "EncJPEGEncJRLE"
) )
var ( var (
_EncodingType_index_0 = [...]uint8{0, 14} _EncodingType_index_0 = [...]uint8{0, 17}
_EncodingType_index_1 = [...]uint8{0, 20} _EncodingType_index_1 = [...]uint8{0, 14}
_EncodingType_index_2 = [...]uint8{0, 6, 17, 23} _EncodingType_index_2 = [...]uint8{0, 20}
_EncodingType_index_3 = [...]uint8{0, 10} _EncodingType_index_3 = [...]uint8{0, 6, 17, 23}
_EncodingType_index_4 = [...]uint8{0, 7, 14} _EncodingType_index_4 = [...]uint8{0, 8, 18, 25, 33, 43, 52, 61}
_EncodingType_index_5 = [...]uint8{0, 7, 14}
_EncodingType_index_6 = [...]uint8{0, 7, 14}
) )
func (i EncodingType) String() string { func (i EncodingType) String() string {
switch { switch {
case i == -239: case i == -311:
return _EncodingType_name_0 return _EncodingType_name_0
case i == -223: case i == -239:
return _EncodingType_name_1 return _EncodingType_name_1
case i == -223:
return _EncodingType_name_2
case 0 <= i && i <= 2: case 0 <= i && i <= 2:
return _EncodingType_name_2[_EncodingType_index_2[i]:_EncodingType_index_2[i+1]] return _EncodingType_name_3[_EncodingType_index_3[i]:_EncodingType_index_3[i+1]]
case i == 5: case 4 <= i && i <= 10:
return _EncodingType_name_3 i -= 4
return _EncodingType_name_4[_EncodingType_index_4[i]:_EncodingType_index_4[i+1]]
case 15 <= i && i <= 16: case 15 <= i && i <= 16:
i -= 15 i -= 15
return _EncodingType_name_4[_EncodingType_index_4[i]:_EncodingType_index_4[i+1]] return _EncodingType_name_5[_EncodingType_index_5[i]:_EncodingType_index_5[i+1]]
case 21 <= i && i <= 22:
i -= 21
return _EncodingType_name_6[_EncodingType_index_6[i]:_EncodingType_index_6[i+1]]
default: default:
return fmt.Sprintf("EncodingType(%d)", i) return fmt.Sprintf("EncodingType(%d)", i)
} }

View File

@ -2,69 +2,111 @@ package main
import ( import (
"context" "context"
"flag"
"log" "log"
"net" "net"
"os"
"time"
vnc "github.com/kward/go-vnc" vnc "github.com/vtolstov/go-vnc"
"github.com/kward/go-vnc/logging"
"github.com/kward/go-vnc/messages"
"github.com/kward/go-vnc/rfbflags"
) )
func main() { func main() {
flag.Parse() ln, err := net.Listen("tcp", ":5900")
logging.V(logging.FnDeclLevel)
ln, err := net.Listen("tcp", os.Args[1])
if err != nil { if err != nil {
log.Fatalf("Error listen. %v", err) log.Fatalf("Error listen. %v", err)
} }
// Negotiate connection with the server. schServer := make(chan vnc.ClientMessage)
sch := make(chan vnc.ClientMessage) schClient := make(chan vnc.ServerMessage)
// handle client messages. scfg := &vnc.ServerConfig{
vcc := vnc.NewServerConfig() Width: 800,
vcc.Auth = []vnc.ServerAuth{&vnc.ServerAuthNone{}} Height: 600,
vcc.ClientMessageCh = sch VersionHandler: vnc.ServerVersionHandler,
go vnc.Serve(context.Background(), ln, vcc) SecurityHandler: vnc.ServerSecurityHandler,
SecurityHandlers: []vnc.SecurityHandler{&vnc.ClientAuthNone{}},
ClientInitHandler: vnc.ServerClientInitHandler,
ServerInitHandler: vnc.ServerServerInitHandler,
Encodings: []vnc.Encoding{&vnc.RawEncoding{}},
PixelFormat: vnc.PixelFormat24bit,
ClientMessageCh: schServer,
ServerMessageCh: schClient,
ClientMessages: vnc.DefaultClientMessages,
DesktopName: []byte("vnc proxy"),
}
go vnc.Serve(context.Background(), ln, scfg)
nc, err := net.Dial("tcp", os.Args[1]) c, err := net.Dial("tcp", "127.0.0.1:5944")
if err != nil { if err != nil {
log.Fatalf("Error connecting to VNC host. %v", err) log.Fatalf("Error dial. %v", err)
}
cchServer := make(chan vnc.ServerMessage)
cchClient := make(chan vnc.ClientMessage)
ccfg := &vnc.ClientConfig{
VersionHandler: vnc.ClientVersionHandler,
SecurityHandler: vnc.ClientSecurityHandler,
SecurityHandlers: []vnc.SecurityHandler{&vnc.ClientAuthNone{}},
ClientInitHandler: vnc.ClientClientInitHandler,
ServerInitHandler: vnc.ClientServerInitHandler,
PixelFormat: vnc.PixelFormat24bit,
ClientMessageCh: cchClient,
ServerMessageCh: cchServer,
ServerMessages: vnc.DefaultServerMessages,
Encodings: []vnc.Encoding{&vnc.RawEncoding{}},
} }
// Negotiate connection with the server. cc, err := vnc.Connect(context.Background(), c, ccfg)
cch := make(chan vnc.ServerMessage)
vc, err := vnc.Connect(context.Background(), nc,
&vnc.ClientConfig{
Auth: []vnc.ClientAuth{&vnc.ClientAuthNone{}},
ServerMessageCh: cch,
})
if err != nil { if err != nil {
log.Fatalf("Error negotiating connection to VNC host. %v", err) log.Fatalf("Error dial. %v", err)
} }
defer cc.Close()
go cc.Handle()
// Listen and handle server messages.
go vc.ListenAndHandle()
// Process messages coming in on the ServerMessage channel.
for { for {
msg := <-ch select {
case msg := <-cchClient:
switch msg.Type() { switch msg.Type() {
case messages.FramebufferUpdate:
log.Println("Received FramebufferUpdate message.")
default: default:
log.Printf("Received message type:%v msg:%v\n", msg.Type(), msg) log.Printf("00 Received message type:%v msg:%v\n", msg.Type(), msg)
}
case msg := <-cchServer:
switch msg.Type() {
default:
log.Printf("01 Received message type:%v msg:%v\n", msg.Type(), msg)
}
case msg := <-schClient:
switch msg.Type() {
default:
log.Printf("10 Received message type:%v msg:%v\n", msg.Type(), msg)
}
case msg := <-schServer:
log.Printf("11 Received message type:%v msg:%v\n", msg.Type(), msg)
switch msg.Type() {
case vnc.SetEncodingsMsgType:
encRaw := &vnc.RawEncoding{}
msg1 := &vnc.SetEncodings{
MsgType: vnc.SetEncodingsMsgType,
EncNum: 1,
Encodings: []vnc.EncodingType{encRaw.Type()},
}
if err := msg1.Write(cc); err != nil {
log.Fatalf("err %v\n", err)
}
msg2 := &vnc.FramebufferUpdateRequest{
MsgType: vnc.FramebufferUpdateRequestMsgType,
Inc: 0,
X: 0,
Y: 0,
Width: cc.Width(),
Height: cc.Height(),
}
if err := msg2.Write(cc); err != nil {
log.Fatalf("err %v\n", err)
}
default:
if err := msg.Write(cc); err != nil {
log.Fatalf("err %v\n", err)
}
} }
} }
// Process messages coming in on the ClientMessage channel.
for {
msg := <-ch
msg.Write(
} }
} }

View File

@ -2,8 +2,12 @@ package main
import ( import (
"context" "context"
"fmt"
"image"
"log" "log"
"math"
"net" "net"
"time"
vnc "github.com/vtolstov/go-vnc" vnc "github.com/vtolstov/go-vnc"
) )
@ -17,10 +21,16 @@ func main() {
chServer := make(chan vnc.ClientMessage) chServer := make(chan vnc.ClientMessage)
chClient := make(chan vnc.ServerMessage) chClient := make(chan vnc.ServerMessage)
im := image.NewRGBA(image.Rect(0, 0, width, height))
tick := time.NewTicker(time.Second / 2)
defer tick.Stop()
cfg := &vnc.ServerConfig{ cfg := &vnc.ServerConfig{
Width: 800,
Height: 600,
VersionHandler: vnc.ServerVersionHandler, VersionHandler: vnc.ServerVersionHandler,
SecurityHandler: vnc.ServerSecurityHandler, SecurityHandler: vnc.ServerSecurityHandler,
SecurityHandlers: []vnc.ServerHandler{vnc.ServerSecurityNoneHandler}, SecurityHandlers: []vnc.SecurityHandler{&vnc.ClientAuthNone{}},
ClientInitHandler: vnc.ServerClientInitHandler, ClientInitHandler: vnc.ServerClientInitHandler,
ServerInitHandler: vnc.ServerServerInitHandler, ServerInitHandler: vnc.ServerServerInitHandler,
Encodings: []vnc.Encoding{&vnc.RawEncoding{}}, Encodings: []vnc.Encoding{&vnc.RawEncoding{}},
@ -33,10 +43,51 @@ func main() {
// Process messages coming in on the ClientMessage channel. // Process messages coming in on the ClientMessage channel.
for { for {
msg := <-chClient select {
case <-tick.C:
drawImage(im, 0)
fmt.Printf("tick\n")
case msg := <-chClient:
switch msg.Type() { switch msg.Type() {
default: default:
log.Printf("Received message type:%v msg:%v\n", msg.Type(), msg) log.Printf("11 Received message type:%v msg:%v\n", msg.Type(), msg)
}
case msg := <-chServer:
switch msg.Type() {
default:
log.Printf("22 Received message type:%v msg:%v\n", msg.Type(), msg)
}
}
}
}
const (
width = 800
height = 600
)
func drawImage(im *image.RGBA, anim int) {
pos := 0
const border = 50
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
var r, g, b uint8
switch {
case x < border*2.5 && x < int((1.1+math.Sin(float64(y+anim*2)/40))*border):
r = 255
case x > width-border*2.5 && x > width-int((1.1+math.Sin(math.Pi+float64(y+anim*2)/40))*border):
g = 255
case y < border*2.5 && y < int((1.1+math.Sin(float64(x+anim*2)/40))*border):
r, g = 255, 255
case y > height-border*2.5 && y > height-int((1.1+math.Sin(math.Pi+float64(x+anim*2)/40))*border):
b = 255
default:
r, g, b = uint8(x+anim), uint8(y+anim), uint8(x+y+anim*3)
}
im.Pix[pos] = r
im.Pix[pos+1] = g
im.Pix[pos+2] = b
pos += 4 // skipping alpha
} }
} }
} }

View File

@ -1,5 +1,10 @@
package vnc package vnc
import (
"encoding/binary"
"fmt"
)
// ClientMessage is the interface // ClientMessage is the interface
type ClientMessage interface { type ClientMessage interface {
Type() ClientMessageType Type() ClientMessageType
@ -14,22 +19,264 @@ type ServerMessage interface {
Write(Conn) error Write(Conn) error
} }
const ProtoVersionLength = 12
const (
ProtoVersionUnknown = ""
ProtoVersion33 = "RFB 003.003\n"
ProtoVersion38 = "RFB 003.008\n"
)
func ParseProtoVersion(pv []byte) (uint, uint, error) {
var major, minor uint
if len(pv) < ProtoVersionLength {
return 0, 0, fmt.Errorf("ProtocolVersion message too short (%v < %v)", len(pv), ProtoVersionLength)
}
l, err := fmt.Sscanf(string(pv), "RFB %d.%d\n", &major, &minor)
if l != 2 {
return 0, 0, fmt.Errorf("error parsing ProtocolVersion.")
}
if err != nil {
return 0, 0, err
}
return major, minor, nil
}
func ClientVersionHandler(cfg *ClientConfig, c Conn) error {
var version [ProtoVersionLength]byte
if err := binary.Read(c, binary.BigEndian, &version); err != nil {
return err
}
major, minor, err := ParseProtoVersion(version[:])
if err != nil {
return err
}
pv := ProtoVersionUnknown
if major == 3 {
if minor >= 8 {
pv = ProtoVersion38
} else if minor >= 3 {
pv = ProtoVersion38
}
}
if pv == ProtoVersionUnknown {
return fmt.Errorf("ProtocolVersion handshake failed; unsupported version '%v'", string(version[:]))
}
c.SetProtoVersion(pv)
if err := binary.Write(c, binary.BigEndian, []byte(pv)); err != nil {
return err
}
return c.Flush()
}
func ServerVersionHandler(cfg *ServerConfig, c Conn) error { func ServerVersionHandler(cfg *ServerConfig, c Conn) error {
var version [ProtoVersionLength]byte
if err := binary.Write(c, binary.BigEndian, []byte(ProtoVersion38)); err != nil {
return err
}
if err := c.Flush(); err != nil {
return err
}
if err := binary.Read(c, binary.BigEndian, &version); err != nil {
return err
}
major, minor, err := ParseProtoVersion(version[:])
if err != nil {
return err
}
pv := ProtoVersionUnknown
if major == 3 {
if minor >= 8 {
pv = ProtoVersion38
} else if minor >= 3 {
pv = ProtoVersion33
}
}
if pv == ProtoVersionUnknown {
return fmt.Errorf("ProtocolVersion handshake failed; unsupported version '%v'", string(version[:]))
}
c.SetProtoVersion(pv)
return nil
}
func ClientSecurityHandler(cfg *ClientConfig, c Conn) error {
var numSecurityTypes uint8
if err := binary.Read(c, binary.BigEndian, &numSecurityTypes); err != nil {
return err
}
secTypes := make([]SecurityType, numSecurityTypes)
if err := binary.Read(c, binary.BigEndian, &secTypes); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, cfg.SecurityHandlers[0].Type()); err != nil {
return err
}
if err := c.Flush(); err != nil {
return err
}
var authCode uint32
if err := binary.Read(c, binary.BigEndian, &authCode); err != nil {
return err
}
if authCode == 1 {
var reasonLength uint32
if err := binary.Read(c, binary.BigEndian, &reasonLength); err != nil {
return err
}
reasonText := make([]byte, reasonLength)
if err := binary.Read(c, binary.BigEndian, &reasonText); err != nil {
return err
}
return fmt.Errorf("%s", reasonText)
}
return nil return nil
} }
func ServerSecurityHandler(cfg *ServerConfig, c Conn) error { func ServerSecurityHandler(cfg *ServerConfig, c Conn) error {
if err := binary.Write(c, binary.BigEndian, uint8(len(cfg.SecurityHandlers))); err != nil {
return err
}
for _, sectype := range cfg.SecurityHandlers {
if err := binary.Write(c, binary.BigEndian, sectype.Type()); err != nil {
return err
}
}
if err := c.Flush(); err != nil {
return err
}
var secType SecurityType
if err := binary.Read(c, binary.BigEndian, &secType); err != nil {
return err
}
secTypes := make(map[SecurityType]SecurityHandler)
for _, sType := range cfg.SecurityHandlers {
secTypes[sType.Type()] = sType
}
sType, ok := secTypes[secType]
if !ok {
return fmt.Errorf("server type %d not implemented")
}
var authCode uint32
authErr := sType.Auth(c)
if authErr != nil {
authCode = uint32(1)
}
if err := binary.Write(c, binary.BigEndian, authCode); err != nil {
return err
}
if err := c.Flush(); err != nil {
return err
}
if authErr != nil {
if err := binary.Write(c, binary.BigEndian, len(authErr.Error())); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, []byte(authErr.Error())); err != nil {
return err
}
return authErr
}
return nil return nil
} }
func ServerSecurityNoneHandler(cfg *ServerConfig, c Conn) error { func ClientServerInitHandler(cfg *ClientConfig, c Conn) error {
srvInit := &ServerInit{}
if err := binary.Read(c, binary.BigEndian, &srvInit.FBWidth); err != nil {
return err
}
if err := binary.Read(c, binary.BigEndian, &srvInit.FBHeight); err != nil {
return err
}
if err := binary.Read(c, binary.BigEndian, &srvInit.PixelFormat); err != nil {
return err
}
if err := binary.Read(c, binary.BigEndian, &srvInit.NameLength); err != nil {
return err
}
nameText := make([]byte, srvInit.NameLength)
if err := binary.Read(c, binary.BigEndian, nameText); err != nil {
return err
}
srvInit.NameText = nameText
c.SetDesktopName(string(srvInit.NameText))
c.SetWidth(srvInit.FBWidth)
c.SetHeight(srvInit.FBHeight)
c.SetPixelFormat(&srvInit.PixelFormat)
return nil return nil
} }
func ServerServerInitHandler(cfg *ServerConfig, c Conn) error { func ServerServerInitHandler(cfg *ServerConfig, c Conn) error {
return nil srvInit := &ServerInit{
FBWidth: c.Width(),
FBHeight: c.Height(),
PixelFormat: *c.PixelFormat(),
NameLength: uint32(len(cfg.DesktopName)),
NameText: []byte(cfg.DesktopName),
}
if err := binary.Write(c, binary.BigEndian, srvInit.FBWidth); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, srvInit.FBHeight); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, srvInit.PixelFormat); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, srvInit.NameLength); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, srvInit.NameText); err != nil {
return err
}
return c.Flush()
}
func ClientClientInitHandler(cfg *ClientConfig, c Conn) error {
if err := binary.Write(c, binary.BigEndian, cfg.Exclusive); err != nil {
return err
}
return c.Flush()
} }
func ServerClientInitHandler(cfg *ServerConfig, c Conn) error { func ServerClientInitHandler(cfg *ServerConfig, c Conn) error {
var shared uint8
if err := binary.Read(c, binary.BigEndian, &shared); err != nil {
return err
}
/* TODO
if shared != 1 {
c.SetShared(false)
}
*/
return nil return nil
} }

View File

@ -1,6 +1,8 @@
package vnc package vnc
import ( import (
"encoding/binary"
"fmt"
"image" "image"
) )
@ -29,9 +31,14 @@ func NewColor(pf *PixelFormat, cm *ColorMap) *Color {
type Rectangle struct { type Rectangle struct {
X, Y uint16 X, Y uint16
Width, Height uint16 Width, Height uint16
EncType EncodingType
Enc Encoding Enc Encoding
} }
func NewRectangle() *Rectangle {
return &Rectangle{}
}
// Marshal implements the Marshaler interface. // Marshal implements the Marshaler interface.
func (c *Color) Marshal() ([]byte, error) { func (c *Color) Marshal() ([]byte, error) {
order := c.pf.order() order := c.pf.order()
@ -106,50 +113,54 @@ func colorsToImage(x, y, width, height uint16, colors []Color) *image.RGBA64 {
} }
// Marshal implements the Marshaler interface. // Marshal implements the Marshaler interface.
func (r *Rectangle) Marshal() ([]byte, error) { func (r *Rectangle) Write(c Conn) error {
/* if err := binary.Write(c, binary.BigEndian, r.X); err != nil {
buf := bytes.NewBuffer(nil)
var msg Rectangle
msg.X, msg.Y, msg.W, msg.H = r.X, r.Y, r.Width, r.Height
msg.E = r.Enc.Type()
if err := binary.Write(buf, binary.BigEndian, msg); err != nil {
return nil, err
}
bytes, err := r.Enc.Marshal()
if err != nil {
return nil, err
}
if err := binary.Write(buf, binary.BigEndian, bytes); err != nil {
return nil, err
}
return buf.Bytes(), nil
*/
return nil, nil
}
// Unmarshal implements the Unmarshaler interface.
func (r *Rectangle) Unmarshal(data []byte) error {
/*
buf := bytes.NewBuffer(data)
var msg Rectangle
if err := binary.Read(buf, binary.BigEndian, &msg); err != nil {
return err return err
} }
r.X, r.Y, r.Width, r.Height = msg.X, msg.Y, msg.W, msg.H if err := binary.Write(c, binary.BigEndian, r.Y); err != nil {
return err
switch msg.E {
case encodings.Raw:
r.Enc = &RawEncoding{}
default:
return fmt.Errorf("unable to unmarshal encoding %v", msg.E)
} }
return nil if err := binary.Write(c, binary.BigEndian, r.Width); err != nil {
*/ return err
return nil }
if err := binary.Write(c, binary.BigEndian, r.Height); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, r.EncType); err != nil {
return err
}
if err := r.Enc.Write(c, r); err != nil {
return err
}
return c.Flush()
}
func (r *Rectangle) Read(c Conn) error {
fmt.Printf("qqq\n")
var err error
if err = binary.Read(c, binary.BigEndian, &r.X); err != nil {
return err
}
if err = binary.Read(c, binary.BigEndian, &r.Y); err != nil {
return err
}
if err = binary.Read(c, binary.BigEndian, &r.Width); err != nil {
return err
}
if err = binary.Read(c, binary.BigEndian, &r.Height); err != nil {
return err
}
if err = binary.Read(c, binary.BigEndian, &r.EncType); err != nil {
return err
}
fmt.Printf("rrrr %#+v\n", r)
switch r.EncType {
case EncRaw:
r.Enc = &RawEncoding{}
}
return r.Enc.Read(c, r)
} }
// Area returns the total area in pixels of the Rectangle. // Area returns the total area in pixels of the Rectangle.

View File

@ -13,6 +13,7 @@ import (
var ( var (
PixelFormat8bit *PixelFormat = NewPixelFormat(8) PixelFormat8bit *PixelFormat = NewPixelFormat(8)
PixelFormat16bit *PixelFormat = NewPixelFormat(16) PixelFormat16bit *PixelFormat = NewPixelFormat(16)
PixelFormat24bit *PixelFormat = NewPixelFormat(24)
PixelFormat32bit *PixelFormat = NewPixelFormat(32) PixelFormat32bit *PixelFormat = NewPixelFormat(32)
) )
@ -31,7 +32,7 @@ const pixelFormatLen = 16
// NewPixelFormat returns a populated PixelFormat structure. // NewPixelFormat returns a populated PixelFormat structure.
func NewPixelFormat(bpp uint8) *PixelFormat { func NewPixelFormat(bpp uint8) *PixelFormat {
bigEndian := uint8(1) bigEndian := uint8(0)
rgbMax := uint16(math.Exp2(float64(bpp))) - 1 rgbMax := uint16(math.Exp2(float64(bpp))) - 1
var ( var (
tc = uint8(1) tc = uint8(1)
@ -53,18 +54,18 @@ func NewPixelFormat(bpp uint8) *PixelFormat {
func (pf *PixelFormat) Marshal() ([]byte, error) { func (pf *PixelFormat) Marshal() ([]byte, error) {
// Validation checks. // Validation checks.
switch pf.BPP { switch pf.BPP {
case 8, 16, 32: case 8, 16, 24, 32:
default: default:
return nil, fmt.Errorf("Invalid BPP value %v; must be 8, 16, or 32.", pf.BPP) return nil, fmt.Errorf("Invalid BPP value %v; must be 8, 16, 24 or 32.", pf.BPP)
} }
if pf.Depth < pf.BPP { if pf.Depth < pf.BPP {
return nil, fmt.Errorf("Invalid Depth value %v; cannot be < BPP", pf.Depth) return nil, fmt.Errorf("Invalid Depth value %v; cannot be < BPP", pf.Depth)
} }
switch pf.Depth { switch pf.Depth {
case 8, 16, 32: case 8, 16, 24, 32:
default: default:
return nil, fmt.Errorf("Invalid Depth value %v; must be 8, 16, or 32.", pf.Depth) return nil, fmt.Errorf("Invalid Depth value %v; must be 8, 16, 24 or 32.", pf.Depth)
} }
// Create the slice of bytes // Create the slice of bytes

180
security.go Normal file
View File

@ -0,0 +1,180 @@
package vnc
import (
"crypto/des"
"encoding/binary"
"fmt"
)
type SecurityType uint8
const (
SecTypeUnknown = SecurityType(0)
SecTypeNone = SecurityType(1)
SecTypeVNC = SecurityType(2)
SecTypeVeNCrypt = SecurityType(19)
)
type SecuritySubType uint32
const (
SecSubTypeUnknown = SecuritySubType(0)
)
const (
SecSubTypeVeNCrypt01Unknown = SecuritySubType(0)
SecSubTypeVeNCrypt01Plain = SecuritySubType(19)
SecSubTypeVeNCrypt01TLSNone = SecuritySubType(20)
SecSubTypeVeNCrypt01TLSVNC = SecuritySubType(21)
SecSubTypeVeNCrypt01TLSPlain = SecuritySubType(22)
SecSubTypeVeNCrypt01X509None = SecuritySubType(23)
SecSubTypeVeNCrypt01X509VNC = SecuritySubType(24)
SecSubTypeVeNCrypt01X509Plain = SecuritySubType(25)
)
const (
SecSubTypeVeNCrypt02Unknown = SecuritySubType(0)
SecSubTypeVeNCrypt02Plain = SecuritySubType(256)
SecSubTypeVeNCrypt02TLSNone = SecuritySubType(257)
SecSubTypeVeNCrypt02TLSVNC = SecuritySubType(258)
SecSubTypeVeNCrypt02TLSPlain = SecuritySubType(259)
SecSubTypeVeNCrypt02X509None = SecuritySubType(260)
SecSubTypeVeNCrypt02X509VNC = SecuritySubType(261)
SecSubTypeVeNCrypt02X509Plain = SecuritySubType(262)
)
type SecurityHandler interface {
Type() SecurityType
SubType() SecuritySubType
Auth(Conn) error
}
type ClientAuthNone struct{}
func (*ClientAuthNone) Type() SecurityType {
return SecTypeNone
}
func (*ClientAuthNone) SubType() SecuritySubType {
return SecSubTypeUnknown
}
func (*ClientAuthNone) Auth(conn Conn) error {
return nil
}
// ServerAuthNone is the "none" authentication. See 7.2.1.
type ServerAuthNone struct{}
func (*ServerAuthNone) Type() SecurityType {
return SecTypeNone
}
func (*ServerAuthNone) Auth(c Conn) error {
return nil
}
func (*ClientAuthVeNCrypt02Plain) SubType() SecuritySubType {
return SecSubTypeVeNCrypt02Plain
}
// ClientAuthVeNCryptPlain see https://www.berrange.com/~dan/vencrypt.txt
type ClientAuthVeNCrypt02Plain struct {
Username []byte
Password []byte
}
func (auth *ClientAuthVeNCrypt02Plain) Auth(c Conn) error {
if len(auth.Password) == 0 || len(auth.Username) == 0 {
return fmt.Errorf("Security Handshake failed; no username and/or password provided for VeNCryptAuth.")
}
if err := binary.Write(c, binary.BigEndian, uint32(len(auth.Username))); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, uint32(len(auth.Password))); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, auth.Username); err != nil {
return err
}
if err := binary.Write(c, binary.BigEndian, auth.Password); err != nil {
return err
}
return c.Flush()
}
// ServerAuthVNC is the standard password authentication. See 7.2.2.
type ServerAuthVNC struct{}
func (*ServerAuthVNC) Type() SecurityType {
return SecTypeVNC
}
func (*ServerAuthVNC) SubType() SecuritySubType {
return SecSubTypeUnknown
}
func (auth *ServerAuthVNC) Auth(c Conn) error {
return nil
}
// ClientAuthVNC is the standard password authentication. See 7.2.2.
type ClientAuthVNC struct {
Challenge [16]byte
Password []byte
}
func (*ClientAuthVNC) Type() SecurityType {
return SecTypeVNC
}
func (*ClientAuthVNC) SubType() SecuritySubType {
return SecSubTypeUnknown
}
func (auth *ClientAuthVNC) Auth(c Conn) error {
if len(auth.Password) == 0 {
return fmt.Errorf("Security Handshake failed; no password provided for VNCAuth.")
}
if err := binary.Read(c, binary.BigEndian, auth.Challenge); err != nil {
return err
}
auth.encode()
// Send the encrypted challenge back to server
if err := binary.Write(c, binary.BigEndian, auth.Challenge); err != nil {
return err
}
return c.Flush()
}
func (auth *ClientAuthVNC) encode() error {
// Copy password string to 8 byte 0-padded slice
key := make([]byte, 8)
copy(key, auth.Password)
// Each byte of the password needs to be reversed. This is a
// non RFC-documented behaviour of VNC clients and servers
for i := range key {
key[i] = (key[i]&0x55)<<1 | (key[i]&0xAA)>>1 // Swap adjacent bits
key[i] = (key[i]&0x33)<<2 | (key[i]&0xCC)>>2 // Swap adjacent pairs
key[i] = (key[i]&0x0F)<<4 | (key[i]&0xF0)>>4 // Swap the 2 halves
}
// Encrypt challenge with key.
cipher, err := des.NewCipher(key)
if err != nil {
return err
}
for i := 0; i < len(auth.Challenge); i += cipher.BlockSize() {
cipher.Encrypt(auth.Challenge[i:i+cipher.BlockSize()], auth.Challenge[i:i+cipher.BlockSize()])
}
return nil
}

122
server.go
View File

@ -6,6 +6,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"sync"
) )
var DefaultClientMessages = []ClientMessage{ var DefaultClientMessages = []ClientMessage{
@ -17,9 +18,43 @@ var DefaultClientMessages = []ClientMessage{
&ClientCutText{}, &ClientCutText{},
} }
type ServerInit struct {
FBWidth, FBHeight uint16
PixelFormat PixelFormat
NameLength uint32
NameText []byte
}
var _ Conn = (*ServerConn)(nil) var _ Conn = (*ServerConn)(nil)
func (c *ServerConn) UnreadByte() error {
return c.br.UnreadByte()
}
func (c *ServerConn) Conn() net.Conn {
return c.c
}
func (c *ServerConn) SetEncodings(encs []EncodingType) error {
encodings := make(map[EncodingType]Encoding)
for _, enc := range c.cfg.Encodings {
encodings[enc.Type()] = enc
}
for _, encType := range encs {
if enc, ok := encodings[encType]; ok {
c.encodings = append(c.encodings, enc)
}
}
return nil
}
func (c *ServerConn) SetProtoVersion(pv string) {
c.protocol = pv
}
func (c *ServerConn) Flush() error { func (c *ServerConn) Flush() error {
c.m.Lock()
defer c.m.Unlock()
return c.bw.Flush() return c.bw.Flush()
} }
@ -41,6 +76,8 @@ func (c *ServerConn) Read(buf []byte) (int, error) {
} }
func (c *ServerConn) Write(buf []byte) (int, error) { func (c *ServerConn) Write(buf []byte) (int, error) {
c.m.Lock()
defer c.m.Unlock()
return c.bw.Write(buf) return c.bw.Write(buf)
} }
@ -57,6 +94,13 @@ func (c *ServerConn) DesktopName() string {
func (c *ServerConn) PixelFormat() *PixelFormat { func (c *ServerConn) PixelFormat() *PixelFormat {
return c.pixelFormat return c.pixelFormat
} }
func (c *ServerConn) SetDesktopName(name string) {
c.desktopName = name
}
func (c *ServerConn) SetPixelFormat(pf *PixelFormat) error {
c.pixelFormat = pf
return nil
}
func (c *ServerConn) Encodings() []Encoding { func (c *ServerConn) Encodings() []Encoding {
return c.encodings return c.encodings
} }
@ -94,20 +138,20 @@ const (
// FramebufferUpdate holds a FramebufferUpdate wire format message. // FramebufferUpdate holds a FramebufferUpdate wire format message.
type FramebufferUpdate struct { type FramebufferUpdate struct {
MsgType ServerMessageType MsgType ServerMessageType
NumRect uint16 // number-of-rectangles
_ [1]byte // pad _ [1]byte // pad
NumRect uint16 // number-of-rectangles
Rects []Rectangle // rectangles Rects []Rectangle // rectangles
} }
func (msg *FramebufferUpdate) Type() ServerMessageType { func (msg *FramebufferUpdate) Type() ServerMessageType {
return msg.MsgType return FramebufferUpdateMsgType
} }
func (msg *FramebufferUpdate) Read(c Conn) error { func (msg *FramebufferUpdate) Read(c Conn) error {
if err := binary.Read(c, binary.BigEndian, msg.MsgType); err != nil { if err := binary.Read(c, binary.BigEndian, msg.MsgType); err != nil {
return err return err
} }
fmt.Printf("qqqqq\n")
var pad [1]byte var pad [1]byte
if err := binary.Read(c, binary.BigEndian, &pad); err != nil { if err := binary.Read(c, binary.BigEndian, &pad); err != nil {
return err return err
@ -116,17 +160,15 @@ func (msg *FramebufferUpdate) Read(c Conn) error {
if err := binary.Read(c, binary.BigEndian, msg.NumRect); err != nil { if err := binary.Read(c, binary.BigEndian, msg.NumRect); err != nil {
return err return err
} }
/* msg.Rects = make([]Rectangle, msg.NumRect)
// Extract rectangles.
rects := make([]Rectangle, msg.NumRect)
for i := uint16(0); i < msg.NumRect; i++ { for i := uint16(0); i < msg.NumRect; i++ {
rect := NewRectangle(c) rect := NewRectangle()
if err := rect.Read(c); err != nil { if err := rect.Read(c); err != nil {
return err return err
} }
msg.Rects = append(msg.Rects, *rect) msg.Rects = append(msg.Rects, *rect)
} }
*/
return nil return nil
} }
@ -141,13 +183,11 @@ func (msg *FramebufferUpdate) Write(c Conn) error {
if err := binary.Write(c, binary.BigEndian, msg.NumRect); err != nil { if err := binary.Write(c, binary.BigEndian, msg.NumRect); err != nil {
return err return err
} }
/*
for _, rect := range msg.Rects { for _, rect := range msg.Rects {
if err := rect.Write(c); err != nil { if err := rect.Write(c); err != nil {
return err return err
} }
} }
*/
return c.Flush() return c.Flush()
} }
@ -157,7 +197,7 @@ type ServerConn struct {
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
protocol string protocol string
m sync.Mutex
// If the pixel format uses a color map, then this is the color // If the pixel format uses a color map, then this is the color
// map that is used. This should not be modified directly, since // map that is used. This should not be modified directly, since
// the data comes from the server. // the data comes from the server.
@ -190,7 +230,7 @@ type ServerHandler func(*ServerConfig, Conn) error
type ServerConfig struct { type ServerConfig struct {
VersionHandler ServerHandler VersionHandler ServerHandler
SecurityHandler ServerHandler SecurityHandler ServerHandler
SecurityHandlers []ServerHandler SecurityHandlers []SecurityHandler
ClientInitHandler ServerHandler ClientInitHandler ServerHandler
ServerInitHandler ServerHandler ServerInitHandler ServerHandler
Encodings []Encoding Encodings []Encoding
@ -199,35 +239,46 @@ type ServerConfig struct {
ClientMessageCh chan ClientMessage ClientMessageCh chan ClientMessage
ServerMessageCh chan ServerMessage ServerMessageCh chan ServerMessage
ClientMessages []ClientMessage ClientMessages []ClientMessage
DesktopName []byte
Height uint16
Width uint16
} }
func NewServerConn(c net.Conn, cfg *ServerConfig) (*ServerConn, error) { func NewServerConn(c net.Conn, cfg *ServerConfig) (*ServerConn, error) {
if cfg.ClientMessageCh == nil { if cfg.ClientMessageCh == nil {
return nil, fmt.Errorf("ClientMessageCh nil") return nil, fmt.Errorf("ClientMessageCh nil")
} }
if len(cfg.ClientMessages) == 0 { if len(cfg.ClientMessages) == 0 {
return nil, fmt.Errorf("ClientMessage 0") return nil, fmt.Errorf("ClientMessage 0")
} }
return &ServerConn{ return &ServerConn{
c: c, c: c,
br: bufio.NewReader(c), br: bufio.NewReader(c),
bw: bufio.NewWriter(c), bw: bufio.NewWriter(c),
cfg: cfg, cfg: cfg,
quit: make(chan struct{}),
encodings: cfg.Encodings, encodings: cfg.Encodings,
pixelFormat: cfg.PixelFormat, pixelFormat: cfg.PixelFormat,
fbWidth: cfg.Width,
fbHeight: cfg.Height,
}, nil }, nil
} }
func Serve(ctx context.Context, ln net.Listener, cfg *ServerConfig) error { func Serve(ctx context.Context, ln net.Listener, cfg *ServerConfig) error {
for { for {
c, err := ln.Accept() c, err := ln.Accept()
if err != nil { if err != nil {
continue continue
} }
conn, err := NewServerConn(c, cfg) conn, err := NewServerConn(c, cfg)
if err != nil { if err != nil {
continue continue
} }
if err := cfg.VersionHandler(cfg, conn); err != nil { if err := cfg.VersionHandler(cfg, conn); err != nil {
conn.Close() conn.Close()
continue continue
@ -237,62 +288,80 @@ func Serve(ctx context.Context, ln net.Listener, cfg *ServerConfig) error {
conn.Close() conn.Close()
continue continue
} }
if err := cfg.ClientInitHandler(cfg, conn); err != nil { if err := cfg.ClientInitHandler(cfg, conn); err != nil {
conn.Close() conn.Close()
continue continue
} }
if err := cfg.ServerInitHandler(cfg, conn); err != nil { if err := cfg.ServerInitHandler(cfg, conn); err != nil {
conn.Close() conn.Close()
continue continue
} }
go conn.Handle() go conn.Handle()
} }
} }
func (c *ServerConn) Handle() error { func (c *ServerConn) Handle() error {
var err error var err error
var wg sync.WaitGroup
defer c.Close() defer c.Close()
clientMessages := make(map[ClientMessageType]ClientMessage) clientMessages := make(map[ClientMessageType]ClientMessage)
for _, m := range c.cfg.ClientMessages { for _, m := range c.cfg.ClientMessages {
clientMessages[m.Type()] = m clientMessages[m.Type()] = m
} }
wg.Add(2)
serverLoop: // server
go func() error {
defer wg.Done()
for { for {
select { select {
case msg := <-c.cfg.ServerMessageCh: case msg := <-c.cfg.ServerMessageCh:
if err = msg.Write(c); err != nil { if err = msg.Write(c); err != nil {
return err return err
} }
c.Flush()
case <-c.quit: case <-c.quit:
break serverLoop return nil
} }
} }
}()
clientLoop: // client
go func() error {
defer wg.Done()
for { for {
select { select {
case <-c.quit: case <-c.quit:
break clientLoop return nil
default: default:
var messageType ClientMessageType var messageType ClientMessageType
if err := binary.Read(c, binary.BigEndian, &messageType); err != nil { if err := binary.Read(c, binary.BigEndian, &messageType); err != nil {
break clientLoop return err
}
if err := c.UnreadByte(); err != nil {
return err
} }
msg, ok := clientMessages[messageType] msg, ok := clientMessages[messageType]
if !ok { if !ok {
err = fmt.Errorf("unsupported message-type: %v", messageType) return fmt.Errorf("unsupported message-type: %v", messageType)
break clientLoop
} }
if err := msg.Read(c); err != nil { if err := msg.Read(c); err != nil {
break clientLoop return err
} }
c.cfg.ClientMessageCh <- msg c.cfg.ClientMessageCh <- msg
} }
} }
}()
wg.Wait()
return nil return nil
} }
@ -304,7 +373,7 @@ type ServerCutText struct {
} }
func (msg *ServerCutText) Type() ServerMessageType { func (msg *ServerCutText) Type() ServerMessageType {
return msg.MsgType return ServerCutTextMsgType
} }
func (msg *ServerCutText) Read(c Conn) error { func (msg *ServerCutText) Read(c Conn) error {
@ -353,7 +422,7 @@ type Bell struct {
} }
func (msg *Bell) Type() ServerMessageType { func (msg *Bell) Type() ServerMessageType {
return msg.MsgType return BellMsgType
} }
func (msg *Bell) Read(c Conn) error { func (msg *Bell) Read(c Conn) error {
@ -361,7 +430,10 @@ func (msg *Bell) Read(c Conn) error {
} }
func (msg *Bell) Write(c Conn) error { func (msg *Bell) Write(c Conn) error {
return binary.Write(c, binary.BigEndian, msg.MsgType) if err := binary.Write(c, binary.BigEndian, msg.MsgType); err != nil {
return err
}
return c.Flush()
} }
type SetColorMapEntries struct { type SetColorMapEntries struct {
@ -373,7 +445,7 @@ type SetColorMapEntries struct {
} }
func (msg *SetColorMapEntries) Type() ServerMessageType { func (msg *SetColorMapEntries) Type() ServerMessageType {
return msg.MsgType return SetColorMapEntriesMsgType
} }
func (msg *SetColorMapEntries) Read(c Conn) error { func (msg *SetColorMapEntries) Read(c Conn) error {