Adds test package

This splits out the mock libvirt server for testing within
other packages. Constants from the main libvirt package
have been moved to a separate package, constants, for shared access.
This commit is contained in:
Ben LeMasurier 2016-05-19 20:56:27 -06:00
parent 2ccd33a8df
commit 65d878e075
6 changed files with 120 additions and 92 deletions

View File

@ -0,0 +1,54 @@
// Copyright 2016 The go-libvirt Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package constants provides shared data for the libvirt package.
package constants
// magic program numbers
// see: https://libvirt.org/git/?p=libvirt.git;a=blob_plain;f=src/remote/remote_protocol.x;hb=HEAD
const (
ProgramVersion = 1
ProgramRemote = 0x20008086
ProgramQEMU = 0x20008087
ProgramKeepAlive = 0x6b656570
)
// libvirt procedure identifiers
const (
ProcConnectOpen = 1
ProcConnectClose = 2
ProcDomainLookupByName = 23
ProcAuthList = 66
ProcConnectGetLibVersion = 157
ProcConnectListAllDomains = 273
)
// qemu procedure identifiers
const (
QEMUDomainMonitor = 1
QEMUConnectDomainMonitorEventRegister = 4
QEMUConnectDomainMonitorEventDeregister = 5
QEMUDomainMonitorEvent = 6
)
const (
// PacketLengthSize is the packet length, in bytes.
PacketLengthSize = 4
// HeaderSize is the packet header size, in bytes.
HeaderSize = 24
// UUIDSize is the length of a UUID, in bytes.
UUIDSize = 16
)

View File

