迟到的大学数据结构作业——堆排序


时隔好几年,才发现算法的趣味性,比如堆排序。堆排序同时也是大学时自己未完成的数据结构作业之一,就像《高等数学》一样之后都是抄抄了事……

个人理解,堆排序是一个有名的排序算法,比较适合优先级队列。虽然时间复杂度是O(nlogn),论排序算法中比较快的还是快速排序。堆排序的特点是使用完全或近似完全二叉树来模拟一种父节点肯定比子节点大(最大堆)或者小(最小堆)的结果,并通过多次“删除”根节点并恢复结构来达到“模拟”排序的目的。实践中需要关注的是三点:

如何构造最大/小堆

首先要明确的是最大/最小堆的基础是父节点比子节点大或者小。对于一组原始数据来说,大部分时候都是不满足这个条件的。为了满足这个条件,可以想象构造一个空的二叉堆,逐个往里面加元素。也可以考虑另外一种就地的方法:自底向上恢复结构。考虑用数组[1, 3, 8, 4, 6, 7, 2]构造最大堆。

   1
 3   8
4 6 7 2

尝试从最底层逐个元素恢复结构,2 -> 7 -> 6 -> 4,因为是最底层的元素,肯定不会做实际调整。接下来是第二层的元素,右边的(8, 7, 2)满足最大堆要求,但是左边的(3, 4, 6)不满足,调整左边的子树,得到如下结构:

   1
 6   8
4 3 7 2

接下来是顶层根节点,(1, 6, 8)肯定不满足要求,调整时同时向下传递,最后得到最大堆:

   8
 6   7
4 3 1 2

恢复结构的核心步骤

不管是构造最大/小堆还是删除元素时的调整,恢复结构的代码肯定是最核心的逻辑。删除时恢复结构的步骤在大学数据结构中有提到过,是类似shift_down的一种步骤,从上往下的。刚才的自底向上构造最大/最小堆时起始使用的也是shift_down。shift_down的原理是找出当前元素和其子节点(最多三个节点)中最大/小的元素,把最大/小的放在最上面,实际操作是交换最上面的节点和最大/小的节点。如果最上面的节点本来就是最大/小的,shift_down停止,否则继续shift_down被交换的子节点。代码如下:

private void shiftDown(int[] array, int heapSize, int index) {
  int maxIndex = index;
  int leftChildIndex = (index << 1) + 1;
  if (leftChildIndex >= heapSize) return; // no left child
  if (array[maxIndex] < array[leftChildIndex]) maxIndex = leftChildIndex;
  int rightChildIndex = (index << 1) + 2;
  if (rightChildIndex < heapSize && array[maxIndex] < array[rightChildIndex])
    maxIndex = rightChildIndex;
  if (maxIndex == index) return;
  swap(array, maxIndex, index);
  shiftDown(array, heapSize, maxIndex);
}

注意上面是从0开始计算的(基于数组的二叉堆)。那么左右子节点是按照2i + 1和2i + 2计算的。其次存在左右节点均不存在的情况(叶子节点),只有左节点情况(近似完全二叉树),两个子节点都存在,所以代码中需要判断计算的子节点的索引,如果超出了就认为节点不存在。如果左节点不存在直接可以退出,因为是叶子节点。

如何“删除”根节点

理论上来说,逐个删除根节点,恢复结构可以枚举出最大/小值。其中恢复结构是把最后一个元素放在根节点位置,shift_down根节点。另外为了实现就地,删除操作上有一点技巧。直接交换数组的第一个元素(根节点,索引0)和最后一个元素,代表删除根节点并把最后一个元素放在了根节点的位置。这样不断交换索引0位置的元素和二叉堆最后一个元素,最大堆就会变成升序序列。

最后,给出实现代码:

private void swap(int[] array, int i, int j) {
  int x = array[i];
  array[i] = array[j];
  array[j] = x;
}

private void shiftDown(int[] array, int heapSize, int index) {
  int maxIndex = index;
  int leftChildIndex = (index << 1) + 1;
  if (leftChildIndex >= heapSize) return; // no left child
  if (array[maxIndex] < array[leftChildIndex]) maxIndex = leftChildIndex;
  int rightChildIndex = (index << 1) + 2;
  if (rightChildIndex < heapSize && array[maxIndex] < array[rightChildIndex])
    maxIndex = rightChildIndex;
  if (maxIndex == index) return;
  swap(array, maxIndex, index);
  shiftDown(array, heapSize, maxIndex);
}

public void sort(int[] ns) {
  int length = ns.length;
  // build max heap
  for (int i = length - 1; i >= 0; i--) {
    shiftDown(ns, length, i);
  }

  // remove root element in-place
  for (int i = length - 1; i >= 0; i--) {
    swap(ns, 0, i);
    shiftDown(ns, i, 0);
  }
}

代码可能很简单,但是值得花时间理解一下。