Dijkstra非负加权有向图最短路径算法


成果在这里,前天完成的,花了大约半天的时间。参考资料主要是维基百科一篇博客

在分析自己代码前,先写一点自己对于Dijkstra算法的理解。

   u
s     v

假定上午的s, u, v分别是某个非负加权有向图的顶点,要求从s到v的最短路径。定义d(x)为从s到x定点的最短路径距离,则到v顶点的最短路径据路即是d(v)。同时定义从顶点x到顶点y的权值为w(x, y),假如存在从x到y的一条有向边,则w(x, y)即这条边的权值,否则w(x, y)无意义,或者说是无穷大。
Dijkstra算法认为:假如存在从u到v的一条有向边,并且d(u) + w(u, v) < d(v),则认为到v顶点最短路径经过u顶点。为了计算方法,可以认为d(v)原先为无穷大,实际计算的是所有到达点是v的u中最小的u。 单独看上述的理论时,个人理解为计算v到s的最短路径应该是一个回溯,更确切的讲应该是DP。但是Dijkstra算法实际为广度搜索,可以计算所有顶点的最短路径。我的目的是学习这个经典算法,所以就按照算法的步骤来。 以下是来自维基百科的伪代码

  function Dijkstra(G, w, s)
     for each vertex v in V[G]                        // 初始化
           d[v] := infinity                                 // 將各點的已知最短距離先設成無窮大
           previous[v] := undefined                         // 各点的已知最短路径上的前趋都未知
     d[s] := 0                                              // 因为出发点到出发点间不需移动任何距离,所以可以直接将s到s的最小距离设为0
     S := empty set
     Q := set of all vertices
     while Q is not an empty set                      // Dijkstra演算法主體
           u := Extract_Min(Q)
           S.append(u)
           for each edge outgoing from u as (u,v)
                  if d[v] > d[u] + w(u,v)             // 拓展边(u,v)。w(u,v)为从u到v的路径长度。
                        d[v] := d[u] + w(u,v)               // 更新路径长度到更小的那个和值。
                        previous[v] := u                    // 紀錄前趨頂點

老实说单独看这段伪代码不太好理解算法步骤,所以我又看了博客中的算法步骤图后才大致理解了。不过实际实现时还是遇到了一些问题,还是先从伪代码讲起。

伪代码中首先是初始化,设置所有d(x)为无穷大,先忽略路径回溯用的顶点数组previous,无穷大我实际设置为Integer.MAX_VALUE。
接下来是设置d(s) = 0,这个很好理解,从起点到起点的最短距离肯定是0。
S和Q是两个集合,分别保存已经访问过的顶点和未访问过的顶点。
extract_min是算法的重点之一,维基百科中的说明是取得d(x)最小的顶点。比如第一步时d(s)为0,其他d(x)都为无穷大,所以毫无疑问第一次u肯定为s。
接下来取得邻接u的顶点v集合,设置d(v),思路和之前我理解的Dijkstra算法的思想类似。实际d(v)一开始为无穷大,根据边拓展得到可用的路径和相应的路径。这里会有一个疑问,拓展得到的d(v)肯定是最短的么?严格来说是不保证的,考虑如下情况:

s -(w0)-> v0
 \       /
  \     /
  (w1)(w10)
    \ /
    v1

存在邻接表: s -> v0, w0; s -> v1, w1; v1-> v0, w10
一开始从s的拓展会更新d(v0) = w0,d(v1) = w1。因为从v0的拓展不会更新d(v0)本身,如果w0 > w1 + w10,即实际d(v0)应该为w1 + w10,那么算法就应该是错的。实际中算法还会拓展v1,比较d(v0) 和 d(v1) + w10的关系,如果比较了就可以得到正确的最短路径。但是拓展v0和v1存在先后关系,假设先拓展v0,算法就会出问题。于是焦点就变为extract_min步骤,如果先拓展v0,那么算法就有问题。但是w0 > w1 + w10 并且w为非负数 => w0 > w1 => d(v0) > d(v1) => extract_min(q) 结果为v1,即实际不会先拓展v0。换句话说,一次替换d(v)不会得到最小值,但和extract_min一起保证能取到最小值。

从上面的伪代码翻译为实际代码主要关注两部分:第一部分是图的表示,我选用邻接表,第二部分是综合extract_min, d[v]获取和修改的顶点距离。
针对上述伪代码的邻接表接口定义为:

List<Integer> getEndVectorFrom(int startVector)
int weight(int startVector, int endVector)

两个函数分别针对while中for的数据来源和w(u, v)。内部使用哈希表,保存映射(start vector) -> (end vector, weight)。具体代码如下,输入为(u, v, w)的有向边和权值的表示方法:

static class AdjacencyList {

  // start -> [(end, weight)]
  private Map<Integer, List<VectorAndWeight>> connection =
      new HashMap<Integer, List<VectorAndWeight>>();

  AdjacencyList(int[][] adjacencyList) {
    for (int[] edge : adjacencyList) {
      // start, end, weight
      int startVector = edge[0];
      if (!connection.containsKey(startVector)) {
        connection.put(startVector, new ArrayList<VectorAndWeight>());
      }
      connection.get(startVector).add(new VectorAndWeight(edge[1], edge[2]));
    }
  }

