prevent connection write collisions (#52)
* prevent connection write collisions When multiple calls are made on the same connection, it's possible for the writes to collide with each other. This adds a write mutex when communicating with the libvirt daemon.
This commit is contained in:
		| @@ -39,6 +39,7 @@ type Libvirt struct { | ||||
| 	conn net.Conn | ||||
| 	r    *bufio.Reader | ||||
| 	w    *bufio.Writer | ||||
| 	mu   *sync.Mutex | ||||
|  | ||||
| 	// method callbacks | ||||
| 	cm        sync.Mutex | ||||
| @@ -809,6 +810,7 @@ func New(conn net.Conn) *Libvirt { | ||||
| 		s:         0, | ||||
| 		r:         bufio.NewReader(conn), | ||||
| 		w:         bufio.NewWriter(conn), | ||||
| 		mu:        &sync.Mutex{}, | ||||
| 		callbacks: make(map[uint32]chan response), | ||||
| 		events:    make(map[uint32]chan *DomainEvent), | ||||
| 	} | ||||
|   | ||||
| @@ -19,6 +19,7 @@ package libvirt | ||||
| import ( | ||||
| 	"encoding/xml" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| @@ -251,6 +252,49 @@ func TestXMLIntegration(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // verify we're able to concurrently communicate with libvirtd. | ||||
| // see: https://github.com/digitalocean/go-libvirt/pull/52 | ||||
| func Test_concurrentWrite(t *testing.T) { | ||||
| 	l := New(testConn(t)) | ||||
|  | ||||
| 	if err := l.Connect(); err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	defer l.Disconnect() | ||||
|  | ||||
| 	count := 10 | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	wg.Add(count) | ||||
|  | ||||
| 	start := make(chan struct{}) | ||||
| 	done := make(chan struct{}) | ||||
|  | ||||
| 	go func() { | ||||
| 		wg.Wait() | ||||
| 		close(done) | ||||
| 	}() | ||||
|  | ||||
| 	for i := 0; i < count; i++ { | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			<-start | ||||
|  | ||||
| 			_, err := l.Domains() | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	close(start) | ||||
|  | ||||
| 	select { | ||||
| 	case <-done: | ||||
| 	case <-time.After(10 * time.Second): | ||||
| 		t.Fatal("timed out waiting for execution to complete") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func testConn(t *testing.T) net.Conn { | ||||
| 	conn, err := net.DialTimeout("tcp", testAddr, time.Second*2) | ||||
| 	if err != nil { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user