A Binary-Split Tree Implementation of Ranking Massive Numbers of Scores


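The problem in brief: a website has many users, and each user has a score. When a user logs in, the site must show that user's rank by score, and scores change over time. The solution must cope with a large amount of data.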

The idea is to decompose the rank computation using a tree. The reference is here.
The actual code has three parts:

  • Querying the rank of a given score
  • Updating scores
  • Building the score tree

Querying the rank of a score

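The idea is to add up the counts of all right subtrees or right siblings that hold higher scores, then add 1. Each node stores the number of users who have its score, or whose score falls in its score range. A right subtree or right sibling represents score nodes greater than the queried score. Concretely, walk from the root toward the leaf for the queried score: every time the path goes into a left child, the whole right subtree next to it holds higher scores, so its count is added. A rank is simply the number of users whose score is higher than yours, plus 1. The final +1 is needed because nobody is above the highest score, so that count is 0, but a rank cannot be 0.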
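Written out directly, this definition is a linear scan over all users. That is the O(n)-per-query baseline the tree is meant to replace. The following is a minimal sketch; `BruteForceRank` and `rankOf` are illustrative names, not part of the original code.

```scala
object BruteForceRank {
  // Rank = (number of users whose score is strictly higher) + 1.
  // Users with equal scores therefore share the same rank.
  // Scans every user, so each query costs O(n).
  def rankOf(score: Int, allScores: Seq[Int]): Int =
    allScores.count(_ > score) + 1
}
```

With the scores 1, 2, 3, 0, 3 used in the example at the end of this post, `rankOf(2, ...)` returns 3, because two users have score 3.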
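The details depend on how the tree is stored (for example as an array or as linked nodes), but the general logic is as follows (recursive):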

in: current node, score
out: number of users under current node whose score is higher than score
if current node is score node
  return 0  # a right sibling's count, if any, was already added by the parent
else # current node is score range node
  if score in left subtree
    return count(left child, score) + count of right subtree
  else # score in right subtree
    return count(right child, score)
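When the tree lives in an array, the same rule can also run bottom-up, and the right-sibling view becomes the whole algorithm. From the leaf up to the root, add the right sibling's count whenever the current node is a left child, and add nothing else. Adding both right siblings and right subtrees would count the same users twice. The following is a minimal sketch. It assumes a perfect binary tree padded to a power-of-two number of leaves. `ArrayScoreTree`, `add` and `rankOf` are hypothetical names, not the post's classes.

```scala
// Hypothetical array-backed variant (not the post's ScoreNode/ScoreRangeNode).
// A perfect binary tree is stored in an array: root at index 1, children of
// node i at 2i and 2i + 1. Leaves hold per-score user counts, low scores left.
class ArrayScoreTree(val minScore: Int, val maxScore: Int) {
  require(minScore <= maxScore, "minScore must not be greater than maxScore")

  // Leaf count rounded up to a power of two; padding leaves always stay 0.
  private val leafCount: Int = {
    val n = maxScore - minScore + 1
    var size = 1
    while (size < n) size *= 2
    size
  }
  // counts(i) = users under node i (a cached sum for internal nodes).
  private val counts = new Array[Int](2 * leafCount)

  private def leafIndex(score: Int): Int = {
    require(score >= minScore && score <= maxScore, s"score $score is out of range")
    leafCount + (score - minScore)
  }

  // Update: change the leaf and every ancestor's cached sum by delta.
  def add(score: Int, delta: Int): Unit = {
    var i = leafIndex(score)
    while (i >= 1) {
      counts(i) += delta
      i /= 2
    }
  }

  // Query: walk from the leaf to the root. Whenever the current node is a left
  // child (even index), its right sibling covers only higher scores, so add its
  // count. Every higher score is counted exactly once.
  def rankOf(score: Int): Int = {
    var i = leafIndex(score)
    var higher = 0
    while (i > 1) {
      if (i % 2 == 0) higher += counts(i + 1)
      i /= 2
    }
    higher + 1
  }
}
```

For example, after `Seq(1, 2, 3, 0, 3).foreach(s => tree.add(s, 1))` on `new ArrayScoreTree(0, 3)`, `tree.rankOf(0)` walks through leaf 4 and node 2 and adds the counts of nodes 5 and 3.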

Updating scores

This is also recursive. The pseudocode is:

in: current node, score, delta
if current node is score node
  count of current node += delta
else # current node is score range node
  count cache of current node += delta
  if score in left subtree
    update(left child, score, delta)
  else # score in right subtree
    update(right child, score, delta)

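Note that when a user's score changes, the tree is updated in two steps: decrease the count of the old score, then increase the count of the new score.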
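The tree only stores how many users hold each score. To know which count to decrease, the caller therefore has to remember each user's current score. The following is a minimal sketch of such a wrapper. It assumes a hypothetical `RankIndex` trait that uses the same method names as the `ScoreRanking` class below. So that the example runs on its own, a linear-scan stand-in (`LinearRankIndex`, scores in [0, maxScore]) implements the trait. `UserScores` and `setScore` are likewise illustrative.

```scala
import scala.collection.mutable

// Hypothetical interface with the same method names as the post's ScoreRanking.
trait RankIndex {
  def increaseCountOf(score: Int): Unit
  def decreaseCountOf(score: Int): Unit
  def getRankingOf(score: Int): Int
}

// Stand-in (linear scan over per-score counts, scores assumed in [0, maxScore])
// so the sketch is self-contained; the tree-based index would be used in practice.
class LinearRankIndex(maxScore: Int) extends RankIndex {
  private val counts = new Array[Int](maxScore + 1)
  def increaseCountOf(score: Int): Unit = counts(score) += 1
  def decreaseCountOf(score: Int): Unit = counts(score) -= 1
  def getRankingOf(score: Int): Int = counts.drop(score + 1).sum + 1
}

// The index knows only counts per score, so each user's score is kept here.
class UserScores(index: RankIndex) {
  private val scoreOf = mutable.Map.empty[String, Int]

  def setScore(user: String, newScore: Int): Unit = {
    scoreOf.get(user).foreach(old => index.decreaseCountOf(old)) // step 1: old score
    index.increaseCountOf(newScore)                              // step 2: new score
    scoreOf(user) = newScore
  }

  def rankOf(user: String): Int = index.getRankingOf(scoreOf(user))
}
```

The first `setScore` call for a user only increases a count. Later calls perform both steps. Plugging in the tree would only need `ScoreRanking` to extend the trait, or a thin adapter around it.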

Building the ranking tree

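In principle the tree structure should be defined first, but here I designed it around the functionality.
First, every node needs a count. A score node's count is the number of users with that score. A score-range node's count is the sum of all score nodes in its range and is kept as a cached value.
Queries need each node's right sibling, as well as its left and right subtrees.
To guard against unexpected scores, the tree should expose its minimum and maximum score so that inputs can be checked.
I first implemented this in Python using arrays. Later I felt that pattern matching was a better fit, so I rewrote it in Scala.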
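Because there is one leaf per possible score, the size of the tree is set by the score range, not by the number of users. The shape can be computed without allocating any nodes by mirroring the `createNode` recursion used in the code. `ScoreTreeShape.shape` is a hypothetical helper. It splits at `minScore + (maxScore - minScore) / 2`, which equals `(minScore + maxScore) / 2` for non-negative scores.

```scala
object ScoreTreeShape {
  // Mirrors createNode's recursion (split at the midpoint, the left half takes
  // the extra score when the range has odd length) without allocating nodes.
  // Returns (number of nodes, height), where a single leaf has height 0.
  def shape(minScore: Int, maxScore: Int): (Long, Int) = {
    if (minScore == maxScore) (1L, 0)
    else {
      val pivot = minScore + (maxScore - minScore) / 2
      val (leftNodes, leftHeight) = shape(minScore, pivot)
      val (rightNodes, rightHeight) = shape(pivot + 1, maxScore)
      (leftNodes + rightNodes + 1, math.max(leftHeight, rightHeight) + 1)
    }
  }
}
```

For scores in [0, 1000000], this gives 2000001 nodes (2n - 1 for n = 1000001 leaves) and a height of 20. A query or update therefore touches at most 21 nodes, however many users there are.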

The final code:

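// Base class of the score tree: leaves are single scores, inner nodes are score ranges.
abstract class AbstractScoreNode {
  // Set by the parent ScoreRangeNode on its left child; stays None for right children and the root.
  var rightSibling: Option[AbstractScoreNode] = None
  // Number of users whose score falls within this node's score range.
  def getCount(): Int

  val minScore: Int
  val maxScore: Int
}

// Leaf: one score and how many users currently have it.
class ScoreNode(val score: Int, var count: Int) extends AbstractScoreNode {
  def getCount() = count
  def updateCount(delta: Int): Unit = {
    // Reject decreasing a score nobody has; the leaf is updated before any cache, so nothing is left half-updated.
    require(count + delta >= 0, "count of score %d cannot become negative".format(score))
    count += delta
  }

  val minScore = score
  val maxScore = score

  override def toString(): String = "ScoreNode(score = %d, count = %d)".format(score, count)
}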

class ScoreRangeNode(val leftChild: AbstractScoreNode, val rightChild: AbstractScoreNode) extends AbstractScoreNode {
  leftChild.rightSibling = Some(rightChild)

  // Cached total of all users in this range, kept in sync by updateCountCache.
  private var countCache = leftChild.getCount() + rightChild.getCount()
  def getCount() = countCache
  def updateCountCache(delta: Int): Unit = countCache += delta

  // The range runs from the left child's lowest score to the right child's highest.
  val minScore: Int = leftChild.minScore
  val maxScore: Int = rightChild.maxScore
  // Scores <= pivotScore live in the left subtree; using the left child's maximum
  // keeps this consistent with however createNode split the range.
  val pivotScore: Int = leftChild.maxScore

  override def toString(): String = "ScoreRangeNode(min = %d, max = %d, count = %d)".format(minScore, maxScore, getCount())
}

object ScoreRanking {
  def apply(maxScore: Int): ScoreRanking = apply(0, maxScore)
  def apply(minScore: Int, maxScore: Int): ScoreRanking = {
    if (minScore > maxScore) throw new IllegalArgumentException("min score must not be greater than max score")
    new ScoreRanking(createNode(minScore, maxScore))
  }
  // Split [minScore, maxScore] in half recursively until each leaf holds one score.
  private def createNode(minScore: Int, maxScore: Int): AbstractScoreNode = {
    if (minScore == maxScore) new ScoreNode(minScore, 0)
    else {
      // Rounds down even for negative scores ((minScore + maxScore) / 2 truncates toward
      // zero and recurses forever on a range like [-2, -1]) and cannot overflow for large
      // non-negative scores.
      val pivot = minScore + (maxScore - minScore) / 2
      new ScoreRangeNode(createNode(minScore, pivot), createNode(pivot + 1, maxScore))
    }
  }
}

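class ScoreRanking(root: AbstractScoreNode) {
  // Rank = (number of users with a strictly higher score) + 1.
  def getRankingOf(score: Int): Int = { checkScore(score); getCount(root, score) + 1 }
  private def checkScore(score: Int): Unit = {
    if (score < root.minScore || score > root.maxScore) {
      throw new IllegalArgumentException("score should be in [%d, %d]".format(root.minScore, root.maxScore))
    }
  }
  // Number of users under `node` whose score is strictly higher than `score`.
  private def getCount(node: AbstractScoreNode, score: Int): Int = {
    node match {
      // In trees built by createNode, a left leaf's sibling is a right leaf, which has no
      // sibling itself, so this branch always yields 0: the higher scores were already
      // added by the parent through rightChild.getCount().
      case x: ScoreNode => x.rightSibling match {
        case Some(rightSibling) => getCount(rightSibling, score)
        case _ => 0
      }
      case x: ScoreRangeNode => if (score <= x.pivotScore) {
        // Going left: everyone in the right subtree has a higher score.
        getCount(x.leftChild, score) + x.rightChild.getCount()
      } else {
        getCount(x.rightChild, score)
      }
    }
  }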
  def increaseCountOf(score: Int): Unit = { checkScore(score); updateCountOf(root, score, +1) }
  def decreaseCountOf(score: Int): Unit = { checkScore(score); updateCountOf(root, score, -1) }
  // Walk down to the leaf for `score`, update it, then adjust each cached count on the way back up.
  private def updateCountOf(node: AbstractScoreNode, score: Int, delta: Int): Unit = {
    node match {
      case x: ScoreNode => x.updateCount(delta)
      case x: ScoreRangeNode =>
        if (score <= x.pivotScore) {
          updateCountOf(x.leftChild, score, delta)
        } else {
          updateCountOf(x.rightChild, score, delta)
        }
        x.updateCountCache(delta)
    }
  }
  override def toString(): String = root.toString()
}

object Ranking {
  def main(args: Array[String]): Unit = {
    val scoreRanking = ScoreRanking(3)
    scoreRanking.increaseCountOf(1)
    scoreRanking.increaseCountOf(2)
    scoreRanking.increaseCountOf(3)
    scoreRanking.increaseCountOf(0)
    scoreRanking.increaseCountOf(3)
    println(scoreRanking.getRankingOf(0))
    println(scoreRanking.getRankingOf(1))
    println(scoreRanking.getRankingOf(2))
    println(scoreRanking.getRankingOf(3))
  }
}

Main logic: build a score ranking over [0, 3], add users with scores 1, 2, 3, 0 and 3, then query the ranks of scores 0, 1, 2 and 3. The output is:

$ scala Ranking
5
4
3
1
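Users with the same score share a rank, and the next rank is skipped (standard competition ranking). Both users with score 3 are ranked 1, and score 2 is ranked 3, not 2. As a cross-check of the output above, a hypothetical sort-based helper, `CompetitionRank.ranks`, reproduces the same four numbers. It costs O(n log n) per computation, which is the work the tree avoids on every login.

```scala
object CompetitionRank {
  // Sort all scores from high to low. The rank of a score is the 1-based position
  // of its first occurrence, so equal scores share a rank and the next rank skips.
  def ranks(allScores: Seq[Int]): Map[Int, Int] = {
    val descending = allScores.sorted(Ordering[Int].reverse)
    allScores.distinct.map(s => s -> (descending.indexOf(s) + 1)).toMap
  }
}
```

For the scores 1, 2, 3, 0, 3, the descending order is 3, 3, 2, 1, 0. This gives ranks 5, 4, 3 and 1 for scores 0, 1, 2 and 3.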