  List<Integer> getEndVectorFrom(int startVector) {
    List<Integer> vectors = new ArrayList<Integer>();
    if (connection.containsKey(startVector)) {
      for (VectorAndWeight item : connection.get(startVector)) {
        vectors.add(item.vector);
      }
    }
    return vectors;
  }

  int weight(int startVector, int endVector) {
    if (connection.containsKey(startVector)) {
      for (VectorAndWeight item : connection.get(startVector)) {
        if (item.vector == endVector) return item.weight;
      }
    }
    throw new IllegalArgumentException("no edge between vector ["
        + startVector + "] and vector [" + endVector + "]");
  }

}

static class VectorAndWeight {

  int vector;
  int weight;

  VectorAndWeight(int vector, int weight) {
    this.vector = vector;
    this.weight = weight;
  }

}

顶点距离相对难实现一些,伪代码中需要的功能是:

boolean isEmpty();
VectorAndDistance removeMin();
void updateDistance(int vector, int newDistance);
int getDistance(int vector);

个人理解为最小堆移除功能,又要随机访问功能。其次,理论上移除后顶点到了另外一个集合,但是设置距离要跨越集合。感觉是一个很纠结的数据结构?简单起见,我就用数组实现。同时,为了区分顶点是否访问过,我增加了accessed属性。

class SimpleVectorDistanceSet implements VectorDistanceSet {

  private DistanceAndStatus[] elements;
  private int size;

  SimpleVectorDistanceSet(int vectorCount) {
    if (vectorCount < 1)
      throw new IllegalArgumentException("vector count must >= 1");

    elements = new DistanceAndStatus[size = vectorCount];
    for (int i = 0; i < size; i++) {
      elements[i] = new DistanceAndStatus(Integer.MAX_VALUE, false);
    }
  }

  public boolean isEmpty() {
    return size == 0;
  }

  public VectorAndDistance removeMin() {
    if (isEmpty()) throw new IllegalStateException("empty set");
    int minDistance = Integer.MAX_VALUE;
    int minVector = -1;
    for (int i = 0; i < elements.length; i++) {
      if (!elements[i].accessed && elements[i].distance <= minDistance) {
        minDistance = elements[i].distance;
        minVector = i;
      }
    }
    elements[minVector].accessed = true;
    size--;
    return new VectorAndDistance(minVector, minDistance);
  }

  public void updateDistance(int vector, int newDistance) {
    elements[vector].distance = newDistance;
  }

  public int getDistance(int vector) {
    return elements[vector].distance;
  }

}

class DistanceAndStatus {

  int distance;
  boolean accessed;

  DistanceAndStatus(int distance, boolean accessed) {
    this.distance = distance;
    this.accessed = accessed;
  }

}

设计好这两个主要部分之后,原先的伪代码可以转化为:

public void findMinDistance(AdjacencyList adjacencyList, int vectorCount,
      int startVector, int endVector) {
    VectorDistanceSet distanceSet = new SimpleVectorDistanceSet(vectorCount);
    int[] previousLink = new int[vectorCount];

    distanceSet.updateDistance(startVector, 0);
    while (!distanceSet.isEmpty()) {
      VectorAndDistance vd = distanceSet.removeMin();
      int u = vd.vector;
      if (u == endVector) break;
      for (int v : adjacencyList.getEndVectorFrom(u)) {
        int weight = adjacencyList.weight(u, v);
        if (distanceSet.getDistance(v) > vd.distance + weight) {
          distanceSet.updateDistance(v, vd.distance + weight);
          previousLink[v] = u;
        }
      }
    }

    int minDistance = distanceSet.getDistance(endVector);
    System.out.println(minDistance);
    if (minDistance == Integer.MAX_VALUE) return; // no path
    printPath(startVector, endVector, previousLink);
  }

上述代码有一个之前未提到的点,Dijkstra虽然能获取到所有点的最短距离,但是有时候我们只需要指定两点就行,所以在extract_min是指定的目标顶点时,就可以跳出。(来自维基百科的提示)
测试代码

AdjacencyList adjacencyList =
    new AdjacencyList(new int[][] { {0, 1, 6}, {0, 2, 3}, {1, 2, 2},
        {1, 3, 5}, {2, 3, 3}, {2, 4, 4}, {3, 4, 2}, {3, 5, 3}, {4, 5, 5}});
Dijkstra dijkstra = new Dijkstra();
dijkstra.findMinDistance(adjacencyList, 6, 0, 4);

除了获取最短路径距离之外,如何记录最短路径也是需要考虑的。因为最短路径距离肯定关联一个前向点,每个前向点的最短距离下也肯定有前向点。所以代码中只要记录每个最短路径距离时的前向点即可。即上述代码中的previousLink。相应的,恢复路径需要回溯这个前向数组。

private void printPath(int startVector, int endVector, int[] previousLink) {
  LinkedList<Integer> stack = new LinkedList<Integer>();
  stack.push(endVector);
  int vx = endVector;
  while (previousLink[vx] != startVector) {
    stack.push(previousLink[vx]);
    vx = previousLink[vx];
  }
  stack.push(startVector);
  System.out.println(stack);
}

这里使用了栈来回溯,如果两点之间不连接的话,在findMinDistance中就已经停止了,不会再尝试恢复最短路径。

最后,小结一下个人感受:Dijkstra是一个涉及多种数据结构的算法,想要完全了解他个人还需要进一步深化数据结构的认识。