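This is the fourth post in my Java concurrency series. The series alternates between synchronizers and concurrent data structures, and this time it is a data structure's turn. As mentioned in the earlier post on ConcurrentLinkedQueue, concurrent data structures do not involve the problems specific to synchronizers, so they are somewhat simpler. The analysis focuses on the data structure itself.
The topic this time is ConcurrentSkipListMap. To be honest, I did not know about this class at first: if ConcurrentHashMap exists, why have ConcurrentSkipListMap as well? TreeMap is not used that often even in single-threaded code, and a sorted Map under concurrency is used even less. That said, after reading chapter 14 of The Art of Multiprocessor Programming, I found that ConcurrentSkipListMap is an excellent example to study, in particular as production Java code for an HM Linked List (also called a Harris Linked List).
If you have read articles that analyze the ConcurrentSkipListMap source, you probably know about SkipList, since it is part of the name. But what actually maintains correctness in ConcurrentSkipListMap is not the SkipList; it is the HM Linked List at the bottom level. In other words, the core is the HM Linked List, not the SkipList that only speeds up access.
Before reading the ConcurrentSkipListMap source, I recommend getting familiar with the following:
- HM Linked List
- The marker-node optimization of HM Linked List
- An ordinary (non-concurrent) SkipList
HM Linked List is a fairly well-known singly linked list for concurrent settings; interested readers can look up the original papers. It requires that a node can CAS its mark and its next pointer in a single step. In Java this corresponds to AtomicMarkableReference. A typical implementation (chapter 9 of The Art of Multiprocessor Programming) looks like this: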

    import java.util.concurrent.atomic.AtomicMarkableReference;

    /**
     * HM linked list based on {@code AtomicMarkableReference}.
     *
     * @param <T> element type
     */
    public class ListBasedSet1<T> {

        private final Node<T> head;

        public ListBasedSet1() {
            head = new Node<>(Integer.MIN_VALUE, null, new Node<>(Integer.MAX_VALUE));
        }

        public boolean add(T item) {
            final int key = item.hashCode();
            Window<T> window;
            Node<T> newNode;
            while (true) {
                window = find(head, key);
                // same key
                if (window.current.key == key) {
                    return false;
                }
                newNode = new Node<>(key, item, window.current);
                if (window.predecessor.nextAndMark.compareAndSet(
                        window.current, newNode, false, false)) {
                    return true;
                }
                // retry
            }
        }

        public boolean remove(T item) {
            final int key = item.hashCode();
            Window<T> window;
            Node<T> successor;
            while (true) {
                window = find(head, key);
                // not found
                if (window.current.key != key) {
                    return false;
                }
                successor = window.current.nextAndMark.getReference();
                // logical delete; a CAS from unmarked to marked is used instead of
                // attemptMark, which also returns true when the mark is already set
                if (window.current.nextAndMark.compareAndSet(
                        successor, successor, false, true)) {
                    // physical delete, failure is ok
                    window.predecessor.nextAndMark.compareAndSet(
                            window.current, successor, false, false);
                    return true;
                }
                // retry
            }
        }

        public boolean contains(T item) {
            int key = item.hashCode();
            Node<T> current = head.next();
            while (current.key < key) {
                current = current.next();
            }
            return current.key == key && !current.isMarked();
        }

        private Window<T> find(Node<T> head, int key) {
            boolean[] markedHolder = {false};
            boolean snip;
            retry:
            while (true) {
                // every retry has to start again from head
                Node<T> predecessor = head;
                Node<T> current = predecessor.next();
                Node<T> successor;
                while (true) {
                    successor = current.nextAndMark.get(markedHolder);
                    // current node is deleted
                    while (markedHolder[0]) {
                        snip = predecessor.nextAndMark.compareAndSet(
                                current, successor, false, false);
                        if (!snip) {
                            continue retry;
                        }
                        current = successor;
                        successor = current.nextAndMark.get(markedHolder);
                    }
                    if (current.key >= key) {
                        return new Window<>(predecessor, current);
                    }
                    predecessor = current;
                    current = successor;
                }
            }
        }

        private static class Window<T> {
            private final Node<T> predecessor;
            private final Node<T> current;

            Window(Node<T> predecessor, Node<T> current) {
                this.predecessor = predecessor;
                this.current = current;
            }
        }

        private static class Node<T> {
            private final int key;
            private final T item;
            private final AtomicMarkableReference<Node<T>> nextAndMark;

            Node(int key) {
                this.key = key;
                this.item = null;
                nextAndMark = new AtomicMarkableReference<>(null, false);
            }

            Node(int key, T item, Node<T> next) {
                this.key = key;
                this.item = item;
                nextAndMark = new AtomicMarkableReference<>(next, false);
            }

            boolean isMarked() {
                return nextAndMark.isMarked();
            }

            Node<T> next() {
                return nextAndMark.getReference();
            }
        }
    }

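This is a ListBasedSet, i.e. a Set backed by a linked list, implementing contains/add/remove. The basic idea is to keep a sorted linked list and traverse it from the head to find the target node (contains, remove) or the insertion position (add).
You may think this is an inefficient Set, since the worst-case time complexity is O(N). With a SkipList on top, the expected time complexity becomes O(log N), which is what ConcurrentSkipListMap does internally.
Back to ListBasedSet. How can the correctness of the implementation above be shown? Put differently: why does a single-threaded singly linked list (only a next pointer) break under concurrency, what exactly goes wrong, and how does ListBasedSet solve it?
Consider the following singly linked list:

    a -> b -> c -> d

Suppose we only CAS the next pointer. Thread 1 wants to delete b while thread 2 wants to delete c, and the following happens at run time:

    thread 1: CAS(a.next, b, b.next)
    thread 2: CAS(b.next, c, c.next)

Both CASes succeed, but c is not deleted (a now points to c), so the result is wrong. An adjacent remove and add can similarly cause one operation to be "ignored" (in theory the Index levels of ConcurrentSkipListMap can hit this, but that is acceptable because correctness is guaranteed by the HM Linked List). On the surface the cause is that the two CASes touch different next pointers; the real problem is that a CAS on a next pointer cannot guarantee that the node owning that pointer has not been deleted. HM Linked List therefore CASes a node's next pointer together with its mark:

    thread 1: CAS(b, [b.next, false], [b.next, true])
    thread 1: CAS(a, [b, false], [b.next, false])
    thread 2: CAS(c, [c.next, false], [c.next, true])
    thread 2: CAS(b, [c, false], [c.next, false])

Here the brackets hold the next pointer and the mark (true means deleted, false means present). Replaying the scenario, the last CAS fails because b is already marked. For the actual correctness proof, see the original paper or The Art of Multiprocessor Programming.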
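The two interleavings above can be replayed deterministically in a single thread. The following is only a sketch: `LostDeleteDemo`, `PlainNode` and `MarkedNode` are names made up for this illustration, and the "threads" are just statements executed in the order shown in the listings.

```java
import java.util.concurrent.atomic.AtomicMarkableReference;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical demo class: replays both interleavings in one thread. */
class LostDeleteDemo {

    /** Node whose next pointer is CASed on its own. */
    static final class PlainNode {
        final String name;
        final AtomicReference<PlainNode> next = new AtomicReference<>();
        PlainNode(String name) { this.name = name; }
    }

    /** Node whose next pointer and deletion mark are CASed together. */
    static final class MarkedNode {
        final String name;
        final AtomicMarkableReference<MarkedNode> next =
                new AtomicMarkableReference<>(null, false);
        MarkedNode(String name) { this.name = name; }
    }

    /** CAS on next only: both deletes report success, yet c stays reachable. */
    static String plainScenario() {
        PlainNode a = new PlainNode("a"), b = new PlainNode("b");
        PlainNode c = new PlainNode("c"), d = new PlainNode("d");
        a.next.set(b); b.next.set(c); c.next.set(d);
        PlainNode bNext = b.next.get();               // thread 1 reads c
        PlainNode cNext = c.next.get();               // thread 2 reads d
        boolean t1 = a.next.compareAndSet(b, bNext);  // thread 1 unlinks b
        boolean t2 = b.next.compareAndSet(c, cNext);  // thread 2 "unlinks" c
        StringBuilder path = new StringBuilder();
        for (PlainNode n = a; n != null; n = n.next.get()) {
            if (path.length() > 0) path.append("->");
            path.append(n.name);
        }
        return t1 + "," + t2 + ":" + path;
    }

    /** CAS on (next, mark): the last CAS fails because b is already marked. */
    static String markedScenario() {
        MarkedNode a = new MarkedNode("a"), b = new MarkedNode("b");
        MarkedNode c = new MarkedNode("c"), d = new MarkedNode("d");
        a.next.set(b, false); b.next.set(c, false); c.next.set(d, false);
        boolean markB = b.next.compareAndSet(c, c, false, true);    // thread 1
        boolean unlinkB = a.next.compareAndSet(b, c, false, false); // thread 1
        boolean markC = c.next.compareAndSet(d, d, false, true);    // thread 2
        boolean unlinkC = b.next.compareAndSet(c, d, false, false); // thread 2
        return markB + "," + unlinkB + "," + markC + "," + unlinkC;
    }

    public static void main(String[] args) {
        System.out.println(plainScenario());
        System.out.println(markedScenario());
    }
}
```

In the marked version c is still linked after the failed CAS, but it is logically deleted, so a later find can unlink it safely.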
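In the code, this process shows up in remove. The first CAS is the logical delete and the CAS after it is the physical delete. Only the first CAS has to succeed; the physical delete is allowed to fail, because find (called by add/remove) helps to physically delete any logically deleted node it sees (if no other thread helps, the physical-delete CAS will of course succeed). If the logical delete fails, the code restarts the traversal from the head and tries again. A logical delete usually fails because two threads are deleting the same node at the same time. In that case exactly one must succeed and the other must fail (this is another problem with CASing only the next pointer: with bad luck both calls return true). When the slower thread retries and calls find, the node has either been unlinked inside find or already been removed by the other thread, so it is no longer reachable and remove returns false. Hence the approach is correct.
Although HM Linked List is proven correct, the practical implementation is not ideal. In particular, AtomicMarkableReference creates a new object on every change. In C/C++ the mark can be emulated with one bit of the next pointer, but then deciding when to reclaim a list node becomes a problem (you cannot simply free a node while another thread may still be traversing it).
The optimization used in ConcurrentSkipListMap is to append a marker node after the node being deleted (setting the node's value to null addresses a different problem; do not mix the two up). To delete node b from the list below, a marker node is appended after b:

    a -> b -> c
    a -> b -> marker -> c

Adding a node only on deletion is much cheaper than the AtomicMarkableReference approach. The code does get more complex, though: instead of the three nodes a, b, c you now have to look at four: a, b, marker, c.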
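Why a marker node protects b can be seen in a few lines: once `b.next` points to the marker, any CAS on `b.next` that still expects c fails. The sketch below is single-threaded and uses made-up names (`MarkerNodeDemo`, `appendMarker`, `insertAfter`); it is not code from ConcurrentSkipListMap.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical demo of deletion by appending a marker node. */
class MarkerNodeDemo {

    static final class Node {
        final String name;
        final boolean marker;
        final AtomicReference<Node> next;

        Node(String name, boolean marker, Node next) {
            this.name = name;
            this.marker = marker;
            this.next = new AtomicReference<>(next);
        }
    }

    /** Logical delete: append a marker unless the node already has one. */
    static boolean appendMarker(Node node) {
        Node successor = node.next.get();
        if (successor != null && successor.marker) {
            return false; // already logically deleted
        }
        return node.next.compareAndSet(successor, new Node("marker", true, successor));
    }

    /** Insert a new node after predecessor, expecting a given successor. */
    static boolean insertAfter(Node predecessor, Node expectedNext, String name) {
        return predecessor.next.compareAndSet(
                expectedNext, new Node(name, false, expectedNext));
    }

    static String scenario() {
        Node c = new Node("c", false, null);
        Node b = new Node("b", false, c);
        Node a = new Node("a", false, b);
        boolean marked = appendMarker(b);            // a -> b -> marker -> c
        boolean lateInsert = insertAfter(b, c, "x"); // still expects b.next == c
        boolean secondMarker = appendMarker(b);      // only one marker per node
        Node marker = b.next.get();
        // physical delete: skip b and its marker in one CAS
        boolean unlinked = a.next.compareAndSet(b, marker.next.get());
        StringBuilder path = new StringBuilder();
        for (Node n = a; n != null; n = n.next.get()) {
            if (path.length() > 0) path.append("->");
            path.append(n.name);
        }
        return marked + "," + lateInsert + "," + secondMarker + "," + unlinked + ":" + path;
    }

    public static void main(String[] args) {
        System.out.println(scenario());
    }
}
```

Because every change to `b.next` goes through a CAS that expects a non-marker successor, a node can never get two markers, and nothing can be inserted between a node and its marker.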
A reference implementation using marker nodes:

    import java.util.concurrent.atomic.AtomicReference;

    /**
     * HM linked list using marker node.
     *
     * @param <T> element type
     */
    public class ListBasedSet2<T> {

        private final Node<T> head;

        ListBasedSet2() {
            Node<T> tail = new Node<>(NodeKind.NORMAL, Integer.MAX_VALUE, null, null);
            head = new Node<>(NodeKind.NORMAL, Integer.MIN_VALUE, null, tail);
        }

        public boolean contains(T item) {
            final int key = item.hashCode();
            Node<T> current = head.next();
            boolean[] markHolder = {false};
            while (current.key < key) {
                current = current.next(markHolder); // skip marker node
                assert current != null;
            }
            // markHolder so far describes the node visited before current,
            // so read the mark of current itself
            current.next(markHolder);
            return current.key == key && !markHolder[0];
        }

        @SuppressWarnings("unchecked")
        public boolean add(T item) {
            final int key = item.hashCode();
            Node<T>[] nodes = (Node<T>[]) new Node[3];
            while (true) {
                // same key
                if (find(key, nodes)) {
                    return false;
                }
                // p -> (n) -> c -> s
                Node<T> newNode = new Node<>(NodeKind.NORMAL, key, item, nodes[1]);
                if (nodes[0].next.compareAndSet(nodes[1], newNode)) {
                    return true;
                }
            }
        }

        private boolean find(int key, Node<T>[] nodes) {
            Node<T> predecessor;
            Node<T> current;
            Node<T> successor;
            boolean[] markHolder = {false};
            boolean snip;
            restart:
            while (true) {
                // every restart has to begin at head again
                predecessor = head;
                current = predecessor.next();
                while (true) {
                    assert current != null;
                    successor = current.next(markHolder);
                    while (markHolder[0]) {
                        assert successor != null;
                        snip = predecessor.next.compareAndSet(current, successor);
                        if (!snip) {
                            continue restart;
                        }
                        current = successor;
                        successor = current.next(markHolder);
                    }
                    if (current.key < key) {
                        predecessor = current;
                        current = successor;
                    } else {
                        break restart;
                    }
                }
            }
            nodes[0] = predecessor;
            nodes[1] = current;
            nodes[2] = successor;
            return current.key == key;
        }

        @SuppressWarnings("unchecked")
        public boolean remove(T item) {
            final int key = item.hashCode();
            Node<T>[] nodes = (Node<T>[]) new Node[3];
            while (true) {
                if (!find(key, nodes)) {
                    return false;
                }
                // p -> [c -> (m)] -> s
                // p -> s
                Node<T> marker = new Node<>(NodeKind.MARKER, 0, null, nodes[2]);
                if (nodes[1].next.compareAndSet(nodes[2], marker)) {
                    nodes[0].next.compareAndSet(nodes[1], nodes[2]);
                    return true;
                }
            }
        }

        @Override
        public String toString() {
            return "ListBasedSet2{" + "head=" + head + '}';
        }

        private enum NodeKind {
            NORMAL, MARKER;
        }

        private static final class Node<T> {
            final NodeKind kind;
            final int key;
            final T item;
            final AtomicReference<Node<T>> next;

            Node(NodeKind kind, int key, T item, Node<T> next) {
                this.kind = kind;
                this.key = key;
                this.item = item;
                this.next = new AtomicReference<>(next);
            }

            Node<T> next() {
                return next.get();
            }

            Node<T> next(boolean[] markHolder) {
                Node<T> successor = next.get();
                if (successor == null) {
                    // tail
                    markHolder[0] = false;
                    return null;
                }
                if (successor.kind == NodeKind.MARKER) {
                    markHolder[0] = true;
                    return successor.next();
                }
                markHolder[0] = false;
                return successor;
            }

            @Override
            public String toString() {
                StringBuilder builder = new StringBuilder("Node{");
                if (kind == NodeKind.NORMAL) {
                    builder.append("kind=NORMAL,");
                    builder.append("key=").append(key).append(',');
                    builder.append("item=").append(item).append(',');
                } else {
                    builder.append("kind=MARKER,");
                }
                builder.append("next=\n").append(next.get()).append('}');
                return builder.toString();
            }
        }
    }

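To keep the code close to ListBasedSet1, ListBasedSet2 uses a method with the signature next(boolean[] markHolder) that helps to skip marker nodes. To tell normal nodes and marker nodes apart, ListBasedSet2 also adds the NodeKind enum. The content of a marker node is meaningless, so it is simply set to null. Note that you have to be very careful: you must not append two marker nodes after one normal node, and you must not insert a node between a normal node and the marker that follows it.
You may have noticed that ListBasedSet1 and ListBasedSet2 use two sentinel nodes, head and tail, and order elements by hashCode. In real code the tail is not needed, and it is better to let the user pass in a comparator. Below is an implementation without a tail and with a user-supplied comparator: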

    import java.util.Comparator;
    import java.util.concurrent.atomic.AtomicReference;

    /**
     * 1. HM linked list using marker node
     * 2. no tail node
     * 3. comparator
     *
     * @param <T> element type
     */
    @SuppressWarnings("Duplicates")
    public class ListBasedSet3<T> {

        private final Comparator<T> comparator;
        private final Node<T> head;

        public ListBasedSet3(Comparator<T> comparator) {
            this.comparator = comparator;
            head = new Node<>(NodeKind.HEAD, null, null);
        }

        private int compareItem(Node<T> node1, T item) {
            switch (node1.kind) {
                case HEAD:
                    return -1;
                case NORMAL:
                    return comparator.compare(node1.item, item);
                default:
                    throw new IllegalStateException("cannot compareItem");
            }
        }

        public boolean contains(T item) {
            int c;
            for (Node<T> current = head.next.get(), successor; ; ) {
                if (current == null) {
                    // predecessor is last node
                    return false;
                }
                // current.kind is never HEAD
                assert current.kind != NodeKind.HEAD;
                successor = current.next.get();
                if (current.kind != NodeKind.MARKER) { // skip marker
                    c = compareItem(current, item);
                    if (c == 0) {
                        // current.item == item and current.kind == NORMAL:
                        // item is present if current is not marked
                        return successor == null || successor.kind != NodeKind.MARKER;
                    }
                    if (c > 0) {
                        // current.item > item, not found
                        return false;
                    }
                }
                // 1. current.kind == marker
                // 2. current.item < item
                // go next
                current = successor;
            }
        }

        public boolean add(T item) {
            int c;
            boolean[] markHolder = {false};
            restart:
            while (true) {
                for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
                    if (current == null) {
                        // no more node
                        if (insert(predecessor, item, null)) {
                            return true;
                        }
                        continue restart;
                    }
                    successor = current.next(markHolder);
                    // current is deleted
                    while (markHolder[0]) {
                        if (!predecessor.next.compareAndSet(current, successor)) {
                            continue restart;
                        }
                        current = successor;
                        if (current == null) {
                            // no more node
                            if (insert(predecessor, item, null)) {
                                return true;
                            }
                            continue restart;
                        }
                        successor = current.next(markHolder);
                    }
                    c = compareItem(current, item);
                    if (c == 0) {
                        // same key, item will not be replaced
                        return false;
                    }
                    if (c > 0) {
                        // current.item > item: insert between predecessor and current
                        if (insert(predecessor, item, current)) {
                            return true;
                        }
                        continue restart;
                    }
                    predecessor = current;
                    current = successor;
                }
            }
        }

        /** Links a new node between predecessor and successor. */
        private boolean insert(Node<T> predecessor, T item, Node<T> successor) {
            Node<T> newNode = new Node<>(NodeKind.NORMAL, item, successor);
            return predecessor.next.compareAndSet(successor, newNode);
        }

        public boolean remove(T item) {
            int c;
            boolean[] markHolder = {false};
            Node<T> marker;
            restart:
            while (true) {
                for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
                    if (current == null) {
                        // no more node
                        return false;
                    }
                    successor = current.next(markHolder);
                    while (markHolder[0]) {
                        if (!predecessor.next.compareAndSet(current, successor)) {
                            continue restart;
                        }
                        current = successor;
                        if (current == null) {
                            // no more node
                            return false;
                        }
                        successor = current.next(markHolder);
                    }
                    c = compareItem(current, item);
                    if (c == 0) {
                        marker = new Node<>(NodeKind.MARKER, null, successor);
                        // logical remove
                        if (current.next.compareAndSet(successor, marker)) {
                            // physical remove
                            predecessor.next.compareAndSet(current, successor);
                            return true;
                        }
                        continue restart;
                    }
                    if (c > 0) {
                        // not found
                        return false;
                    }
                    predecessor = current;
                    current = successor;
                }
            }
        }

        private enum NodeKind {
            HEAD, NORMAL, MARKER
        }

        private static final class Node<T> {
            final NodeKind kind;
            final T item;
            final AtomicReference<Node<T>> next;

            Node(NodeKind kind, T item, Node<T> next) {
                this.kind = kind;
                this.item = item;
                this.next = new AtomicReference<>(next);
            }

            Node<T> next(boolean[] markHolder) {
                Node<T> successor = next.get();
                if (successor == null) {
                    // last element
                    markHolder[0] = false;
                    return null;
                }
                if (successor.kind == NodeKind.MARKER) {
                    markHolder[0] = true;
                    return successor.next.get();
                }
                markHolder[0] = false;
                return successor;
            }
        }
    }

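The annoying part of the code above is that, without a tail, you constantly have to check whether the current node is null.
ListBasedSet4 below is my final version. It differs from the previous classes in that it allows a value to be replaced. Note that an HM Linked List does nothing to the node when it encounters an equal element. That is correct for a Set, but it is not enough for a Map: the user may pass the same key with a different value, so there has to be a way to replace the value.
Replacing a value directly in an HM Linked List runs into a problem: the value of a logically deleted node may get replaced. Consider the following execution:

    thread 1: logical delete node with key 1, value xxx
    thread 2: replace xxx with yyy
    thread 1: physical delete node with key 1

In theory a logically deleted node must not be replaced. But the existing scheme cannot put the steps above into a definite order, so ConcurrentSkipListMap adds a second logical-delete step: CAS the value to null. Deletion then takes three steps:
- CAS item not null -> null
- append marker
- CAS predecessor's next pointer to successor
The upside is that after the first step, the replacing thread fails. The downside is that null cannot be used as a node value, and the code gets more complex.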
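The three steps, and the way step one blocks a concurrent replace, can be replayed in a single thread. This is only a sketch with made-up names (`ThreeStepDeleteDemo` and its `Node`), with the statements of "thread 1" and "thread 2" interleaved by hand.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical demo of the three-step delete racing with a replace. */
class ThreeStepDeleteDemo {

    static final class Node {
        final AtomicReference<String> item;
        final boolean marker;
        final AtomicReference<Node> next;

        Node(String item, boolean marker, Node next) {
            this.item = new AtomicReference<>(item);
            this.marker = marker;
            this.next = new AtomicReference<>(next);
        }
    }

    static String scenario() {
        Node successor = new Node("zzz", false, null);
        Node node = new Node("xxx", false, successor);
        Node predecessor = new Node("head", false, node);

        // thread 2 reads the value it is about to replace
        String seen = node.item.get();
        // thread 1, step 1: CAS item from non-null to null
        boolean step1 = node.item.compareAndSet(seen, null);
        // thread 2: the replace now fails because the item is no longer "seen"
        boolean replaced = node.item.compareAndSet(seen, "yyy");
        // thread 1, step 2: append a marker after the node
        Node next = node.next.get();
        boolean step2 = node.next.compareAndSet(next, new Node(null, true, next));
        // thread 1, step 3: point the predecessor at the successor
        boolean step3 = predecessor.next.compareAndSet(node, next);

        boolean unlinked = predecessor.next.get() == successor;
        return step1 + "," + replaced + "," + step2 + "," + step3 + "," + unlinked;
    }

    public static void main(String[] args) {
        System.out.println(scenario());
    }
}
```

Had thread 2 run its CAS before step one, the replace would have succeeded and thread 1's step one would have failed instead, so the two operations are always ordered by the CAS on the item.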

    import com.google.common.base.Preconditions;

    import java.util.Comparator;
    import java.util.concurrent.atomic.AtomicReference;

    /**
     * 1. HM linked list using marker node
     * 2. no tail node
     * 3. comparator
     * 4. replaceable
     *
     * @param <T> element type
     */
    @SuppressWarnings("Duplicates")
    public class ListBasedSet4<T> {

        private final Comparator<T> comparator;
        private final Node<T> head;

        public ListBasedSet4(Comparator<T> comparator) {
            this.comparator = comparator;
            this.head = new Node<>(NodeKind.NORMAL, null, null);
        }

        public boolean contains(T item) {
            Preconditions.checkNotNull(item);
            int c;
            T value;
            for (Node<T> current = head.next.get(), successor; ; ) {
                if (current == null) {
                    return false;
                }
                value = current.item.get();
                if (value == null) {
                    // current is deleted
                    successor = current.next.get();
                    // skip marker if present
                    if (successor != null && successor.kind == NodeKind.MARKER) {
                        successor = successor.next.get();
                    }
                    current = successor;
                    continue;
                }
                c = comparator.compare(value, item);
                if (c == 0) {
                    return true;
                }
                if (c > 0) {
                    // current.item > item
                    return false;
                }
                current = current.next.get();
            }
        }

        public boolean add(T item) {
            Preconditions.checkNotNull(item);
            int c;
            T value;
            Node<T> newNode;
            while (true) {
                for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
                    if (current == null) {
                        // no more node
                        newNode = new Node<>(NodeKind.NORMAL, item, null);
                        if (predecessor.casNext(null, newNode)) {
                            return true;
                        }
                        break; // restart
                    }
                    value = current.item.get();
                    if (value == null) {
                        // current is deleted
                        successor = current.nextNotMarker();
                        if (predecessor.casNext(current, successor)) {
                            current = successor;
                            continue;
                        } else {
                            break;
                        }
                    }
                    c = comparator.compare(value, item);
                    if (c == 0) {
                        // same key, replace
                        if (current.casItem(value, item)) {
                            return true;
                        }
                        break;
                    }
                    if (c > 0) {
                        // current.item > item: link the new node in front of current
                        newNode = new Node<>(NodeKind.NORMAL, item, current);
                        if (predecessor.casNext(current, newNode)) {
                            return true;
                        }
                        break;
                    }
                    predecessor = current;
                    // current may be deleted at this point; never step onto its marker
                    current = current.nextNotMarker();
                }
            }
        }

        public boolean remove(T item) {
            Preconditions.checkNotNull(item);
            int c;
            T value;
            Node<T> marker;
            while (true) {
                for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
                    if (current == null) {
                        // no more node
                        return false;
                    }
                    value = current.item.get();
                    if (value == null) {
                        // current is deleted
                        successor = current.nextNotMarker();
                        if (predecessor.casNext(current, successor)) {
                            current = successor;
                            continue;
                        } else {
                            break; // restart
                        }
                    }
                    c = comparator.compare(value, item);
                    if (c == 0) {
                        // the node to remove
                        if (!current.casItem(value, null)) {
                            break;
                        }
                        // The logical delete succeeded, so this call must report true.
                        // Only this thread appends the marker, so successor is never a marker;
                        // retry if a node was inserted right after current in the meantime.
                        while (true) {
                            successor = current.next.get();
                            marker = new Node<>(NodeKind.MARKER, null, successor);
                            if (current.casNext(successor, marker)) {
                                predecessor.casNext(current, successor);
                                return true;
                            }
                        }
                    }
                    if (c > 0) {
                        // not found
                        return false;
                    }
                    predecessor = current;
                    // current may be deleted at this point
                    current = current.nextNotMarker();
                }
            }
        }

        private enum NodeKind {
            NORMAL, MARKER
        }

        @SuppressWarnings("Duplicates")
        private static final class Node<T> {
            final NodeKind kind;
            final AtomicReference<T> item;
            final AtomicReference<Node<T>> next;

            Node(NodeKind kind, T item, Node<T> next) {
                this.kind = kind;
                this.item = new AtomicReference<>(item);
                this.next = new AtomicReference<>(next);
            }

            Node<T> nextNotMarker() {
                Node<T> successor = next.get();
                if (successor == null) {
                    return null;
                }
                if (successor.kind == NodeKind.MARKER) {
                    return successor.next.get();
                }
                return successor;
            }

            boolean casItem(T expect, T update) {
                return item.compareAndSet(expect, update);
            }

            boolean casNext(Node<T> expect, Node<T> update) {
                return next.compareAndSet(expect, update);
            }
        }
    }

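You can see that before any value comparison, the current value must first be checked for null. The code above is already quite close to how Node is used in ConcurrentSkipListMap. Before wrapping it with SkipList code, let's look at a basic SkipList implementation: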

    import java.util.Random;

    public class SkipList<T> {

        private static final int MAX_LEVEL = 5;
        private static final int NOT_FOUND = -1;

        private final Random random = new Random();
        private final Node<T> head;

        public SkipList() {
            head = new Node<>(Integer.MIN_VALUE);
            final Node<T> tail = new Node<>(Integer.MAX_VALUE);
            for (int level = 0; level < MAX_LEVEL; level++) {
                head.next[level] = tail;
            }
        }

        public boolean contains(T x) {
            int key = x.hashCode();
            Node<T> predecessor = head;
            Node<T> current;
            // levels are indexed 0 .. MAX_LEVEL - 1
            for (int level = MAX_LEVEL - 1; level >= 0; level--) {
                current = predecessor.next[level];
                while (current.key < key) {
                    predecessor = current;
                    current = current.next[level];
                }
                if (current.key == key) {
                    return true;
                }
            }
            return false;
        }

        private int find(T x, Node<T>[] predecessors, Node<T>[] successors) {
            int key = x.hashCode();
            int foundAtLevel = NOT_FOUND;
            Node<T> predecessor = head;
            Node<T> current;
            for (int level = MAX_LEVEL - 1; level >= 0; level--) {
                current = predecessor.next[level];
                while (current.key < key) {
                    predecessor = current;
                    current = current.next[level];
                }
                predecessors[level] = predecessor;
                successors[level] = current;
                if (foundAtLevel == NOT_FOUND && current.key == key) {
                    foundAtLevel = level;
                }
            }
            return foundAtLevel;
        }

        @SuppressWarnings("unchecked")
        public boolean add(T x) {
            Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
            Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
            int foundAtLevel = find(x, predecessors, successors);
            if (foundAtLevel != NOT_FOUND) {
                return false;
            }
            // 1 <= topLevel <= MAX_LEVEL, so the node is always linked into level 0
            int topLevel = random.nextInt(MAX_LEVEL) + 1;
            Node<T> node = new Node<>(x.hashCode(), x, topLevel);
            for (int level = 0; level < topLevel; level++) {
                node.next[level] = successors[level];
                predecessors[level].next[level] = node;
            }
            return true;
        }

        @SuppressWarnings("unchecked")
        public boolean remove(T x) {
            Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
            Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
            int foundAtLevel = find(x, predecessors, successors);
            if (foundAtLevel == NOT_FOUND) {
                return false;
            }
            Node<T> node = successors[foundAtLevel];
            for (int level = 0; level <= foundAtLevel; level++) {
                predecessors[level].next[level] = node.next[level];
            }
            return true;
        }

        private static class Node<T> {
            private final int key;
            private final T value;
            private final Node<T>[] next;

            @SuppressWarnings("unchecked")
            Node(int key) {
                this.key = key;
                value = null;
                next = (Node<T>[]) new Node[MAX_LEVEL];
            }

            // topLevel >= 1 && topLevel <= MAX_LEVEL
            @SuppressWarnings("unchecked")
            Node(int key, T value, int topLevel) {
                this.key = key;
                this.value = value;
                next = (Node<T>[]) new Node[topLevel];
            }

            int topLevel() {
                return next.length;
            }

            @Override
            public String toString() {
                return "Node{" + "key=" + key + ", item=" + value + '}';
            }
        }
    }

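The above is a simplified version of the SkipList used in chapter 14 of The Art of Multiprocessor Programming. Frankly, it is far simpler than ConcurrentSkipListMap: there is a tail, so there are no constant null checks; the maximum level is fixed; and a random number decides each node's level.
Next is the final SkipListBasedSet from chapter 14, which combines HM Linked List (ListBasedSet1) with SkipList: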

    import com.google.common.base.Preconditions;

    import java.util.concurrent.ThreadLocalRandom;
    import java.util.concurrent.atomic.AtomicMarkableReference;

    @SuppressWarnings("Duplicates")
    public class SkipListSet1<T> {

        private static final int MAX_LEVEL = 5;

        private final Node<T> head;

        public SkipListSet1() {
            head = new Node<>(null, Integer.MIN_VALUE, MAX_LEVEL);
            final Node<T> tail = new Node<>(null, Integer.MAX_VALUE, MAX_LEVEL);
            for (int i = 0; i < MAX_LEVEL; i++) {
                head.next[i] = new AtomicMarkableReference<>(tail, false);
            }
        }

        /**
         * Test if element in set.
         *
         * @param x element
         * @return true if present, otherwise false
         */
        public boolean contains(T x) {
            Preconditions.checkNotNull(x);
            final int key = x.hashCode();
            Node<T> predecessor = head;
            Node<T> current = null;
            Node<T> successor;
            boolean[] markHolder = new boolean[1];
            for (int i = MAX_LEVEL - 1; i >= 0; i--) {
                current = predecessor.next(i);
                while (true) {
                    successor = current.next[i].get(markHolder);
                    while (markHolder[0]) {
                        current = successor;
                        successor = current.next[i].get(markHolder);
                    }
                    if (current.key < key) {
                        predecessor = current;
                        current = current.next(i);
                    } else {
                        break;
                    }
                }
            }
            return current.key == key;
        }

        /**
         * Add element.
         *
         * @param x element
         * @return true if added, false if present
         */
        @SuppressWarnings("unchecked")
        public boolean add(T x) {
            Preconditions.checkNotNull(x);
            final int key = x.hashCode();
            // ThreadLocalRandom has to be obtained on the calling thread, not cached in a field
            final int level = ThreadLocalRandom.current().nextInt(MAX_LEVEL) + 1;
            final Node<T> node = new Node<>(x, key, level);
            Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
            Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
            boolean[] markHolder = new boolean[1];
            boolean success;
            while (true) {
                if (find(key, predecessors, successors)) {
                    return false;
                }
                // node is not published yet, so a plain set is safe here
                node.next[0].set(successors[0], false);
                success = predecessors[0].next[0].compareAndSet(successors[0], node, false, false);
                if (!success) {
                    continue; // retry
                }
                // node added to skip list if success
                for (int i = 1; i < level; i++) {
                    while (true) {
                        // node is visible now: never overwrite a mark set by a concurrent remove
                        Node<T> oldNext = node.next[i].get(markHolder);
                        if (markHolder[0]
                                || !node.next[i].compareAndSet(oldNext, successors[i], false, false)) {
                            return true; // node is being removed, stop linking upper levels
                        }
                        success = predecessors[i].next[i].compareAndSet(successors[i], node, false, false);
                        if (success) {
                            break;
                        } else {
                            find(key, predecessors, successors);
                        }
                    }
                }
                return true;
            }
        }

        /**
         * Find element in skip list and return predecessors and successors.
         *
         * @param key key
         * @return true if found, otherwise false
         */
        private boolean find(int key, Node<T>[] predecessors, Node<T>[] successors) {
            boolean[] markHolder = new boolean[1];
            boolean snip;
            restart:
            while (true) {
                // every restart has to begin at the top level of head again
                Node<T> predecessor = head;
                Node<T> current = null;
                Node<T> successor;
                for (int i = MAX_LEVEL - 1; i >= 0; i--) {
                    current = predecessor.next(i);
                    // find first node
                    // 1. not marked
                    // 2. node's key is not less than specified key
                    while (true) {
                        successor = current.next[i].get(markHolder);
                        while (markHolder[0]) {
                            snip = predecessor.next[i].compareAndSet(current, successor, false, false);
                            if (!snip) {
                                continue restart;
                            }
                            current = successor;
                            successor = current.next[i].get(markHolder);
                        }
                        // current is not marked
                        if (current.key < key) {
                            predecessor = current;
                            current = successor;
                        } else {
                            break;
                        }
                    }
                    predecessors[i] = predecessor;
                    successors[i] = current;
                }
                return current.key == key;
            }
        }

        /**
         * Remove element.
         *
         * @param x element
         * @return true if success, false if not found or removed
         */
        @SuppressWarnings("unchecked")
        public boolean remove(T x) {
            final int key = x.hashCode();
            Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
            Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
            Node<T> node;
            Node<T> next;
            boolean[] markHolder = new boolean[1];
            boolean success;
            if (!find(key, predecessors, successors)) {
                return false; // not found
            }
            node = successors[0];
            for (int i = node.level() - 1; i >= 1; i--) {
                next = node.next[i].get(markHolder);
                while (!markHolder[0]) {
                    success = node.next[i].attemptMark(next, true);
                    if (success) {
                        break;
                    }
                    next = node.next[i].get(markHolder);
                }
            }
            while (true) {
                next = node.next[0].get(markHolder);
                // logically removed
                if (markHolder[0]) {
                    return false;
                }
                // CAS from unmarked to marked so that only one remover wins;
                // attemptMark also returns true when the mark is already set
                success = node.next[0].compareAndSet(next, next, false, true);
                if (success) {
                    find(key, predecessors, successors);
                    return true;
                }
            }
        }

        private static final class Node<T> {
            final T value;
            final int key;
            final AtomicMarkableReference<Node<T>>[] next;

            @SuppressWarnings("unchecked")
            Node(T value, int key, int level) {
                this.value = value;
                this.key = key;
                next = (AtomicMarkableReference<Node<T>>[]) new AtomicMarkableReference[level];
                // every level needs a reference object, otherwise reading tail.next[i] throws
                for (int i = 0; i < level; i++) {
                    next[i] = new AtomicMarkableReference<>(null, false);
                }
            }

            Node<T> next(int level) {
                return next[level].getReference();
            }

            int level() {
                return next.length;
            }
        }
    }

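Note this class's invariant: the book states that the bottom level is authoritative. A SkipList is really several lists stacked on top of each other, and keeping all levels consistent is hard. For a detailed analysis of the code above, interested readers should read the book.
That concludes the preliminaries. Next, let's look at some designs specific to ConcurrentSkipListMap.
The bottom level is, needless to say, an HM Linked List. It uses marker nodes to indicate deletion, and it CASes item to null before appending the marker node. As for the SkipList part:
- It starts with a single level and grows as nodes are inserted, according to the randomly chosen level; the maximum is 32 levels including the bottom level
- When a node is removed, the number of levels may shrink
- There is no tail on the right; I believe this is because growing head and tail together is hard, and in practice the tail is not needed
- Above the bottom level, the Index levels do not use marker nodes, so concurrency anomalies are possible, but the bottom level is authoritative, so the impact is small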
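The "randomly chosen level" in the list above can be pictured with a small sketch. This is not the generator used by ConcurrentSkipListMap; it only shows the common idea that each additional level is half as likely as the one below it, capped at a maximum. `RandomLevelSketch` and its methods are names made up for this illustration.

```java
import java.util.concurrent.ThreadLocalRandom;

/** Hypothetical sketch of a geometric level generator. */
class RandomLevelSketch {

    /** Returns a level in [0, maxLevel]; level k has probability of roughly 2^-(k+1). */
    static int randomLevel(int maxLevel) {
        int level = 0;
        // flip a fair coin until it comes up tails or the cap is reached
        while (level < maxLevel && ThreadLocalRandom.current().nextBoolean()) {
            level++;
        }
        return level;
    }

    /** Counts how often each level is produced over the given number of samples. */
    static int[] histogram(int samples, int maxLevel) {
        int[] counts = new int[maxLevel + 1];
        for (int i = 0; i < samples; i++) {
            counts[randomLevel(maxLevel)]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] counts = histogram(1_000_000, 31);
        for (int level = 0; level < 8; level++) {
            System.out.println("level " + level + ": " + counts[level]);
        }
    }
}
```

With such a distribution about half of the nodes get no index at all, a quarter get one index level, and so on, which is what keeps the expected search cost logarithmic.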
A simplified ConcurrentSkipListMap with all of the features above:

    import com.google.common.base.Preconditions;

    import javax.annotation.Nonnull;
    import java.util.Comparator;
    import java.util.concurrent.atomic.AtomicReference;

    @SuppressWarnings("Duplicates")
    public class SkipListSet3<T> {

        private final Comparator<T> comparator;
        private final AtomicReference<HeadIndex<T>> headIndex;

        public SkipListSet3(Comparator<T> comparator) {
            Preconditions.checkNotNull(comparator);
            this.comparator = comparator;
            Node<T> node = Node.ofNormal(null, null);
            headIndex = new AtomicReference<>(new HeadIndex<>(node, null, null, 1));
        }

        public boolean contains(T item) {
            Preconditions.checkNotNull(item);
            for (Node<T> node = findNode(item, false); node != null; node = node.nextNotMarker()) {
                T value = node.item.get();
                if (value == null) {
                    // node is removed (or is the header node), no helping, go next
                    continue;
                }
                int c = comparator.compare(value, item);
                if (c == 0) {
                    // found
                    return true;
                }
                if (c > 0) {
                    // node.item > item, not found
                    return false;
                }
            }
            // no more node
            return false;
        }

        public boolean contains2(T item) {
            Preconditions.checkNotNull(item);
            restart:
            while (true) {
                for (Node<T> predecessor = findNode(item, true), current = predecessor.next.get();
                     current != null; ) {
                    if (current.marker) {
                        // predecessor is deleted
                        continue restart;
                    }
                    T value = current.item.get();
                    if (value == null) {
                        // current is deleted
                        helpDelete(predecessor, current);
                        continue restart;
                    }
                    int c = comparator.compare(value, item);
                    if (c == 0) {
                        // found
                        return true;
                    }
                    if (c > 0) {
                        // current.item > item, not found
                        return false;
                    }
                    // current.item < item
                    predecessor = current;
                    current = current.next.get();
                }
                return false;
            }
        }

        /**
         * Help delete current.
         *
         * @param predecessor predecessor
         * @param current     current
         */
        private void helpDelete(Node<T> predecessor, Node<T> current) {
            Node<T> successor = current.next.get();
            if (successor == null || !successor.marker) {
                // step 1 -> 2, insert marker, logical remove
                current.casNext(successor, Node.ofMarker(successor));
            } else {
                // step 2 -> 3, physical remove
                predecessor.casNext(current, successor.next.get());
            }
        }

        @Nonnull
        private Node<T> findNode(T item, boolean onlyPredecessor) {
            while (true) {
                for (Index<T> predecessor = headIndex.get(), current = predecessor.right.get(),
                     successor, down; ; ) {
                    if (current != null) {
                        T value = current.node.item.get();
                        if (value == null) {
                            // current is deleted
                            successor = current.right.get();
                            if (!predecessor.casRight(current, successor)) {
                                break; // restart
                            }
                            current = successor;
                            continue;
                        }
                        int c = comparator.compare(value, item);
                        if (c == 0) {
                            // found
                            return onlyPredecessor ? predecessor.node : current.node;
                        }
                        if (c < 0) {
                            // current.node.item < item, go right
                            predecessor = current;
                            current = current.right.get();
                            continue;
                        }
                    }
                    // 1. current == null, no more index
                    // 2. current.item > item
                    // go down
                    down = predecessor.down;
                    if (down == null) {
                        // the last index level
                        return predecessor.node;
                    }
                    predecessor = down;
                    current = predecessor.right.get();
                }
            }
        }

        public T add(T item) {
            Preconditions.checkNotNull(item);
            while (true) {
                for (Node<T> predecessor = findNode(item, true),
                     current = predecessor.nextNotMarker(), successor; ; ) {
                    if (current != null) {
                        T value = current.item.get();
                        if (value == null) {
                            // current is deleted
                            successor = current.nextNotMarker();
                            if (!predecessor.casNext(current, successor)) {
                                break; // restart
                            }
                            current = successor;
                            continue;
                        }
                        int c = comparator.compare(value, item);
                        if (c == 0) {
                            if (value.equals(item) || current.casItem(value, item)) {
                                return value;
                            }
                            break; // restart
                        }
                        if (c < 0) {
                            // current.item < item
                            predecessor = current;
                            // current may be deleted at this point
                            current = current.nextNotMarker();
                            continue;
                        }
                    }
                    // 1. no more node
                    // 2. current.item > item
                    // insert node
                    Node<T> newNode = Node.ofNormal(item, current);
                    if (predecessor.casNext(current, newNode)) {
                        // ok
                        buildIndices(randomLevel(), newNode);
                        return null;
                    }
                    break; // restart
                }
            }
        }

        public T add2(T item) {
            Preconditions.checkNotNull(item);
            while (true) {
                for (Node<T> predecessor = findNode(item, true), current = predecessor.next.get(); ; ) {
                    if (current != null) {
                        if (current.marker) {
                            // predecessor is deleted
                            break; // restart
                        }
                        T value = current.item.get();
                        if (value == null) {
                            // current is deleted
                            helpDelete(predecessor, current);
                            break;
                        }
                        int c = comparator.compare(value, item);
                        if (c == 0) {
                            if (value.equals(item) || current.casItem(value, item)) {
                                return value;
                            }
                            break; // restart
                        }
                        if (c < 0) {
                            // current.item < item
                            predecessor = current;
                            current = current.next.get();
                            continue;
                        }
                    }
                    // 1. no more node
                    // 2. current.item > item
                    // insert node
                    Node<T> newNode = Node.ofNormal(item, current);
                    if (predecessor.casNext(current, newNode)) {
                        // ok
                        buildIndices(randomLevel(), newNode);
                        return null;
                    }
                    break; // restart
                }
            }
        }

        private void buildIndices(int level, Node<T> node) {
            if (level < 1) {
                // no index
                return;
            }
            HeadIndex<T> head = headIndex.get();
            Index<T>[] indices;
            int insertLevel; // highest level that still has to be spliced in
            if (level <= head.level) {
                insertLevel = level;
                indices = makeIndices(level, node);
            } else {
                // grow the head by at most one level
                insertLevel = head.level + 1;
                indices = makeIndices(insertLevel, node);
                while (true) {
                    head = headIndex.get();
                    int oldLevel = head.level;
                    if (insertLevel <= oldLevel) {
                        // lost the race, another thread already grew the head
                        break;
                    }
                    HeadIndex<T> newHead = head;
                    for (int l = oldLevel + 1; l <= insertLevel; l++) {
                        // head indices keep pointing at the header node
                        newHead = new HeadIndex<>(head.node, indices[l - 1], newHead, l);
                    }
                    if (headIndex.compareAndSet(head, newHead)) {
                        head = newHead;
                        // levels above oldLevel are already linked through the new head
                        insertLevel = oldLevel;
                        break;
                    }
                }
            }
            insertIndices(head, insertLevel, indices);
        }

        private void insertIndices(HeadIndex<T> head, int level, Index<T>[] indices) {
            final T item = indices[0].node.item.get();
            if (item == null) {
                // node was deleted in the meantime, no index needed
                return;
            }
            int insertLevel = level; // decremented after each spliced level
            Index<T> predecessor;
            Index<T> current;
            Index<T> successor;
            restart:
            while (true) {
                predecessor = head;
                current = predecessor.right.get();
                int currentLevel = head.level;
                while (currentLevel > 0 && insertLevel > 0) {
                    if (current != null) {
                        T value = current.node.item.get();
                        if (value == null) {
                            // current is deleted
                            successor = current.right.get();
                            if (!predecessor.casRight(current, successor)) {
                                continue restart;
                            }
                            current = successor;
                            continue;
                        }
                        if (comparator.compare(value, item) < 0) {
                            // go right
                            predecessor = current;
                            current = current.right.get();
                            continue;
                        }
                        // an equal item here is one of our own, already linked indices
                    }
                    // 1. current == null
                    // 2. current.item >= item
                    if (currentLevel == insertLevel) {
                        // insert index
                        Index<T> index = indices[currentLevel - 1];
                        index.lazySetRight(current);
                        if (!predecessor.casRight(current, index)) {
                            continue restart;
                        }
                        // node may be deleted at this point
                        insertLevel--;
                    }
                    // 1. index inserted
                    // 2. current level > insert level
                    // go down, unless this was the lowest index level
                    if (--currentLevel > 0) {
                        predecessor = predecessor.down;
                        current = predecessor.right.get();
                    }
                }
                // indices inserted
                return;
            }
        }

        @SuppressWarnings("unchecked")
        private Index<T>[] makeIndices(int level, Node<T> node) {
            assert level > 0;
            Index<T>[] indices = (Index<T>[]) new Index[level];
            Index<T> lastIndex = null;
            for (int i = 0; i < level; i++) {
                indices[i] = new Index<>(node, null, lastIndex);
                lastIndex = indices[i];
            }
            return indices;
        }

        private int randomLevel() {
            int r = (int) System.nanoTime();
            // xor shift
            r ^= r << 13;
            r ^= r >>> 17;
            r ^= r << 5;
            if ((r & 0x80000001) != 0) {
                return 0;
            }
            int level = 1;
            while (((r >>>= 1) & 1) != 0) {
                level++;
            }
            return level;
        }

        public T remove(T item) {
            Preconditions.checkNotNull(item);
            while (true) {
                for (Node<T> predecessor = findNode(item, true),
                     current = predecessor.nextNotMarker(), successor; ; ) {
                    if (current == null) {
                        // not found
                        return null;
                    }
                    T value = current.item.get();
                    if (value == null) {
                        // current is deleted
                        successor = current.nextNotMarker();
                        if (!predecessor.casNext(current, successor)) {
                            break; // restart
                        }
                        current = successor;
                        continue;
                    }
                    int c = comparator.compare(value, item);
                    if (c > 0) {
                        // not found
                        return null;
                    }
                    if (c < 0) {
                        // current.item < item, go next
                        predecessor = current;
                        current = current.nextNotMarker();
                        continue;
                    }
                    // c == 0, found
                    if (!current.casItem(value, null)) {
                        break; // restart
                    }
                    successor = current.next.get();
                    if (successor == null || !successor.marker) {
                        Node<T> marker = Node.ofMarker(successor);
                        // unlink only if our marker went in; a failed CAS means
                        // current.next has changed and successor is stale
                        if (current.casNext(successor, marker)) {
                            predecessor.casNext(current, successor);
                        }
                    } else {
                        predecessor.casNext(current, successor.next.get());
                    }
                    // unlink indices
                    findNode(item, false);
                    tryDecreaseLevel();
                    return value;
                }
            }
        }

        public T remove2(T item) {
            Preconditions.checkNotNull(item);
            restart:
            while (true) {
                for (Node<T> predecessor = findNode(item, true), current = predecessor.next.get(),
                     successor; current != null; ) {
                    if (current.marker) {
                        // predecessor is deleted
                        continue restart;
                    }
                    T value = current.item.get();
                    if (value == null) {
                        // current is deleted
                        helpDelete(predecessor, current);
                        continue restart;
                    }
                    int c = comparator.compare(value, item);
                    if (c > 0) {
                        // current.item > item, not found
                        return null;
                    }
                    if (c < 0) {
                        // current.item < item, go next
                        predecessor = current;
                        current = current.next.get();
                        continue;
                    }
                    // c == 0, found
                    if (!current.casItem(value, null)) {
                        continue restart;
                    }
                    successor = current.next.get();
                    if (successor == null || !successor.marker) {
                        Node<T> marker = Node.ofMarker(successor);
                        // unlink only if our marker went in; otherwise later
                        // traversals finish the job through helpDelete
                        if (current.casNext(successor, marker)) {
                            predecessor.casNext(current, successor);
                        }
                    } else {
                        predecessor.casNext(current, successor.next.get());
                    }
                    // unlink indices
                    findNode(item, false);
                    tryDecreaseLevel();
                    return value;
                }
                // not found
                return null;
            }
        }

        private void tryDecreaseLevel() {
            HeadIndex<T> t1 = headIndex.get();
            if (t1.level <= 3) {
                return;
            }
            HeadIndex<T> t2 = (HeadIndex<T>) t1.down;
            HeadIndex<T> t3 = (HeadIndex<T>) t2.down;
            if (t3.right.get() != null || t2.right.get() != null || t1.right.get() != null) {
                return;
            }
            if (headIndex.compareAndSet(t1, t2)) {
                // rollback if right of t1 appeared
                if (t1.right.get() != null) {
                    headIndex.compareAndSet(t2, t1);
                }
            }
        }

        private static class Index<T> {
            final Node<T> node;
            final AtomicReference<Index<T>> right;
            final Index<T> down;

            Index(Node<T> node, Index<T> right, Index<T> down) {
                this.node = node;
                this.right = new AtomicReference<>(right);
                this.down = down;
            }

            @SuppressWarnings("BooleanMethodIsAlwaysInverted")
            boolean casRight(Index<T> expect, Index<T> update) {
                return node.item.get() != null && right.compareAndSet(expect, update);
            }

            void lazySetRight(Index<T> right) {
                this.right.lazySet(right);
            }
        }

        private static final class HeadIndex<T> extends Index<T> {
            final int level;

            HeadIndex(Node<T> node, Index<T> right, Index<T> down, int level) {
                super(node, right, down);
                this.level = level;
            }

            @Override
            boolean casRight(Index<T> expect, Index<T> update) {
                // the header node carries a null item but is never deleted,
                // so the deleted-node check of Index must not apply here
                return right.compareAndSet(expect, update);
            }
        }

        private static final class Node<T> {
            final boolean marker;
            final AtomicReference<T> item;
            final AtomicReference<Node<T>> next;

            Node(boolean marker, T item, Node<T> next) {
                this.marker = marker;
                this.item = new AtomicReference<>(item);
                this.next = new AtomicReference<>(next);
            }

            static <T> Node<T> ofMarker(Node<T> next) {
                return new Node<>(true, null, next);
            }

            static <T> Node<T> ofNormal(T item, Node<T> next) {
                return new Node<>(false, item, next);
            }

            Node<T> nextNotMarker() {
                Node<T> successor = next.get();
                if (successor == null) {
                    return null;
                }
                if (successor.marker) {
                    return successor.next.get();
                }
                return successor;
            }

            boolean casItem(T expect, T update) {
                return item.compareAndSet(expect, update);
            }

            boolean casNext(Node<T> expect, Node<T> update) {
                return next.compareAndSet(expect, update);
            }
        }
    }

In the code, contains/add/remove traverse with nextNotMarker, while contains2/add2/remove2 stay closer to the ConcurrentSkipListMap source.
In terms of code complexity: contains < remove < add.
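Let's start with contains. It is simpler than contains2 because it does not help delete nodes while walking the bottom HM Linked List, so it can traverse like an ordinary singly linked list, skipping nodes whose value is null and nodes smaller than the target. The findNode method that supplies the starting node has nothing to do with findNode in ConcurrentSkipListMap; if anything, it is closer to findPredecessor. findNode starts at the topmost Index, helps unlink the Index entries of deleted nodes along the way, and keeps descending until it reaches the node closest to the target value, which may be the target node itself. Whether it returns the predecessor or the node itself is decided by the onlyPredecessor parameter.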
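The descent that findNode performs can be sketched as follows. This is a minimal single-threaded sketch, not the article's findNode: it has no CAS and no help-delete, and the names `DescentSketch`, `SNode`, `SIndex`, `findPredecessor` and `demo` are made up for illustration. It only shows the "go right while smaller, otherwise go down" walk that ends on the bottom list.

```java
/** Hypothetical base-level node for the sketch (plain fields, no atomics). */
final class SNode {
    final int key;
    final SNode next;

    SNode(int key, SNode next) {
        this.key = key;
        this.next = next;
    }
}

/** Hypothetical index entry: points at a base node, right on the same level, down one level. */
final class SIndex {
    final SNode node;
    final SIndex right;
    final SIndex down;

    SIndex(SNode node, SIndex right, SIndex down) {
        this.node = node;
        this.right = right;
        this.down = down;
    }
}

final class DescentSketch {
    /** Returns the last base-level node whose key is smaller than target. */
    static SNode findPredecessor(SIndex top, int target) {
        SIndex index = top;
        while (true) {
            // go right while the next index is still smaller than the target
            while (index.right != null && index.right.node.key < target) {
                index = index.right;
            }
            if (index.down == null) {
                break;
            }
            // otherwise go down one level
            index = index.down;
        }
        // finish on the bottom linked list
        SNode node = index.node;
        while (node.next != null && node.next.key < target) {
            node = node.next;
        }
        return node;
    }

    /** Builds head -> 10 -> 20 -> 30 -> 40 with two index levels and searches it. */
    static int demo(int target) {
        SNode n40 = new SNode(40, null);
        SNode n30 = new SNode(30, n40);
        SNode n20 = new SNode(20, n30);
        SNode n10 = new SNode(10, n20);
        SNode head = new SNode(Integer.MIN_VALUE, n10);
        // level 1: head -> 20 -> 40
        SIndex l1n40 = new SIndex(n40, null, null);
        SIndex l1n20 = new SIndex(n20, l1n40, null);
        SIndex l1head = new SIndex(head, l1n20, null);
        // level 2: head -> 20
        SIndex l2n20 = new SIndex(n20, null, l1n20);
        SIndex l2head = new SIndex(head, l2n20, l1head);
        return findPredecessor(l2head, target).key;
    }

    public static void main(String[] args) {
        System.out.println("predecessor of 35: " + demo(35));
        System.out.println("predecessor of 25: " + demo(25));
    }
}
```

In the concurrent version, each step to the right additionally checks whether the index's node has a null value and, if so, tries to unlink that index with a CAS before moving on.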
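contains2 is somewhat more complex than contains. Because it helps delete nodes, it needs the predecessor. In addition, when the current node has been logically deleted, what it does depends on which step the deletion has reached. If the value has been set to null, it appends a marker node. If the marker node already exists, it physically unlinks the node. Could this be done in one step? The value is already null, so why not unlink the node directly? In other words, can the marker node be left out? I think it might work in theory, but with some probability a node whose value is null would never be physically removed and would stay in the data structure as garbage, for example in the case mentioned earlier where adjacent nodes are deleted at the same time.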
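To make the role of the marker concrete, here is a small deterministic sketch. It is an assumption-laden illustration rather than the article's code: `MarkerNode` is a cut-down stand-in for the Node class, `MarkerSketch` and `staleInsertRejected` are hypothetical names, and the "concurrent" inserter is simulated by interleaving the steps by hand in one thread. It shows that once the marker is appended, an insert that was prepared against the deleted node's old successor fails its CAS.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical, simplified node: a marker flag, a CAS-able item and a CAS-able next. */
final class MarkerNode {
    final boolean marker;
    final AtomicReference<String> item;
    final AtomicReference<MarkerNode> next;

    MarkerNode(boolean marker, String item, MarkerNode next) {
        this.marker = marker;
        this.item = new AtomicReference<>(item);
        this.next = new AtomicReference<>(next);
    }
}

final class MarkerSketch {
    /** Returns true if the stale insert after the deleted node was rejected. */
    static boolean staleInsertRejected() {
        // list: A -> B -> C
        MarkerNode c = new MarkerNode(false, "C", null);
        MarkerNode b = new MarkerNode(false, "B", c);
        MarkerNode a = new MarkerNode(false, "A", b);

        // An inserter reads B.next == C and prepares B -> X -> C, then is paused.
        MarkerNode seenSuccessor = b.next.get();
        MarkerNode x = new MarkerNode(false, "X", seenSuccessor);

        // Meanwhile a remover deletes B in three steps.
        b.item.compareAndSet("B", null);                    // 1. logical delete
        MarkerNode successor = b.next.get();
        MarkerNode marker = new MarkerNode(true, null, successor);
        b.next.compareAndSet(successor, marker);            // 2. append marker
        a.next.compareAndSet(b, successor);                 // 3. physical unlink

        // The inserter resumes: it expects C, but B.next is now the marker.
        boolean inserted = b.next.compareAndSet(seenSuccessor, x);
        return !inserted && a.next.get() == c;
    }

    public static void main(String[] args) {
        System.out.println("stale insert rejected: " + staleInsertRejected());
    }
}
```

Without step 2, the inserter's CAS on B.next would succeed and X would hang off a node that is no longer reachable from A, which is exactly the kind of lost update the marker is there to prevent.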
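Next is remove2. Deletion also contains a help-delete phase. Once the node is found, it first CASes item to null and then, depending on the state of the successor, appends a marker and physically unlinks the node. Setting the value to null already counts as the logical deletion, so appending the marker is allowed to fail, and so is the physical unlink. Deletion ends with an attempt to lower the level. When there are more than 3 levels and the top 3 levels contain no Index other than the head, ConcurrentSkipListMap tries to drop one level. If a new Index is inserted in the meantime, it abandons the drop.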
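Last comes add2, the most complex of the three. It is complex because there is a lot of code, and ConcurrentSkipListMap writes all of it in one place. I split it into several methods by function. First comes the regular HM Linked List insertion. Replacing the value is allowed, and if a value was replaced the method returns immediately. Next comes index creation. A random number is generated with xorshift, and the run of consecutive 1 bits is counted starting from the low end. I tested randomLevel locally and it roughly follows a distribution in which higher levels appear less often. Of course you cannot control where your nodes end up being inserted, so it still comes down to luck. Depending on randomLevel, the head may grow by one level or stay the same. Growing can coincide with a height decrease caused by a deletion, so in theory the ABA problem can occur and AtomicStampedReference should be used. I believe it is not used here because the SkipList does not take part in the invariants, so some instability can be tolerated. Finally, the indexes of the new node are inserted into the SkipList. Note that because the SkipList levels do not use HM Linked List, in theory an index may fail to be removed or fail to be added, but the probability is low and the outcome is tolerable.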
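The distribution of randomLevel can be checked with a small histogram. This is a sketch under one assumption that differs from the code above: the seed comes from a fixed-seed `java.util.Random` instead of `System.nanoTime()`, so the run is repeatable. `RandomLevelSketch`, `levelFor` and `histogram` are names made up for this example. The bit logic is copied unchanged.

```java
import java.util.Random;

final class RandomLevelSketch {
    /** Same bit logic as randomLevel() above, but with the seed passed in. */
    static int levelFor(int seed) {
        int r = seed;
        // xor shift
        r ^= r << 13;
        r ^= r >>> 17;
        r ^= r << 5;
        // highest or lowest bit set: no index at all
        if ((r & 0x80000001) != 0) {
            return 0;
        }
        // count the run of 1 bits starting at bit 1
        int level = 1;
        while (((r >>>= 1) & 1) != 0) {
            level++;
        }
        return level;
    }

    /** Counts how often each level comes up for the given number of samples. */
    static int[] histogram(int samples, long seed) {
        Random random = new Random(seed);
        int[] counts = new int[32]; // levels range from 0 to 31
        for (int i = 0; i < samples; i++) {
            counts[levelFor(random.nextInt())]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] counts = histogram(1_000_000, 42L);
        for (int level = 0; level < 6; level++) {
            System.out.println("level " + level + ": " + counts[level]);
        }
    }
}
```

If the input bits are uniform, the mask check sends about three quarters of the inserts to level 0, and each further level should be roughly half as frequent as the one below it, which matches the "higher level, fewer occurrences" shape described above.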
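To wrap up: ConcurrentSkipListMap relies on HM Linked List and optimizes it with marker nodes. SkipList itself is a randomized data structure, and in a concurrent setting it combines well with HM Linked List, avoiding problems such as the resize that array-based Maps have to deal with. Because it is a combination of two data structures the code is fairly complex, but it is well worth reading and thinking about.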