Geoff/dereg callbacks on disconnect (#69)
Deregister all callbacks when disconnecting from libvirt. Deregister all callbacks when we lose or close the connection to libvirt. This fixes a problem where goroutines with outstanding requests waiting for replies would block forever if the libvirt connection dies, whether because disconnect is called, or the libvirt daemon crashes or restarts.
This commit is contained in:
parent
17c24de803
commit
2b098b4625
@ -3,7 +3,7 @@ libvirt [![GoDoc](http://godoc.org/github.com/digitalocean/go-libvirt?status.svg
|
||||
|
||||
Package `go-libvirt` provides a pure Go interface for interacting with libvirt.
|
||||
|
||||
Rather than using Libvirt's C bindings, this package makes use of
|
||||
Rather than using libvirt's C bindings, this package makes use of
|
||||
libvirt's RPC interface, as documented [here](https://libvirt.org/internals/rpc.html).
|
||||
Connections to the libvirt server may be local, or remote. RPC packets are encoded
|
||||
using the XDR standard as defined by [RFC 4506](https://tools.ietf.org/html/rfc4506.html).
|
||||
@ -20,6 +20,7 @@ and produces go bindings for all the remote procedures defined there.
|
||||
|
||||
How to Use This Library
|
||||
-----------------------
|
||||
|
||||
Once you've vendored go-libvirt into your project, you'll probably want to call
|
||||
some libvirt functions. There's some example code below showing how to connect
|
||||
to libvirt and make one such call, but once you get past the introduction you'll
|
||||
@ -108,8 +109,9 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
//c, err := net.DialTimeout("tcp", "127.0.0.1:16509", 2*time.Second)
|
||||
//c, err := net.DialTimeout("tcp", "192.168.1.12:16509", 2*time.Second)
|
||||
// This dials libvirt on the local machine, but you can substitute the first
|
||||
// two parameters with "tcp", "<ip address>:<port>" to connect to libvirt on
|
||||
// a remote machine.
|
||||
c, err := net.DialTimeout("unix", "/var/run/libvirt/libvirt-sock", 2*time.Second)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to dial libvirt: %v", err)
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2016 The go-libvirt Authors.
|
||||
// Copyright 2018 The go-libvirt Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -124,6 +124,10 @@ func (l *Libvirt) Disconnect() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Deregister all the callbacks so that clients with outstanding requests
|
||||
// will unblock.
|
||||
l.deregisterAll()
|
||||
|
||||
_, err := l.request(constants.ProcConnectClose, constants.Program, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
32
rpc.go
32
rpc.go
@ -1,4 +1,4 @@
|
||||
// Copyright 2016 The go-libvirt Authors.
|
||||
// Copyright 2018 The go-libvirt Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -117,6 +117,9 @@ func (e libvirtError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// checkError is used to check whether an error is a libvirtError, and if it is,
|
||||
// whether its error code matches the one passed in. It will return false if
|
||||
// these conditions are not met.
|
||||
func checkError(err error, expectedError errorNumber) bool {
|
||||
e, ok := err.(libvirtError)
|
||||
if ok {
|
||||
@ -140,6 +143,7 @@ func (l *Libvirt) listen() {
|
||||
// When the underlying connection EOFs or is closed, stop
|
||||
// this goroutine
|
||||
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
||||
l.deregisterAll()
|
||||
return
|
||||
}
|
||||
|
||||
@ -267,8 +271,24 @@ func (l *Libvirt) register(id uint32, c chan response) {
|
||||
// deregister destroys a method response callback
|
||||
func (l *Libvirt) deregister(id uint32) {
|
||||
l.cm.Lock()
|
||||
close(l.callbacks[id])
|
||||
delete(l.callbacks, id)
|
||||
if _, ok := l.callbacks[id]; ok {
|
||||
close(l.callbacks[id])
|
||||
delete(l.callbacks, id)
|
||||
}
|
||||
l.cm.Unlock()
|
||||
}
|
||||
|
||||
// deregisterAll closes all the waiting callback channels. This is used to clean
|
||||
// up if the connection to libvirt is lost. Callers waiting for responses will
|
||||
// return an error when the response channel is closed, rather than just
|
||||
// hanging.
|
||||
func (l *Libvirt) deregisterAll() {
|
||||
l.cm.Lock()
|
||||
for id := range l.callbacks {
|
||||
// can't call deregister() here because we're already holding the lock.
|
||||
close(l.callbacks[id])
|
||||
delete(l.callbacks, id)
|
||||
}
|
||||
l.cm.Unlock()
|
||||
}
|
||||
|
||||
@ -304,7 +324,11 @@ func (l *Libvirt) processIncomingStream(c chan response, inStream io.Writer) (re
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte, outStream io.Reader, inStream io.Writer) (response, error) {
|
||||
// requestStream performs a libvirt RPC request. The outStream and inStream
|
||||
// parameters are optional, and should be nil for RPC endpoints that don't
|
||||
// return a stream.
|
||||
func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte,
|
||||
outStream io.Reader, inStream io.Writer) (response, error) {
|
||||
serial := l.serial()
|
||||
c := make(chan response)
|
||||
|
||||
|
24
rpc_test.go
24
rpc_test.go
@ -375,4 +375,28 @@ func TestLookup(t *testing.T) {
|
||||
if d.Name != name {
|
||||
t.Errorf("expected domain %s, got %s", name, d.Name)
|
||||
}
|
||||
|
||||
// The callback should now be deregistered.
|
||||
if _, ok := l.callbacks[id]; ok {
|
||||
t.Error("expected callback to deregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeregisterAll(t *testing.T) {
|
||||
conn := libvirttest.New()
|
||||
c1 := make(chan response)
|
||||
c2 := make(chan response)
|
||||
l := New(conn)
|
||||
if len(l.callbacks) != 0 {
|
||||
t.Error("expected callback map to be empty at test start")
|
||||
}
|
||||
l.register(1, c1)
|
||||
l.register(2, c2)
|
||||
if len(l.callbacks) != 2 {
|
||||
t.Error("expected callback map to have 2 entries after inserts")
|
||||
}
|
||||
l.deregisterAll()
|
||||
if len(l.callbacks) != 0 {
|
||||
t.Error("expected callback map to be empty after deregisterAll")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user