diff --git a/rpc.go b/rpc.go index e1ee7ff..c2ced94 100644 --- a/rpc.go +++ b/rpc.go @@ -329,38 +329,7 @@ func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte, out abortOutStream = make(chan bool) outStreamErr = make(chan error) go func() { - var err error - var n int - buf := make([]byte, 1<<22-24) - for { - select { - case <-abortOutStream: - err = l.sendPacket(serial, proc, program, nil, Stream, StatusError) - break - default: - } - n, err = outStream.Read(buf) - if err != nil { - if err == io.EOF { - err = l.sendPacket(serial, proc, program, nil, Stream, StatusOK) - } else { - // keep original error - err := l.sendPacket(serial, proc, program, nil, Stream, StatusError) - if err != nil { - outStreamErr <- err - return - } - } - break - } - if n > 0 { - err = l.sendPacket(serial, proc, program, buf[:n], Stream, StatusContinue) - if err != nil { - break - } - } - } - outStreamErr <- err + outStreamErr <- l.sendStream(serial, proc, program, outStream, abortOutStream) }() } @@ -402,6 +371,35 @@ func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte, out return resp, nil } +func (l *Libvirt) sendStream(serial uint32, proc uint32, program uint32, stream io.Reader, abort chan bool) error { + buf := make([]byte, 1<<22-24) + for { + select { + case <-abort: + return l.sendPacket(serial, proc, program, nil, Stream, StatusError) + default: + } + n, err := stream.Read(buf) + if err != nil { + if err == io.EOF { + return l.sendPacket(serial, proc, program, nil, Stream, StatusOK) + } + // keep original error + err2 := l.sendPacket(serial, proc, program, nil, Stream, StatusError) + if err2 != nil { + return err2 + } + return err + } + if n > 0 { + err = l.sendPacket(serial, proc, program, buf[:n], Stream, StatusContinue) + if err != nil { + return err + } + } + } +} + func (l *Libvirt) sendPacket(serial uint32, proc uint32, program uint32, payload []byte, typ uint32, status uint32) error { size := constants.PacketLengthSize + constants.HeaderSize if payload != nil {