Adds test package
This splits out the mock libvirt server for testing within other packages. Constants from the main libvirt package have been moved to a separate package, constants, for shared access.
This commit is contained in:
parent
2ccd33a8df
commit
65d878e075
54
internal/constants/constants.go
Normal file
54
internal/constants/constants.go
Normal file
@ -0,0 +1,54 @@
|
||||
// Copyright 2016 The go-libvirt Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package constants provides shared data for the libvirt package.
|
||||
package constants
|
||||
|
||||
// magic program numbers
|
||||
// see: https://libvirt.org/git/?p=libvirt.git;a=blob_plain;f=src/remote/remote_protocol.x;hb=HEAD
|
||||
const (
|
||||
ProgramVersion = 1
|
||||
ProgramRemote = 0x20008086
|
||||
ProgramQEMU = 0x20008087
|
||||
ProgramKeepAlive = 0x6b656570
|
||||
)
|
||||
|
||||
// libvirt procedure identifiers
|
||||
const (
|
||||
ProcConnectOpen = 1
|
||||
ProcConnectClose = 2
|
||||
ProcDomainLookupByName = 23
|
||||
ProcAuthList = 66
|
||||
ProcConnectGetLibVersion = 157
|
||||
ProcConnectListAllDomains = 273
|
||||
)
|
||||
|
||||
// qemu procedure identifiers
|
||||
const (
|
||||
QEMUDomainMonitor = 1
|
||||
QEMUConnectDomainMonitorEventRegister = 4
|
||||
QEMUConnectDomainMonitorEventDeregister = 5
|
||||
QEMUDomainMonitorEvent = 6
|
||||
)
|
||||
|
||||
const (
|
||||
// PacketLengthSize is the packet length, in bytes.
|
||||
PacketLengthSize = 4
|
||||
|
||||
// HeaderSize is the packet header size, in bytes.
|
||||
HeaderSize = 24
|
||||
|
||||
// UUIDSize is the length of a UUID, in bytes.
|
||||
UUIDSize = 16
|
||||
)
|
13
libvirt.go
13
libvirt.go
@ -25,6 +25,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/davecgh/go-xdr/xdr2"
|
||||
"github.com/digitalocean/go-libvirt/internal/constants"
|
||||
)
|
||||
|
||||
// ErrEventsNotSupported is returned by Events() if event streams
|
||||
@ -52,7 +53,7 @@ type Libvirt struct {
|
||||
// Domain represents a domain as seen by libvirt.
|
||||
type Domain struct {
|
||||
Name string
|
||||
UUID [uuidSize]byte
|
||||
UUID [constants.UUIDSize]byte
|
||||
ID int
|
||||
}
|
||||
|
||||
@ -108,7 +109,7 @@ func (l *Libvirt) Domains() ([]Domain, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := l.request(procConnectListAllDomains, programRemote, &buf)
|
||||
resp, err := l.request(constants.ProcConnectListAllDomains, constants.ProgramRemote, &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -159,7 +160,7 @@ func (l *Libvirt) Events(dom string) (<-chan DomainEvent, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := l.request(qemuConnectDomainMonitorEventRegister, programQEMU, &buf)
|
||||
resp, err := l.request(constants.QEMUConnectDomainMonitorEventRegister, constants.ProgramQEMU, &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -218,7 +219,7 @@ func (l *Libvirt) Run(dom string, cmd []byte) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := l.request(qemuDomainMonitor, programQEMU, &buf)
|
||||
resp, err := l.request(constants.QEMUDomainMonitor, constants.ProgramQEMU, &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -242,7 +243,7 @@ func (l *Libvirt) Run(dom string, cmd []byte) ([]byte, error) {
|
||||
|
||||
// Version returns the version of the libvirt daemon.
|
||||
func (l *Libvirt) Version() (string, error) {
|
||||
resp, err := l.request(procConnectGetLibVersion, programRemote, nil)
|
||||
resp, err := l.request(constants.ProcConnectGetLibVersion, constants.ProgramRemote, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -286,7 +287,7 @@ func (l *Libvirt) lookup(name string) (*Domain, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := l.request(procDomainLookupByName, programRemote, &buf)
|
||||
resp, err := l.request(constants.ProcDomainLookupByName, constants.ProgramRemote, &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -19,10 +19,12 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/digitalocean/go-libvirt/libvirttest"
|
||||
)
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
conn := setupTest()
|
||||
conn := libvirttest.New()
|
||||
l := New(conn)
|
||||
|
||||
err := l.Connect()
|
||||
@ -32,7 +34,7 @@ func TestConnect(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDisconnect(t *testing.T) {
|
||||
conn := setupTest()
|
||||
conn := libvirttest.New()
|
||||
l := New(conn)
|
||||
|
||||
err := l.Disconnect()
|
||||
@ -42,7 +44,7 @@ func TestDisconnect(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDomains(t *testing.T) {
|
||||
conn := setupTest()
|
||||
conn := libvirttest.New()
|
||||
l := New(conn)
|
||||
|
||||
domains, err := l.Domains()
|
||||
@ -70,7 +72,7 @@ func TestDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEvents(t *testing.T) {
|
||||
conn := setupTest()
|
||||
conn := libvirttest.New()
|
||||
l := New(conn)
|
||||
done := make(chan struct{})
|
||||
|
||||
@ -108,14 +110,14 @@ func TestEvents(t *testing.T) {
|
||||
}()
|
||||
|
||||
// send an event to the listener goroutine
|
||||
conn.test.Write(append(testEventHeader, testEvent...))
|
||||
conn.Test.Write(append(testEventHeader, testEvent...))
|
||||
|
||||
// wait for completion
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestRun(t *testing.T) {
|
||||
conn := setupTest()
|
||||
conn := libvirttest.New()
|
||||
l := New(conn)
|
||||
|
||||
res, err := l.Run("test", []byte(`{"query-version"}`))
|
||||
@ -147,7 +149,7 @@ func TestRun(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
conn := setupTest()
|
||||
conn := libvirttest.New()
|
||||
l := New(conn)
|
||||
|
||||
version, err := l.Version()
|
||||
|
@ -12,12 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package libvirt
|
||||
// Package libvirttest provides a mock libvirt server for RPC testing.
|
||||
package libvirttest
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/digitalocean/go-libvirt/internal/constants"
|
||||
)
|
||||
|
||||
var testDomainResponse = []byte{
|
||||
@ -164,18 +167,20 @@ var testVersionReply = []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x4d, 0xfc, // version (1003004)
|
||||
}
|
||||
|
||||
type mockLibvirt struct {
|
||||
// MockLibvirt provides a mock libvirt server for testing.
|
||||
type MockLibvirt struct {
|
||||
net.Conn
|
||||
test net.Conn
|
||||
Test net.Conn
|
||||
serial uint32
|
||||
}
|
||||
|
||||
func setupTest() *mockLibvirt {
|
||||
// New creates a new mock Libvirt server.
|
||||
func New() *MockLibvirt {
|
||||
serv, conn := net.Pipe()
|
||||
|
||||
m := &mockLibvirt{
|
||||
m := &MockLibvirt{
|
||||
Conn: conn,
|
||||
test: serv,
|
||||
Test: serv,
|
||||
}
|
||||
|
||||
go m.handle(serv)
|
||||
@ -183,9 +188,10 @@ func setupTest() *mockLibvirt {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockLibvirt) handle(conn net.Conn) {
|
||||
func (m *MockLibvirt) handle(conn net.Conn) {
|
||||
for {
|
||||
buf := make([]byte, packetLengthSize+headerSize)
|
||||
// packetLengthSize + headerSize
|
||||
buf := make([]byte, 28)
|
||||
conn.Read(buf)
|
||||
|
||||
// extract program
|
||||
@ -195,45 +201,45 @@ func (m *mockLibvirt) handle(conn net.Conn) {
|
||||
proc := binary.BigEndian.Uint32(buf[12:16])
|
||||
|
||||
switch prog {
|
||||
case programRemote:
|
||||
case constants.ProgramRemote:
|
||||
m.handleRemote(proc, conn)
|
||||
case programQEMU:
|
||||
case constants.ProgramQEMU:
|
||||
m.handleQEMU(proc, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockLibvirt) handleRemote(procedure uint32, conn net.Conn) {
|
||||
func (m *MockLibvirt) handleRemote(procedure uint32, conn net.Conn) {
|
||||
switch procedure {
|
||||
case procAuthList:
|
||||
case constants.ProcAuthList:
|
||||
conn.Write(m.reply(testAuthReply))
|
||||
case procConnectOpen:
|
||||
case constants.ProcConnectOpen:
|
||||
conn.Write(m.reply(testConnectReply))
|
||||
case procConnectClose:
|
||||
case constants.ProcConnectClose:
|
||||
conn.Write(m.reply(testDisconnectReply))
|
||||
case procConnectGetLibVersion:
|
||||
case constants.ProcConnectGetLibVersion:
|
||||
conn.Write(m.reply(testVersionReply))
|
||||
case procDomainLookupByName:
|
||||
case constants.ProcDomainLookupByName:
|
||||
conn.Write(m.reply(testDomainResponse))
|
||||
case procConnectListAllDomains:
|
||||
case constants.ProcConnectListAllDomains:
|
||||
conn.Write(m.reply(testDomainsReply))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockLibvirt) handleQEMU(procedure uint32, conn net.Conn) {
|
||||
func (m *MockLibvirt) handleQEMU(procedure uint32, conn net.Conn) {
|
||||
switch procedure {
|
||||
case qemuConnectDomainMonitorEventRegister:
|
||||
case constants.QEMUConnectDomainMonitorEventRegister:
|
||||
conn.Write(m.reply(testRegisterEvent))
|
||||
case qemuConnectDomainMonitorEventDeregister:
|
||||
case constants.QEMUConnectDomainMonitorEventDeregister:
|
||||
conn.Write(m.reply(testDeregisterEvent))
|
||||
case qemuDomainMonitor:
|
||||
case constants.QEMUDomainMonitor:
|
||||
conn.Write(m.reply(testRunReply))
|
||||
}
|
||||
}
|
||||
|
||||
// reply automatically injects the correct serial
|
||||
// number into the provided response buffer.
|
||||
func (m *mockLibvirt) reply(buf []byte) []byte {
|
||||
func (m *MockLibvirt) reply(buf []byte) []byte {
|
||||
atomic.AddUint32(&m.serial, 1)
|
||||
binary.BigEndian.PutUint32(buf[20:24], m.serial)
|
||||
|
59
rpc.go
59
rpc.go
@ -23,6 +23,7 @@ import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/davecgh/go-xdr/xdr2"
|
||||
"github.com/digitalocean/go-libvirt/internal/constants"
|
||||
)
|
||||
|
||||
// ErrUnsupported is returned if a procedure is not supported by libvirt
|
||||
@ -68,44 +69,6 @@ const (
|
||||
StatusContinue
|
||||
)
|
||||
|
||||
// magic program numbers
|
||||
// see: https://libvirt.org/git/?p=libvirt.git;a=blob_plain;f=src/remote/remote_protocol.x;hb=HEAD
|
||||
const (
|
||||
programVersion = 1
|
||||
programRemote = 0x20008086
|
||||
programQEMU = 0x20008087
|
||||
programKeepAlive = 0x6b656570
|
||||
)
|
||||
|
||||
// libvirt procedure identifiers
|
||||
const (
|
||||
procConnectOpen = 1
|
||||
procConnectClose = 2
|
||||
procDomainLookupByName = 23
|
||||
procAuthList = 66
|
||||
procConnectGetLibVersion = 157
|
||||
procConnectListAllDomains = 273
|
||||
)
|
||||
|
||||
// qemu procedure identifiers
|
||||
const (
|
||||
qemuDomainMonitor = 1
|
||||
qemuConnectDomainMonitorEventRegister = 4
|
||||
qemuConnectDomainMonitorEventDeregister = 5
|
||||
qemuDomainMonitorEvent = 6
|
||||
)
|
||||
|
||||
const (
|
||||
// packet length, in bytes.
|
||||
packetLengthSize = 4
|
||||
|
||||
// packet header, in bytes.
|
||||
headerSize = 24
|
||||
|
||||
// UUID size, in bytes.
|
||||
uuidSize = 16
|
||||
)
|
||||
|
||||
// header is a libvirt rpc packet header
|
||||
type header struct {
|
||||
// Program identifier
|
||||
@ -168,7 +131,7 @@ func (l *Libvirt) connect() error {
|
||||
|
||||
// libvirt requires that we call auth-list prior to connecting,
|
||||
// event when no authentication is used.
|
||||
resp, err := l.request(procAuthList, programRemote, &buf)
|
||||
resp, err := l.request(constants.ProcAuthList, constants.ProgramRemote, &buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -178,7 +141,7 @@ func (l *Libvirt) connect() error {
|
||||
return decodeError(r.Payload)
|
||||
}
|
||||
|
||||
resp, err = l.request(procConnectOpen, programRemote, &buf)
|
||||
resp, err = l.request(constants.ProcConnectOpen, constants.ProgramRemote, &buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -192,7 +155,7 @@ func (l *Libvirt) connect() error {
|
||||
}
|
||||
|
||||
func (l *Libvirt) disconnect() error {
|
||||
resp, err := l.request(procConnectClose, programRemote, nil)
|
||||
resp, err := l.request(constants.ProcConnectClose, constants.ProgramRemote, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -230,7 +193,7 @@ func (l *Libvirt) listen() {
|
||||
}
|
||||
|
||||
// payload: packet length minus what was previously read
|
||||
size := int(length) - (packetLengthSize + headerSize)
|
||||
size := int(length) - (constants.PacketLengthSize + constants.HeaderSize)
|
||||
buf := make([]byte, size)
|
||||
for n := 0; n < size; {
|
||||
nn, err := l.r.Read(buf)
|
||||
@ -260,7 +223,7 @@ func (l *Libvirt) callback(id uint32, res response) {
|
||||
// route sends incoming packets to their listeners.
|
||||
func (l *Libvirt) route(h *header, buf []byte) {
|
||||
// route events to their respective listener
|
||||
if h.Program == programQEMU && h.Procedure == qemuDomainMonitorEvent {
|
||||
if h.Program == constants.ProgramQEMU && h.Procedure == constants.QEMUDomainMonitorEvent {
|
||||
l.stream(buf)
|
||||
return
|
||||
}
|
||||
@ -320,7 +283,7 @@ func (l *Libvirt) removeStream(id uint32) error {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := l.request(qemuConnectDomainMonitorEventDeregister, programQEMU, &buf)
|
||||
resp, err := l.request(constants.QEMUConnectDomainMonitorEventDeregister, constants.ProgramQEMU, &buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -361,7 +324,7 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (<
|
||||
|
||||
l.register(serial, c)
|
||||
|
||||
size := packetLengthSize + headerSize
|
||||
size := constants.PacketLengthSize + constants.HeaderSize
|
||||
if payload != nil {
|
||||
size += payload.Len()
|
||||
}
|
||||
@ -370,7 +333,7 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (<
|
||||
Len: uint32(size),
|
||||
Header: header{
|
||||
Program: program,
|
||||
Version: programVersion,
|
||||
Version: constants.ProgramVersion,
|
||||
Procedure: proc,
|
||||
Type: Call,
|
||||
Serial: serial,
|
||||
@ -442,7 +405,7 @@ func decodeEvent(buf []byte) (*DomainEvent, error) {
|
||||
// If an error is encountered reading the provided Reader, the
|
||||
// error is returned and response length will be 0.
|
||||
func pktlen(r io.Reader) (uint32, error) {
|
||||
buf := make([]byte, packetLengthSize)
|
||||
buf := make([]byte, constants.PacketLengthSize)
|
||||
|
||||
for n := 0; n < cap(buf); {
|
||||
nn, err := r.Read(buf)
|
||||
@ -458,7 +421,7 @@ func pktlen(r io.Reader) (uint32, error) {
|
||||
|
||||
// extractHeader returns the decoded header from an incoming response.
|
||||
func extractHeader(r io.Reader) (*header, error) {
|
||||
buf := make([]byte, headerSize)
|
||||
buf := make([]byte, constants.HeaderSize)
|
||||
|
||||
for n := 0; n < cap(buf); {
|
||||
nn, err := r.Read(buf)
|
||||
|
20
rpc_test.go
20
rpc_test.go
@ -20,11 +20,13 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-xdr/xdr2"
|
||||
"github.com/digitalocean/go-libvirt/internal/constants"
|
||||
"github.com/digitalocean/go-libvirt/libvirttest"
|
||||
)
|
||||
|
||||
var (
|
||||
// dc229f87d4de47198cfd2e21c6105b01
|
||||
testUUID = [uuidSize]byte{
|
||||
testUUID = [constants.UUIDSize]byte{
|
||||
0xdc, 0x22, 0x9f, 0x87, 0xd4, 0xde, 0x47, 0x19,
|
||||
0x8c, 0xfd, 0x2e, 0x21, 0xc6, 0x10, 0x5b, 0x01,
|
||||
}
|
||||
@ -118,16 +120,16 @@ func TestExtractHeader(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if h.Program != programRemote {
|
||||
t.Errorf("expected Program %q, got %q", programRemote, h.Program)
|
||||
if h.Program != constants.ProgramRemote {
|
||||
t.Errorf("expected Program %q, got %q", constants.ProgramRemote, h.Program)
|
||||
}
|
||||
|
||||
if h.Version != programVersion {
|
||||
t.Errorf("expected version %q, got %q", programVersion, h.Version)
|
||||
if h.Version != constants.ProgramVersion {
|
||||
t.Errorf("expected version %q, got %q", constants.ProgramVersion, h.Version)
|
||||
}
|
||||
|
||||
if h.Procedure != procConnectOpen {
|
||||
t.Errorf("expected procedure %q, got %q", procConnectOpen, h.Procedure)
|
||||
if h.Procedure != constants.ProcConnectOpen {
|
||||
t.Errorf("expected procedure %q, got %q", constants.ProcConnectOpen, h.Procedure)
|
||||
}
|
||||
|
||||
if h.Type != Call {
|
||||
@ -270,7 +272,7 @@ func TestAddStream(t *testing.T) {
|
||||
func TestRemoveStream(t *testing.T) {
|
||||
id := uint32(1)
|
||||
|
||||
conn := setupTest()
|
||||
conn := libvirttest.New()
|
||||
l := New(conn)
|
||||
l.events[id] = make(chan *DomainEvent)
|
||||
|
||||
@ -328,7 +330,7 @@ func TestLookup(t *testing.T) {
|
||||
c := make(chan response)
|
||||
name := "test"
|
||||
|
||||
conn := setupTest()
|
||||
conn := libvirttest.New()
|
||||
l := New(conn)
|
||||
|
||||
l.register(id, c)
|
||||
|
Loading…
Reference in New Issue
Block a user