Adds test package

This splits out the mock libvirt server for testing within
other packages. Constants from the main libvirt package
have been moved to a separate package, constants, for shared access.
This commit is contained in:
Ben LeMasurier 2016-05-19 20:56:27 -06:00
parent 2ccd33a8df
commit 65d878e075
6 changed files with 120 additions and 92 deletions

View File

@ -0,0 +1,54 @@
// Copyright 2016 The go-libvirt Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package constants provides shared data for the libvirt package.
package constants
// magic program numbers
// see: https://libvirt.org/git/?p=libvirt.git;a=blob_plain;f=src/remote/remote_protocol.x;hb=HEAD
const (
ProgramVersion = 1
ProgramRemote = 0x20008086
ProgramQEMU = 0x20008087
ProgramKeepAlive = 0x6b656570
)
// libvirt procedure identifiers
const (
ProcConnectOpen = 1
ProcConnectClose = 2
ProcDomainLookupByName = 23
ProcAuthList = 66
ProcConnectGetLibVersion = 157
ProcConnectListAllDomains = 273
)
// qemu procedure identifiers
const (
QEMUDomainMonitor = 1
QEMUConnectDomainMonitorEventRegister = 4
QEMUConnectDomainMonitorEventDeregister = 5
QEMUDomainMonitorEvent = 6
)
const (
// PacketLengthSize is the packet length, in bytes.
PacketLengthSize = 4
// HeaderSize is the packet header size, in bytes.
HeaderSize = 24
// UUIDSize is the length of a UUID, in bytes.
UUIDSize = 16
)

View File

