diff --git a/src/container/list/list.go b/src/container/list/list.go index 210424ceed7..5ed7cbd718b 100644 --- a/src/container/list/list.go +++ b/src/container/list/list.go @@ -236,3 +236,72 @@ func (l *List) PushFrontList(other *List) { l.insertValue(e.Value, &l.root) } } + +// Sort sorts list l given the provided less function. The less function +// receives the list element values and reports whether the first value stands +// before the second. +// +// The runtime complexity is O(n*log(n)). +func (l *List) Sort(less func(a, b interface{}) bool) { + if l.len <= 1 { + return + } + front := mergeSort(l.root.next, less) + // Fix prev pointers after shuffling the list. + var back *Element + for el := front; el != nil; back, el = el, el.Next() { + el.prev = back + } + // Fix prev and next pointers for the front and back elements. + front.prev = &l.root + back.next = &l.root + + l.root = Element{ + next: front, + prev: back, + } +} + +func mergeSort(el *Element, less func(a, b interface{}) bool) *Element { + if el == nil || el.Next() == nil { + return el + } + lo, hi := split(el) + return merge( + mergeSort(lo, less), + mergeSort(hi, less), + less, + ) +} + +func split(el *Element) (lo, hi *Element) { + // Find out the middle element of the list. + var slow, fast *Element + for slow, fast = el, el.Next(); fast != nil && fast.Next() != nil; { + slow = slow.Next() + fast = fast.Next().Next() + } + lo = el + hi = slow.Next() + // Slow is the last element of the lo half, so mark its next pointer as nil + // temporarily. We will fix pointers later in the Sort(). + slow.next = nil + return lo, hi +} + +func merge(a, b *Element, less func(a, b interface{}) bool) (ret *Element) { + if a == nil { + return b + } + if b == nil { + return a + } + if less(a.Value, b.Value) { + a.next = merge(a.Next(), b, less) + ret = a + } else { + b.next = merge(a, b.Next(), less) + ret = b + } + return ret +} diff --git a/src/container/list/list_test.go b/src/container/list/list_test.go index 99e006f39fd..83d5159ff0b 100644 --- a/src/container/list/list_test.go +++ b/src/container/list/list_test.go @@ -4,7 +4,10 @@ package list -import "testing" +import ( + "strconv" + "testing" +) func checkListLen(t *testing.T, l *List, len int) bool { if n := l.Len(); n != len { @@ -341,3 +344,56 @@ func TestMoveUnknownMark(t *testing.T) { checkList(t, &l1, []interface{}{1}) checkList(t, &l2, []interface{}{2}) } + +func TestSort(t *testing.T) { + for _, size := range []int{ + 1, 2, 3, 4, 5, 6, 7, 8, + } { + t.Run(strconv.Itoa(size), func(t *testing.T) { + sorted := make([]int, size) + for i := range sorted { + sorted[i] = i + } + values := make([]interface{}, size) + for i, x := range sorted { + values[i] = x + } + for _, xs := range perm(sorted) { + l := New() + expPtr := make([]*Element, len(xs)) + for _, x := range xs { + // x is the same here as an index in the sorted slice. + // That is, x holds index at which element should be after + // l.Sort(). + expPtr[x] = l.PushBack(x) + } + l.Sort(func(a, b interface{}) bool { + return a.(int) < b.(int) + }) + checkList(t, l, values) + checkListPointers(t, l, expPtr) + } + }) + } +} + +func perm(xs []int) [][]int { + var f func(int, []int) [][]int + f = func(head int, tail []int) (ret [][]int) { + if len(tail) == 0 { + return [][]int{{head}} + } + for _, xs := range f(tail[0], tail[1:]) { + h := len(xs) + xs = append(xs, head) + ret = append(ret, xs) // one with head at highest index. + for i := 0; i < h; i++ { + cp := append(([]int)(nil), xs...) + cp[i], cp[h] = cp[h], cp[i] + ret = append(ret, cp) + } + } + return ret + } + return f(xs[0], xs[1:]) +}