diff --git a/src/compress/flate/huffman_code.go b/src/compress/flate/huffman_code.go index ade4c8fb281..1e26d19300f 100644 --- a/src/compress/flate/huffman_code.go +++ b/src/compress/flate/huffman_code.go @@ -5,9 +5,10 @@ package flate import ( + "cmp" "math" "math/bits" - "sort" + "slices" ) // hcode is a huffman code with a bit code and bit length. @@ -19,8 +20,6 @@ type huffmanEncoder struct { codes []hcode freqcache []literalNode bitCount [17]int32 - lns byLiteral // stored to avoid repeated allocation in generate - lfs byFreq // stored to avoid repeated allocation in generate } type literalNode struct { @@ -256,7 +255,9 @@ func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalN // assigned in literal order (not frequency order). chunk := list[len(list)-int(bits):] - h.lns.sort(chunk) + slices.SortFunc(chunk, func(a, b literalNode) int { + return cmp.Compare(a.literal, b.literal) + }) for _, node := range chunk { h.codes[node.literal] = hcode{code: reverseBits(code, uint8(n)), len: uint16(n)} code++ @@ -299,7 +300,12 @@ func (h *huffmanEncoder) generate(freq []int32, maxBits int32) { } return } - h.lfs.sort(list) + slices.SortFunc(list, func(a, b literalNode) int { + if c := cmp.Compare(a.freq, b.freq); c != 0 { + return c + } + return cmp.Compare(a.literal, b.literal) + }) // Get the number of literals for each bit count bitCount := h.bitCounts(list, maxBits) @@ -307,39 +313,6 @@ func (h *huffmanEncoder) generate(freq []int32, maxBits int32) { h.assignEncodingAndSize(bitCount, list) } -type byLiteral []literalNode - -func (s *byLiteral) sort(a []literalNode) { - *s = byLiteral(a) - sort.Sort(s) -} - -func (s byLiteral) Len() int { return len(s) } - -func (s byLiteral) Less(i, j int) bool { - return s[i].literal < s[j].literal -} - -func (s byLiteral) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type byFreq []literalNode - -func (s *byFreq) sort(a []literalNode) { - *s = byFreq(a) - sort.Sort(s) -} - -func (s byFreq) Len() int { return len(s) } - -func (s byFreq) Less(i, j int) bool { - if s[i].freq == s[j].freq { - return s[i].literal < s[j].literal - } - return s[i].freq < s[j].freq -} - -func (s byFreq) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - func reverseBits(number uint16, bitLength byte) uint16 { return bits.Reverse16(number << (16 - bitLength)) }