Geoff/dereg callbacks on disconnect (#69)

Deregister all callbacks when disconnecting from libvirt.

Deregister all callbacks when we lose or close the connection to libvirt. This fixes a problem where goroutines with outstanding requests waiting for replies would block forever if the libvirt connection dies, whether because disconnect is called, or the libvirt daemon crashes or restarts.
This commit is contained in:
Geoff Hickey 2018-06-27 18:38:11 -04:00 committed by GitHub
parent 17c24de803
commit 2b098b4625
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 8 deletions

View File

@ -3,7 +3,7 @@ libvirt [![GoDoc](http://godoc.org/github.com/digitalocean/go-libvirt?status.svg
Package `go-libvirt` provides a pure Go interface for interacting with libvirt.
Rather than using Libvirt's C bindings, this package makes use of
Rather than using libvirt's C bindings, this package makes use of
libvirt's RPC interface, as documented [here](https://libvirt.org/internals/rpc.html).
Connections to the libvirt server may be local, or remote. RPC packets are encoded
using the XDR standard as defined by [RFC 4506](https://tools.ietf.org/html/rfc4506.html).
@ -20,6 +20,7 @@ and produces go bindings for all the remote procedures defined there.
How to Use This Library
-----------------------
Once you've vendored go-libvirt into your project, you'll probably want to call
some libvirt functions. There's some example code below showing how to connect
to libvirt and make one such call, but once you get past the introduction you'll
@ -108,8 +109,9 @@ import (
)
func main() {
//c, err := net.DialTimeout("tcp", "127.0.0.1:16509", 2*time.Second)
//c, err := net.DialTimeout("tcp", "192.168.1.12:16509", 2*time.Second)
// This dials libvirt on the local machine, but you can substitute the first
// two parameters with "tcp", "<ip address>:<port>" to connect to libvirt on
// a remote machine.
c, err := net.DialTimeout("unix", "/var/run/libvirt/libvirt-sock", 2*time.Second)
if err != nil {
log.Fatalf("failed to dial libvirt: %v", err)

View File

@ -1,4 +1,4 @@
// Copyright 2016 The go-libvirt Authors.
// Copyright 2018 The go-libvirt Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -124,6 +124,10 @@ func (l *Libvirt) Disconnect() error {
}
}
// Deregister all the callbacks so that clients with outstanding requests
// will unblock.
l.deregisterAll()
_, err := l.request(constants.ProcConnectClose, constants.Program, nil)
if err != nil {
return err

32
rpc.go
View File

@ -1,4 +1,4 @@
// Copyright 2016 The go-libvirt Authors.
// Copyright 2018 The go-libvirt Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -117,6 +117,9 @@ func (e libvirtError) Error() string {
return e.Message
}
// checkError is used to check whether an error is a libvirtError, and if it is,
// whether its error code matches the one passed in. It will return false if
// these conditions are not met.
func checkError(err error, expectedError errorNumber) bool {
e, ok := err.(libvirtError)
if ok {
@ -140,6 +143,7 @@ func (l *Libvirt) listen() {
// When the underlying connection EOFs or is closed, stop
// this goroutine
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
l.deregisterAll()
return
}
@ -267,8 +271,24 @@ func (l *Libvirt) register(id uint32, c chan response) {
// deregister destroys a method response callback
func (l *Libvirt) deregister(id uint32) {
l.cm.Lock()
close(l.callbacks[id])
delete(l.callbacks, id)
if _, ok := l.callbacks[id]; ok {
close(l.callbacks[id])
delete(l.callbacks, id)
}
l.cm.Unlock()
}
// deregisterAll closes all the waiting callback channels. This is used to clean
// up if the connection to libvirt is lost. Callers waiting for responses will
// return an error when the response channel is closed, rather than just
// hanging.
func (l *Libvirt) deregisterAll() {
l.cm.Lock()
for id := range l.callbacks {
// can't call deregister() here because we're already holding the lock.
close(l.callbacks[id])
delete(l.callbacks, id)
}
l.cm.Unlock()
}
@ -304,7 +324,11 @@ func (l *Libvirt) processIncomingStream(c chan response, inStream io.Writer) (re
}
}
func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte, outStream io.Reader, inStream io.Writer) (response, error) {
// requestStream performs a libvirt RPC request. The outStream and inStream
// parameters are optional, and should be nil for RPC endpoints that don't
// return a stream.
func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte,
outStream io.Reader, inStream io.Writer) (response, error) {
serial := l.serial()
c := make(chan response)

View File

@ -375,4 +375,28 @@ func TestLookup(t *testing.T) {
if d.Name != name {
t.Errorf("expected domain %s, got %s", name, d.Name)
}
// The callback should now be deregistered.
if _, ok := l.callbacks[id]; ok {
t.Error("expected callback to deregister")
}
}
func TestDeregisterAll(t *testing.T) {
conn := libvirttest.New()
c1 := make(chan response)
c2 := make(chan response)
l := New(conn)
if len(l.callbacks) != 0 {
t.Error("expected callback map to be empty at test start")
}
l.register(1, c1)
l.register(2, c2)
if len(l.callbacks) != 2 {
t.Error("expected callback map to have 2 entries after inserts")
}
l.deregisterAll()
if len(l.callbacks) != 0 {
t.Error("expected callback map to be empty after deregisterAll")
}
}