simplify client

Signed-off-by: Vasiliy Tolstov <v.tolstov@selfip.ru>
This commit is contained in:
Василий Толстов 2017-06-26 14:16:03 +03:00
parent 1ec2df07e9
commit cc3c34ca50
5 changed files with 244 additions and 144 deletions

View File

@ -16,31 +16,34 @@ var DefaultServerMessages = []ServerMessage{
&ServerCutText{}, &ServerCutText{},
} }
var (
DefaultClientHandlers []ClientHandler = []ClientHandler{
&DefaultClientVersionHandler{},
&DefaultClientSecurityHandler{},
&DefaultClientClientInitHandler{},
&DefaultClientServerInitHandler{},
// &DefaultClientMessageHandler{},
}
)
func Connect(ctx context.Context, c net.Conn, cfg *ClientConfig) (*ClientConn, error) { func Connect(ctx context.Context, c net.Conn, cfg *ClientConfig) (*ClientConn, error) {
conn, err := NewClientConn(c, cfg) conn, err := NewClientConn(c, cfg)
if err != nil { if err != nil {
conn.Close() conn.Close()
cfg.ErrorCh <- err
return nil, err return nil, err
} }
if err := cfg.VersionHandler(cfg, conn); err != nil { if len(cfg.Handlers) == 0 {
conn.Close() cfg.Handlers = DefaultClientHandlers
return nil, err
} }
if err := cfg.SecurityHandler(cfg, conn); err != nil { for _, h := range cfg.Handlers {
if err := h.Handle(conn); err != nil {
conn.Close() conn.Close()
cfg.ErrorCh <- err
return nil, err return nil, err
} }
if err := cfg.ClientInitHandler(cfg, conn); err != nil {
conn.Close()
return nil, err
}
if err := cfg.ServerInitHandler(cfg, conn); err != nil {
conn.Close()
return nil, err
} }
return conn, nil return conn, nil
@ -48,6 +51,10 @@ func Connect(ctx context.Context, c net.Conn, cfg *ClientConfig) (*ClientConn, e
var _ Conn = (*ClientConn)(nil) var _ Conn = (*ClientConn)(nil)
func (c *ClientConn) Config() interface{} {
return c.cfg
}
func (c *ClientConn) Conn() net.Conn { func (c *ClientConn) Conn() net.Conn {
return c.c return c.c
} }
@ -89,13 +96,13 @@ func (c *ClientConn) ColorMap() *ColorMap {
func (c *ClientConn) SetColorMap(cm *ColorMap) { func (c *ClientConn) SetColorMap(cm *ColorMap) {
c.colorMap = cm c.colorMap = cm
} }
func (c *ClientConn) DesktopName() string { func (c *ClientConn) DesktopName() []byte {
return c.desktopName return c.desktopName
} }
func (c *ClientConn) PixelFormat() *PixelFormat { func (c *ClientConn) PixelFormat() *PixelFormat {
return c.pixelFormat return c.pixelFormat
} }
func (c *ClientConn) SetDesktopName(name string) { func (c *ClientConn) SetDesktopName(name []byte) {
c.desktopName = name c.desktopName = name
} }
func (c *ClientConn) SetPixelFormat(pf *PixelFormat) error { func (c *ClientConn) SetPixelFormat(pf *PixelFormat) error {
@ -136,7 +143,7 @@ type ClientConn struct {
colorMap *ColorMap colorMap *ColorMap
// Name associated with the desktop, sent from the server. // Name associated with the desktop, sent from the server.
desktopName string desktopName []byte
// Encodings supported by the client. This should not be modified // Encodings supported by the client. This should not be modified
// directly. Instead, SetEncodings() should be used. // directly. Instead, SetEncodings() should be used.
@ -153,7 +160,8 @@ type ClientConn struct {
// SetPixelFormat method. // SetPixelFormat method.
pixelFormat *PixelFormat pixelFormat *PixelFormat
quit chan struct{} quitCh chan struct{}
errorCh chan error
} }
func NewClientConn(c net.Conn, cfg *ClientConfig) (*ClientConn, error) { func NewClientConn(c net.Conn, cfg *ClientConfig) (*ClientConn, error) {
@ -169,7 +177,8 @@ func NewClientConn(c net.Conn, cfg *ClientConfig) (*ClientConn, error) {
br: bufio.NewReader(c), br: bufio.NewReader(c),
bw: bufio.NewWriter(c), bw: bufio.NewWriter(c),
encodings: cfg.Encodings, encodings: cfg.Encodings,
quit: make(chan struct{}), quitCh: cfg.QuitCh,
errorCh: cfg.ErrorCh,
pixelFormat: cfg.PixelFormat, pixelFormat: cfg.PixelFormat,
}, nil }, nil
} }
@ -422,71 +431,75 @@ func (msg *ClientCutText) Write(c Conn) error {
return c.Flush() return c.Flush()
} }
// ListenAndHandle listens to a VNC server and handles server messages. type DefaultClientMessageHandler struct{}
func (c *ClientConn) Handle() error {
// listens to a VNC server and handles server messages.
func (*DefaultClientMessageHandler) Handle(c Conn) error {
cfg := c.Config().(*ClientConfig)
var err error var err error
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
defer c.Close() defer c.Close()
serverMessages := make(map[ServerMessageType]ServerMessage) serverMessages := make(map[ServerMessageType]ServerMessage)
for _, m := range c.cfg.ServerMessages { for _, m := range cfg.ServerMessages {
serverMessages[m.Type()] = m serverMessages[m.Type()] = m
} }
go func() error { go func() {
defer wg.Done() defer wg.Done()
for { for {
select { select {
case msg := <-c.cfg.ClientMessageCh: case msg := <-cfg.ClientMessageCh:
if err = msg.Write(c); err != nil { if err = msg.Write(c); err != nil {
return err cfg.ErrorCh <- err
return
} }
case <-c.quit:
return nil
} }
} }
}() }()
go func() error { go func() {
defer wg.Done() defer wg.Done()
for { for {
select { select {
case <-c.quit:
return nil
default: default:
var messageType ServerMessageType var messageType ServerMessageType
if err = binary.Read(c, binary.BigEndian, &messageType); err != nil { if err = binary.Read(c, binary.BigEndian, &messageType); err != nil {
return err cfg.ErrorCh <- err
return
} }
msg, ok := serverMessages[messageType] msg, ok := serverMessages[messageType]
if !ok { if !ok {
return fmt.Errorf("unknown message-type: %v", messageType) err = fmt.Errorf("unknown message-type: %v", messageType)
cfg.ErrorCh <- err
return
} }
parsedMsg, err := msg.Read(c) parsedMsg, err := msg.Read(c)
if err != nil { if err != nil {
return err cfg.ErrorCh <- err
return
} }
c.cfg.ServerMessageCh <- parsedMsg cfg.ServerMessageCh <- parsedMsg
} }
} }
}() }()
wg.Wait() wg.Wait()
return err return nil
} }
type ClientHandler func(*ClientConfig, Conn) error type ClientHandler interface {
Handle(Conn) error
}
// A ClientConfig structure is used to configure a ClientConn. After // A ClientConfig structure is used to configure a ClientConn. After
// one has been passed to initialize a connection, it must not be modified. // one has been passed to initialize a connection, it must not be modified.
type ClientConfig struct { type ClientConfig struct {
VersionHandler ClientHandler Handlers []ClientHandler
SecurityHandler ClientHandler
SecurityHandlers []SecurityHandler SecurityHandlers []SecurityHandler
ClientInitHandler ClientHandler
ServerInitHandler ClientHandler
Encodings []Encoding Encodings []Encoding
PixelFormat *PixelFormat PixelFormat *PixelFormat
ColorMap *ColorMap ColorMap *ColorMap
@ -494,4 +507,6 @@ type ClientConfig struct {
ServerMessageCh chan ServerMessage ServerMessageCh chan ServerMessage
Exclusive bool Exclusive bool
ServerMessages []ServerMessage ServerMessages []ServerMessage
QuitCh chan struct{}
ErrorCh chan error
} }

View File

@ -8,6 +8,7 @@ import (
type Conn interface { type Conn interface {
io.ReadWriteCloser io.ReadWriteCloser
Conn() net.Conn Conn() net.Conn
Config() interface{}
Protocol() string Protocol() string
PixelFormat() *PixelFormat PixelFormat() *PixelFormat
SetPixelFormat(*PixelFormat) error SetPixelFormat(*PixelFormat) error
@ -19,8 +20,8 @@ type Conn interface {
Height() uint16 Height() uint16
SetWidth(uint16) SetWidth(uint16)
SetHeight(uint16) SetHeight(uint16)
DesktopName() string DesktopName() []byte
SetDesktopName(string) SetDesktopName([]byte)
Flush() error Flush() error
SetProtoVersion(string) SetProtoVersion(string)
} }

View File

@ -45,7 +45,9 @@ func ParseProtoVersion(pv []byte) (uint, uint, error) {
return major, minor, nil return major, minor, nil
} }
func ClientVersionHandler(cfg *ClientConfig, c Conn) error { type DefaultClientVersionHandler struct{}
func (*DefaultClientVersionHandler) Handle(c Conn) error {
var version [ProtoVersionLength]byte var version [ProtoVersionLength]byte
if err := binary.Read(c, binary.BigEndian, &version); err != nil { if err := binary.Read(c, binary.BigEndian, &version); err != nil {
@ -76,7 +78,9 @@ func ClientVersionHandler(cfg *ClientConfig, c Conn) error {
return c.Flush() return c.Flush()
} }
func ServerVersionHandler(cfg *ServerConfig, c Conn) error { type DefaultServerVersionHandler struct{}
func (*DefaultServerVersionHandler) Handle(c Conn) error {
var version [ProtoVersionLength]byte var version [ProtoVersionLength]byte
if err := binary.Write(c, binary.BigEndian, []byte(ProtoVersion38)); err != nil { if err := binary.Write(c, binary.BigEndian, []byte(ProtoVersion38)); err != nil {
return err return err
@ -109,7 +113,10 @@ func ServerVersionHandler(cfg *ServerConfig, c Conn) error {
return nil return nil
} }
func ClientSecurityHandler(cfg *ClientConfig, c Conn) error { type DefaultClientSecurityHandler struct{}
func (*DefaultClientSecurityHandler) Handle(c Conn) error {
cfg := c.Config().(*ClientConfig)
var numSecurityTypes uint8 var numSecurityTypes uint8
if err := binary.Read(c, binary.BigEndian, &numSecurityTypes); err != nil { if err := binary.Read(c, binary.BigEndian, &numSecurityTypes); err != nil {
return err return err
@ -161,7 +168,10 @@ func ClientSecurityHandler(cfg *ClientConfig, c Conn) error {
return nil return nil
} }
func ServerSecurityHandler(cfg *ServerConfig, c Conn) error { type DefaultServerSecurityHandler struct{}
func (*DefaultServerSecurityHandler) Handle(c Conn) error {
cfg := c.Config().(*ServerConfig)
if err := binary.Write(c, binary.BigEndian, uint8(len(cfg.SecurityHandlers))); err != nil { if err := binary.Write(c, binary.BigEndian, uint8(len(cfg.SecurityHandlers))); err != nil {
return err return err
} }
@ -220,7 +230,9 @@ func ServerSecurityHandler(cfg *ServerConfig, c Conn) error {
return nil return nil
} }
func ClientServerInitHandler(cfg *ClientConfig, c Conn) error { type DefaultClientServerInitHandler struct{}
func (*DefaultClientServerInitHandler) Handle(c Conn) error {
srvInit := &ServerInit{} srvInit := &ServerInit{}
if err := binary.Read(c, binary.BigEndian, &srvInit.FBWidth); err != nil { if err := binary.Read(c, binary.BigEndian, &srvInit.FBWidth); err != nil {
@ -242,42 +254,55 @@ func ClientServerInitHandler(cfg *ClientConfig, c Conn) error {
} }
srvInit.NameText = nameText srvInit.NameText = nameText
c.SetDesktopName(string(srvInit.NameText)) c.SetDesktopName(srvInit.NameText)
c.SetWidth(srvInit.FBWidth) c.SetWidth(srvInit.FBWidth)
c.SetHeight(srvInit.FBHeight) c.SetHeight(srvInit.FBHeight)
c.SetPixelFormat(&srvInit.PixelFormat) c.SetPixelFormat(&srvInit.PixelFormat)
if c.Protocol() == "aten" {
fmt.Printf("$$$$$$\n")
var pad [28]byte
/* 12
8 byte unknown
1 byte IKVMVideoEnable
1 byte IKVMKMEnable
1 byte IKVMKickEnable
1 byte VUSBEnable
*/
if err := binary.Read(c, binary.BigEndian, &pad); err != nil {
return err
}
fmt.Printf("rrrr\n")
}
return nil return nil
} }
func ServerServerInitHandler(cfg *ServerConfig, c Conn) error { type DefaultServerServerInitHandler struct{}
srvInit := &ServerInit{
FBWidth: c.Width(),
FBHeight: c.Height(),
PixelFormat: *c.PixelFormat(),
NameLength: uint32(len(cfg.DesktopName)),
NameText: []byte(cfg.DesktopName),
}
if err := binary.Write(c, binary.BigEndian, srvInit.FBWidth); err != nil { func (*DefaultServerServerInitHandler) Handle(c Conn) error {
if err := binary.Write(c, binary.BigEndian, c.Width()); err != nil {
return err return err
} }
if err := binary.Write(c, binary.BigEndian, srvInit.FBHeight); err != nil { if err := binary.Write(c, binary.BigEndian, c.Height()); err != nil {
return err return err
} }
if err := binary.Write(c, binary.BigEndian, srvInit.PixelFormat); err != nil { if err := binary.Write(c, binary.BigEndian, c.PixelFormat()); err != nil {
return err return err
} }
if err := binary.Write(c, binary.BigEndian, srvInit.NameLength); err != nil { if err := binary.Write(c, binary.BigEndian, uint32(len(c.DesktopName()))); err != nil {
return err return err
} }
if err := binary.Write(c, binary.BigEndian, srvInit.NameText); err != nil { if err := binary.Write(c, binary.BigEndian, []byte(c.DesktopName())); err != nil {
return err return err
} }
return c.Flush() return c.Flush()
} }
func ClientClientInitHandler(cfg *ClientConfig, c Conn) error { type DefaultClientClientInitHandler struct{}
func (*DefaultClientClientInitHandler) Handle(c Conn) error {
cfg := c.Config().(*ClientConfig)
var shared uint8 var shared uint8
if cfg.Exclusive { if cfg.Exclusive {
shared = 0 shared = 0
@ -290,7 +315,9 @@ func ClientClientInitHandler(cfg *ClientConfig, c Conn) error {
return c.Flush() return c.Flush()
} }
func ServerClientInitHandler(cfg *ServerConfig, c Conn) error { type DefaultServerClientInitHandler struct{}
func (*DefaultServerClientInitHandler) Handle(c Conn) error {
var shared uint8 var shared uint8
if err := binary.Read(c, binary.BigEndian, &shared); err != nil { if err := binary.Read(c, binary.BigEndian, &shared); err != nil {
return err return err

View File

@ -73,6 +73,10 @@ func (*ServerAuthNone) Type() SecurityType {
return SecTypeNone return SecTypeNone
} }
func (*ServerAuthNone) SubType() SecuritySubType {
return SecSubTypeUnknown
}
func (*ServerAuthNone) Auth(c Conn) error { func (*ServerAuthNone) Auth(c Conn) error {
return nil return nil
} }
@ -115,7 +119,7 @@ func (auth *ClientAuthATEN) Auth(c Conn) error {
return err return err
} }
if (nt&0xffff0ff0)>>0 == 0xaff90fb0 { if (nt&0xffff0ff0)>>0 == 0xaff90fb0 {
fmt.Printf("aten\n") c.SetProtoVersion("aten")
var skip [20]byte var skip [20]byte
binary.Read(c, binary.BigEndian, &skip) binary.Read(c, binary.BigEndian, &skip)
fmt.Printf("skip %s\n", skip) fmt.Printf("skip %s\n", skip)
@ -158,13 +162,22 @@ func (auth *ClientAuthATEN) Auth(c Conn) error {
sendPassword[i] = 0 sendPassword[i] = 0
} }
} }
if err := binary.Write(c, binary.BigEndian, sendUsername); err != nil { if err := binary.Write(c, binary.BigEndian, sendUsername); err != nil {
return err return err
} }
if err := binary.Write(c, binary.BigEndian, sendPassword); err != nil { if err := binary.Write(c, binary.BigEndian, sendPassword); err != nil {
return err return err
} }
return c.Flush()
if err := c.Flush(); err != nil {
return err
}
//var pp [10]byte
//binary.Read(c, binary.BigEndian, &pp)
//fmt.Printf("ddd %v\n", pp)
return nil
} }
func (*ClientAuthVeNCrypt02Plain) Type() SecurityType { func (*ClientAuthVeNCrypt02Plain) Type() SecurityType {
@ -270,7 +283,11 @@ func (auth *ClientAuthVeNCrypt02Plain) Auth(c Conn) error {
} }
// ServerAuthVNC is the standard password authentication. See 7.2.2. // ServerAuthVNC is the standard password authentication. See 7.2.2.
type ServerAuthVNC struct{} type ServerAuthVNC struct {
Challenge []byte
Password []byte
Crypted []byte
}
func (*ServerAuthVNC) Type() SecurityType { func (*ServerAuthVNC) Type() SecurityType {
return SecTypeVNC return SecTypeVNC
@ -279,13 +296,44 @@ func (*ServerAuthVNC) SubType() SecuritySubType {
return SecSubTypeUnknown return SecSubTypeUnknown
} }
func (auth *ServerAuthVNC) WriteChallenge(c Conn) error {
if err := binary.Write(c, binary.BigEndian, auth.Challenge); err != nil {
return err
}
return c.Flush()
}
func (auth *ServerAuthVNC) ReadChallenge(c Conn) error {
var crypted [16]byte
if err := binary.Read(c, binary.BigEndian, &crypted); err != nil {
return err
}
auth.Crypted = crypted[:]
return nil
}
func (auth *ServerAuthVNC) Auth(c Conn) error { func (auth *ServerAuthVNC) Auth(c Conn) error {
if err := auth.WriteChallenge(c); err != nil {
return err
}
if err := auth.ReadChallenge(c); err != nil {
return err
}
encrypted, err := AuthVNCEncode(auth.Password, auth.Challenge)
if err != nil {
return err
}
if !bytes.Equal(encrypted, auth.Crypted) {
return fmt.Errorf("password invalid")
}
return nil return nil
} }
// ClientAuthVNC is the standard password authentication. See 7.2.2. // ClientAuthVNC is the standard password authentication. See 7.2.2.
type ClientAuthVNC struct { type ClientAuthVNC struct {
Challenge [16]byte Challenge []byte
Password []byte Password []byte
} }
@ -300,25 +348,34 @@ func (auth *ClientAuthVNC) Auth(c Conn) error {
if len(auth.Password) == 0 { if len(auth.Password) == 0 {
return fmt.Errorf("Security Handshake failed; no password provided for VNCAuth.") return fmt.Errorf("Security Handshake failed; no password provided for VNCAuth.")
} }
var challenge [16]byte
if err := binary.Read(c, binary.BigEndian, auth.Challenge); err != nil { if err := binary.Read(c, binary.BigEndian, &challenge); err != nil {
return err return err
} }
auth.encode() crypted, err := AuthVNCEncode(auth.Password, challenge[:])
if err != nil {
return err
}
// Send the encrypted challenge back to server // Send the encrypted challenge back to server
if err := binary.Write(c, binary.BigEndian, auth.Challenge); err != nil { if err := binary.Write(c, binary.BigEndian, crypted); err != nil {
return err return err
} }
return c.Flush() return c.Flush()
} }
func (auth *ClientAuthVNC) encode() error { func AuthVNCEncode(password []byte, challenge []byte) ([]byte, error) {
if len(password) > 8 {
return nil, fmt.Errorf("password too long")
}
if len(challenge) != 16 {
return nil, fmt.Errorf("challenge size not 16 byte long")
}
// Copy password string to 8 byte 0-padded slice // Copy password string to 8 byte 0-padded slice
key := make([]byte, 8) key := make([]byte, 8)
copy(key, auth.Password) copy(key, password)
// Each byte of the password needs to be reversed. This is a // Each byte of the password needs to be reversed. This is a
// non RFC-documented behaviour of VNC clients and servers // non RFC-documented behaviour of VNC clients and servers
@ -331,11 +388,11 @@ func (auth *ClientAuthVNC) encode() error {
// Encrypt challenge with key. // Encrypt challenge with key.
cipher, err := des.NewCipher(key) cipher, err := des.NewCipher(key)
if err != nil { if err != nil {
return err return nil, err
} }
for i := 0; i < len(auth.Challenge); i += cipher.BlockSize() { for i := 0; i < len(challenge); i += cipher.BlockSize() {
cipher.Encrypt(auth.Challenge[i:i+cipher.BlockSize()], auth.Challenge[i:i+cipher.BlockSize()]) cipher.Encrypt(challenge[i:i+cipher.BlockSize()], challenge[i:i+cipher.BlockSize()])
} }
return nil return challenge, nil
} }

View File

@ -27,6 +27,10 @@ type ServerInit struct {
var _ Conn = (*ServerConn)(nil) var _ Conn = (*ServerConn)(nil)
func (c *ServerConn) Config() interface{} {
return c.cfg
}
func (c *ServerConn) Conn() net.Conn { func (c *ServerConn) Conn() net.Conn {
return c.c return c.c
} }
@ -71,13 +75,13 @@ func (c *ServerConn) ColorMap() *ColorMap {
func (c *ServerConn) SetColorMap(cm *ColorMap) { func (c *ServerConn) SetColorMap(cm *ColorMap) {
c.colorMap = cm c.colorMap = cm
} }
func (c *ServerConn) DesktopName() string { func (c *ServerConn) DesktopName() []byte {
return c.desktopName return c.desktopName
} }
func (c *ServerConn) PixelFormat() *PixelFormat { func (c *ServerConn) PixelFormat() *PixelFormat {
return c.pixelFormat return c.pixelFormat
} }
func (c *ServerConn) SetDesktopName(name string) { func (c *ServerConn) SetDesktopName(name []byte) {
c.desktopName = name c.desktopName = name
} }
func (c *ServerConn) SetPixelFormat(pf *PixelFormat) error { func (c *ServerConn) SetPixelFormat(pf *PixelFormat) error {
@ -182,7 +186,7 @@ type ServerConn struct {
colorMap *ColorMap colorMap *ColorMap
// Name associated with the desktop, sent from the server. // Name associated with the desktop, sent from the server.
desktopName string desktopName []byte
// Encodings supported by the client. This should not be modified // Encodings supported by the client. This should not be modified
// directly. Instead, SetEncodings() should be used. // directly. Instead, SetEncodings() should be used.
@ -198,28 +202,35 @@ type ServerConn struct {
// be modified. If you wish to set a new pixel format, use the // be modified. If you wish to set a new pixel format, use the
// SetPixelFormat method. // SetPixelFormat method.
pixelFormat *PixelFormat pixelFormat *PixelFormat
quit chan struct{}
} }
type ServerHandler func(*ServerConfig, Conn) error type ServerHandler interface {
Handle(Conn) error
}
var (
DefaultServerHandlers []ServerHandler = []ServerHandler{
&DefaultServerVersionHandler{},
&DefaultServerSecurityHandler{},
&DefaultServerClientInitHandler{},
&DefaultServerServerInitHandler{},
&DefaultServerMessageHandler{},
}
)
type ServerConfig struct { type ServerConfig struct {
VersionHandler ServerHandler Handlers []ServerHandler
SecurityHandler ServerHandler
SecurityHandlers []SecurityHandler SecurityHandlers []SecurityHandler
ClientInitHandler ServerHandler
ServerInitHandler ServerHandler
Encodings []Encoding Encodings []Encoding
PixelFormat *PixelFormat PixelFormat *PixelFormat
ColorMap *ColorMap ColorMap *ColorMap
ClientMessageCh chan ClientMessage ClientMessageCh chan ClientMessage
ServerMessageCh chan ServerMessage ServerMessageCh chan ServerMessage
ErrorCh chan error
ClientMessages []ClientMessage ClientMessages []ClientMessage
DesktopName []byte DesktopName []byte
Height uint16 Height uint16
Width uint16 Width uint16
errorCh chan error
} }
func NewServerConn(c net.Conn, cfg *ServerConfig) (*ServerConn, error) { func NewServerConn(c net.Conn, cfg *ServerConfig) (*ServerConn, error) {
@ -236,7 +247,7 @@ func NewServerConn(c net.Conn, cfg *ServerConfig) (*ServerConn, error) {
br: bufio.NewReader(c), br: bufio.NewReader(c),
bw: bufio.NewWriter(c), bw: bufio.NewWriter(c),
cfg: cfg, cfg: cfg,
quit: make(chan struct{}), desktopName: cfg.DesktopName,
encodings: cfg.Encodings, encodings: cfg.Encodings,
pixelFormat: cfg.PixelFormat, pixelFormat: cfg.PixelFormat,
fbWidth: cfg.Width, fbWidth: cfg.Width,
@ -256,80 +267,69 @@ func Serve(ctx context.Context, ln net.Listener, cfg *ServerConfig) error {
if err != nil { if err != nil {
continue continue
} }
if len(cfg.Handlers) == 0 {
cfg.Handlers = DefaultServerHandlers
}
if err := cfg.VersionHandler(cfg, conn); err != nil { for _, h := range cfg.Handlers {
if err := h.Handle(conn); err != nil {
conn.Close() conn.Close()
continue continue
} }
if err := cfg.SecurityHandler(cfg, conn); err != nil {
conn.Close()
continue
} }
if err := cfg.ClientInitHandler(cfg, conn); err != nil {
conn.Close()
continue
}
if err := cfg.ServerInitHandler(cfg, conn); err != nil {
conn.Close()
continue
}
go conn.Handle()
} }
} }
func (c *ServerConn) Handle() error { type DefaultServerMessageHandler struct{}
func (*DefaultServerMessageHandler) Handle(c Conn) error {
cfg := c.Config().(*ServerConfig)
var err error var err error
var wg sync.WaitGroup var wg sync.WaitGroup
defer c.Close() defer c.Close()
clientMessages := make(map[ClientMessageType]ClientMessage) clientMessages := make(map[ClientMessageType]ClientMessage)
for _, m := range c.cfg.ClientMessages { for _, m := range cfg.ClientMessages {
clientMessages[m.Type()] = m clientMessages[m.Type()] = m
} }
wg.Add(2) wg.Add(2)
// server // server
go func() error { go func() {
defer wg.Done() defer wg.Done()
for { for {
select { select {
case msg := <-c.cfg.ServerMessageCh: case msg := <-cfg.ServerMessageCh:
if err = msg.Write(c); err != nil { if err = msg.Write(c); err != nil {
return err cfg.errorCh <- err
return
} }
case <-c.quit:
return nil
} }
} }
}() }()
// client // client
go func() error { go func() {
defer wg.Done() defer wg.Done()
for { for {
select { select {
case <-c.quit:
return nil
default: default:
var messageType ClientMessageType var messageType ClientMessageType
if err := binary.Read(c, binary.BigEndian, &messageType); err != nil { if err := binary.Read(c, binary.BigEndian, &messageType); err != nil {
return err cfg.errorCh <- err
return
} }
msg, ok := clientMessages[messageType] msg, ok := clientMessages[messageType]
if !ok { if !ok {
return fmt.Errorf("unsupported message-type: %v", messageType) cfg.errorCh <- fmt.Errorf("unsupported message-type: %v", messageType)
return
} }
parsedMsg, err := msg.Read(c) parsedMsg, err := msg.Read(c)
if err != nil { if err != nil {
fmt.Printf("srv err %s\n", err.Error()) cfg.errorCh <- err
return err return
} }
c.cfg.ClientMessageCh <- parsedMsg cfg.ClientMessageCh <- parsedMsg
} }
} }
}() }()