diff --git a/src/encoding/pem/pem.go b/src/encoding/pem/pem.go index 4b4f7490210..951e9163571 100644 --- a/src/encoding/pem/pem.go +++ b/src/encoding/pem/pem.go @@ -193,30 +193,29 @@ type lineBreaker struct { var nl = []byte{'\n'} func (l *lineBreaker) Write(b []byte) (n int, err error) { - if l.used+len(b) < pemLineLength { - copy(l.line[l.used:], b) - l.used += len(b) - return len(b), nil + var n1 int + for len(b) > 0 { + if l.used+len(b) < pemLineLength { + copy(l.line[l.used:], b) + l.used += len(b) + n += len(b) + return + } + _, err = l.out.Write(l.line[0:l.used]) + excess := pemLineLength - l.used + l.used = 0 + n1, err = l.out.Write(b[0:excess]) + n += n1 + if err != nil { + return n, err + } + _, err = l.out.Write(nl) + if err != nil { + return + } + b = b[excess:] } - - n, err = l.out.Write(l.line[0:l.used]) - if err != nil { - return - } - excess := pemLineLength - l.used - l.used = 0 - - n, err = l.out.Write(b[0:excess]) - if err != nil { - return - } - - n, err = l.out.Write(nl) - if err != nil { - return - } - - return l.Write(b[excess:]) + return } func (l *lineBreaker) Close() (err error) { diff --git a/src/encoding/pem/pem_test.go b/src/encoding/pem/pem_test.go index 56a7754b220..b0920751aea 100644 --- a/src/encoding/pem/pem_test.go +++ b/src/encoding/pem/pem_test.go @@ -233,6 +233,34 @@ func TestLineBreaker(t *testing.T) { t.Errorf("#%d: (byte by byte) got:%s want:%s", i, got, test.out) } } + + t.Run("SmallBuffer", func(t *testing.T) { + buf := new(strings.Builder) + breaker := lineBreaker{out: buf} + input := bytes.Repeat([]byte("a"), 10) // Less than pemLineLength + + written, err := breaker.Write(input) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + if written != len(input) { + t.Errorf("Expected to write %d bytes, wrote %d bytes", len(input), written) + } + }) + + t.Run("LargeBuffer", func(t *testing.T) { + buf := new(strings.Builder) + breaker := lineBreaker{out: buf} + input := bytes.Repeat([]byte("a"), 200) // More than pemLineLength + + written, err := breaker.Write(input) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + if written != len(input) { + t.Errorf("Expected to write %d bytes, wrote %d bytes", len(input), written) + } + }) } func TestFuzz(t *testing.T) {