From 65d878e075584013e907f932bf90aa0ed4bb7fe7 Mon Sep 17 00:00:00 2001 From: Ben LeMasurier Date: Thu, 19 May 2016 20:56:27 -0600 Subject: [PATCH] Adds test package This splits out the mock libvirt server for testing within other packages. Constants from the main libvirt package have been moved to a separate package, constants, for shared access. --- internal/constants/constants.go | 54 +++++++++++++++++++++++ libvirt.go | 13 +++--- libvirt_test.go | 16 ++++--- mock_test.go => libvirttest/libvirt.go | 50 ++++++++++++---------- rpc.go | 59 +++++--------------------- rpc_test.go | 20 +++++---- 6 files changed, 120 insertions(+), 92 deletions(-) create mode 100644 internal/constants/constants.go rename mock_test.go => libvirttest/libvirt.go (84%) diff --git a/internal/constants/constants.go b/internal/constants/constants.go new file mode 100644 index 0000000..9db50c1 --- /dev/null +++ b/internal/constants/constants.go @@ -0,0 +1,54 @@ +// Copyright 2016 The go-libvirt Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package constants provides shared data for the libvirt package. +package constants + +// magic program numbers +// see: https://libvirt.org/git/?p=libvirt.git;a=blob_plain;f=src/remote/remote_protocol.x;hb=HEAD +const ( + ProgramVersion = 1 + ProgramRemote = 0x20008086 + ProgramQEMU = 0x20008087 + ProgramKeepAlive = 0x6b656570 +) + +// libvirt procedure identifiers +const ( + ProcConnectOpen = 1 + ProcConnectClose = 2 + ProcDomainLookupByName = 23 + ProcAuthList = 66 + ProcConnectGetLibVersion = 157 + ProcConnectListAllDomains = 273 +) + +// qemu procedure identifiers +const ( + QEMUDomainMonitor = 1 + QEMUConnectDomainMonitorEventRegister = 4 + QEMUConnectDomainMonitorEventDeregister = 5 + QEMUDomainMonitorEvent = 6 +) + +const ( + // PacketLengthSize is the packet length, in bytes. + PacketLengthSize = 4 + + // HeaderSize is the packet header size, in bytes. + HeaderSize = 24 + + // UUIDSize is the length of a UUID, in bytes. + UUIDSize = 16 +) diff --git a/libvirt.go b/libvirt.go index ead2d71..34d9375 100644 --- a/libvirt.go +++ b/libvirt.go @@ -25,6 +25,7 @@ import ( "sync" "github.com/davecgh/go-xdr/xdr2" + "github.com/digitalocean/go-libvirt/internal/constants" ) // ErrEventsNotSupported is returned by Events() if event streams @@ -52,7 +53,7 @@ type Libvirt struct { // Domain represents a domain as seen by libvirt. type Domain struct { Name string - UUID [uuidSize]byte + UUID [constants.UUIDSize]byte ID int } @@ -108,7 +109,7 @@ func (l *Libvirt) Domains() ([]Domain, error) { return nil, err } - resp, err := l.request(procConnectListAllDomains, programRemote, &buf) + resp, err := l.request(constants.ProcConnectListAllDomains, constants.ProgramRemote, &buf) if err != nil { return nil, err } @@ -159,7 +160,7 @@ func (l *Libvirt) Events(dom string) (<-chan DomainEvent, error) { return nil, err } - resp, err := l.request(qemuConnectDomainMonitorEventRegister, programQEMU, &buf) + resp, err := l.request(constants.QEMUConnectDomainMonitorEventRegister, constants.ProgramQEMU, &buf) if err != nil { return nil, err } @@ -218,7 +219,7 @@ func (l *Libvirt) Run(dom string, cmd []byte) ([]byte, error) { return nil, err } - resp, err := l.request(qemuDomainMonitor, programQEMU, &buf) + resp, err := l.request(constants.QEMUDomainMonitor, constants.ProgramQEMU, &buf) if err != nil { return nil, err } @@ -242,7 +243,7 @@ func (l *Libvirt) Run(dom string, cmd []byte) ([]byte, error) { // Version returns the version of the libvirt daemon. func (l *Libvirt) Version() (string, error) { - resp, err := l.request(procConnectGetLibVersion, programRemote, nil) + resp, err := l.request(constants.ProcConnectGetLibVersion, constants.ProgramRemote, nil) if err != nil { return "", err } @@ -286,7 +287,7 @@ func (l *Libvirt) lookup(name string) (*Domain, error) { return nil, err } - resp, err := l.request(procDomainLookupByName, programRemote, &buf) + resp, err := l.request(constants.ProcDomainLookupByName, constants.ProgramRemote, &buf) if err != nil { return nil, err } diff --git a/libvirt_test.go b/libvirt_test.go index dbed492..8b987e0 100644 --- a/libvirt_test.go +++ b/libvirt_test.go @@ -19,10 +19,12 @@ import ( "fmt" "testing" "time" + + "github.com/digitalocean/go-libvirt/libvirttest" ) func TestConnect(t *testing.T) { - conn := setupTest() + conn := libvirttest.New() l := New(conn) err := l.Connect() @@ -32,7 +34,7 @@ func TestConnect(t *testing.T) { } func TestDisconnect(t *testing.T) { - conn := setupTest() + conn := libvirttest.New() l := New(conn) err := l.Disconnect() @@ -42,7 +44,7 @@ func TestDisconnect(t *testing.T) { } func TestDomains(t *testing.T) { - conn := setupTest() + conn := libvirttest.New() l := New(conn) domains, err := l.Domains() @@ -70,7 +72,7 @@ func TestDomains(t *testing.T) { } func TestEvents(t *testing.T) { - conn := setupTest() + conn := libvirttest.New() l := New(conn) done := make(chan struct{}) @@ -108,14 +110,14 @@ func TestEvents(t *testing.T) { }() // send an event to the listener goroutine - conn.test.Write(append(testEventHeader, testEvent...)) + conn.Test.Write(append(testEventHeader, testEvent...)) // wait for completion <-done } func TestRun(t *testing.T) { - conn := setupTest() + conn := libvirttest.New() l := New(conn) res, err := l.Run("test", []byte(`{"query-version"}`)) @@ -147,7 +149,7 @@ func TestRun(t *testing.T) { } func TestVersion(t *testing.T) { - conn := setupTest() + conn := libvirttest.New() l := New(conn) version, err := l.Version() diff --git a/mock_test.go b/libvirttest/libvirt.go similarity index 84% rename from mock_test.go rename to libvirttest/libvirt.go index e79707c..77e097c 100644 --- a/mock_test.go +++ b/libvirttest/libvirt.go @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package libvirt +// Package libvirttest provides a mock libvirt server for RPC testing. +package libvirttest import ( "encoding/binary" "net" "sync/atomic" + + "github.com/digitalocean/go-libvirt/internal/constants" ) var testDomainResponse = []byte{ @@ -164,18 +167,20 @@ var testVersionReply = []byte{ 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x4d, 0xfc, // version (1003004) } -type mockLibvirt struct { +// MockLibvirt provides a mock libvirt server for testing. +type MockLibvirt struct { net.Conn - test net.Conn + Test net.Conn serial uint32 } -func setupTest() *mockLibvirt { +// New creates a new mock Libvirt server. +func New() *MockLibvirt { serv, conn := net.Pipe() - m := &mockLibvirt{ + m := &MockLibvirt{ Conn: conn, - test: serv, + Test: serv, } go m.handle(serv) @@ -183,9 +188,10 @@ func setupTest() *mockLibvirt { return m } -func (m *mockLibvirt) handle(conn net.Conn) { +func (m *MockLibvirt) handle(conn net.Conn) { for { - buf := make([]byte, packetLengthSize+headerSize) + // packetLengthSize + headerSize + buf := make([]byte, 28) conn.Read(buf) // extract program @@ -195,45 +201,45 @@ func (m *mockLibvirt) handle(conn net.Conn) { proc := binary.BigEndian.Uint32(buf[12:16]) switch prog { - case programRemote: + case constants.ProgramRemote: m.handleRemote(proc, conn) - case programQEMU: + case constants.ProgramQEMU: m.handleQEMU(proc, conn) } } } -func (m *mockLibvirt) handleRemote(procedure uint32, conn net.Conn) { +func (m *MockLibvirt) handleRemote(procedure uint32, conn net.Conn) { switch procedure { - case procAuthList: + case constants.ProcAuthList: conn.Write(m.reply(testAuthReply)) - case procConnectOpen: + case constants.ProcConnectOpen: conn.Write(m.reply(testConnectReply)) - case procConnectClose: + case constants.ProcConnectClose: conn.Write(m.reply(testDisconnectReply)) - case procConnectGetLibVersion: + case constants.ProcConnectGetLibVersion: conn.Write(m.reply(testVersionReply)) - case procDomainLookupByName: + case constants.ProcDomainLookupByName: conn.Write(m.reply(testDomainResponse)) - case procConnectListAllDomains: + case constants.ProcConnectListAllDomains: conn.Write(m.reply(testDomainsReply)) } } -func (m *mockLibvirt) handleQEMU(procedure uint32, conn net.Conn) { +func (m *MockLibvirt) handleQEMU(procedure uint32, conn net.Conn) { switch procedure { - case qemuConnectDomainMonitorEventRegister: + case constants.QEMUConnectDomainMonitorEventRegister: conn.Write(m.reply(testRegisterEvent)) - case qemuConnectDomainMonitorEventDeregister: + case constants.QEMUConnectDomainMonitorEventDeregister: conn.Write(m.reply(testDeregisterEvent)) - case qemuDomainMonitor: + case constants.QEMUDomainMonitor: conn.Write(m.reply(testRunReply)) } } // reply automatically injects the correct serial // number into the provided response buffer. -func (m *mockLibvirt) reply(buf []byte) []byte { +func (m *MockLibvirt) reply(buf []byte) []byte { atomic.AddUint32(&m.serial, 1) binary.BigEndian.PutUint32(buf[20:24], m.serial) diff --git a/rpc.go b/rpc.go index b8f2121..c0c1c46 100644 --- a/rpc.go +++ b/rpc.go @@ -23,6 +23,7 @@ import ( "sync/atomic" "github.com/davecgh/go-xdr/xdr2" + "github.com/digitalocean/go-libvirt/internal/constants" ) // ErrUnsupported is returned if a procedure is not supported by libvirt @@ -68,44 +69,6 @@ const ( StatusContinue ) -// magic program numbers -// see: https://libvirt.org/git/?p=libvirt.git;a=blob_plain;f=src/remote/remote_protocol.x;hb=HEAD -const ( - programVersion = 1 - programRemote = 0x20008086 - programQEMU = 0x20008087 - programKeepAlive = 0x6b656570 -) - -// libvirt procedure identifiers -const ( - procConnectOpen = 1 - procConnectClose = 2 - procDomainLookupByName = 23 - procAuthList = 66 - procConnectGetLibVersion = 157 - procConnectListAllDomains = 273 -) - -// qemu procedure identifiers -const ( - qemuDomainMonitor = 1 - qemuConnectDomainMonitorEventRegister = 4 - qemuConnectDomainMonitorEventDeregister = 5 - qemuDomainMonitorEvent = 6 -) - -const ( - // packet length, in bytes. - packetLengthSize = 4 - - // packet header, in bytes. - headerSize = 24 - - // UUID size, in bytes. - uuidSize = 16 -) - // header is a libvirt rpc packet header type header struct { // Program identifier @@ -168,7 +131,7 @@ func (l *Libvirt) connect() error { // libvirt requires that we call auth-list prior to connecting, // event when no authentication is used. - resp, err := l.request(procAuthList, programRemote, &buf) + resp, err := l.request(constants.ProcAuthList, constants.ProgramRemote, &buf) if err != nil { return err } @@ -178,7 +141,7 @@ func (l *Libvirt) connect() error { return decodeError(r.Payload) } - resp, err = l.request(procConnectOpen, programRemote, &buf) + resp, err = l.request(constants.ProcConnectOpen, constants.ProgramRemote, &buf) if err != nil { return err } @@ -192,7 +155,7 @@ func (l *Libvirt) connect() error { } func (l *Libvirt) disconnect() error { - resp, err := l.request(procConnectClose, programRemote, nil) + resp, err := l.request(constants.ProcConnectClose, constants.ProgramRemote, nil) if err != nil { return err } @@ -230,7 +193,7 @@ func (l *Libvirt) listen() { } // payload: packet length minus what was previously read - size := int(length) - (packetLengthSize + headerSize) + size := int(length) - (constants.PacketLengthSize + constants.HeaderSize) buf := make([]byte, size) for n := 0; n < size; { nn, err := l.r.Read(buf) @@ -260,7 +223,7 @@ func (l *Libvirt) callback(id uint32, res response) { // route sends incoming packets to their listeners. func (l *Libvirt) route(h *header, buf []byte) { // route events to their respective listener - if h.Program == programQEMU && h.Procedure == qemuDomainMonitorEvent { + if h.Program == constants.ProgramQEMU && h.Procedure == constants.QEMUDomainMonitorEvent { l.stream(buf) return } @@ -320,7 +283,7 @@ func (l *Libvirt) removeStream(id uint32) error { return err } - resp, err := l.request(qemuConnectDomainMonitorEventDeregister, programQEMU, &buf) + resp, err := l.request(constants.QEMUConnectDomainMonitorEventDeregister, constants.ProgramQEMU, &buf) if err != nil { return err } @@ -361,7 +324,7 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (< l.register(serial, c) - size := packetLengthSize + headerSize + size := constants.PacketLengthSize + constants.HeaderSize if payload != nil { size += payload.Len() } @@ -370,7 +333,7 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (< Len: uint32(size), Header: header{ Program: program, - Version: programVersion, + Version: constants.ProgramVersion, Procedure: proc, Type: Call, Serial: serial, @@ -442,7 +405,7 @@ func decodeEvent(buf []byte) (*DomainEvent, error) { // If an error is encountered reading the provided Reader, the // error is returned and response length will be 0. func pktlen(r io.Reader) (uint32, error) { - buf := make([]byte, packetLengthSize) + buf := make([]byte, constants.PacketLengthSize) for n := 0; n < cap(buf); { nn, err := r.Read(buf) @@ -458,7 +421,7 @@ func pktlen(r io.Reader) (uint32, error) { // extractHeader returns the decoded header from an incoming response. func extractHeader(r io.Reader) (*header, error) { - buf := make([]byte, headerSize) + buf := make([]byte, constants.HeaderSize) for n := 0; n < cap(buf); { nn, err := r.Read(buf) diff --git a/rpc_test.go b/rpc_test.go index 391a7bf..ae84340 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -20,11 +20,13 @@ import ( "testing" "github.com/davecgh/go-xdr/xdr2" + "github.com/digitalocean/go-libvirt/internal/constants" + "github.com/digitalocean/go-libvirt/libvirttest" ) var ( // dc229f87d4de47198cfd2e21c6105b01 - testUUID = [uuidSize]byte{ + testUUID = [constants.UUIDSize]byte{ 0xdc, 0x22, 0x9f, 0x87, 0xd4, 0xde, 0x47, 0x19, 0x8c, 0xfd, 0x2e, 0x21, 0xc6, 0x10, 0x5b, 0x01, } @@ -118,16 +120,16 @@ func TestExtractHeader(t *testing.T) { t.Error(err) } - if h.Program != programRemote { - t.Errorf("expected Program %q, got %q", programRemote, h.Program) + if h.Program != constants.ProgramRemote { + t.Errorf("expected Program %q, got %q", constants.ProgramRemote, h.Program) } - if h.Version != programVersion { - t.Errorf("expected version %q, got %q", programVersion, h.Version) + if h.Version != constants.ProgramVersion { + t.Errorf("expected version %q, got %q", constants.ProgramVersion, h.Version) } - if h.Procedure != procConnectOpen { - t.Errorf("expected procedure %q, got %q", procConnectOpen, h.Procedure) + if h.Procedure != constants.ProcConnectOpen { + t.Errorf("expected procedure %q, got %q", constants.ProcConnectOpen, h.Procedure) } if h.Type != Call { @@ -270,7 +272,7 @@ func TestAddStream(t *testing.T) { func TestRemoveStream(t *testing.T) { id := uint32(1) - conn := setupTest() + conn := libvirttest.New() l := New(conn) l.events[id] = make(chan *DomainEvent) @@ -328,7 +330,7 @@ func TestLookup(t *testing.T) { c := make(chan response) name := "test" - conn := setupTest() + conn := libvirttest.New() l := New(conn) l.register(id, c)