1
0
mirror of https://github.com/golang/go synced 2024-11-24 00:00:23 -07:00

compress/flate: simplify sorting in huffman_code

Replace custom sort implementations with slices.SortFunc and
cmp.Compare for clarity and simplicity. This change removes
the previously defined byLiteral and byFreq types and their
associated methods, leveraging the standard library's generic
functions to achieve the same functionality.

The update makes use of "slices" and "cmp" packages for sorting,
streamlining the codebase and improving readability.
This commit is contained in:
aimuz 2024-04-14 16:27:11 +08:00
parent 519f6a00e4
commit 1070520c71
No known key found for this signature in database
GPG Key ID: 63C3DC9FBA22D9D7

View File

@ -5,9 +5,10 @@
package flate
import (
"cmp"
"math"
"math/bits"
"sort"
"slices"
)
// hcode is a huffman code with a bit code and bit length.
@ -19,8 +20,6 @@ type huffmanEncoder struct {
codes []hcode
freqcache []literalNode
bitCount [17]int32
lns byLiteral // stored to avoid repeated allocation in generate
lfs byFreq // stored to avoid repeated allocation in generate
}
type literalNode struct {
@ -256,7 +255,9 @@ func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalN
// assigned in literal order (not frequency order).
chunk := list[len(list)-int(bits):]
h.lns.sort(chunk)
slices.SortFunc(chunk, func(a, b literalNode) int {
return cmp.Compare(a.literal, b.literal)
})
for _, node := range chunk {
h.codes[node.literal] = hcode{code: reverseBits(code, uint8(n)), len: uint16(n)}
code++
@ -299,7 +300,12 @@ func (h *huffmanEncoder) generate(freq []int32, maxBits int32) {
}
return
}
h.lfs.sort(list)
slices.SortFunc(list, func(a, b literalNode) int {
if c := cmp.Compare(a.freq, b.freq); c != 0 {
return c
}
return cmp.Compare(a.literal, b.literal)
})
// Get the number of literals for each bit count
bitCount := h.bitCounts(list, maxBits)
@ -307,39 +313,6 @@ func (h *huffmanEncoder) generate(freq []int32, maxBits int32) {
h.assignEncodingAndSize(bitCount, list)
}
type byLiteral []literalNode
func (s *byLiteral) sort(a []literalNode) {
*s = byLiteral(a)
sort.Sort(s)
}
func (s byLiteral) Len() int { return len(s) }
func (s byLiteral) Less(i, j int) bool {
return s[i].literal < s[j].literal
}
func (s byLiteral) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type byFreq []literalNode
func (s *byFreq) sort(a []literalNode) {
*s = byFreq(a)
sort.Sort(s)
}
func (s byFreq) Len() int { return len(s) }
func (s byFreq) Less(i, j int) bool {
if s[i].freq == s[j].freq {
return s[i].literal < s[j].literal
}
return s[i].freq < s[j].freq
}
func (s byFreq) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func reverseBits(number uint16, bitLength byte) uint16 {
return bits.Reverse16(number << (16 - bitLength))
}