@ -25,6 +25,7 @@ import (
"sync" "sync"
"github.com/davecgh/go-xdr/xdr2" "github.com/davecgh/go-xdr/xdr2"
"github.com/digitalocean/go-libvirt/internal/constants"
) )
// ErrEventsNotSupported is returned by Events() if event streams // ErrEventsNotSupported is returned by Events() if event streams
@ -52,7 +53,7 @@ type Libvirt struct {
// Domain represents a domain as seen by libvirt. // Domain represents a domain as seen by libvirt.
type Domain struct { type Domain struct {
Name string Name string
UUID [uuidSize]byte UUID [constants.UUIDSize]byte
ID int ID int
} }
@ -108,7 +109,7 @@ func (l *Libvirt) Domains() ([]Domain, error) {
return nil, err return nil, err
} }
resp, err := l.request(procConnectListAllDomains, programRemote, &buf) resp, err := l.request(constants.ProcConnectListAllDomains, constants.ProgramRemote, &buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -159,7 +160,7 @@ func (l *Libvirt) Events(dom string) (<-chan DomainEvent, error) {
return nil, err return nil, err
} }
resp, err := l.request(qemuConnectDomainMonitorEventRegister, programQEMU, &buf) resp, err := l.request(constants.QEMUConnectDomainMonitorEventRegister, constants.ProgramQEMU, &buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -218,7 +219,7 @@ func (l *Libvirt) Run(dom string, cmd []byte) ([]byte, error) {
return nil, err return nil, err
} }
resp, err := l.request(qemuDomainMonitor, programQEMU, &buf) resp, err := l.request(constants.QEMUDomainMonitor, constants.ProgramQEMU, &buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,7 +243,7 @@ func (l *Libvirt) Run(dom string, cmd []byte) ([]byte, error) {
// Version returns the version of the libvirt daemon. // Version returns the version of the libvirt daemon.
func (l *Libvirt) Version() (string, error) { func (l *Libvirt) Version() (string, error) {
resp, err := l.request(procConnectGetLibVersion, programRemote, nil) resp, err := l.request(constants.ProcConnectGetLibVersion, constants.ProgramRemote, nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -286,7 +287,7 @@ func (l *Libvirt) lookup(name string) (*Domain, error) {
return nil, err return nil, err
} }
resp, err := l.request(procDomainLookupByName, programRemote, &buf) resp, err := l.request(constants.ProcDomainLookupByName, constants.ProgramRemote, &buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -19,10 +19,12 @@ import (
"fmt" "fmt"
"testing" "testing"
"time" "time"
"github.com/digitalocean/go-libvirt/libvirttest"
) )
func TestConnect(t *testing.T) { func TestConnect(t *testing.T) {
conn := setupTest() conn := libvirttest.New()
l := New(conn) l := New(conn)
err := l.Connect() err := l.Connect()
@ -32,7 +34,7 @@ func TestConnect(t *testing.T) {
} }
func TestDisconnect(t *testing.T) { func TestDisconnect(t *testing.T) {
conn := setupTest() conn := libvirttest.New()
l := New(conn) l := New(conn)
err := l.Disconnect() err := l.Disconnect()
@ -42,7 +44,7 @@ func TestDisconnect(t *testing.T) {
} }
func TestDomains(t *testing.T) { func TestDomains(t *testing.T) {
conn := setupTest() conn := libvirttest.New()
l := New(conn) l := New(conn)
domains, err := l.Domains() domains, err := l.Domains()
@ -70,7 +72,7 @@ func TestDomains(t *testing.T) {
} }
func TestEvents(t *testing.T) { func TestEvents(t *testing.T) {
conn := setupTest() conn := libvirttest.New()
l := New(conn) l := New(conn)
done := make(chan struct{}) done := make(chan struct{})
@ -108,14 +110,14 @@ func TestEvents(t *testing.T) {
}() }()
// send an event to the listener goroutine // send an event to the listener goroutine
conn.test.Write(append(testEventHeader, testEvent...)) conn.Test.Write(append(testEventHeader, testEvent...))
// wait for completion // wait for completion
<-done <-done
} }
func TestRun(t *testing.T) { func TestRun(t *testing.T) {
conn := setupTest() conn := libvirttest.New()
l := New(conn) l := New(conn)
res, err := l.Run("test", []byte(`{"query-version"}`)) res, err := l.Run("test", []byte(`{"query-version"}`))
@ -147,7 +149,7 @@ func TestRun(t *testing.T) {
} }
func TestVersion(t *testing.T) { func TestVersion(t *testing.T) {
conn := setupTest() conn := libvirttest.New()
l := New(conn) l := New(conn)
version, err := l.Version() version, err := l.Version()

View File

@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package libvirt // Package libvirttest provides a mock libvirt server for RPC testing.
package libvirttest
import ( import (
"encoding/binary" "encoding/binary"
"net" "net"
"sync/atomic" "sync/atomic"
"github.com/digitalocean/go-libvirt/internal/constants"
) )
var testDomainResponse = []byte{ var testDomainResponse = []byte{
@ -164,18 +167,20 @@ var testVersionReply = []byte{
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x4d, 0xfc, // version (1003004) 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x4d, 0xfc, // version (1003004)
} }
type mockLibvirt struct { // MockLibvirt provides a mock libvirt server for testing.
type MockLibvirt struct {
net.Conn net.Conn
test net.Conn Test net.Conn
serial uint32 serial uint32
} }
func setupTest() *mockLibvirt { // New creates a new mock Libvirt server.
func New() *MockLibvirt {
serv, conn := net.Pipe() serv, conn := net.Pipe()
m := &mockLibvirt{ m := &MockLibvirt{
Conn: conn, Conn: conn,
test: serv, Test: serv,
} }
go m.handle(serv) go m.handle(serv)
@ -183,9 +188,10 @@ func setupTest() *mockLibvirt {
return m return m
} }
func (m *mockLibvirt) handle(conn net.Conn) { func (m *MockLibvirt) handle(conn net.Conn) {
for { for {
buf := make([]byte, packetLengthSize+headerSize) // packetLengthSize + headerSize
buf := make([]byte, 28)
conn.Read(buf) conn.Read(buf)
// extract program // extract program
@ -195,45 +201,45 @@ func (m *mockLibvirt) handle(conn net.Conn) {
proc := binary.BigEndian.Uint32(buf[12:16]) proc := binary.BigEndian.Uint32(buf[12:16])
switch prog { switch prog {
case programRemote: case constants.ProgramRemote:
m.handleRemote(proc, conn) m.handleRemote(proc, conn)
case programQEMU: case constants.ProgramQEMU:
m.handleQEMU(proc, conn) m.handleQEMU(proc, conn)
} }
} }
} }
func (m *mockLibvirt) handleRemote(procedure uint32, conn net.Conn) { func (m *MockLibvirt) handleRemote(procedure uint32, conn net.Conn) {
switch procedure { switch procedure {
case procAuthList: case constants.ProcAuthList:
conn.Write(m.reply(testAuthReply)) conn.Write(m.reply(testAuthReply))
case procConnectOpen: case constants.ProcConnectOpen:
conn.Write(m.reply(testConnectReply)) conn.Write(m.reply(testConnectReply))
case procConnectClose: case constants.ProcConnectClose:
conn.Write(m.reply(testDisconnectReply)) conn.Write(m.reply(testDisconnectReply))
case procConnectGetLibVersion: case constants.ProcConnectGetLibVersion:
conn.Write(m.reply(testVersionReply)) conn.Write(m.reply(testVersionReply))
case procDomainLookupByName: case constants.ProcDomainLookupByName:
conn.Write(m.reply(testDomainResponse)) conn.Write(m.reply(testDomainResponse))
case procConnectListAllDomains: case constants.ProcConnectListAllDomains:
conn.Write(m.reply(testDomainsReply)) conn.Write(m.reply(testDomainsReply))
} }
} }
func (m *mockLibvirt) handleQEMU(procedure uint32, conn net.Conn) { func (m *MockLibvirt) handleQEMU(procedure uint32, conn net.Conn) {
switch procedure { switch procedure {
case qemuConnectDomainMonitorEventRegister: case constants.QEMUConnectDomainMonitorEventRegister:
conn.Write(m.reply(testRegisterEvent)) conn.Write(m.reply(testRegisterEvent))
case qemuConnectDomainMonitorEventDeregister: case constants.QEMUConnectDomainMonitorEventDeregister:
conn.Write(m.reply(testDeregisterEvent)) conn.Write(m.reply(testDeregisterEvent))
case qemuDomainMonitor: case constants.QEMUDomainMonitor:
conn.Write(m.reply(testRunReply)) conn.Write(m.reply(testRunReply))
} }
} }
// reply automatically injects the correct serial // reply automatically injects the correct serial
// number into the provided response buffer. // number into the provided response buffer.
func (m *mockLibvirt) reply(buf []byte) []byte { func (m *MockLibvirt) reply(buf []byte) []byte {
atomic.AddUint32(&m.serial, 1) atomic.AddUint32(&m.serial, 1)
binary.BigEndian.PutUint32(buf[20:24], m.serial) binary.BigEndian.PutUint32(buf[20:24], m.serial)

59
rpc.go
View File

@ -23,6 +23,7 @@ import (
"sync/atomic" "sync/atomic"
"github.com/davecgh/go-xdr/xdr2" "github.com/davecgh/go-xdr/xdr2"
"github.com/digitalocean/go-libvirt/internal/constants"
) )
// ErrUnsupported is returned if a procedure is not supported by libvirt // ErrUnsupported is returned if a procedure is not supported by libvirt
@ -68,44 +69,6 @@ const (
StatusContinue StatusContinue
) )
// magic program numbers
// see: https://libvirt.org/git/?p=libvirt.git;a=blob_plain;f=src/remote/remote_protocol.x;hb=HEAD
const (
programVersion = 1
programRemote = 0x20008086
programQEMU = 0x20008087
programKeepAlive = 0x6b656570
)
// libvirt procedure identifiers
const (
procConnectOpen = 1
procConnectClose = 2
procDomainLookupByName = 23
procAuthList = 66
procConnectGetLibVersion = 157
procConnectListAllDomains = 273
)
// qemu procedure identifiers
const (
qemuDomainMonitor = 1
qemuConnectDomainMonitorEventRegister = 4
qemuConnectDomainMonitorEventDeregister = 5
qemuDomainMonitorEvent = 6
)
const (
// packet length, in bytes.
packetLengthSize = 4
// packet header, in bytes.
headerSize = 24
// UUID size, in bytes.
uuidSize = 16
)
// header is a libvirt rpc packet header // header is a libvirt rpc packet header
type header struct { type header struct {
// Program identifier // Program identifier
@ -168,7 +131,7 @@ func (l *Libvirt) connect() error {
// libvirt requires that we call auth-list prior to connecting, // libvirt requires that we call auth-list prior to connecting,
// event when no authentication is used. // event when no authentication is used.
resp, err := l.request(procAuthList, programRemote, &buf) resp, err := l.request(constants.ProcAuthList, constants.ProgramRemote, &buf)
if err != nil { if err != nil {
return err return err
} }
@ -178,7 +141,7 @@ func (l *Libvirt) connect() error {
return decodeError(r.Payload) return decodeError(r.Payload)
} }
resp, err = l.request(procConnectOpen, programRemote, &buf) resp, err = l.request(constants.ProcConnectOpen, constants.ProgramRemote, &buf)
if err != nil { if err != nil {
return err return err
} }
@ -192,7 +155,7 @@ func (l *Libvirt) connect() error {
} }
func (l *Libvirt) disconnect() error { func (l *Libvirt) disconnect() error {
resp, err := l.request(procConnectClose, programRemote, nil) resp, err := l.request(constants.ProcConnectClose, constants.ProgramRemote, nil)
if err != nil { if err != nil {
return err return err
} }
@ -230,7 +193,7 @@ func (l *Libvirt) listen() {
} }
// payload: packet length minus what was previously read // payload: packet length minus what was previously read
size := int(length) - (packetLengthSize + headerSize) size := int(length) - (constants.PacketLengthSize + constants.HeaderSize)
buf := make([]byte, size) buf := make([]byte, size)
for n := 0; n < size; { for n := 0; n < size; {
nn, err := l.r.Read(buf) nn, err := l.r.Read(buf)
@ -260,7 +223,7 @@ func (l *Libvirt) callback(id uint32, res response) {
// route sends incoming packets to their listeners. // route sends incoming packets to their listeners.
func (l *Libvirt) route(h *header, buf []byte) { func (l *Libvirt) route(h *header, buf []byte) {
// route events to their respective listener // route events to their respective listener
if h.Program == programQEMU && h.Procedure == qemuDomainMonitorEvent { if h.Program == constants.ProgramQEMU && h.Procedure == constants.QEMUDomainMonitorEvent {
l.stream(buf) l.stream(buf)
return return
} }
@ -320,7 +283,7 @@ func (l *Libvirt) removeStream(id uint32) error {
return err return err
} }
resp, err := l.request(qemuConnectDomainMonitorEventDeregister, programQEMU, &buf) resp, err := l.request(constants.QEMUConnectDomainMonitorEventDeregister, constants.ProgramQEMU, &buf)
if err != nil { if err != nil {
return err return err
} }
@ -361,7 +324,7 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (<
l.register(serial, c) l.register(serial, c)
size := packetLengthSize + headerSize size := constants.PacketLengthSize + constants.HeaderSize
if payload != nil { if payload != nil {
size += payload.Len() size += payload.Len()
} }
@ -370,7 +333,7 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (<
Len: uint32(size), Len: uint32(size),
Header: header{ Header: header{
Program: program, Program: program,
Version: programVersion, Version: constants.ProgramVersion,
Procedure: proc, Procedure: proc,
Type: Call, Type: Call,
Serial: serial, Serial: serial,
@ -442,7 +405,7 @@ func decodeEvent(buf []byte) (*DomainEvent, error) {
// If an error is encountered reading the provided Reader, the // If an error is encountered reading the provided Reader, the
// error is returned and response length will be 0. // error is returned and response length will be 0.
func pktlen(r io.Reader) (uint32, error) { func pktlen(r io.Reader) (uint32, error) {
buf := make([]byte, packetLengthSize) buf := make([]byte, constants.PacketLengthSize)
for n := 0; n < cap(buf); { for n := 0; n < cap(buf); {
nn, err := r.Read(buf) nn, err := r.Read(buf)
@ -458,7 +421,7 @@ func pktlen(r io.Reader) (uint32, error) {
// extractHeader returns the decoded header from an incoming response. // extractHeader returns the decoded header from an incoming response.
func extractHeader(r io.Reader) (*header, error) { func extractHeader(r io.Reader) (*header, error) {
buf := make([]byte, headerSize) buf := make([]byte, constants.HeaderSize)
for n := 0; n < cap(buf); { for n := 0; n < cap(buf); {
nn, err := r.Read(buf) nn, err := r.Read(buf)

View File

@ -20,11 +20,13 @@ import (
"testing" "testing"
"github.com/davecgh/go-xdr/xdr2" "github.com/davecgh/go-xdr/xdr2"
"github.com/digitalocean/go-libvirt/internal/constants"
"github.com/digitalocean/go-libvirt/libvirttest"
) )
var ( var (
// dc229f87d4de47198cfd2e21c6105b01 // dc229f87d4de47198cfd2e21c6105b01
testUUID = [uuidSize]byte{ testUUID = [constants.UUIDSize]byte{
0xdc, 0x22, 0x9f, 0x87, 0xd4, 0xde, 0x47, 0x19, 0xdc, 0x22, 0x9f, 0x87, 0xd4, 0xde, 0x47, 0x19,
0x8c, 0xfd, 0x2e, 0x21, 0xc6, 0x10, 0x5b, 0x01, 0x8c, 0xfd, 0x2e, 0x21, 0xc6, 0x10, 0x5b, 0x01,
} }
@ -118,16 +120,16 @@ func TestExtractHeader(t *testing.T) {
t.Error(err) t.Error(err)
} }
if h.Program != programRemote { if h.Program != constants.ProgramRemote {
t.Errorf("expected Program %q, got %q", programRemote, h.Program) t.Errorf("expected Program %q, got %q", constants.ProgramRemote, h.Program)
} }
if h.Version != programVersion { if h.Version != constants.ProgramVersion {
t.Errorf("expected version %q, got %q", programVersion, h.Version) t.Errorf("expected version %q, got %q", constants.ProgramVersion, h.Version)
} }
if h.Procedure != procConnectOpen { if h.Procedure != constants.ProcConnectOpen {
t.Errorf("expected procedure %q, got %q", procConnectOpen, h.Procedure) t.Errorf("expected procedure %q, got %q", constants.ProcConnectOpen, h.Procedure)
} }
if h.Type != Call { if h.Type != Call {
@ -270,7 +272,7 @@ func TestAddStream(t *testing.T) {
func TestRemoveStream(t *testing.T) { func TestRemoveStream(t *testing.T) {
id := uint32(1) id := uint32(1)
conn := setupTest() conn := libvirttest.New()
l := New(conn) l := New(conn)
l.events[id] = make(chan *DomainEvent) l.events[id] = make(chan *DomainEvent)
@ -328,7 +330,7 @@ func TestLookup(t *testing.T) {
c := make(chan response) c := make(chan response)
name := "test" name := "test"
conn := setupTest() conn := libvirttest.New()
l := New(conn) l := New(conn)
l.register(id, c) l.register(id, c)