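This is the fourth post in my Java concurrency blog series. The series alternates between synchronizers and concurrent data structures, and this post covers a concurrent data structure. As mentioned in the earlier post on ConcurrentLinkedQueue, concurrent data structures do not involve the problems specific to synchronizers, so they are somewhat simpler. The analysis focuses on the data structure itself.
The topic this time is ConcurrentSkipListMap. To be honest, I did not know this class at first: if ConcurrentHashMap exists, why is ConcurrentSkipListMap needed? TreeMap is not used that often even in single-threaded code, and a sorted Map under concurrency is used even less. That said, after reading chapter 14 of The Art of Multiprocessor Programming, I found that ConcurrentSkipListMap is an excellent example to learn from, especially as production Java code for the HM Linked List (also called the Harris Linked List).
If you have read articles analyzing the ConcurrentSkipListMap source, you probably know about the SkipList, since it is right there in the name. In fact, though, what maintains correctness in ConcurrentSkipListMap is not the SkipList but the HM Linked List at the bottom level. In other words, the core is the HM Linked List, not the SkipList that speeds up access.
Before reading the ConcurrentSkipListMap code, I recommend understanding the following:
- The HM Linked List
- The marker-node optimization of the HM Linked List
- The ordinary (non-concurrent) SkipList
The HM Linked List is a fairly well-known singly linked list implementation for concurrent environments; interested readers can read the related papers. It requires that a node be able to CAS its two fields, marker and next, in a single step. In Java this corresponds to AtomicMarkableReference. A typical implementation (chapter 9 of The Art of Multiprocessor Programming) follows.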
import java.util.concurrent.atomic.AtomicMarkableReference;
/**
* HM linked list based on {@code AtomicMarkableReference}.
*
* @param <T>
*/
public class ListBasedSet1<T> {
private final Node<T> head;
public ListBasedSet1() {
head = new Node<>(Integer.MIN_VALUE, null, new Node<>(Integer.MAX_VALUE));
}
public boolean add(T item) {
final int key = item.hashCode();
Window<T> window;
Node<T> newNode;
while (true) {
window = find(head, key);
// same key
if (window.current.key == key) {
return false;
}
newNode = new Node<>(key, item, window.current);
if (window.predecessor.nextAndMark.compareAndSet(
window.current, newNode, false, false)) {
return true;
}
// retry
}
}
public boolean remove(T item) {
final int key = item.hashCode();
Window<T> window;
Node<T> successor;
while (true) {
window = find(head, key);
// not found
if (window.current.key != key) {
return false;
}
successor = window.current.nextAndMark.getReference();
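// logical delete: the mark must flip from false to true, so that of two
// concurrent removers only one wins (attemptMark also succeeds on a marked node)
if (window.current.nextAndMark.compareAndSet(successor, successor, false, true)) {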
// physical delete
window.predecessor.nextAndMark.compareAndSet(
window.current, successor, false, false);
// fail is ok
return true;
}
// retry
}
}
public boolean contains(T item) {
int key = item.hashCode();
Node<T> current = head.next();
while (current.key < key) {
current = current.next();
}
return current.key == key && !current.isMarked();
}
private Window<T> find(Node<T> head, int key) {
boolean[] markedHolder = {false};
boolean snip;
retry:
while (true) { // a labeled continue re-enters here, so the search really restarts from head
Node<T> predecessor = head;
Node<T> current = predecessor.next();
Node<T> successor;
while (true) {
successor = current.nextAndMark.get(markedHolder);
// current node is deleted
while (markedHolder[0]) {
snip = predecessor.nextAndMark.compareAndSet(
current, successor, false, false);
if (!snip) {
continue retry;
}
current = successor;
successor = current.nextAndMark.get(markedHolder);
}
if (current.key >= key) {
return new Window<>(predecessor, current);
}
predecessor = current;
current = successor;
}
}
}
private static class Window<T> {
private final Node<T> predecessor;
private final Node<T> current;
Window(Node<T> predecessor, Node<T> current) {
this.predecessor = predecessor;
this.current = current;
}
}
private static class Node<T> {
private final int key;
private final T item;
private final AtomicMarkableReference<Node<T>> nextAndMark;
Node(int key) {
this.key = key;
this.item = null;
nextAndMark = new AtomicMarkableReference<>(null, false);
}
Node(int key, T item, Node<T> next) {
this.key = key;
this.item = item;
nextAndMark = new AtomicMarkableReference<>(next, false);
}
boolean isMarked() {
return nextAndMark.isMarked();
}
Node<T> next() {
return nextAndMark.getReference();
}
}
}
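This is a ListBasedSet, a Set built on a linked list. It implements three methods: contains, add, and remove. The basic idea is to maintain a sorted linked list and traverse from the head node to find the target node (contains, remove) or the insertion position (add).
You may think this is an inefficient Set, because the worst-case time complexity is O(N). With a SkipList on top, however, the expected time complexity becomes O(log N). This is also what ConcurrentSkipListMap does internally.
Back to ListBasedSet. How can the correctness of the implementation above be shown? Put differently: why does a single-threaded singly linked list (with only a next pointer) break under concurrency, what exactly goes wrong, and how does ListBasedSet solve it?
Consider the following singly linked list: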
a -> b -> c -> d
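Suppose only the next pointer is CASed. Thread 1 wants to delete node b while thread 2 wants to delete c, and the following happens at run time:
thread 1: CAS(a.next, b, b.next)
thread 2: CAS(b.next, c, c.next)
Both CASes succeed, but c is not deleted, so the result is wrong. A remove and an add on adjacent nodes can likewise cause one operation to be "ignored" (in theory the Index levels of ConcurrentSkipListMap can hit this, but that is acceptable because correctness is guaranteed by the HM Linked List). On the surface the cause is that the two CASes hit different next pointers; the real problem is that a CAS on a next pointer cannot guarantee that the node involved has not been deleted. The HM Linked List therefore requires CASing the next pointer together with the node's marker:
thread 1: CAS(b, [b.next, false], [b.next, true])
thread 1: CAS(a, [b, false], [b.next, false])
thread 2: CAS(c, [c.next, false], [c.next, true])
thread 2: CAS(b, [c, false], [c.next, false])
Here the square brackets hold the next pointer value and the marker (true means deleted, false means present). Run this way, the last CAS fails, because b has been marked. For a full correctness proof, see the original paper or The Art of Multiprocessor Programming.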
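To make the two interleavings above concrete, here is a minimal sketch that replays them on a single thread. The class and method names (AdjacentDeleteDemo, plainCas, markedCas) are my own for illustration and do not come from the book or the JDK; only AtomicReference and AtomicMarkableReference are real library classes.

```java
import java.util.concurrent.atomic.AtomicMarkableReference;
import java.util.concurrent.atomic.AtomicReference;

// Replays the two interleavings on one thread; all names here are illustrative only.
class AdjacentDeleteDemo {
    static final class PlainNode {
        final String name;
        final AtomicReference<PlainNode> next;
        PlainNode(String name, PlainNode next) {
            this.name = name;
            this.next = new AtomicReference<>(next);
        }
    }

    static final class MarkedNode {
        final AtomicMarkableReference<MarkedNode> next;
        MarkedNode(MarkedNode next) {
            this.next = new AtomicMarkableReference<>(next, false);
        }
    }

    /** CAS on next only: both deletes report success, yet c stays reachable. */
    static String plainCas() {
        PlainNode d = new PlainNode("d", null);
        PlainNode c = new PlainNode("c", d);
        PlainNode b = new PlainNode("b", c);
        PlainNode a = new PlainNode("a", b);
        boolean t1 = a.next.compareAndSet(b, c); // thread 1 unlinks b
        boolean t2 = b.next.compareAndSet(c, d); // thread 2 "unlinks" c through the dead node b
        StringBuilder list = new StringBuilder();
        for (PlainNode n = a; n != null; n = n.next.get()) {
            if (list.length() > 0) {
                list.append("->");
            }
            list.append(n.name);
        }
        return t1 + "," + t2 + "," + list;
    }

    /** CAS on (next, mark): the snip through the marked node b is rejected. */
    static String markedCas() {
        MarkedNode d = new MarkedNode(null);
        MarkedNode c = new MarkedNode(d);
        MarkedNode b = new MarkedNode(c);
        MarkedNode a = new MarkedNode(b);
        boolean t1Mark = b.next.compareAndSet(c, c, false, true);  // thread 1 marks b
        boolean t1Snip = a.next.compareAndSet(b, c, false, false); // thread 1 unlinks b
        boolean t2Mark = c.next.compareAndSet(d, d, false, true);  // thread 2 marks c
        boolean t2Snip = b.next.compareAndSet(c, d, false, false); // expects mark == false on b
        // c is still linked after a, but it is marked, so a later traversal can snip it
        return t1Mark + "," + t1Snip + "," + t2Mark + "," + t2Snip + "," + c.next.isMarked();
    }
}
```

In the first method both CAS calls return true and the list still reads a->c->d. In the second, the last CAS returns false, and c remains marked so that a later find can finish the job.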
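In the code, this process shows up in remove. The first CAS is a logical delete, and the CAS after it is the actual (physical) delete. The implementation only requires the first CAS, the logical delete, to succeed; the physical delete is allowed to fail. The reason is that the find called by add/remove helps with the physical delete whenever it sees a logically deleted node (and if no other thread interferes, the physical delete CAS normally succeeds). When the logical delete fails, the code restarts the traversal from the head node and tries again. A logical delete usually fails because two threads are deleting the same node at the same time. In that case exactly one must succeed and the other must fail (this is another problem of CASing only the next pointer, where with bad timing both calls return true). When the slower thread retries and calls find, the node has either been unlinked with its help or was already unlinked by the other thread, so it cannot reach the node and returns false. That is why this approach is correct.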
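The "exactly one remover wins" requirement can be checked in isolation. The sketch below replays two removers racing on the same node, one after the other on a single thread. LogicalDeleteRace and its methods are hypothetical names. The attemptMark half reflects my understanding of the OpenJDK implementation, where the call returns true if the reference matches and the mark already has the requested value; treat that as an assumption rather than a documented guarantee.

```java
import java.util.concurrent.atomic.AtomicMarkableReference;

// Two removers racing on the same node, replayed sequentially; names are illustrative only.
class LogicalDeleteRace {
    /** A false -> true CAS on the mark: the second remover is told it lost. */
    static boolean[] withCompareAndSet() {
        Object successor = new Object();
        AtomicMarkableReference<Object> next = new AtomicMarkableReference<>(successor, false);
        boolean first = next.compareAndSet(successor, successor, false, true);
        boolean second = next.compareAndSet(successor, successor, false, true);
        return new boolean[] {first, second};
    }

    /** attemptMark only compares the reference, so it cannot tell the removers apart. */
    static boolean[] withAttemptMark() {
        Object successor = new Object();
        AtomicMarkableReference<Object> next = new AtomicMarkableReference<>(successor, false);
        // the Javadoc allows spurious failure, so retry until the mark is set
        boolean first;
        do {
            first = next.attemptMark(successor, true);
        } while (!first);
        boolean second = next.attemptMark(successor, true);
        return new boolean[] {first, second};
    }
}
```

With compareAndSet the results are true then false. With attemptMark, under the assumption stated above, both calls report success, which is exactly the "both return true" situation the paragraph warns about.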
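Although the HM Linked List is proven correct, the practical implementation is still not ideal. In particular, AtomicMarkableReference creates a new object on every update. C/C++ might simulate the marker with a bit of the next pointer, but then deciding when to reclaim list nodes becomes a problem (you cannot simply free a node if another thread is traversing the list during the remove).
One optimization used in ConcurrentSkipListMap is to append a marker node after the node to be deleted (setting the node value to null addresses a different problem; do not confuse the two). Suppose I want to delete node b from the list below; I append a marker node after b.
a -> b -> c
a -> b -> marker -> c
Adding a node only on deletion is much cheaper than the AtomicMarkableReference approach, but the code becomes more complex. Previously you only had to look at three nodes a, b, c; now you have to look at four: a, b, marker, c.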
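Why a marker node is enough can be seen from the expected value of each CAS. The following sketch is my own illustration; MarkerNodeDemo and its Node class are hypothetical and are not the classes used elsewhere in this post. It appends a marker after b and then shows that a stale insert after b and a second marker both fail, while the physical unlink drops b and its marker together.

```java
import java.util.concurrent.atomic.AtomicReference;

// Single-threaded replay of a marker-based delete; all names are illustrative only.
class MarkerNodeDemo {
    static final class Node {
        final String name;
        final boolean marker;
        final AtomicReference<Node> next;
        Node(String name, boolean marker, Node next) {
            this.name = name;
            this.marker = marker;
            this.next = new AtomicReference<>(next);
        }
    }

    static String run() {
        Node c = new Node("c", false, null);
        Node b = new Node("b", false, c);
        Node a = new Node("a", false, b);

        // remover: logical delete of b by appending a marker after it
        boolean marked = b.next.compareAndSet(c, new Node("marker", true, c));
        // inserter that still believes b -> c: b.next is now the marker, so this fails
        boolean inserted = b.next.compareAndSet(c, new Node("x", false, c));
        // a second remover cannot append another marker for the same reason
        boolean markedTwice = b.next.compareAndSet(c, new Node("marker", true, c));
        // physical delete: swing a.next from b to c, dropping b and its marker together
        boolean unlinked = a.next.compareAndSet(b, c);

        StringBuilder list = new StringBuilder();
        for (Node n = a; n != null; n = n.next.get()) {
            if (list.length() > 0) {
                list.append("->");
            }
            list.append(n.name);
        }
        return marked + "," + inserted + "," + markedTwice + "," + unlinked + "," + list;
    }
}
```

The marker plays the role of the mark bit: once b.next points at a marker, every CAS that expects the old successor is rejected, without allocating a new holder object on each ordinary update.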
A reference implementation using marker nodes:
import java.util.concurrent.atomic.AtomicReference;
/**
* HM linked list using marker node.
*
* @param <T> element type
*/
public class ListBasedSet2<T> {
private final Node<T> head;
ListBasedSet2() {
Node<T> tail = new Node<>(NodeKind.NORMAL, Integer.MAX_VALUE, null, null);
head = new Node<>(NodeKind.NORMAL, Integer.MIN_VALUE, null, tail);
}
public boolean contains(T item) {
final int key = item.hashCode();
Node<T> current = head.next();
boolean[] markHolder = {false};
while (current.key < key) {
current = current.next(markHolder); // skip marker node
assert current != null;
}
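// markHolder still describes the node we came from, so re-read it for current
markHolder[0] = false;
current.next(markHolder);
return current.key == key && !markHolder[0];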
}
@SuppressWarnings("unchecked")
public boolean add(T item) {
final int key = item.hashCode();
Node<T>[] nodes = (Node<T>[]) new Node[3];
while (true) {
// same key
if (find(key, nodes)) {
return false;
}
// p -> (n) -> c -> s
Node<T> newNode = new Node<>(NodeKind.NORMAL, key, item, nodes[1]);
if (nodes[0].next.compareAndSet(nodes[1], newNode)) {
return true;
}
}
}
private boolean find(int key, Node<T>[] nodes) {
Node<T> predecessor;
Node<T> current;
Node<T> successor;
boolean[] markHolder = {false};
boolean snip;
restart:
for (predecessor = head, current = predecessor.next(); ; ) {
assert current != null;
successor = current.next(markHolder);
while (markHolder[0]) {
assert successor != null;
snip = predecessor.next.compareAndSet(current, successor);
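if (!snip) {
// a labeled continue does not re-run the for initializer, so reset by hand
predecessor = head;
current = predecessor.next();
continue restart;
}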
current = successor;
successor = current.next(markHolder);
}
if (current.key < key) {
predecessor = current;
current = successor;
} else {
break;
}
}
nodes[0] = predecessor;
nodes[1] = current;
nodes[2] = successor;
return current.key == key;
}
@SuppressWarnings("unchecked")
public boolean remove(T item) {
final int key = item.hashCode();
Node<T>[] nodes = (Node<T>[]) new Node[3];
while (true) {
if (!find(key, nodes)) {
return false;
}
// p -> [c -> (m)] -> s
// p -> s
Node<T> marker = new Node<>(NodeKind.MARKER, 0, null, nodes[2]);
if (nodes[1].next.compareAndSet(nodes[2], marker)) {
nodes[0].next.compareAndSet(nodes[1], nodes[2]);
return true;
}
}
}
@Override
public String toString() {
return "ListBasedSet2{" +
"head=" + head +
'}';
}
private enum NodeKind {
NORMAL,
MARKER;
}
private static final class Node<T> {
final NodeKind kind;
final int key;
final T item;
final AtomicReference<Node<T>> next;
Node(NodeKind kind, int key, T item, Node<T> next) {
this.kind = kind;
this.key = key;
this.item = item;
this.next = new AtomicReference<>(next);
}
Node<T> next() {
return next.get();
}
Node<T> next(boolean[] markHolder) {
Node<T> successor = next.get();
if (successor == null) { // tail
return null;
}
if (successor.kind == NodeKind.MARKER) {
markHolder[0] = true;
return successor.next();
}
markHolder[0] = false;
return successor;
}
@Override
public String toString() {
StringBuilder builder = new StringBuilder("Node{");
if (kind == NodeKind.NORMAL) {
builder.append("kind=NORMAL,");
builder.append("key=").append(key).append(',');
builder.append("item=").append(item).append(',');
} else {
builder.append("kind=MARKER,");
}
builder.append("next=\n").append(next.get()).append('}');
return builder.toString();
}
}
}
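To keep the code close to ListBasedSet1, ListBasedSet2 uses a method with the signature next(boolean[] markHolder) that helps skip marker nodes. To distinguish normal nodes from marker nodes, ListBasedSet2 also adds the NodeKind enum. The content of a marker node is meaningless, so it is simply set to null. Note that you must be very careful: you cannot append two marker nodes after one normal node, and you cannot insert a node between a normal node and the marker node that follows it.
You may notice that ListBasedSet1 and ListBasedSet2 use two sentinel nodes, head and tail, and order elements by hashCode. In real code tail is not needed, and letting the user pass a comparator is better. The following implementation drops tail and accepts a user-supplied comparator.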
import java.util.Comparator;
import java.util.concurrent.atomic.AtomicReference;
/**
* 1. HM linked list using marker node
* 2. no tail node
* 3. comparator
*
* @param <T> element type
*/
@SuppressWarnings("Duplicates")
public class ListBasedSet3<T> {
private final Comparator<T> comparator;
private final Node<T> head;
public ListBasedSet3(Comparator<T> comparator) {
this.comparator = comparator;
head = new Node<>(NodeKind.HEAD, null, null);
}
private int compareItem(Node<T> node1, T item) {
switch (node1.kind) {
case HEAD:
return -1;
case NORMAL:
return comparator.compare(node1.item, item);
default:
throw new IllegalStateException("cannot compareItem");
}
}
public boolean contains(T item) {
int c;
for (Node<T> current = head.next.get(), successor; ; ) {
if (current == null) {
// predecessor is last node
return false;
}
// current.kind never be head
assert current.kind != NodeKind.HEAD;
successor = current.next.get();
if (current.kind != NodeKind.MARKER) { // skip marker
c = compareItem(current, item);
if (c == 0) {
// current.item == item
// current.kind = normal
// item present if current is not marked
return successor == null || successor.kind != NodeKind.MARKER;
}
if (c > 0) {
// current.item > item
// not found
return false;
}
}
// 1. current.kind == marker
// 2. current.item < item
// go next
current = successor;
}
}
public boolean add(T item) {
int c;
boolean[] markHolder = {false};
restart:
while (true) {
for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
if (current == null) {
// no more node
if (insert(predecessor, item, null)) {
return true;
}
continue restart;
}
successor = current.next(markHolder);
// current is deleted
while (markHolder[0]) {
if (!predecessor.next.compareAndSet(current, successor)) {
continue restart;
}
current = successor;
if (current == null) {
// no more node
if (insert(predecessor, item, null)) {
return true;
}
continue restart;
}
successor = current.next(markHolder);
}
c = compareItem(current, item);
if (c == 0) {
// same key, item will not be replaced
return false;
}
if (c > 0) {
// current.item > item: insert between predecessor and current
if (insert(predecessor, item, current)) {
return true;
}
continue restart;
}
predecessor = current;
current = successor;
}
}
}
private boolean insert(Node<T> predecessor, T item, Node<T> successor) {
Node<T> newNode = new Node<>(NodeKind.NORMAL, item, successor);
return predecessor.next.compareAndSet(successor, newNode);
}
public boolean remove(T item) {
int c;
boolean[] markHolder = {false};
Node<T> marker;
restart:
while (true) {
for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
if (current == null) {
// no more node
return false;
}
successor = current.next(markHolder);
while (markHolder[0]) {
if (!predecessor.next.compareAndSet(current, successor)) {
continue restart;
}
current = successor;
if (current == null) {
// no more node
return false;
}
successor = current.next(markHolder);
}
c = compareItem(current, item);
if (c == 0) {
marker = new Node<>(NodeKind.MARKER, null, successor);
// logical remove
if (current.next.compareAndSet(successor, marker)) {
// physical remove
predecessor.next.compareAndSet(current, successor);
return true;
}
continue restart;
}
if (c > 0) {
// not found
return false;
}
predecessor = current;
current = successor;
}
}
}
private enum NodeKind {
HEAD,
NORMAL,
MARKER
}
private static final class Node<T> {
final NodeKind kind;
final T item;
final AtomicReference<Node<T>> next;
Node(NodeKind kind, T item, Node<T> next) {
this.kind = kind;
this.item = item;
this.next = new AtomicReference<>(next);
}
Node<T> next(boolean[] markHolder) {
Node<T> successor = next.get();
if (successor == null) {
// last element
markHolder[0] = false;
return null;
}
if (successor.kind == NodeKind.MARKER) {
markHolder[0] = true;
return successor.next.get();
}
markHolder[0] = false;
return successor;
}
}
}
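The troublesome part of the code above is that, without tail, you constantly have to check whether the current node is null.
ListBasedSet4 below is my final version. It differs from the previous classes in that it allows values to be replaced. Note that by default the HM Linked List does nothing to the node when it encounters an equal value. That is correct for a Set, but hard to maintain for a Map: the user may pass the same key with a different value, so you need a way to let the user replace the value.
In the HM Linked List, replacing a value directly runs into a problem: the value of a logically deleted node may get replaced. Consider the following execution sequence:
thread 1: logical delete node with key 1, value xxx
thread 2: replace xxx with yyy
thread 1: physical delete node with key 1
In theory you must not replace the value of a logically deleted node. But the existing strategy cannot impose an order on the execution above, so ConcurrentSkipListMap adds a second logical-delete step: CAS the value to null. Deletion then becomes three steps: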
- CAS item not null -> null
- append marker
- CAS predecessor’s next pointer to successor
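The advantage is that after the first step a replacing thread will fail. The drawbacks are that null cannot be used as a node value and that the code gets more complex.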
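The three steps, together with the replace that loses the race, can be replayed on one thread. This is only a sketch under my own naming (NullItemDeleteDemo and its Node are hypothetical); it relies on nothing beyond AtomicReference.compareAndSet, which compares by reference.

```java
import java.util.concurrent.atomic.AtomicReference;

// Single-threaded replay of the three-step delete racing with a replace; names are illustrative only.
class NullItemDeleteDemo {
    static final class Node {
        final boolean marker;
        final AtomicReference<String> item;
        final AtomicReference<Node> next;
        Node(boolean marker, String item, Node next) {
            this.marker = marker;
            this.item = new AtomicReference<>(item);
            this.next = new AtomicReference<>(next);
        }
    }

    static String run() {
        Node last = new Node(false, "zzz", null);
        Node victim = new Node(false, "xxx", last);
        Node predecessor = new Node(false, "aaa", victim);

        // both threads read the current value before either one acts
        String seen = victim.item.get();
        // step 1 (remover): CAS item not null -> null
        boolean step1 = victim.item.compareAndSet(seen, null);
        // replacer: still expects "xxx", so the replace is rejected
        boolean replaced = victim.item.compareAndSet(seen, "yyy");
        // step 2 (remover): append marker
        Node successor = victim.next.get();
        boolean step2 = victim.next.compareAndSet(successor, new Node(true, null, successor));
        // step 3 (remover): CAS predecessor's next pointer to successor
        boolean step3 = predecessor.next.compareAndSet(victim, successor);

        return step1 + "," + replaced + "," + step2 + "," + step3
                + "," + (predecessor.next.get() == last);
    }
}
```

Because the value itself is the first thing to be CASed, the delete and the replace are ordered by a single memory location: whichever CAS on item lands first wins, and the other one sees a changed value.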
import com.google.common.base.Preconditions;
import java.util.Comparator;
import java.util.concurrent.atomic.AtomicReference;
/**
* 1. HM linked list using marker node
* 2. no tail node
* 3. comparator
* 4. replaceable
*
* @param <T> element type
*/
@SuppressWarnings("Duplicates")
public class ListBasedSet4<T> {
private final Comparator<T> comparator;
private final Node<T> head;
public ListBasedSet4(Comparator<T> comparator) {
this.comparator = comparator;
this.head = new Node<>(NodeKind.NORMAL, null, null);
}
public boolean contains(T item) {
Preconditions.checkNotNull(item);
int c;
T value;
for (Node<T> current = head.next.get(), successor; ; ) {
if (current == null) {
return false;
}
value = current.item.get();
if (value == null) {
// current is deleted
successor = current.next.get();
// skip marker if present
if (successor != null && successor.kind == NodeKind.MARKER) {
successor = successor.next.get();
}
current = successor;
continue;
}
c = comparator.compare(value, item);
if (c == 0) {
return true;
}
if (c > 0) {
// current.item > item
return false;
}
current = current.next.get();
}
}
public boolean add(T item) {
Preconditions.checkNotNull(item);
int c;
T value;
Node<T> newNode;
while (true) {
for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
if (current == null) {
// no more node
newNode = new Node<>(NodeKind.NORMAL, item, null);
if (predecessor.casNext(null, newNode)) {
return true;
}
break; // restart
}
value = current.item.get();
if (value == null) {
// current is deleted
successor = current.nextNotMarker();
if (predecessor.casNext(current, successor)) {
current = successor;
continue;
} else {
break;
}
}
c = comparator.compare(value, item);
if (c == 0) {
// same key, replace
if (current.casItem(value, item)) {
return true;
}
break;
}
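// skip a marker in case current was deleted after its item was read
successor = current.nextNotMarker();
if (c > 0) {
// current.item > item: the new node must point at current, not at its successor
newNode = new Node<>(NodeKind.NORMAL, item, current);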
if (predecessor.casNext(current, newNode)) {
return true;
}
break;
}
predecessor = current;
current = successor;
}
}
}
public boolean remove(T item) {
Preconditions.checkNotNull(item);
int c;
T value;
Node<T> marker;
while (true) {
for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
if (current == null) {
// no more node
return false;
}
value = current.item.get();
if (value == null) {
// current is deleted
successor = current.nextNotMarker();
if (predecessor.casNext(current, successor)) {
current = successor;
continue;
} else {
break; // restart
}
}
c = comparator.compare(value, item);
if (c == 0) {
// the node to remove
if (!current.casItem(value, null)) {
break;
}
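// this thread now owns the removal, so it must not restart and report "not found";
// only the owner appends a marker here, so successor is never a marker node
while (true) {
successor = current.next.get();
marker = new Node<>(NodeKind.MARKER, null, successor);
if (current.casNext(successor, marker)) {
// physical remove, failure is fine
predecessor.casNext(current, successor);
return true;
}
// current.next changed concurrently, re-read and retry
}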
}
if (c > 0) {
// not found
return false;
}
predecessor = current;
// current may be deleted at this point
current = current.nextNotMarker();
}
}
}
private enum NodeKind {
NORMAL,
MARKER
}
@SuppressWarnings("Duplicates")
private static final class Node<T> {
final NodeKind kind;
final AtomicReference<T> item;
final AtomicReference<Node<T>> next;
Node(NodeKind kind, T item, Node<T> next) {
this.kind = kind;
this.item = new AtomicReference<>(item);
this.next = new AtomicReference<>(next);
}
Node<T> nextNotMarker() {
Node<T> successor = next.get();
if (successor == null) {
return null;
}
if (successor.kind == NodeKind.MARKER) {
return successor.next.get();
}
return successor;
}
boolean casItem(T expect, T update) {
return item.compareAndSet(expect, update);
}
boolean casNext(Node<T> expect, Node<T> update) {
return next.compareAndSet(expect, update);
}
}
}
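You can see that before every value comparison you must first check whether the current value is null. The code above is already very close to how Node is used inside ConcurrentSkipListMap. Before looking at how to put the SkipList code on top, here is a basic SkipList implementation.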
import java.util.Random;
public class SkipList<T> {
private static final int MAX_LEVEL = 5;
private static final int NOT_FOUND = -1;
private final Random random = new Random();
private final Node<T> head;
public SkipList() {
head = new Node<>(Integer.MIN_VALUE);
final Node<T> tail = new Node<>(Integer.MAX_VALUE);
for (int level = 0; level < MAX_LEVEL; level++) {
head.next[level] = tail;
}
}
public boolean contains(T x) {
int key = x.hashCode();
Node<T> predecessor = head;
Node<T> current;
for (int level = MAX_LEVEL - 1; level >= 0; level--) { // valid indices are 0..MAX_LEVEL-1
current = predecessor.next[level];
while (current.key < key) {
predecessor = current;
current = current.next[level];
}
if (current.key == key) {
return true;
}
}
return false;
}
private int find(T x, Node<T>[] predecessors, Node<T>[] successors) {
int key = x.hashCode();
int foundAtLevel = NOT_FOUND;
Node<T> predecessor = head;
Node<T> current;
for (int level = MAX_LEVEL - 1; level >= 0; level--) {
current = predecessor.next[level];
while (current.key < key) {
predecessor = current;
current = current.next[level];
}
predecessors[level] = predecessor;
successors[level] = current;
if (foundAtLevel == NOT_FOUND && current.key == key) {
foundAtLevel = level;
}
}
return foundAtLevel;
}
@SuppressWarnings("unchecked")
public boolean add(T x) {
Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
int foundAtLevel = find(x, predecessors, successors);
if (foundAtLevel != NOT_FOUND) {
return false;
}
int topLevel = random.nextInt(MAX_LEVEL) + 1; // 1..MAX_LEVEL, so the node is always linked at level 0
Node<T> node = new Node<>(x.hashCode(), x, topLevel);
for (int level = 0; level < topLevel; level++) {
node.next[level] = successors[level];
predecessors[level].next[level] = node;
}
return true;
}
@SuppressWarnings("unchecked")
public boolean remove(T x) {
Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
int foundAtLevel = find(x, predecessors, successors);
if (foundAtLevel == NOT_FOUND) {
return false;
}
Node<T> node = successors[foundAtLevel];
for (int level = 0; level <= foundAtLevel; level++) {
predecessors[level].next[level] = node.next[level];
}
return true;
}
private static class Node<T> {
private final int key;
private final T value;
private final Node<T>[] next;
@SuppressWarnings("unchecked")
Node(int key) {
this.key = key;
value = null;
next = (Node<T>[]) new Node[MAX_LEVEL];
}
// topLevel >= 1 && topLevel <= MAX_LEVEL
@SuppressWarnings("unchecked")
Node(int key, T value, int topLevel) {
this.key = key;
this.value = value;
next = (Node<T>[]) new Node[topLevel];
}
int topLevel() {
return next.length;
}
@Override
public String toString() {
return "Node{" +
"key=" + key +
", item=" + value +
'}';
}
}
}
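The above is a simplified version of the SkipList used in chapter 14 of The Art of Multiprocessor Programming. Honestly, it is far simpler than ConcurrentSkipListMap: there is a tail, so there is no constant null checking; the maximum level is fixed; and a random number decides the level of each node.
Next is the final SkipListBasedSet from chapter 14, which combines the HM Linked List (ListBasedSet1) with a SkipList.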
import com.google.common.base.Preconditions;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicMarkableReference;
@SuppressWarnings("Duplicates")
public class SkipListSet1<T> {
private static final int MAX_LEVEL = 5;
private final ThreadLocalRandom random = ThreadLocalRandom.current();
private final Node<T> head;
public SkipListSet1() {
head = new Node<>(null, Integer.MIN_VALUE, MAX_LEVEL);
final Node<T> tail = new Node<>(null, Integer.MAX_VALUE, MAX_LEVEL);
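for (int i = 0; i < MAX_LEVEL; i++) {
head.next[i] = new AtomicMarkableReference<>(tail, false);
// traversal dereferences tail.next[i], so it must hold a reference object, not null
tail.next[i] = new AtomicMarkableReference<>(null, false);
}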
}
/**
* Test if element in set.
*
* @param x element
* @return true if present, otherwise false
*/
public boolean contains(T x) {
Preconditions.checkNotNull(x);
final int key = x.hashCode();
Node<T> predecessor = head;
Node<T> current = null;
Node<T> successor;
boolean[] markHolder = new boolean[1];
for (int i = MAX_LEVEL - 1; i >= 0; i--) {
current = predecessor.next(i);
while (true) {
successor = current.next[i].get(markHolder);
while (markHolder[0]) {
current = successor;
successor = current.next[i].get(markHolder);
}
if (current.key < key) {
predecessor = current;
current = current.next(i);
} else {
break;
}
}
}
return current.key == key;
}
/**
* Add element.
*
* @param x element
* @return true if added, false if present
*/
@SuppressWarnings("unchecked")
public boolean add(T x) {
Preconditions.checkNotNull(x);
final int key = x.hashCode();
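// ThreadLocalRandom must be obtained on the calling thread rather than shared through a field
final int level = ThreadLocalRandom.current().nextInt(MAX_LEVEL) + 1;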
final Node<T> node = new Node<>(x, key, level);
Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
boolean success;
while (true) {
if (find(key, predecessors, successors)) {
return false;
}
node.next[0] = new AtomicMarkableReference<>(successors[0], false);
success = predecessors[0].next[0].compareAndSet(successors[0], node, false, false);
if (!success) {
continue; // retry
}
// node added to skip list if success
for (int i = 1; i < level; i++) {
while (true) {
node.next[i] = new AtomicMarkableReference<>(successors[i], false);
success = predecessors[i].next[i].compareAndSet(successors[i], node, false, false);
if (success) {
break;
} else {
find(key, predecessors, successors);
}
}
}
return true;
}
}
/**
* Find element in skip list and return predecessors and successors.
*
* @param key key
* @return true if found, otherwise false
*/
private boolean find(int key, Node<T>[] predecessors, Node<T>[] successors) {
Node<T> predecessor = head;
Node<T> current = null;
Node<T> successor;
boolean[] markHolder = new boolean[1];
boolean snip;
restart:
for (int i = MAX_LEVEL - 1; i >= 0; i--) {
current = predecessor.next(i);
// find first node
// 1. not marked
// 2. node's key is not less than specified key
while (true) {
successor = current.next[i].get(markHolder);
while (markHolder[0]) {
snip = predecessor.next[i].compareAndSet(current, successor, false, false);
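if (!snip) {
// restart the whole search from the top level of head
predecessor = head;
i = MAX_LEVEL; // the loop update decrements this to MAX_LEVEL - 1
continue restart;
}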
current = successor;
successor = current.next[i].get(markHolder);
}
// current is not marked
if (current.key < key) {
predecessor = current;
current = successor;
} else {
break;
}
}
predecessors[i] = predecessor;
successors[i] = current;
}
return current.key == key;
}
/**
* Remove element.
*
* @param x element
* @return true if success, false if not found or removed
*/
@SuppressWarnings("unchecked")
public boolean remove(T x) {
final int key = x.hashCode();
Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
Node<T> node;
Node<T> next;
boolean[] markHolder = new boolean[1];
boolean success;
if (!find(key, predecessors, successors)) {
return false; // not found
}
node = successors[0];
for (int i = node.level() - 1; i >= 1; i--) {
next = node.next[i].get(markHolder);
while (!markHolder[0]) {
success = node.next[i].attemptMark(next, true);
if (success) {
break;
}
next = node.next[i].get(markHolder);
}
}
while (true) {
next = node.next[0].get(markHolder);
// logically removed
if (markHolder[0]) {
return false;
}
// only the thread that flips the mark from false to true may report success
success = node.next[0].compareAndSet(next, next, false, true);
if (success) {
find(key, predecessors, successors);
return true;
}
}
}
private static final class Node<T> {
final T value;
final int key;
final AtomicMarkableReference<Node<T>>[] next;
@SuppressWarnings("unchecked")
Node(T value, int key, int level) {
this.value = value;
this.key = key;
next = (AtomicMarkableReference<Node<T>>[]) new AtomicMarkableReference[level];
}
Node<T> next(int level) {
return next[level].getReference();
}
int level() {
return next.length;
}
}
}
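Note the invariant of this class: the book says the bottom level is authoritative. A SkipList is really a stack of several lists, and keeping all levels consistent is difficult. For a detailed analysis of the code above, interested readers should consult the book.
That concludes the background. Next, let us look at some designs specific to ConcurrentSkipListMap.
The bottom level is, needless to say, an HM Linked List. It uses marker nodes to indicate deletion, and CASes the item to null before appending the marker node. The SkipList part:
- Starts with a single level and grows as nodes are inserted, according to the randomly chosen level; the maximum is 32 levels including the bottom one
- May shrink the number of levels when a node is deleted
- Has no tail on the right; I think this is because growing head and tail together is difficult, and tail is not needed at run time anyway
- Index levels above the bottom do not use marker nodes, which may cause concurrency anomalies, but since the bottom level is authoritative the impact is small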
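The "randomly chosen level" in the list above can be looked at on its own. The sketch below isolates the bit tests from the random number generation so that the mapping from bits to level is deterministic. RandomLevelSketch and levelFromBits are hypothetical names of mine; the bit tests follow the scheme used in the simplified code of this post, and I am assuming (not asserting) that this matches what a given JDK version does.

```java
// Maps 32 random bits to an index level; names are illustrative only.
class RandomLevelSketch {
    /**
     * Returns 0 (no index at all) unless both the highest and the lowest bit are clear;
     * otherwise returns 1 plus the number of consecutive one-bits starting at bit 1.
     */
    static int levelFromBits(int r) {
        if ((r & 0x80000001) != 0) {
            return 0;
        }
        int level = 1;
        while (((r >>>= 1) & 1) != 0) {
            level++;
        }
        return level;
    }
}
```

With uniformly random bits, roughly one insertion in four gets any index, and each additional level is half as likely as the one below it. The largest possible result is 31 index levels (bits 1 to 30 all set), which together with the bottom list gives the 32 levels mentioned above.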
A simplified ConcurrentSkipListMap with all of the features above:
import com.google.common.base.Preconditions;
import javax.annotation.Nonnull;
import java.util.Comparator;
import java.util.concurrent.atomic.AtomicReference;
@SuppressWarnings("Duplicates")
public class SkipListSet3<T> {
private final Comparator<T> comparator;
private final AtomicReference<HeadIndex<T>> headIndex;
public SkipListSet3(Comparator<T> comparator) {
Preconditions.checkNotNull(comparator);
this.comparator = comparator;
Node<T> node = Node.ofNormal(null, null);
headIndex = new AtomicReference<>(new HeadIndex<>(node, null, null, 1));
}
public boolean contains(T item) {
Preconditions.checkNotNull(item);
for (Node<T> node = findNode(item, false); node != null; node = node.nextNotMarker()) {
T value = node.item.get();
if (value == null) {
// node is removed
// no helping
// go next
continue;
}
int c = comparator.compare(value, item);
if (c == 0) {
// found
return true;
}
if (c > 0) {
// node.item > item
// not found
return false;
}
}
// no more node
return false;
}
public boolean contains2(T item) {
Preconditions.checkNotNull(item);
restart:
while (true) {
for (Node<T> predecessor = findNode(item, true),
current = predecessor.next.get(); current != null; ) {
if (current.marker) {
// predecessor is deleted
continue restart;
}
T value = current.item.get();
if (value == null) {
// current is deleted
helpDelete(predecessor, current);
continue restart;
}
int c = comparator.compare(value, item);
if (c == 0) {
// found
return true;
}
if (c > 0) {
// current.item > item
// not found
return false;
}
// current.item < item
predecessor = current;
current = current.next.get();
}
return false;
}
}
/**
* Help delete current.
*
* @param predecessor predecessor
* @param current current
*/
private void helpDelete(Node<T> predecessor, Node<T> current) {
Node<T> successor = current.next.get();
if (successor == null || !successor.marker) {
// step 1 -> 2, insert marker, logical remove
current.casNext(successor, Node.ofMarker(successor));
} else {
// step 2 -> 3, physical remove
predecessor.casNext(current, successor.next.get());
}
}
@Nonnull
private Node<T> findNode(T item, boolean onlyPredecessor) {
while (true) {
for (Index<T> predecessor = headIndex.get(), current = predecessor.right.get(), successor, down; ; ) {
if (current != null) {
T value = current.node.item.get();
if (value == null) {
// current is deleted
successor = current.right.get();
if (!predecessor.casRight(current, successor)) {
break; // restart
}
current = successor;
continue;
}
int c = comparator.compare(value, item);
if (c == 0) {
// found
return onlyPredecessor ? predecessor.node : current.node;
}
if (c < 0) {
// current.node.item < item
// go right
predecessor = current;
current = current.right.get();
continue;
}
}
// 1. current == null, no more index
// 2. current.item > item
// go down
down = predecessor.down;
if (down == null) {
// the last index level
return predecessor.node;
}
predecessor = down;
current = predecessor.right.get();
}
}
}
public T add(T item) {
Preconditions.checkNotNull(item);
while (true) {
for (Node<T> predecessor = findNode(item, true),
current = predecessor.nextNotMarker(), successor; ; ) {
if (current != null) {
T value = current.item.get();
if (value == null) {
// current is deleted
successor = current.nextNotMarker();
if (!predecessor.casNext(current, successor)) {
break; // restart
}
current = successor;
continue;
}
int c = comparator.compare(value, item);
if (c == 0) {
if (value.equals(item) || current.casItem(value, item)) {
return value;
}
break; // restart
}
if (c < 0) {
// current.item < item
predecessor = current;
// current may be deleted at this point
current = current.nextNotMarker();
continue;
}
}
// 1. no more node
// 2. current.item > item
// insert node
Node<T> newNode = Node.ofNormal(item, current);
if (predecessor.casNext(current, newNode)) {
// ok
buildIndices(randomLevel(), newNode);
return null;
}
break; // restart
}
}
}
public T add2(T item) {
Preconditions.checkNotNull(item);
while (true) {
for (Node<T> predecessor = findNode(item, true), current = predecessor.next.get(); ; ) {
if (current != null) {
if (current.marker) {
// predecessor is deleted
break; // restart
}
T value = current.item.get();
if (value == null) {
// current is deleted
helpDelete(predecessor, current);
break;
}
int c = comparator.compare(value, item);
if (c == 0) {
if (value.equals(item) || current.casItem(value, item)) {
return value;
}
break; // restart
}
if (c < 0) {
// current.item < item
predecessor = current;
current = current.next.get();
continue;
}
}
// 1. no more node
// 2. current.item > item
// insert node
Node<T> newNode = Node.ofNormal(item, current);
if (predecessor.casNext(current, newNode)) {
// ok
buildIndices(randomLevel(), newNode);
return null;
}
break; // restart
}
}
}
private void buildIndices(int level, Node<T> node) {
if (level < 1) {
// no index
return;
}
HeadIndex<T> head = headIndex.get();
int newLevel;
Index<T>[] indices;
if (level <= head.level) { // fits under the current head; only the else branch grows by one level
newLevel = level;
indices = makeIndices(newLevel, node);
} else {
newLevel = head.level + 1;
indices = makeIndices(newLevel, node);
head = increaseLevel(newLevel, head, node, indices);
}
insertIndices(head, newLevel, indices);
}
private void insertIndices(HeadIndex<T> head, int level, Index<T>[] indices) {
final T item = indices[0].node.item.get();
Index<T> predecessor;
Index<T> current;
Index<T> successor;
restart:
while (true) {
predecessor = head;
current = predecessor.right.get();
int currentLevel = head.level;
while (currentLevel > 0) {
if (current != null) {
T value = current.node.item.get();
if (value == null) {
// current is deleted
successor = current.right.get();
if (!predecessor.casRight(current, successor)) {
continue restart;
}
current = successor;
continue;
}
int c = comparator.compare(value, item);
if (c == 0) { // impossible
throw new IllegalStateException("encounter index with same item when insert index");
}
if (c < 0) {
// go right
predecessor = current;
current = current.right.get();
continue;
}
}
// 1. current == null
// 2. current.item > item
// insert index
if (currentLevel <= level) {
// insert index
indices[currentLevel - 1].lazySetRight(current);
if (!predecessor.casRight(current, indices[currentLevel - 1])) {
continue restart;
}
// node maybe deleted at this point
}
// 1. index inserted
// 2. current level > level
// go down
if (--currentLevel == 0) {
// the lowest index level has no down pointer
break;
}
predecessor = predecessor.down;
current = predecessor.right.get();
}
// indices inserted
return;
}
}
private HeadIndex<T> increaseLevel(int expectLevel, HeadIndex<T> head, Node<T> node, Index<T>[] indices) {
HeadIndex<T> oldHead = head;
HeadIndex<T> newHead;
while (oldHead.level < expectLevel) {
// build head indices at once
newHead = oldHead;
for (int level = oldHead.level + 1; level <= expectLevel; level++) {
newHead = new HeadIndex<>(node, indices[level - 1], newHead, level);
}
// newHead = new HeadIndex<>(node, indices[oldHead.level], oldHead, oldHead.level + 1);
if (headIndex.compareAndSet(oldHead, newHead)) {
return newHead;
}
oldHead = headIndex.get();
}
return oldHead;
}
@SuppressWarnings("unchecked")
private Index<T>[] makeIndices(int level, Node<T> node) {
assert level > 0;
Index<T>[] indices = (Index<T>[]) new Index[level];
Index<T> lastIndex = null;
for (int i = 0; i < level; i++) {
indices[i] = new Index<>(node, null, lastIndex);
lastIndex = indices[i];
}
return indices;
}
private int randomLevel() {
int r = (int) System.nanoTime();
// xor shift
r ^= r << 13;
r ^= r >>> 17;
r ^= r << 5;
if ((r & 0x80000001) != 0) {
return 0;
}
int level = 1;
while (((r >>>= 1) & 1) != 0) {
level++;
}
return level;
}
public T remove(T item) {
Preconditions.checkNotNull(item);
while (true) {
for (Node<T> predecessor = findNode(item, true),
current = predecessor.nextNotMarker(), successor; ; ) {
if (current == null) {
// not found
return null;
}
T value = current.item.get();
if (value == null) {
// current is deleted
successor = current.nextNotMarker();
if (!predecessor.casNext(current, successor)) {
break; // restart
}
current = successor;
continue;
}
int c = comparator.compare(value, item);
if (c > 0) {
// not found
return null;
}
if (c < 0) {
// current.item < item
// go next
predecessor = current;
current = current.nextNotMarker();
continue;
}
// c == 0
// found
if (!current.casItem(value, null)) {
break; // restart
}
successor = current.next.get();
if (successor == null || !successor.marker) {
Node<T> marker = Node.ofMarker(successor);
current.casNext(successor, marker);
// if failed, a marker node is inserted
predecessor.casNext(current, successor);
} else {
predecessor.casNext(current, successor.next.get());
}
// unlink indices
findNode(item, false);
tryDecreaseLevel();
return value;
}
}
}
public T remove2(T item) {
Preconditions.checkNotNull(item);
restart:
while (true) {
for (Node<T> predecessor = findNode(item, true),
current = predecessor.next.get(), successor; current != null; ) {
if (current.marker) {
// predecessor is deleted
continue restart;
}
T value = current.item.get();
if (value == null) {
// current is deleted
helpDelete(predecessor, current);
continue restart;
}
int c = comparator.compare(value, item);
if (c > 0) {
// current.item > item
// not found
return null;
}
if (c < 0) {
// current.item < item
// go next
predecessor = current;
current = current.next.get();
continue;
}
// c == 0
// found
if (!current.casItem(value, null)) {
continue restart;
}
successor = current.next.get();
if (successor == null || !successor.marker) {
Node<T> marker = Node.ofMarker(successor);
current.casNext(successor, marker);
// if failed, a marker node is inserted
predecessor.casNext(current, successor);
} else {
predecessor.casNext(current, successor.next.get());
}
// unlink indices
findNode(item, false);
tryDecreaseLevel();
return value;
}
// not found
return null;
}
}
private void tryDecreaseLevel() {
HeadIndex<T> t1 = headIndex.get();
if (t1.level <= 3) {
return;
}
HeadIndex<T> t2 = (HeadIndex<T>) t1.down;
HeadIndex<T> t3 = (HeadIndex<T>) t2.down;
if (t3.right.get() != null || t2.right.get() != null || t1.right.get() != null) {
return;
}
if (headIndex.compareAndSet(t1, t2)) {
// rollback if right of t1 appeared
if (t1.right.get() != null) {
headIndex.compareAndSet(t2, t1);
}
}
}
private static class Index<T> {
final Node<T> node;
final AtomicReference<Index<T>> right;
final Index<T> down;
Index(Node<T> node, Index<T> right, Index<T> down) {
this.node = node;
this.right = new AtomicReference<>(right);
this.down = down;
}
@SuppressWarnings("BooleanMethodIsAlwaysInverted")
boolean casRight(Index<T> expect, Index<T> update) {
return node.item.get() != null && right.compareAndSet(expect, update);
}
void lazySetRight(Index<T> right) {
this.right.lazySet(right);
}
}
private static final class HeadIndex<T> extends Index<T> {
final int level;
HeadIndex(Node<T> node, Index<T> right, Index<T> down, int level) {
super(node, right, down);
this.level = level;
}
}
private static final class Node<T> {
final boolean marker;
final AtomicReference<T> item;
final AtomicReference<Node<T>> next;
Node(boolean marker, T item, Node<T> next) {
this.marker = marker;
this.item = new AtomicReference<>(item);
this.next = new AtomicReference<>(next);
}
static <T> Node<T> ofMarker(Node<T> next) {
return new Node<>(true, null, next);
}
static <T> Node<T> ofNormal(T item, Node<T> next) {
return new Node<>(false, item, next);
}
Node<T> nextNotMarker() {
Node<T> successor = next.get();
if (successor == null) {
return null;
}
if (successor.marker) {
return successor.next.get();
}
return successor;
}
boolean casItem(T expect, T update) {
return item.compareAndSet(expect, update);
}
boolean casNext(Node<T> expect, Node<T> update) {
return next.compareAndSet(expect, update);
}
}
}
In the code above, contains/add/remove go through nextNotMarker, while contains2/add2/remove2 stay closer to the ConcurrentSkipListMap source.
In terms of code complexity: contains < remove < add.
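Let's start with contains. contains is simpler than contains2 because it does not help delete nodes while walking the underlying HM Linked List, so it can traverse it like an ordinary singly linked list, skipping nodes whose item is null and nodes smaller than the target. The findNode method that supplies the starting node has nothing to do with ConcurrentSkipListMap's findNode; if anything, it is closer to findPredecessor. findNode starts from the topmost Index level, helps unlink the Index entries of deleted nodes along the way, and keeps descending until it reaches the node closest to the target, which may be the target node itself. Whether it returns the predecessor or the node itself is decided by the onlyPredecessor parameter.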
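The descent described above can be sketched with a single-threaded model. This is only an illustration under simplifying assumptions: `DescentSketch`, its `Node` and `Index` classes, and `sample()` are hypothetical names, keys are plain ints, and there is no deletion or helping at all.

```java
/** Hypothetical single-threaded model of the index descent; not the article's classes. */
final class DescentSketch {
    static final class Node {
        final int key;
        Node next;
        Node(int key, Node next) { this.key = key; this.next = next; }
    }

    static final class Index {
        final Node node;
        final Index down;
        Index right;
        Index(Node node, Index down, Index right) {
            this.node = node;
            this.down = down;
            this.right = right;
        }
    }

    /** Returns the base-level node with the largest key strictly smaller than key. */
    static Node findPredecessor(Index head, int key) {
        Index q = head;
        while (true) {
            Index r = q.right;
            if (r != null && r.node.key < key) {
                q = r;        // the next index is still smaller: go right
            } else if (q.down != null) {
                q = q.down;   // cannot go right any more: drop one level
            } else {
                break;        // bottom index level reached
            }
        }
        // finish the search on the base linked list
        Node n = q.node;
        while (n.next != null && n.next.key < key) {
            n = n.next;
        }
        return n;
    }

    /** Builds head -> 1 -> 3 -> 5 -> 7 with index level 1 = {3, 7} and level 2 = {7}. */
    static Index sample() {
        Node n7 = new Node(7, null);
        Node n5 = new Node(5, n7);
        Node n3 = new Node(3, n5);
        Node n1 = new Node(1, n3);
        Node base = new Node(Integer.MIN_VALUE, n1);
        Index i7l1 = new Index(n7, null, null);
        Index i3l1 = new Index(n3, null, i7l1);
        Index head1 = new Index(base, null, i3l1);
        Index i7l2 = new Index(n7, i7l1, null);
        return new Index(base, head1, i7l2);
    }
}
```

The index levels only shorten the walk. The final answer always comes from the base list, which is why the base list is the part that has to be correct.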
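contains2 is a bit more complex than contains. Because it helps delete nodes, it needs the predecessor. In addition, when the current node is logically deleted, what it does depends on which step the deletion has reached. If the item has been set to null but no marker exists yet, it appends a marker node. If the marker node already exists, it physically unlinks the node. Could this be done in one step? The item is already null, so why not unlink the node directly? In other words, can the marker node be skipped? I think it might be possible in theory, but with some probability a node whose item is null would never be physically unlinked and would remain as garbage in the data structure, for example when adjacent nodes are deleted at the same time, as mentioned earlier.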
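The two helping steps can be sketched as follows. This is a simplified illustration, not the helpDelete used by the class above: `SketchNode` and `HelpDeleteSketch` are hypothetical names, and the re-validation of predecessor and successor that a production implementation performs before each CAS is left out.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical, simplified node: a marker node carries no item. */
final class SketchNode<T> {
    final boolean marker;
    final AtomicReference<T> item;
    final AtomicReference<SketchNode<T>> next;

    SketchNode(boolean marker, T item, SketchNode<T> next) {
        this.marker = marker;
        this.item = new AtomicReference<>(item);
        this.next = new AtomicReference<>(next);
    }
}

final class HelpDeleteSketch {
    /**
     * Advances the deletion of current (whose item is already null) by one step.
     * Both CAS operations may fail; a failure means another thread got there first.
     */
    static <T> void helpDelete(SketchNode<T> predecessor, SketchNode<T> current) {
        SketchNode<T> successor = current.next.get();
        if (successor == null || !successor.marker) {
            // no marker yet: append one so that nothing can be inserted after current
            current.next.compareAndSet(successor, new SketchNode<>(true, null, successor));
        } else {
            // marker present: unlink current and its marker in a single CAS
            predecessor.next.compareAndSet(current, successor.next.get());
        }
    }
}
```

Once the marker is in place, an inserter that still expects the old successor in `current.next` fails its CAS, so it cannot attach a new node to one that is about to be unlinked.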
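Next is remove2. Deletion also has a help-delete phase. After the node is found, remove2 first CASes item to null, then, depending on the successor, appends a marker and physically unlinks the node. Setting the item to null already is the logical deletion, so appending the marker is allowed to fail, and so is the physical unlink. At the end of the deletion there is an attempt to lower the level. When there are more than 3 levels and the top 3 levels contain no Index other than the Head, ConcurrentSkipListMap tries to drop one level. If a new Index is inserted in the meantime, it gives up and restores the old head.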
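Finally the most complex one, add2. It is complex because there is a lot of code, and ConcurrentSkipListMap writes all of it in one place. I split it into several methods by function. First comes the regular HM Linked List insertion. Replacing the value is allowed, and if the value was replaced the method returns right away. Next comes index creation. A random number is generated with xorshift. If its highest or lowest bit is set, the node gets no index; otherwise the level is 1 plus the number of consecutive 1 bits counted from bit 1 upward. I tested this randomLevel locally, and it roughly follows a distribution where higher levels occur less often. Of course you cannot control where a node is inserted, so it still comes down to luck. Depending on randomLevel, the Head may grow by one level or stay as it is. While it grows, a deletion may lower the height at the same time. In theory this runs into the ABA problem and AtomicStampedReference should be used. I think it is not used here because the SkipList does not take part in the invariants, so some instability can be tolerated. The last step inserts the new node's indices into the SkipList. Note that the SkipList levels do not use HM Linked List, so in theory an index may fail to be removed or fail to be added, but the probability is small and the result can be tolerated.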
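The shape of that distribution can be checked with a small histogram. The sketch below copies the bit logic of randomLevel but takes the seed as a parameter instead of reading System.nanoTime(), which is an assumption made only so that runs are repeatable; `RandomLevelSketch`, `levelOf` and `histogram` are hypothetical names. For uniformly distributed bits, level 0 has probability 3/4 and level k (k >= 1) has probability 1/4 * (1/2)^k.

```java
/** Hypothetical helper for inspecting the level distribution; not part of the original class. */
final class RandomLevelSketch {
    /** Same bit logic as randomLevel, with the seed passed in. */
    static int levelOf(int seed) {
        int r = seed;
        // xor shift
        r ^= r << 13;
        r ^= r >>> 17;
        r ^= r << 5;
        if ((r & 0x80000001) != 0) {
            return 0; // highest or lowest bit set: no index for this node
        }
        int level = 1;
        while (((r >>>= 1) & 1) != 0) {
            level++; // one more level per consecutive 1 bit, starting at bit 1
        }
        return level;
    }

    /** Counts how often each level occurs for the seeds 0 .. samples - 1. */
    static int[] histogram(int samples) {
        int[] counts = new int[33];
        for (int seed = 0; seed < samples; seed++) {
            counts[levelOf(seed)]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] counts = histogram(1 << 20);
        for (int level = 0; level <= 6; level++) {
            System.out.println("level " + level + ": " + counts[level]);
        }
    }
}
```

Sequential seeds are used here only for repeatability. The real method mixes the low bits of System.nanoTime(), so its output depends on timing.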
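To sum up ConcurrentSkipListMap: it relies on HM Linked List and uses marker nodes as an optimization. SkipList itself is a randomized data structure, and in a concurrent setting it combines well with HM Linked List, avoiding problems such as the resize of array-based Maps. Because it is a combination of two data structures the code is fairly complex, but it is worth reading and thinking about.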