From 3deb572f72950c0247542085648362428bf40840 Mon Sep 17 00:00:00 2001 From: pugnack Date: Sun, 15 Jun 2025 19:24:48 +0500 Subject: [PATCH] [v4] fix out-of-bounds behavior in seeker buffer and add tests (#219) * add check negative position to Read() and write tests * add tests for Write() method * add tests for Write() method * add checks of whence and negative position to Seek() and write tests * add tests for Rewind() * add tests for Close() * add tests for Reset() * add tests for Len() * add tests for Bytes() * tests polishing * tests polishing * tests polishing * tests polishing --- logger/slog/slog_test.go | 5 +- store/noop_test.go | 4 +- util/buffer/seeker_buffer.go | 46 +++- util/buffer/seeker_buffer_test.go | 406 +++++++++++++++++++++++++++--- 4 files changed, 412 insertions(+), 49 deletions(-) diff --git a/logger/slog/slog_test.go b/logger/slog/slog_test.go index f6d46393..4ba8783a 100644 --- a/logger/slog/slog_test.go +++ b/logger/slog/slog_test.go @@ -80,7 +80,7 @@ func TestTime(t *testing.T) { WithHandlerFunc(slog.NewTextHandler), logger.WithAddStacktrace(true), logger.WithTimeFunc(func() time.Time { - return time.Unix(0, 0) + return time.Unix(0, 0).UTC() }), ) if err := l.Init(logger.WithFields("key1", "val1")); err != nil { @@ -89,8 +89,7 @@ func TestTime(t *testing.T) { l.Error(ctx, "msg1", errors.New("err")) - if !bytes.Contains(buf.Bytes(), []byte(`timestamp=1970-01-01T03:00:00.000000000+03:00`)) && - !bytes.Contains(buf.Bytes(), []byte(`timestamp=1970-01-01T00:00:00.000000000Z`)) { + if !bytes.Contains(buf.Bytes(), []byte(`timestamp=1970-01-01T00:00:00.000000000Z`)) { t.Fatalf("logger error not works, buf contains: %s", buf.Bytes()) } } diff --git a/store/noop_test.go b/store/noop_test.go index 19fab2d7..009f7a70 100644 --- a/store/noop_test.go +++ b/store/noop_test.go @@ -2,6 +2,7 @@ package store import ( "context" + "errors" "testing" ) @@ -25,7 +26,8 @@ func TestHook(t *testing.T) { t.Fatal(err) } - if err := s.Exists(context.TODO(), "test"); err != nil { + err := s.Exists(context.TODO(), "test") + if !errors.Is(err, ErrNotFound) { t.Fatal(err) } diff --git a/util/buffer/seeker_buffer.go b/util/buffer/seeker_buffer.go index 5d48c826..924ccb57 100644 --- a/util/buffer/seeker_buffer.go +++ b/util/buffer/seeker_buffer.go @@ -1,13 +1,16 @@ package buffer -import "io" +import ( + "fmt" + "io" +) var _ interface { io.ReadCloser io.ReadSeeker } = (*SeekerBuffer)(nil) -// Buffer is a ReadWriteCloser that supports seeking. It's intended to +// SeekerBuffer is a ReadWriteCloser that supports seeking. It's intended to // replicate the functionality of bytes.Buffer that I use in my projects. // // Note that the seeking is limited to the read marker; all writes are @@ -23,6 +26,7 @@ func NewSeekerBuffer(data []byte) *SeekerBuffer { } } +// Read reads up to len(p) bytes into p from the current read position. func (b *SeekerBuffer) Read(p []byte) (int, error) { if b.pos >= int64(len(b.data)) { return 0, io.EOF @@ -30,29 +34,51 @@ func (b *SeekerBuffer) Read(p []byte) (int, error) { n := copy(p, b.data[b.pos:]) b.pos += int64(n) + return n, nil } +// Write appends the contents of p to the end of the buffer. It does not affect the read position. func (b *SeekerBuffer) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + b.data = append(b.data, p...) + return len(p), nil } -// Seek sets the read pointer to pos. +// Seek sets the offset for the next Read operation. +// The offset is interpreted according to whence: +// - io.SeekStart: relative to the beginning of the buffer +// - io.SeekCurrent: relative to the current position +// - io.SeekEnd: relative to the end of the buffer +// +// Returns an error if the resulting position is negative or if whence is invalid. func (b *SeekerBuffer) Seek(offset int64, whence int) (int64, error) { + var newPos int64 + switch whence { case io.SeekStart: - b.pos = offset + newPos = offset case io.SeekEnd: - b.pos = int64(len(b.data)) + offset + newPos = int64(len(b.data)) + offset case io.SeekCurrent: - b.pos += offset + newPos = b.pos + offset + default: + return 0, fmt.Errorf("invalid whence: %d", whence) } + if newPos < 0 { + return 0, fmt.Errorf("invalid seek: resulting position %d is negative", newPos) + } + + b.pos = newPos return b.pos, nil } -// Rewind resets the read pointer to 0. +// Rewind resets the read position to 0. func (b *SeekerBuffer) Rewind() error { if _, err := b.Seek(0, io.SeekStart); err != nil { return err @@ -75,10 +101,16 @@ func (b *SeekerBuffer) Reset() { // Len returns the length of data remaining to be read. func (b *SeekerBuffer) Len() int { + if b.pos >= int64(len(b.data)) { + return 0 + } return len(b.data[b.pos:]) } // Bytes returns the underlying bytes from the current position. func (b *SeekerBuffer) Bytes() []byte { + if b.pos >= int64(len(b.data)) { + return []byte{} + } return b.data[b.pos:] } diff --git a/util/buffer/seeker_buffer_test.go b/util/buffer/seeker_buffer_test.go index 915aa19d..2afc7126 100644 --- a/util/buffer/seeker_buffer_test.go +++ b/util/buffer/seeker_buffer_test.go @@ -2,54 +2,384 @@ package buffer import ( "fmt" - "strings" + "io" "testing" + + "github.com/stretchr/testify/require" ) -func noErrorT(t *testing.T, err error) { - if nil != err { - t.Fatalf("%s", err) +func TestNewSeekerBuffer(t *testing.T) { + input := []byte{'a', 'b', 'c', 'd', 'e'} + expected := &SeekerBuffer{data: []byte{'a', 'b', 'c', 'd', 'e'}, pos: 0} + require.Equal(t, expected, NewSeekerBuffer(input)) +} + +func TestSeekerBuffer_Read(t *testing.T) { + tests := []struct { + name string + data []byte + initPos int64 + readBuf []byte + expectedN int + expectedData []byte + expectedErr error + expectedPos int64 + }{ + { + name: "read with empty buffer", + data: []byte("hello"), + initPos: 0, + readBuf: []byte{}, + expectedN: 0, + expectedData: []byte{}, + expectedErr: nil, + expectedPos: 0, + }, + { + name: "read with nil buffer", + data: []byte("hello"), + initPos: 0, + readBuf: nil, + expectedN: 0, + expectedData: nil, + expectedErr: nil, + expectedPos: 0, + }, + { + name: "read full buffer", + data: []byte("hello"), + initPos: 0, + readBuf: make([]byte, 5), + expectedN: 5, + expectedData: []byte("hello"), + expectedErr: nil, + expectedPos: 5, + }, + { + name: "read partial buffer", + data: []byte("hello"), + initPos: 2, + readBuf: make([]byte, 2), + expectedN: 2, + expectedData: []byte("ll"), + expectedErr: nil, + expectedPos: 4, + }, + { + name: "read after end", + data: []byte("hello"), + initPos: 5, + readBuf: make([]byte, 5), + expectedN: 0, + expectedData: make([]byte, 5), + expectedErr: io.EOF, + expectedPos: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sb := NewSeekerBuffer(tt.data) + sb.pos = tt.initPos + + n, err := sb.Read(tt.readBuf) + + if tt.expectedErr != nil { + require.Equal(t, err, tt.expectedErr) + } else { + require.NoError(t, err) + } + + require.Equal(t, tt.expectedN, n) + require.Equal(t, tt.expectedData, tt.readBuf) + require.Equal(t, tt.expectedPos, sb.pos) + }) } } -func boolT(t *testing.T, cond bool, s ...string) { - if !cond { - what := strings.Join(s, ", ") - if len(what) > 0 { - what = ": " + what - } - t.Fatalf("assert.Bool failed%s", what) +func TestSeekerBuffer_Write(t *testing.T) { + tests := []struct { + name string + initialData []byte + initialPos int64 + writeData []byte + expectedData []byte + expectedN int + }{ + { + name: "write empty slice", + initialData: []byte("data"), + initialPos: 0, + writeData: []byte{}, + expectedData: []byte("data"), + expectedN: 0, + }, + { + name: "write nil slice", + initialData: []byte("data"), + initialPos: 0, + writeData: nil, + expectedData: []byte("data"), + expectedN: 0, + }, + { + name: "write to empty buffer", + initialData: nil, + initialPos: 0, + writeData: []byte("abc"), + expectedData: []byte("abc"), + expectedN: 3, + }, + { + name: "write to existing buffer", + initialData: []byte("hello"), + initialPos: 0, + writeData: []byte(" world"), + expectedData: []byte("hello world"), + expectedN: 6, + }, + { + name: "write after read", + initialData: []byte("abc"), + initialPos: 2, + writeData: []byte("XYZ"), + expectedData: []byte("abcXYZ"), + expectedN: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sb := NewSeekerBuffer(tt.initialData) + sb.pos = tt.initialPos + + n, err := sb.Write(tt.writeData) + require.NoError(t, err) + require.Equal(t, tt.expectedN, n) + require.Equal(t, tt.expectedData, sb.data) + require.Equal(t, tt.initialPos, sb.pos) + }) } } -func TestSeeking(t *testing.T) { - partA := []byte("hello, ") - partB := []byte("world!") +func TestSeekerBuffer_Seek(t *testing.T) { + tests := []struct { + name string + initialData []byte + initialPos int64 + offset int64 + whence int + expectedPos int64 + expectedErr error + }{ + { + name: "seek with invalid whence", + initialData: []byte("abcdef"), + initialPos: 0, + offset: 1, + whence: 12345, + expectedPos: 0, + expectedErr: fmt.Errorf("invalid whence: %d", 12345), + }, + { + name: "seek negative from start", + initialData: []byte("abcdef"), + initialPos: 0, + offset: -1, + whence: io.SeekStart, + expectedPos: 0, + expectedErr: fmt.Errorf("invalid seek: resulting position %d is negative", -1), + }, + { + name: "seek from start to 0", + initialData: []byte("abcdef"), + initialPos: 0, + offset: 0, + whence: io.SeekStart, + expectedPos: 0, + expectedErr: nil, + }, + { + name: "seek from start to 3", + initialData: []byte("abcdef"), + initialPos: 0, + offset: 3, + whence: io.SeekStart, + expectedPos: 3, + expectedErr: nil, + }, + { + name: "seek from end to -1 (last byte)", + initialData: []byte("abcdef"), + initialPos: 0, + offset: -1, + whence: io.SeekEnd, + expectedPos: 5, + expectedErr: nil, + }, + { + name: "seek from current forward", + initialData: []byte("abcdef"), + initialPos: 2, + offset: 2, + whence: io.SeekCurrent, + expectedPos: 4, + expectedErr: nil, + }, + { + name: "seek from current backward", + initialData: []byte("abcdef"), + initialPos: 4, + offset: -2, + whence: io.SeekCurrent, + expectedPos: 2, + expectedErr: nil, + }, + { + name: "seek to end exactly", + initialData: []byte("abcdef"), + initialPos: 0, + offset: 0, + whence: io.SeekEnd, + expectedPos: 6, + expectedErr: nil, + }, + { + name: "seek to out of range", + initialData: []byte("abcdef"), + initialPos: 0, + offset: 2, + whence: io.SeekEnd, + expectedPos: 8, + expectedErr: nil, + }, + } - buf := NewSeekerBuffer(partA) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sb := NewSeekerBuffer(tt.initialData) + sb.pos = tt.initialPos - boolT(t, buf.Len() == len(partA), fmt.Sprintf("on init: have length %d, want length %d", buf.Len(), len(partA))) + newPos, err := sb.Seek(tt.offset, tt.whence) - b := make([]byte, 32) - - n, err := buf.Read(b) - noErrorT(t, err) - boolT(t, buf.Len() == 0, fmt.Sprintf("after reading 1: have length %d, want length 0", buf.Len())) - boolT(t, n == len(partA), fmt.Sprintf("after reading 2: have length %d, want length %d", n, len(partA))) - - n, err = buf.Write(partB) - noErrorT(t, err) - boolT(t, n == len(partB), fmt.Sprintf("after writing: have length %d, want length %d", n, len(partB))) - - n, err = buf.Read(b) - noErrorT(t, err) - boolT(t, buf.Len() == 0, fmt.Sprintf("after rereading 1: have length %d, want length 0", buf.Len())) - boolT(t, n == len(partB), fmt.Sprintf("after rereading 2: have length %d, want length %d", n, len(partB))) - - partsLen := len(partA) + len(partB) - _ = buf.Rewind() - boolT(t, buf.Len() == partsLen, fmt.Sprintf("after rewinding: have length %d, want length %d", buf.Len(), partsLen)) - - buf.Close() - boolT(t, buf.Len() == 0, fmt.Sprintf("after closing, have length %d, want length 0", buf.Len())) + if tt.expectedErr != nil { + require.Equal(t, tt.expectedErr, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedPos, newPos) + require.Equal(t, tt.expectedPos, sb.pos) + } + }) + } +} + +func TestSeekerBuffer_Rewind(t *testing.T) { + buf := NewSeekerBuffer([]byte("hello world")) + buf.pos = 4 + + require.NoError(t, buf.Rewind()) + require.Equal(t, []byte("hello world"), buf.data) + require.Equal(t, int64(0), buf.pos) +} + +func TestSeekerBuffer_Close(t *testing.T) { + buf := NewSeekerBuffer([]byte("hello world")) + buf.pos = 2 + + require.NoError(t, buf.Close()) + require.Nil(t, buf.data) + require.Equal(t, int64(0), buf.pos) +} + +func TestSeekerBuffer_Reset(t *testing.T) { + buf := NewSeekerBuffer([]byte("hello world")) + buf.pos = 2 + + buf.Reset() + require.Nil(t, buf.data) + require.Equal(t, int64(0), buf.pos) +} + +func TestSeekerBuffer_Len(t *testing.T) { + tests := []struct { + name string + data []byte + pos int64 + expected int + }{ + { + name: "full buffer", + data: []byte("abcde"), + pos: 0, + expected: 5, + }, + { + name: "partial read", + data: []byte("abcde"), + pos: 2, + expected: 3, + }, + { + name: "fully read", + data: []byte("abcde"), + pos: 5, + expected: 0, + }, + { + name: "pos > len", + data: []byte("abcde"), + pos: 10, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := NewSeekerBuffer(tt.data) + buf.pos = tt.pos + require.Equal(t, tt.expected, buf.Len()) + }) + } +} + +func TestSeekerBuffer_Bytes(t *testing.T) { + tests := []struct { + name string + data []byte + pos int64 + expected []byte + }{ + { + name: "start of buffer", + data: []byte("abcde"), + pos: 0, + expected: []byte("abcde"), + }, + { + name: "middle of buffer", + data: []byte("abcde"), + pos: 2, + expected: []byte("cde"), + }, + { + name: "end of buffer", + data: []byte("abcde"), + pos: 5, + expected: []byte{}, + }, + { + name: "pos beyond end", + data: []byte("abcde"), + pos: 10, + expected: []byte{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := NewSeekerBuffer(tt.data) + buf.pos = tt.pos + require.Equal(t, tt.expected, buf.Bytes()) + }) + } }