@ -25,6 +25,7 @@ import (
"sync"
"github.com/davecgh/go-xdr/xdr2"
"github.com/digitalocean/go-libvirt/internal/constants"
)
// ErrEventsNotSupported is returned by Events() if event streams
@ -52,7 +53,7 @@ type Libvirt struct {
// Domain represents a domain as seen by libvirt.
type Domain struct {
Name string
UUID [uuidSize]byte
UUID [constants.UUIDSize]byte
ID int
}
@ -108,7 +109,7 @@ func (l *Libvirt) Domains() ([]Domain, error) {
return nil, err
}
resp, err := l.request(procConnectListAllDomains, programRemote, &buf)
resp, err := l.request(constants.ProcConnectListAllDomains, constants.ProgramRemote, &buf)
if err != nil {
return nil, err
}
@ -159,7 +160,7 @@ func (l *Libvirt) Events(dom string) (<-chan DomainEvent, error) {
return nil, err
}
resp, err := l.request(qemuConnectDomainMonitorEventRegister, programQEMU, &buf)
resp, err := l.request(constants.QEMUConnectDomainMonitorEventRegister, constants.ProgramQEMU, &buf)
if err != nil {
return nil, err
}
@ -218,7 +219,7 @@ func (l *Libvirt) Run(dom string, cmd []byte) ([]byte, error) {
return nil, err
}
resp, err := l.request(qemuDomainMonitor, programQEMU, &buf)
resp, err := l.request(constants.QEMUDomainMonitor, constants.ProgramQEMU, &buf)
if err != nil {
return nil, err
}
@ -242,7 +243,7 @@ func (l *Libvirt) Run(dom string, cmd []byte) ([]byte, error) {
// Version returns the version of the libvirt daemon.
func (l *Libvirt) Version() (string, error) {
resp, err := l.request(procConnectGetLibVersion, programRemote, nil)
resp, err := l.request(constants.ProcConnectGetLibVersion, constants.ProgramRemote, nil)
if err != nil {
return "", err
}
@ -286,7 +287,7 @@ func (l *Libvirt) lookup(name string) (*Domain, error) {
return nil, err
}
resp, err := l.request(procDomainLookupByName, programRemote, &buf)
resp, err := l.request(constants.ProcDomainLookupByName, constants.ProgramRemote, &buf)
if err != nil {
return nil, err
}

View File

@ -19,10 +19,12 @@ import (
"fmt"
"testing"
"time"
"github.com/digitalocean/go-libvirt/libvirttest"
)
func TestConnect(t *testing.T) {
conn := setupTest()
conn := libvirttest.New()
l := New(conn)
err := l.Connect()
@ -32,7 +34,7 @@ func TestConnect(t *testing.T) {
}
func TestDisconnect(t *testing.T) {
conn := setupTest()
conn := libvirttest.New()
l := New(conn)
err := l.Disconnect()
@ -42,7 +44,7 @@ func TestDisconnect(t *testing.T) {
}
func TestDomains(t *testing.T) {
conn := setupTest()
conn := libvirttest.New()
l := New(conn)
domains, err := l.Domains()
@ -70,7 +72,7 @@ func TestDomains(t *testing.T) {
}
func TestEvents(t *testing.T) {
conn := setupTest()
conn := libvirttest.New()
l := New(conn)
done := make(chan struct{})
@ -108,14 +110,14 @@ func TestEvents(t *testing.T) {
}()
// send an event to the listener goroutine
conn.test.Write(append(testEventHeader, testEvent...))
conn.Test.Write(append(testEventHeader, testEvent...))
// wait for completion
<-done
}
func TestRun(t *testing.T) {
conn := setupTest()
conn := libvirttest.New()
l := New(conn)
res, err := l.Run("test", []byte(`{"query-version"}`))
@ -147,7 +149,7 @@ func TestRun(t *testing.T) {
}
func TestVersion(t *testing.T) {
conn := setupTest()
conn := libvirttest.New()
l := New(conn)
version, err := l.Version()

View File

@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package libvirt
// Package libvirttest provides a mock libvirt server for RPC testing.
package libvirttest
import (
"encoding/binary"
"net"
"sync/atomic"
"github.com/digitalocean/go-libvirt/internal/constants"
)
var testDomainResponse = []byte{
@ -164,18 +167,20 @@ var testVersionReply = []byte{
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x4d, 0xfc, // version (1003004)
}
type mockLibvirt struct {
// MockLibvirt provides a mock libvirt server for testing.
type MockLibvirt struct {
net.Conn
test net.Conn
Test net.Conn
serial uint32
}
func setupTest() *mockLibvirt {
// New creates a new mock Libvirt server.
func New() *MockLibvirt {
serv, conn := net.Pipe()
m := &mockLibvirt{
m := &MockLibvirt{
Conn: conn,
test: serv,
Test: serv,
}
go m.handle(serv)
@ -183,9 +188,10 @@ func setupTest() *mockLibvirt {
return m
}
func (m *mockLibvirt) handle(conn net.Conn) {
func (m *MockLibvirt) handle(conn net.Conn) {
for {
buf := make([]byte, packetLengthSize+headerSize)
// packetLengthSize + headerSize
buf := make([]byte, 28)
conn.Read(buf)
// extract program
@ -195,45 +201,45 @@ func (m *mockLibvirt) handle(conn net.Conn) {
proc := binary.BigEndian.Uint32(buf[12:16])
switch prog {
case programRemote:
case constants.ProgramRemote:
m.handleRemote(proc, conn)
case programQEMU:
case constants.ProgramQEMU:
m.handleQEMU(proc, conn)
}
}
}
func (m *mockLibvirt) handleRemote(procedure uint32, conn net.Conn) {
func (m *MockLibvirt) handleRemote(procedure uint32, conn net.Conn) {
switch procedure {
case procAuthList:
case constants.ProcAuthList:
conn.Write(m.reply(testAuthReply))
case procConnectOpen:
case constants.ProcConnectOpen:
conn.Write(m.reply(testConnectReply))
case procConnectClose:
case constants.ProcConnectClose:
conn.Write(m.reply(testDisconnectReply))
case procConnectGetLibVersion:
case constants.ProcConnectGetLibVersion:
conn.Write(m.reply(testVersionReply))
case procDomainLookupByName:
case constants.ProcDomainLookupByName:
conn.Write(m.reply(testDomainResponse))
case procConnectListAllDomains:
case constants.ProcConnectListAllDomains:
conn.Write(m.reply(testDomainsReply))
}
}
func (m *mockLibvirt) handleQEMU(procedure uint32, conn net.Conn) {
func (m *MockLibvirt) handleQEMU(procedure uint32, conn net.Conn) {
switch procedure {
case qemuConnectDomainMonitorEventRegister:
case constants.QEMUConnectDomainMonitorEventRegister:
conn.Write(m.reply(testRegisterEvent))
case qemuConnectDomainMonitorEventDeregister:
case constants.QEMUConnectDomainMonitorEventDeregister:
conn.Write(m.reply(testDeregisterEvent))
case qemuDomainMonitor:
case constants.QEMUDomainMonitor:
conn.Write(m.reply(testRunReply))
}
}
// reply automatically injects the correct serial
// number into the provided response buffer.
func (m *mockLibvirt) reply(buf []byte) []byte {
func (m *MockLibvirt) reply(buf []byte) []byte {
atomic.AddUint32(&m.serial, 1)
binary.BigEndian.PutUint32(buf[20:24], m.serial)

59
rpc.go
View File

@ -23,6 +23,7 @@ import (
"sync/atomic"
"github.com/davecgh/go-xdr/xdr2"
"github.com/digitalocean/go-libvirt/internal/constants"
)
// ErrUnsupported is returned if a procedure is not supported by libvirt
@ -68,44 +69,6 @@ const (
StatusContinue
)
// magic program numbers
// see: https://libvirt.org/git/?p=libvirt.git;a=blob_plain;f=src/remote/remote_protocol.x;hb=HEAD
const (
programVersion = 1
programRemote = 0x20008086
programQEMU = 0x20008087
programKeepAlive = 0x6b656570
)
// libvirt procedure identifiers
const (
procConnectOpen = 1
procConnectClose = 2
procDomainLookupByName = 23
procAuthList = 66
procConnectGetLibVersion = 157
procConnectListAllDomains = 273
)
// qemu procedure identifiers
const (
qemuDomainMonitor = 1
qemuConnectDomainMonitorEventRegister = 4
qemuConnectDomainMonitorEventDeregister = 5
qemuDomainMonitorEvent = 6
)
const (
// packet length, in bytes.
packetLengthSize = 4
// packet header, in bytes.
headerSize = 24
// UUID size, in bytes.
uuidSize = 16
)
// header is a libvirt rpc packet header
type header struct {
// Program identifier
@ -168,7 +131,7 @@ func (l *Libvirt) connect() error {
// libvirt requires that we call auth-list prior to connecting,
// event when no authentication is used.
resp, err := l.request(procAuthList, programRemote, &buf)
resp, err := l.request(constants.ProcAuthList, constants.ProgramRemote, &buf)
if err != nil {
return err
}
@ -178,7 +141,7 @@ func (l *Libvirt) connect() error {
return decodeError(r.Payload)
}
resp, err = l.request(procConnectOpen, programRemote, &buf)
resp, err = l.request(constants.ProcConnectOpen, constants.ProgramRemote, &buf)
if err != nil {
return err
}
@ -192,7 +155,7 @@ func (l *Libvirt) connect() error {
}
func (l *Libvirt) disconnect() error {
resp, err := l.request(procConnectClose, programRemote, nil)
resp, err := l.request(constants.ProcConnectClose, constants.ProgramRemote, nil)
if err != nil {
return err
}
@ -230,7 +193,7 @@ func (l *Libvirt) listen() {
}
// payload: packet length minus what was previously read
size := int(length) - (packetLengthSize + headerSize)
size := int(length) - (constants.PacketLengthSize + constants.HeaderSize)
buf := make([]byte, size)
for n := 0; n < size; {
nn, err := l.r.Read(buf)
@ -260,7 +223,7 @@ func (l *Libvirt) callback(id uint32, res response) {
// route sends incoming packets to their listeners.
func (l *Libvirt) route(h *header, buf []byte) {
// route events to their respective listener
if h.Program == programQEMU && h.Procedure == qemuDomainMonitorEvent {
if h.Program == constants.ProgramQEMU && h.Procedure == constants.QEMUDomainMonitorEvent {
l.stream(buf)
return
}
@ -320,7 +283,7 @@ func (l *Libvirt) removeStream(id uint32) error {
return err
}
resp, err := l.request(qemuConnectDomainMonitorEventDeregister, programQEMU, &buf)
resp, err := l.request(constants.QEMUConnectDomainMonitorEventDeregister, constants.ProgramQEMU, &buf)
if err != nil {
return err
}
@ -361,7 +324,7 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (<
l.register(serial, c)
size := packetLengthSize + headerSize
size := constants.PacketLengthSize + constants.HeaderSize
if payload != nil {
size += payload.Len()
}
@ -370,7 +333,7 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (<
Len: uint32(size),
Header: header{
Program: program,
Version: programVersion,
Version: constants.ProgramVersion,
Procedure: proc,
Type: Call,
Serial: serial,
@ -442,7 +405,7 @@ func decodeEvent(buf []byte) (*DomainEvent, error) {
// If an error is encountered reading the provided Reader, the
// error is returned and response length will be 0.
func pktlen(r io.Reader) (uint32, error) {
buf := make([]byte, packetLengthSize)
buf := make([]byte, constants.PacketLengthSize)
for n := 0; n < cap(buf); {
nn, err := r.Read(buf)
@ -458,7 +421,7 @@ func pktlen(r io.Reader) (uint32, error) {
// extractHeader returns the decoded header from an incoming response.
func extractHeader(r io.Reader) (*header, error) {
buf := make([]byte, headerSize)
buf := make([]byte, constants.HeaderSize)
for n := 0; n < cap(buf); {
nn, err := r.Read(buf)

View File

@ -20,11 +20,13 @@ import (
"testing"
"github.com/davecgh/go-xdr/xdr2"
"github.com/digitalocean/go-libvirt/internal/constants"
"github.com/digitalocean/go-libvirt/libvirttest"
)
var (
// dc229f87d4de47198cfd2e21c6105b01
testUUID = [uuidSize]byte{
testUUID = [constants.UUIDSize]byte{
0xdc, 0x22, 0x9f, 0x87, 0xd4, 0xde, 0x47, 0x19,
0x8c, 0xfd, 0x2e, 0x21, 0xc6, 0x10, 0x5b, 0x01,
}
@ -118,16 +120,16 @@ func TestExtractHeader(t *testing.T) {
t.Error(err)
}
if h.Program != programRemote {
t.Errorf("expected Program %q, got %q", programRemote, h.Program)
if h.Program != constants.ProgramRemote {
t.Errorf("expected Program %q, got %q", constants.ProgramRemote, h.Program)
}
if h.Version != programVersion {
t.Errorf("expected version %q, got %q", programVersion, h.Version)
if h.Version != constants.ProgramVersion {
t.Errorf("expected version %q, got %q", constants.ProgramVersion, h.Version)
}
if h.Procedure != procConnectOpen {
t.Errorf("expected procedure %q, got %q", procConnectOpen, h.Procedure)
if h.Procedure != constants.ProcConnectOpen {
t.Errorf("expected procedure %q, got %q", constants.ProcConnectOpen, h.Procedure)
}
if h.Type != Call {
@ -270,7 +272,7 @@ func TestAddStream(t *testing.T) {
func TestRemoveStream(t *testing.T) {
id := uint32(1)
conn := setupTest()
conn := libvirttest.New()
l := New(conn)
l.events[id] = make(chan *DomainEvent)
@ -328,7 +330,7 @@ func TestLookup(t *testing.T) {
c := make(chan response)
name := "test"
conn := setupTest()
conn := libvirttest.New()
l := New(conn)
l.register(id, c)