prevent connection write collisions (#52)

* prevent connection write collisions

When multiple calls are made on the same connection, it's possible for
the writes to collide with each other. This adds a write mutex when
communicating with the libvirt daemon.
This commit is contained in:
Ben LeMasurier 2017-12-07 10:47:52 -07:00 committed by GitHub
parent 7de663d9cc
commit 165035e03c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 0 deletions

View File

@ -39,6 +39,7 @@ type Libvirt struct {
conn net.Conn conn net.Conn
r *bufio.Reader r *bufio.Reader
w *bufio.Writer w *bufio.Writer
mu *sync.Mutex
// method callbacks // method callbacks
cm sync.Mutex cm sync.Mutex
@ -809,6 +810,7 @@ func New(conn net.Conn) *Libvirt {
s: 0, s: 0,
r: bufio.NewReader(conn), r: bufio.NewReader(conn),
w: bufio.NewWriter(conn), w: bufio.NewWriter(conn),
mu: &sync.Mutex{},
callbacks: make(map[uint32]chan response), callbacks: make(map[uint32]chan response),
events: make(map[uint32]chan *DomainEvent), events: make(map[uint32]chan *DomainEvent),
} }

View File

@ -19,6 +19,7 @@ package libvirt
import ( import (
"encoding/xml" "encoding/xml"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
@ -251,6 +252,49 @@ func TestXMLIntegration(t *testing.T) {
} }
} }
// verify we're able to concurrently communicate with libvirtd.
// see: https://github.com/digitalocean/go-libvirt/pull/52
func Test_concurrentWrite(t *testing.T) {
l := New(testConn(t))
if err := l.Connect(); err != nil {
t.Error(err)
}
defer l.Disconnect()
count := 10
wg := sync.WaitGroup{}
wg.Add(count)
start := make(chan struct{})
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
<-start
_, err := l.Domains()
if err != nil {
t.Fatal(err)
}
}()
}
close(start)
select {
case <-done:
case <-time.After(10 * time.Second):
t.Fatal("timed out waiting for execution to complete")
}
}
func testConn(t *testing.T) net.Conn { func testConn(t *testing.T) net.Conn {
conn, err := net.DialTimeout("tcp", testAddr, time.Second*2) conn, err := net.DialTimeout("tcp", testAddr, time.Second*2)
if err != nil { if err != nil {

2
rpc.go
View File

@ -338,6 +338,8 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (<
} }
// write header // write header
l.mu.Lock()
defer l.mu.Unlock()
err := binary.Write(l.w, binary.BigEndian, p) err := binary.Write(l.w, binary.BigEndian, p)
if err != nil { if err != nil {
return nil, err return nil, err