prevent connection write collisions (#52)
* prevent connection write collisions When multiple calls are made on the same connection, it's possible for the writes to collide with each other. This adds a write mutex when communicating with the libvirt daemon.
This commit is contained in:
parent
7de663d9cc
commit
165035e03c
@ -39,6 +39,7 @@ type Libvirt struct {
|
|||||||
conn net.Conn
|
conn net.Conn
|
||||||
r *bufio.Reader
|
r *bufio.Reader
|
||||||
w *bufio.Writer
|
w *bufio.Writer
|
||||||
|
mu *sync.Mutex
|
||||||
|
|
||||||
// method callbacks
|
// method callbacks
|
||||||
cm sync.Mutex
|
cm sync.Mutex
|
||||||
@ -809,6 +810,7 @@ func New(conn net.Conn) *Libvirt {
|
|||||||
s: 0,
|
s: 0,
|
||||||
r: bufio.NewReader(conn),
|
r: bufio.NewReader(conn),
|
||||||
w: bufio.NewWriter(conn),
|
w: bufio.NewWriter(conn),
|
||||||
|
mu: &sync.Mutex{},
|
||||||
callbacks: make(map[uint32]chan response),
|
callbacks: make(map[uint32]chan response),
|
||||||
events: make(map[uint32]chan *DomainEvent),
|
events: make(map[uint32]chan *DomainEvent),
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ package libvirt
|
|||||||
import (
|
import (
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -251,6 +252,49 @@ func TestXMLIntegration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// verify we're able to concurrently communicate with libvirtd.
|
||||||
|
// see: https://github.com/digitalocean/go-libvirt/pull/52
|
||||||
|
func Test_concurrentWrite(t *testing.T) {
|
||||||
|
l := New(testConn(t))
|
||||||
|
|
||||||
|
if err := l.Connect(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
defer l.Disconnect()
|
||||||
|
|
||||||
|
count := 10
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(count)
|
||||||
|
|
||||||
|
start := make(chan struct{})
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
|
||||||
|
_, err := l.Domains()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
close(start)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for execution to complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func testConn(t *testing.T) net.Conn {
|
func testConn(t *testing.T) net.Conn {
|
||||||
conn, err := net.DialTimeout("tcp", testAddr, time.Second*2)
|
conn, err := net.DialTimeout("tcp", testAddr, time.Second*2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
2
rpc.go
2
rpc.go
@ -338,6 +338,8 @@ func (l *Libvirt) request(proc uint32, program uint32, payload *bytes.Buffer) (<
|
|||||||
}
|
}
|
||||||
|
|
||||||
// write header
|
// write header
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
err := binary.Write(l.w, binary.BigEndian, p)
|
err := binary.Write(l.w, binary.BigEndian, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Loading…
Reference in New Issue
Block a user