diff --git a/README.md b/README.md index 969a6b6..ef7c5ce 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ libvirt [![GoDoc](http://godoc.org/github.com/digitalocean/go-libvirt?status.svg Package `go-libvirt` provides a pure Go interface for interacting with libvirt. -Rather than using Libvirt's C bindings, this package makes use of +Rather than using libvirt's C bindings, this package makes use of libvirt's RPC interface, as documented [here](https://libvirt.org/internals/rpc.html). Connections to the libvirt server may be local, or remote. RPC packets are encoded using the XDR standard as defined by [RFC 4506](https://tools.ietf.org/html/rfc4506.html). @@ -20,6 +20,7 @@ and produces go bindings for all the remote procedures defined there. How to Use This Library ----------------------- + Once you've vendored go-libvirt into your project, you'll probably want to call some libvirt functions. There's some example code below showing how to connect to libvirt and make one such call, but once you get past the introduction you'll @@ -108,8 +109,9 @@ import ( ) func main() { - //c, err := net.DialTimeout("tcp", "127.0.0.1:16509", 2*time.Second) - //c, err := net.DialTimeout("tcp", "192.168.1.12:16509", 2*time.Second) + // This dials libvirt on the local machine, but you can substitute the first + // two parameters with "tcp", ":" to connect to libvirt on + // a remote machine. c, err := net.DialTimeout("unix", "/var/run/libvirt/libvirt-sock", 2*time.Second) if err != nil { log.Fatalf("failed to dial libvirt: %v", err) diff --git a/libvirt.go b/libvirt.go index bf029d7..dd7baa5 100644 --- a/libvirt.go +++ b/libvirt.go @@ -1,4 +1,4 @@ -// Copyright 2016 The go-libvirt Authors. +// Copyright 2018 The go-libvirt Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -124,6 +124,10 @@ func (l *Libvirt) Disconnect() error { } } + // Deregister all the callbacks so that clients with outstanding requests + // will unblock. + l.deregisterAll() + _, err := l.request(constants.ProcConnectClose, constants.Program, nil) if err != nil { return err diff --git a/rpc.go b/rpc.go index 20f8f34..6c54abe 100644 --- a/rpc.go +++ b/rpc.go @@ -1,4 +1,4 @@ -// Copyright 2016 The go-libvirt Authors. +// Copyright 2018 The go-libvirt Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -117,6 +117,9 @@ func (e libvirtError) Error() string { return e.Message } +// checkError is used to check whether an error is a libvirtError, and if it is, +// whether its error code matches the one passed in. It will return false if +// these conditions are not met. func checkError(err error, expectedError errorNumber) bool { e, ok := err.(libvirtError) if ok { @@ -140,6 +143,7 @@ func (l *Libvirt) listen() { // When the underlying connection EOFs or is closed, stop // this goroutine if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") { + l.deregisterAll() return } @@ -267,8 +271,24 @@ func (l *Libvirt) register(id uint32, c chan response) { // deregister destroys a method response callback func (l *Libvirt) deregister(id uint32) { l.cm.Lock() - close(l.callbacks[id]) - delete(l.callbacks, id) + if _, ok := l.callbacks[id]; ok { + close(l.callbacks[id]) + delete(l.callbacks, id) + } + l.cm.Unlock() +} + +// deregisterAll closes all the waiting callback channels. This is used to clean +// up if the connection to libvirt is lost. Callers waiting for responses will +// return an error when the response channel is closed, rather than just +// hanging. +func (l *Libvirt) deregisterAll() { + l.cm.Lock() + for id := range l.callbacks { + // can't call deregister() here because we're already holding the lock. + close(l.callbacks[id]) + delete(l.callbacks, id) + } l.cm.Unlock() } @@ -304,7 +324,11 @@ func (l *Libvirt) processIncomingStream(c chan response, inStream io.Writer) (re } } -func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte, outStream io.Reader, inStream io.Writer) (response, error) { +// requestStream performs a libvirt RPC request. The outStream and inStream +// parameters are optional, and should be nil for RPC endpoints that don't +// return a stream. +func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte, + outStream io.Reader, inStream io.Writer) (response, error) { serial := l.serial() c := make(chan response) diff --git a/rpc_test.go b/rpc_test.go index 4393512..bfa7684 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -375,4 +375,28 @@ func TestLookup(t *testing.T) { if d.Name != name { t.Errorf("expected domain %s, got %s", name, d.Name) } + + // The callback should now be deregistered. + if _, ok := l.callbacks[id]; ok { + t.Error("expected callback to deregister") + } +} + +func TestDeregisterAll(t *testing.T) { + conn := libvirttest.New() + c1 := make(chan response) + c2 := make(chan response) + l := New(conn) + if len(l.callbacks) != 0 { + t.Error("expected callback map to be empty at test start") + } + l.register(1, c1) + l.register(2, c2) + if len(l.callbacks) != 2 { + t.Error("expected callback map to have 2 entries after inserts") + } + l.deregisterAll() + if len(l.callbacks) != 0 { + t.Error("expected callback map to be empty after deregisterAll") + } }