1
0
mirror of https://github.com/golang/go synced 2024-11-22 04:14:42 -07:00

container/list: add Sort() method

This is an initial implementation of the mergesort algorithm.
This commit is contained in:
Sergey Kamardin 2021-08-26 17:26:39 +03:00
parent 8b471db71b
commit 7958c2c0ca
2 changed files with 126 additions and 1 deletions

View File

@ -236,3 +236,72 @@ func (l *List) PushFrontList(other *List) {
l.insertValue(e.Value, &l.root) l.insertValue(e.Value, &l.root)
} }
} }
// Sort sorts list l given the provided less function. The less function
// receives the list element values and reports whether the first value stands
// before the second.
//
// The runtime complexity is O(n*log(n)).
func (l *List) Sort(less func(a, b interface{}) bool) {
if l.len <= 1 {
return
}
front := mergeSort(l.root.next, less)
// Fix prev pointers after shuffling the list.
var back *Element
for el := front; el != nil; back, el = el, el.Next() {
el.prev = back
}
// Fix prev and next pointers for the front and back elements.
front.prev = &l.root
back.next = &l.root
l.root = Element{
next: front,
prev: back,
}
}
func mergeSort(el *Element, less func(a, b interface{}) bool) *Element {
if el == nil || el.Next() == nil {
return el
}
lo, hi := split(el)
return merge(
mergeSort(lo, less),
mergeSort(hi, less),
less,
)
}
func split(el *Element) (lo, hi *Element) {
// Find out the middle element of the list.
var slow, fast *Element
for slow, fast = el, el.Next(); fast != nil && fast.Next() != nil; {
slow = slow.Next()
fast = fast.Next().Next()
}
lo = el
hi = slow.Next()
// Slow is the last element of the lo half, so mark its next pointer as nil
// temporarily. We will fix pointers later in the Sort().
slow.next = nil
return lo, hi
}
func merge(a, b *Element, less func(a, b interface{}) bool) (ret *Element) {
if a == nil {
return b
}
if b == nil {
return a
}
if less(a.Value, b.Value) {
a.next = merge(a.Next(), b, less)
ret = a
} else {
b.next = merge(a, b.Next(), less)
ret = b
}
return ret
}

View File

@ -4,7 +4,10 @@
package list package list
import "testing" import (
"strconv"
"testing"
)
func checkListLen(t *testing.T, l *List, len int) bool { func checkListLen(t *testing.T, l *List, len int) bool {
if n := l.Len(); n != len { if n := l.Len(); n != len {
@ -341,3 +344,56 @@ func TestMoveUnknownMark(t *testing.T) {
checkList(t, &l1, []interface{}{1}) checkList(t, &l1, []interface{}{1})
checkList(t, &l2, []interface{}{2}) checkList(t, &l2, []interface{}{2})
} }
func TestSort(t *testing.T) {
for _, size := range []int{
1, 2, 3, 4, 5, 6, 7, 8,
} {
t.Run(strconv.Itoa(size), func(t *testing.T) {
sorted := make([]int, size)
for i := range sorted {
sorted[i] = i
}
values := make([]interface{}, size)
for i, x := range sorted {
values[i] = x
}
for _, xs := range perm(sorted) {
l := New()
expPtr := make([]*Element, len(xs))
for _, x := range xs {
// x is the same here as an index in the sorted slice.
// That is, x holds index at which element should be after
// l.Sort().
expPtr[x] = l.PushBack(x)
}
l.Sort(func(a, b interface{}) bool {
return a.(int) < b.(int)
})
checkList(t, l, values)
checkListPointers(t, l, expPtr)
}
})
}
}
func perm(xs []int) [][]int {
var f func(int, []int) [][]int
f = func(head int, tail []int) (ret [][]int) {
if len(tail) == 0 {
return [][]int{{head}}
}
for _, xs := range f(tail[0], tail[1:]) {
h := len(xs)
xs = append(xs, head)
ret = append(ret, xs) // one with head at highest index.
for i := 0; i < h; i++ {
cp := append(([]int)(nil), xs...)
cp[i], cp[h] = cp[h], cp[i]
ret = append(ret, cp)
}
}
return ret
}
return f(xs[0], xs[1:])
}