Java Concurrency Study: ConcurrentSkipListMap and the HM Linked List


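This is the fourth post in my Java concurrency series. The series alternates between synchronizers and concurrent data structures, and this time it is a data structure's turn. As mentioned in the earlier post on ConcurrentLinkedQueue, concurrent data structures do not involve the problems specific to synchronizers, so they are somewhat simpler, and the analysis focuses on the data structure itself.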

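The topic this time is ConcurrentSkipListMap. To be honest, I did not know this class at first: if there is ConcurrentHashMap, why have ConcurrentSkipListMap as well? TreeMap is not used that often even in single-threaded code, and a sorted map in a concurrent setting is used even less. That said, after reading chapter 14 of "The Art of Multiprocessor Programming", I found ConcurrentSkipListMap to be a very good case study, in particular as production Java code for the HM Linked List (also called the Harris Linked List).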

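If you have read articles that analyze the ConcurrentSkipListMap code, you probably know about the skip list; it is in the name, after all. But what actually maintains correctness in ConcurrentSkipListMap is not the skip list but the HM Linked List at the bottom level. In other words, the core is the HM Linked List, not the skip list, which only speeds up access.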

Before reading the ConcurrentSkipListMap code, I recommend getting familiar with the following:

  1. The HM Linked List
  2. The marker-node optimization of the HM Linked List
  3. The ordinary (non-concurrent) skip list

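The HM Linked List is a fairly well-known singly linked list for concurrent environments; the original papers are worth reading if you are interested. Its implementation requires that a node's mark and next fields can be updated together by a single CAS. In Java this corresponds to AtomicMarkableReference. A typical implementation (chapter 9 of "The Art of Multiprocessor Programming") looks like this: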

import java.util.concurrent.atomic.AtomicMarkableReference;

/**
 * HM linked list based on {@code AtomicMarkableReference}.
 *
 * @param <T>
 */
public class ListBasedSet1<T> {

    private final Node<T> head;

    public ListBasedSet1() {
        head = new Node<>(Integer.MIN_VALUE, null, new Node<>(Integer.MAX_VALUE));
    }

    public boolean add(T item) {
        final int key = item.hashCode();
        Window<T> window;
        Node<T> newNode;
        while (true) {
            window = find(head, key);
            // same key
            if (window.current.key == key) {
                return false;
            }
            newNode = new Node<>(key, item, window.current);
            if (window.predecessor.nextAndMark.compareAndSet(
                    window.current, newNode, false, false)) {
                return true;
            }
            // retry
        }
    }

    public boolean remove(T item) {
        final int key = item.hashCode();
        Window<T> window;
        Node<T> successor;
        while (true) {
            window = find(head, key);
            // not found
            if (window.current.key != key) {
                return false;
            }
            successor = window.current.nextAndMark.getReference();
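            // logical delete: expect mark == false so that only one of several
            // concurrent removers wins (attemptMark also returns true when the
            // node is already marked)
            if (window.current.nextAndMark.compareAndSet(successor, successor, false, true)) {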
                // physical delete
                window.predecessor.nextAndMark.compareAndSet(
                        window.current, successor, false, false);
                // fail is ok
                return true;
            }
            // retry
        }
    }

    public boolean contains(T item) {
        int key = item.hashCode();
        Node<T> current = head.next();
        while (current.key < key) {
            current = current.next();
        }
        return current.key == key && !current.isMarked();
    }

    private Window<T> find(Node<T> head, int key) {
        boolean[] markedHolder = {false};
        boolean snip;

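        // a labeled for loop does not re-run its initializer on "continue retry",
        // so restart from head with an outer while loop instead
        retry:
        while (true) {
            Node<T> predecessor = head, current = predecessor.next(), successor;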
            while (true) {
                successor = current.nextAndMark.get(markedHolder);

                // current node is deleted
                while (markedHolder[0]) {
                    snip = predecessor.nextAndMark.compareAndSet(
                            current, successor, false, false);
                    if (!snip) {
                        continue retry;
                    }

                    current = successor;
                    successor = current.nextAndMark.get(markedHolder);
                }

                if (current.key >= key) {
                    return new Window<>(predecessor, current);
                }

                predecessor = current;
                current = successor;
            }
        }
    }


    private static class Window<T> {
        private final Node<T> predecessor;
        private final Node<T> current;

        Window(Node<T> predecessor, Node<T> current) {
            this.predecessor = predecessor;
            this.current = current;
        }
    }

    private static class Node<T> {
        private final int key;
        private final T item;
        private final AtomicMarkableReference<Node<T>> nextAndMark;

        Node(int key) {
            this.key = key;
            this.item = null;
            nextAndMark = new AtomicMarkableReference<>(null, false);
        }

        Node(int key, T item, Node<T> next) {
            this.key = key;
            this.item = item;
            nextAndMark = new AtomicMarkableReference<>(next, false);
        }

        boolean isMarked() {
            return nextAndMark.isMarked();
        }

        Node<T> next() {
            return nextAndMark.getReference();
        }
    }
}

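This is a ListBasedSet, i.e. a set backed by a linked list. It implements three methods: contains, add and remove. The basic idea is to keep a sorted linked list and traverse it from the head node to find the target node (contains, remove) or the position at which to insert a node (add).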

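You may think this is an inefficient set implementation, since the worst-case time complexity is O(N). With a skip list on top, however, the expected time complexity becomes O(log N). This is also what ConcurrentSkipListMap does internally.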

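Back to ListBasedSet. How can the correctness of the implementation above be shown? Or, put differently: why does a single-threaded singly linked list (with only a next pointer) break in a concurrent environment? What exactly goes wrong, and how does ListBasedSet solve it?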

Consider the following singly linked list:

a -> b -> c -> d

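Suppose only the next pointer is updated by CAS, thread 1 wants to delete node b, and at the same time thread 2 wants to delete c. The following can happen at run time: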

thread 1: CAS(a.next, b, b.next)
thread 2: CAS(b.next, c, c.next)

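Both CAS operations succeed, but c is not deleted (a now points to c, and only the already unlinked b points to d), so the result is wrong. A remove and an add on adjacent nodes can likewise cause one of the operations to be "ignored" (in theory this can happen to the Index nodes of ConcurrentSkipListMap, but that is acceptable there because correctness is guaranteed by the HM Linked List). On the surface the cause is that the two CAS operations touch different next pointers; the real problem is that a CAS on a next pointer cannot guarantee that the node involved has not been deleted. This is why the HM Linked List requires the CAS to cover a node's next pointer and its mark together: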

thread 1: CAS(b, [b.next, false], [b.next, true])
thread 1: CAS(a, [b, false], [b.next, false])
thread 2: CAS(c, [c.next, false], [c.next, true])
thread 2: CAS(b, [c, false], [c.next, false])

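Here the square brackets hold the next pointer value and the mark (true means deleted, false means present). Executed in this order, the last CAS fails, because b has been marked. For the actual correctness proof, see the original papers or "The Art of Multiprocessor Programming".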
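To make the two interleavings concrete, here is a minimal single-threaded replay of both. The class and method names (AdjacentDeleteReplay, replayPlain, replayMarked) are hypothetical and exist only for this illustration; the steps are executed in one fixed order rather than by real threads, so this is a sketch of the scenario, not a concurrency test.

```java
import java.util.concurrent.atomic.AtomicMarkableReference;
import java.util.concurrent.atomic.AtomicReference;

final class AdjacentDeleteReplay {

    static final class PlainNode {
        final String name;
        final AtomicReference<PlainNode> next;

        PlainNode(String name, PlainNode next) {
            this.name = name;
            this.next = new AtomicReference<>(next);
        }
    }

    static final class MarkedNode {
        final String name;
        final AtomicMarkableReference<MarkedNode> next;

        MarkedNode(String name, MarkedNode next) {
            this.name = name;
            this.next = new AtomicMarkableReference<>(next, false);
        }
    }

    /** CAS on next only. Returns {thread 1 ok, thread 2 ok, a still points to c}. */
    static boolean[] replayPlain() {
        PlainNode d = new PlainNode("d", null);
        PlainNode c = new PlainNode("c", d);
        PlainNode b = new PlainNode("b", c);
        PlainNode a = new PlainNode("a", b);
        boolean t1 = a.next.compareAndSet(b, c); // thread 1 unlinks b
        boolean t2 = b.next.compareAndSet(c, d); // thread 2 "unlinks" c, but only from the dead b
        return new boolean[] {t1, t2, a.next.get() == c};
    }

    /** CAS on next and mark together. Returns the result of the four steps in order. */
    static boolean[] replayMarked() {
        MarkedNode d = new MarkedNode("d", null);
        MarkedNode c = new MarkedNode("c", d);
        MarkedNode b = new MarkedNode("b", c);
        MarkedNode a = new MarkedNode("a", b);
        boolean s1 = b.next.compareAndSet(c, c, false, true);  // thread 1: mark b
        boolean s2 = a.next.compareAndSet(b, c, false, false); // thread 1: unlink b
        boolean s3 = c.next.compareAndSet(d, d, false, true);  // thread 2: mark c
        boolean s4 = b.next.compareAndSet(c, d, false, false); // thread 2: b is marked, so this fails
        return new boolean[] {s1, s2, s3, s4};
    }
}
```

In the marked replay, c stays logically deleted even though its physical unlink failed; a later traversal that starts from a will see the mark on c and unlink it.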

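In the code, this process shows up in remove. The first CAS is a logical deletion, and only the CAS after it is the actual (physical) deletion. The implementation only requires the first CAS, the logical deletion, to succeed; the physical deletion is allowed to fail. The reason is that find, which add and remove both call, physically deletes any logically deleted node it sees (and if no other thread has touched the predecessor in the meantime, the physical-deletion CAS succeeds anyway). If the logical deletion fails, the code starts over from the head node and tries again. A logical deletion usually fails because two threads are deleting the same node at the same time. In theory exactly one of them must succeed and the other must fail (this is another problem of using CAS on the next pointer alone: with bad luck both calls return true). When the slower thread retries and calls find, the node is either unlinked inside find or has already been removed by the other thread, so the thread can no longer reach the node and returns false. So this approach is correct.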
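The "exactly one remover wins" requirement depends on which AtomicMarkableReference method performs the logical deletion. The sketch below (hypothetical class LogicalDeleteDemo, run sequentially) compares attemptMark with the four-argument compareAndSet: attemptMark reports success when the mark already has the requested value, while compareAndSet with an expected mark of false does not.

```java
import java.util.concurrent.atomic.AtomicMarkableReference;

final class LogicalDeleteDemo {

    /** Two removers call attemptMark on the same node, one after the other. */
    static boolean[] withAttemptMark() {
        Object successor = new Object();
        AtomicMarkableReference<Object> next =
                new AtomicMarkableReference<>(successor, false);
        boolean first = next.attemptMark(successor, true);
        // the mark is already true, and attemptMark still returns true
        boolean second = next.attemptMark(successor, true);
        return new boolean[] {first, second};
    }

    /** Two removers call compareAndSet with expectedMark == false. */
    static boolean[] withCompareAndSet() {
        Object successor = new Object();
        AtomicMarkableReference<Object> next =
                new AtomicMarkableReference<>(successor, false);
        boolean first = next.compareAndSet(successor, successor, false, true);
        // the expected mark (false) no longer matches, so the second remover fails
        boolean second = next.compareAndSet(successor, successor, false, true);
        return new boolean[] {first, second};
    }
}
```

With attemptMark both calls report success, so two remove calls for the same element could both return true; with compareAndSet only the first one does.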

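Although the HM Linked List has been proven correct, the actual implementation is still not ideal. In particular, AtomicMarkableReference creates a new object on every change. In C/C++ one might emulate the mark by using a bit of the next pointer, but then deciding when to reclaim list nodes becomes a problem (if a thread is traversing the list while a remove happens, you cannot simply free the node).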

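One optimization used in ConcurrentSkipListMap is to append a marker node after the node that is to be deleted (setting the node's value to null is a design for a different problem; do not confuse the two). Suppose I want to delete node b from the list below: I append a marker node after b.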

a -> b -> c
a -> b -> marker -> c

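Adding a node only on deletion is much cheaper than the original AtomicMarkableReference approach, but the code becomes more complex. Previously only the three nodes a, b and c had to be considered; now there are four: a, b, marker and c.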
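Why the marker node has the same effect as the mark bit can be seen in a short sequential replay. The names below (MarkerNodeReplay, the Node fields) are hypothetical and chosen for this sketch only: once the marker sits behind b, any CAS on b.next that still expects c fails, so nothing can be inserted after the deleted node.

```java
import java.util.concurrent.atomic.AtomicReference;

final class MarkerNodeReplay {

    static final class Node {
        final String name;    // null for the marker node in this sketch
        final boolean marker;
        final AtomicReference<Node> next;

        Node(String name, boolean marker, Node next) {
            this.name = name;
            this.marker = marker;
            this.next = new AtomicReference<>(next);
        }
    }

    /** Returns {marker appended, stale insert succeeded, physical delete ok, a -> c}. */
    static boolean[] replay() {
        Node c = new Node("c", false, null);
        Node b = new Node("b", false, c);
        Node a = new Node("a", false, b);

        // an inserting thread has read b.next == c and prepared b -> x -> c
        Node x = new Node("x", false, c);

        // the deleting thread appends the marker: a -> b -> marker -> c
        Node marker = new Node(null, true, c);
        boolean logical = b.next.compareAndSet(c, marker);

        // the stale insert fails: b.next is now the marker, not c
        boolean staleInsert = b.next.compareAndSet(c, x);

        // physical delete: a -> c, dropping b and its marker together
        boolean physical = a.next.compareAndSet(b, c);
        return new boolean[] {logical, staleInsert, physical, a.next.get() == c};
    }
}
```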

A reference implementation using marker nodes:

import java.util.concurrent.atomic.AtomicReference;

/**
 * HM linked list using marker node.
 *
 * @param <T> element type
 */
public class ListBasedSet2<T> {

    private final Node<T> head;

    ListBasedSet2() {
        Node<T> tail = new Node<>(NodeKind.NORMAL, Integer.MAX_VALUE, null, null);
        head = new Node<>(NodeKind.NORMAL, Integer.MIN_VALUE, null, tail);
    }

    public boolean contains(T item) {
        final int key = item.hashCode();
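        Node<T> current = head.next();
        boolean[] markHolder = {false};
        while (current.key < key) {
            current = current.next(markHolder); // skip marker node
            assert current != null;
        }
        if (current.key != key) {
            return false;
        }
        // markHolder describes the node we came from, not current itself;
        // current is logically deleted if a marker node directly follows it
        Node<T> next = current.next();
        return next == null || next.kind != NodeKind.MARKER;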
    }

    @SuppressWarnings("unchecked")
    public boolean add(T item) {
        final int key = item.hashCode();
        Node<T>[] nodes = (Node<T>[]) new Node[3];
        while (true) {
            // same key
            if (find(key, nodes)) {
                return false;
            }
            // p -> (n) -> c -> s
            Node<T> newNode = new Node<>(NodeKind.NORMAL, key, item, nodes[1]);
            if (nodes[0].next.compareAndSet(nodes[1], newNode)) {
                return true;
            }
        }
    }

    private boolean find(int key, Node<T>[] nodes) {
        Node<T> predecessor;
        Node<T> current;
        Node<T> successor;

        boolean[] markHolder = {false};
        boolean snip;

        restart:
        for (predecessor = head, current = predecessor.next(); ; ) {
            assert current != null;
            successor = current.next(markHolder);
            while (markHolder[0]) {
                assert successor != null;
                snip = predecessor.next.compareAndSet(current, successor);
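                if (!snip) {
                    // restart from head: "continue restart" on a labeled for loop
                    // does not re-run the loop initializer
                    predecessor = head;
                    current = predecessor.next();
                    continue restart;
                }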
                current = successor;
                successor = current.next(markHolder);
            }
            if (current.key < key) {
                predecessor = current;
                current = successor;
            } else {
                break;
            }
        }

        nodes[0] = predecessor;
        nodes[1] = current;
        nodes[2] = successor;
        return current.key == key;
    }

    @SuppressWarnings("unchecked")
    public boolean remove(T item) {
        final int key = item.hashCode();
        Node<T>[] nodes = (Node<T>[]) new Node[3];

        while (true) {
            if (!find(key, nodes)) {
                return false;
            }
            // p -> [c -> (m)] -> s
            // p -> s
            Node<T> marker = new Node<>(NodeKind.MARKER, 0, null, nodes[2]);
            if (nodes[1].next.compareAndSet(nodes[2], marker)) {
                nodes[0].next.compareAndSet(nodes[1], nodes[2]);
                return true;
            }
        }
    }

    @Override
    public String toString() {
        return "ListBasedSet2{" +
                "head=" + head +
                '}';
    }

    private enum NodeKind {
        NORMAL,
        MARKER;
    }

    private static final class Node<T> {
        final NodeKind kind;
        final int key;
        final T item;
        final AtomicReference<Node<T>> next;

        Node(NodeKind kind, int key, T item, Node<T> next) {
            this.kind = kind;
            this.key = key;
            this.item = item;
            this.next = new AtomicReference<>(next);
        }

        Node<T> next() {
            return next.get();
        }

        Node<T> next(boolean[] markHolder) {
            Node<T> successor = next.get();
            if (successor == null) { // tail
                // reset the flag, otherwise a stale "true" makes find() snip the tail
                markHolder[0] = false;
                return null;
            }
            if (successor.kind == NodeKind.MARKER) {
                markHolder[0] = true;
                return successor.next();
            }
            markHolder[0] = false;
            return successor;
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder("Node{");
            if (kind == NodeKind.NORMAL) {
                builder.append("kind=NORMAL,");
                builder.append("key=").append(key).append(',');
                builder.append("item=").append(item).append(',');
            } else {
                builder.append("kind=MARKER,");
            }
            builder.append("next=\n").append(next.get()).append('}');
            return builder.toString();
        }
    }
}

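To keep the code close to ListBasedSet1, ListBasedSet2 uses a method with the signature next(boolean[] markHolder) that helps skip marker nodes. To tell normal nodes and marker nodes apart, ListBasedSet2 also adds the NodeKind enum. The content of a marker node is meaningless, so it is simply set to null. Note that you have to be very careful: a normal node must never be followed by two marker nodes, and nothing may be inserted between a normal node and the marker node after it.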

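You may have noticed that ListBasedSet1 and ListBasedSet2 use two sentinel nodes, head and tail, and order elements by hashCode. Real code does not use a tail, and it is better to let the user pass in a comparator. Below is an implementation without a tail that accepts a user-supplied comparator: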

import java.util.Comparator;
import java.util.concurrent.atomic.AtomicReference;

/**
 * 1. HM linked list using marker node
 * 2. no tail node
 * 3. comparator
 *
 * @param <T> element type
 */
@SuppressWarnings("Duplicates")
public class ListBasedSet3<T> {

    private final Comparator<T> comparator;
    private final Node<T> head;

    public ListBasedSet3(Comparator<T> comparator) {
        this.comparator = comparator;
        head = new Node<>(NodeKind.HEAD, null, null);
    }

    private int compareItem(Node<T> node1, T item) {
        switch (node1.kind) {
            case HEAD:
                return -1;
            case NORMAL:
                return comparator.compare(node1.item, item);
            default:
                throw new IllegalStateException("cannot compareItem");
        }
    }

    public boolean contains(T item) {
        int c;
        for (Node<T> current = head.next.get(), successor; ; ) {
            if (current == null) {
                // predecessor is last node
                return false;
            }
            // current.kind never be head
            assert current.kind != NodeKind.HEAD;
            successor = current.next.get();
            if (current.kind != NodeKind.MARKER) { // skip marker
                c = compareItem(current, item);
                if (c == 0) {
                    // current.item == item
                    // current.kind = normal
                    // item present if current is not marked
                    return successor == null || successor.kind != NodeKind.MARKER;
                }
                if (c > 0) {
                    // current.item > item
                    // not found
                    return false;
                }
            }
            // 1. current.kind == marker
            // 2. current.item < item
            // go next
            current = successor;
        }
    }

    public boolean add(T item) {
        int c;
        boolean[] markHolder = {false};

        restart:
        while (true) {
            for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
                if (current == null) {
                    // no more node
                    if (insert(predecessor, item, null)) {
                        return true;
                    }
                    continue restart;
                }
                successor = current.next(markHolder);
                // current is deleted
                while (markHolder[0]) {
                    if (!predecessor.next.compareAndSet(current, successor)) {
                        continue restart;
                    }
                    current = successor;
                    if (current == null) {
                        // no more node
                        if (insert(predecessor, item, null)) {
                            return true;
                        }
                        continue restart;
                    }
                    successor = current.next(markHolder);
                }
                c = compareItem(current, item);
                if (c == 0) {
                    // same key, item will not be replaced
                    return false;
                }
                if (c > 0) {
                    // current.item > item: insert between predecessor and current
                    if (insert(predecessor, item, current)) {
                        return true;
                    }
                    continue restart;
                }
                predecessor = current;
                current = successor;
            }
        }
    }

    private boolean insert(Node<T> predecessor, T item, Node<T> successor) {
        Node<T> newNode = new Node<>(NodeKind.NORMAL, item, successor);
        return predecessor.next.compareAndSet(successor, newNode);
    }

    public boolean remove(T item) {
        int c;
        boolean[] markHolder = {false};
        Node<T> marker;

        restart:
        while (true) {
            for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
                if (current == null) {
                    // no more node
                    return false;
                }
                successor = current.next(markHolder);
                while (markHolder[0]) {
                    if (!predecessor.next.compareAndSet(current, successor)) {
                        continue restart;
                    }
                    current = successor;
                    if (current == null) {
                        // no more node
                        return false;
                    }
                    successor = current.next(markHolder);
                }
                c = compareItem(current, item);
                if (c == 0) {
                    marker = new Node<>(NodeKind.MARKER, null, successor);
                    // logical remove
                    if (current.next.compareAndSet(successor, marker)) {
                        // physical remove
                        predecessor.next.compareAndSet(current, successor);
                        return true;
                    }
                    continue restart;
                }
                if (c > 0) {
                    // not found
                    return false;
                }
                predecessor = current;
                current = successor;
            }
        }
    }

    private enum NodeKind {
        HEAD,
        NORMAL,
        MARKER
    }

    private static final class Node<T> {
        final NodeKind kind;
        final T item;
        final AtomicReference<Node<T>> next;

        Node(NodeKind kind, T item, Node<T> next) {
            this.kind = kind;
            this.item = item;
            this.next = new AtomicReference<>(next);
        }

        Node<T> next(boolean[] markHolder) {
            Node<T> successor = next.get();
            if (successor == null) {
                // last element
                markHolder[0] = false;
                return null;
            }
            if (successor.kind == NodeKind.MARKER) {
                markHolder[0] = true;
                return successor.next.get();
            }
            markHolder[0] = false;
            return successor;
        }
    }

}

The awkward part of the code above is that, without a tail, you constantly have to check whether the current node is null.

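ListBasedSet4 below is my final version. It differs from the previous classes in that it allows values to be replaced. Note that by default the HM Linked List does nothing to a node when it meets an equal value. That is correct for a set, but hard to maintain for a map: the user may pass the same key with a different value. So you need a way to let the user replace the value.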

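In the HM Linked List, replacing a value directly runs into a problem: the value of a logically deleted node may get replaced. Consider the following execution sequence: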

thread 1: logical delete node with key 1, value xxx
thread 2: replace xxx with yyy
thread 1: physical delete node with key 1

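In theory you must not replace the value of a logically deleted node. But the existing strategy cannot impose an order on an execution like the one above, so ConcurrentSkipListMap adds a second logical-deletion step: CAS the value to null. Deletion thus becomes a three-step process: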

  1. CAS item not null -> null
  2. append marker
  3. CAS predecessor’s next pointer to successor

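The advantage is that after the first step, the replacing thread fails. The disadvantages are that null cannot be used as a node value and that the code gets more complex.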
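The effect of the first step can be replayed sequentially with a single AtomicReference standing in for the node's item field. The class name NullItemDeleteReplay is hypothetical, and only step 1 and the competing replace are modelled; appending the marker and unlinking the node are left out of this sketch.

```java
import java.util.concurrent.atomic.AtomicReference;

final class NullItemDeleteReplay {

    /** Returns {step 1 succeeded, replace succeeded, item is null afterwards}. */
    static boolean[] replay() {
        String original = "xxx";
        AtomicReference<String> item = new AtomicReference<>(original);

        // thread 2 reads the current value and intends to replace it with "yyy"
        String seenByReplacer = item.get();

        // thread 1, step 1 of the deletion: CAS item from non-null to null
        boolean deleted = item.compareAndSet(original, null);

        // thread 2: the CAS expects the old value, but the item is null now
        boolean replaced = item.compareAndSet(seenByReplacer, "yyy");

        return new boolean[] {deleted, replaced, item.get() == null};
    }
}
```

The two CAS operations on the same field give the delete and the replace a definite order: whichever CAS runs first wins, and the other one fails and has to retry.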

import com.google.common.base.Preconditions;

import java.util.Comparator;
import java.util.concurrent.atomic.AtomicReference;

/**
 * 1. HM linked list using marker node
 * 2. no tail node
 * 3. comparator
 * 4. replaceable
 *
 * @param <T> element type
 */
@SuppressWarnings("Duplicates")
public class ListBasedSet4<T> {

    private final Comparator<T> comparator;
    private final Node<T> head;

    public ListBasedSet4(Comparator<T> comparator) {
        this.comparator = comparator;
        this.head = new Node<>(NodeKind.NORMAL, null, null);
    }

    public boolean contains(T item) {
        Preconditions.checkNotNull(item);

        int c;
        T value;

        for (Node<T> current = head.next.get(), successor; ; ) {
            if (current == null) {
                return false;
            }
            value = current.item.get();
            if (value == null) {
                // current is deleted
                successor = current.next.get();
                // skip marker if present
                if (successor != null && successor.kind == NodeKind.MARKER) {
                    successor = successor.next.get();
                }
                current = successor;
                continue;
            }
            c = comparator.compare(value, item);
            if (c == 0) {
                return true;
            }
            if (c > 0) {
                // current.item > item
                return false;
            }
            current = current.next.get();
        }
    }

    public boolean add(T item) {
        Preconditions.checkNotNull(item);

        int c;
        T value;
        Node<T> newNode;

        while (true) {
            for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
                if (current == null) {
                    // no more node
                    newNode = new Node<>(NodeKind.NORMAL, item, null);
                    if (predecessor.casNext(null, newNode)) {
                        return true;
                    }
                    break; // restart
                }
                value = current.item.get();
                if (value == null) {
                    // current is deleted
                    successor = current.nextNotMarker();
                    if (predecessor.casNext(current, successor)) {
                        current = successor;
                        continue;
                    } else {
                        break;
                    }
                }
                c = comparator.compare(value, item);
                if (c == 0) {
                    // same key, replace
                    if (current.casItem(value, item)) {
                        return true;
                    }
                    break;
                }
                successor = current.next.get();
                if (c > 0) {
                    // current.item > item
                    newNode = new Node<>(NodeKind.NORMAL, item, successor);
                    if (predecessor.casNext(current, newNode)) {
                        return true;
                    }
                    break;
                }
                predecessor = current;
                current = successor;
            }
        }
    }

    public boolean remove(T item) {
        Preconditions.checkNotNull(item);

        int c;
        T value;
        Node<T> marker;

        while (true) {
            for (Node<T> predecessor = head, current = predecessor.next.get(), successor; ; ) {
                if (current == null) {
                    // no more node
                    return false;
                }
                value = current.item.get();
                if (value == null) {
                    // current is deleted
                    successor = current.nextNotMarker();
                    if (predecessor.casNext(current, successor)) {
                        current = successor;
                        continue;
                    } else {
                        break; // restart
                    }
                }
                c = comparator.compare(value, item);
                if (c == 0) {
                    // the node to remove
                    if (!current.casItem(value, null)) {
                        break;
                    }
                    // successor never be a marker node
                    successor = current.next.get();
                    marker = new Node<>(NodeKind.MARKER, null, successor);
                    if (current.casNext(successor, marker)) {
                        predecessor.casNext(current, successor);
                        return true;
                    }
                    // successor may be deleted at the same time
                    break;
                }
                if (c > 0) {
                    // not found
                    return false;
                }
                predecessor = current;
                // current may be deleted at this point
                current = current.nextNotMarker();
            }
        }
    }

    private enum NodeKind {
        NORMAL,
        MARKER
    }

    @SuppressWarnings("Duplicates")
    private static final class Node<T> {
        final NodeKind kind;
        final AtomicReference<T> item;
        final AtomicReference<Node<T>> next;

        Node(NodeKind kind, T item, Node<T> next) {
            this.kind = kind;
            this.item = new AtomicReference<>(item);
            this.next = new AtomicReference<>(next);
        }

        Node<T> nextNotMarker() {
            Node<T> successor = next.get();
            if (successor == null) {
                return null;
            }
            if (successor.kind == NodeKind.MARKER) {
                return successor.next.get();
            }
            return successor;
        }

        boolean casItem(T expect, T update) {
            return item.compareAndSet(expect, update);
        }

        boolean casNext(Node<T> expect, Node<T> update) {
            return next.compareAndSet(expect, update);
        }
    }
}

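As you can see, before any value comparison you must first check whether the current value is null. The code above is already very close to how ConcurrentSkipListMap uses its Node. Before looking at how the skip list is put on top, here is a basic skip list implementation: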

import java.util.Random;

public class SkipList<T> {

    private static final int MAX_LEVEL = 5;
    private static final int NOT_FOUND = -1;

    private final Random random = new Random();
    private final Node<T> head;

    public SkipList() {
        head = new Node<>(Integer.MIN_VALUE);
        final Node<T> tail = new Node<>(Integer.MAX_VALUE);
        for (int level = 0; level < MAX_LEVEL; level++) {
            head.next[level] = tail;
        }
    }

    public boolean contains(T x) {
        int key = x.hashCode();
        Node<T> predecessor = head;
        Node<T> current;
        // levels are indexed 0..MAX_LEVEL - 1
        for (int level = MAX_LEVEL - 1; level >= 0; level--) {
            current = predecessor.next[level];
            while (current.key < key) {
                predecessor = current;
                current = current.next[level];
            }
            if (current.key == key) {
                return true;
            }
        }
        return false;
    }

    private int find(T x, Node<T>[] predecessors, Node<T>[] successors) {
        int key = x.hashCode();
        int foundAtLevel = NOT_FOUND;
        Node<T> predecessor = head;
        Node<T> current;
        for (int level = MAX_LEVEL - 1; level >= 0; level--) {
            current = predecessor.next[level];
            while (current.key < key) {
                predecessor = current;
                current = current.next[level];
            }
            predecessors[level] = predecessor;
            successors[level] = current;
            if (foundAtLevel == NOT_FOUND && current.key == key) {
                foundAtLevel = level;
            }
        }
        return foundAtLevel;
    }

    @SuppressWarnings("unchecked")
    public boolean add(T x) {
        Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
        Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
        int foundAtLevel = find(x, predecessors, successors);
        if (foundAtLevel != NOT_FOUND) {
            return false;
        }
        // 1..MAX_LEVEL; a topLevel of 0 would leave the node unlinked at every level
        int topLevel = random.nextInt(MAX_LEVEL) + 1;
        Node<T> node = new Node<>(x.hashCode(), x, topLevel);
        for (int level = 0; level < topLevel; level++) {
            node.next[level] = successors[level];
            predecessors[level].next[level] = node;
        }
        return true;
    }

    @SuppressWarnings("unchecked")
    public boolean remove(T x) {
        Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
        Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
        int foundAtLevel = find(x, predecessors, successors);
        if (foundAtLevel == NOT_FOUND) {
            return false;
        }
        Node<T> node = successors[foundAtLevel];
        for (int level = 0; level <= foundAtLevel; level++) {
            predecessors[level].next[level] = node.next[level];
        }
        return true;
    }

    private static class Node<T> {
        private final int key;
        private final T value;
        private final Node<T>[] next;

        @SuppressWarnings("unchecked")
        Node(int key) {
            this.key = key;
            value = null;
            next = (Node<T>[]) new Node[MAX_LEVEL];
        }

        // topLevel >= 1 && topLevel <= MAX_LEVEL
        @SuppressWarnings("unchecked")
        Node(int key, T value, int topLevel) {
            this.key = key;
            this.value = value;
            next = (Node<T>[]) new Node[topLevel];
        }

        int topLevel() {
            return next.length;
        }

        @Override
        public String toString() {
            return "Node{" +
                    "key=" + key +
                    ", item=" + value +
                    '}';
        }
    }
}

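The above is a simplified version of the skip list used in chapter 14 of "The Art of Multiprocessor Programming". Frankly, it is far simpler than ConcurrentSkipListMap: it has a tail, so there is no constant null checking; the maximum level is fixed; and a random number decides each node's level.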
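The listing above draws the level uniformly from the allowed range. Skip lists are more commonly described with a geometric distribution, where each additional level is reached with probability 1/2, so that higher levels hold progressively fewer nodes. The helper below is a minimal sketch of that alternative; the class RandomLevel and its method are hypothetical and not taken from the book or from ConcurrentSkipListMap.

```java
import java.util.concurrent.ThreadLocalRandom;

final class RandomLevel {

    /**
     * Returns a level in [1, maxLevel]. Each extra level is taken with
     * probability 1/2, so about half of the nodes stay at level 1.
     */
    static int next(int maxLevel) {
        int level = 1;
        // ThreadLocalRandom.current() is looked up on each call, never cached in a field
        while (level < maxLevel && ThreadLocalRandom.current().nextBoolean()) {
            level++;
        }
        return level;
    }
}
```

A call such as RandomLevel.next(5) could replace the uniform draw in add without touching the rest of the listing.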

Next is the final version from chapter 14, SkipListBasedSet, which combines the HM Linked List (ListBasedSet1) with a skip list.

import com.google.common.base.Preconditions;

import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicMarkableReference;

@SuppressWarnings("Duplicates")
public class SkipListSet1<T> {

    private static final int MAX_LEVEL = 5;

    private final ThreadLocalRandom random = ThreadLocalRandom.current();
    private final Node<T> head;

    public SkipListSet1() {
        head = new Node<>(null, Integer.MIN_VALUE, MAX_LEVEL);
        final Node<T> tail = new Node<>(null, Integer.MAX_VALUE, MAX_LEVEL);
        for (int i = 0; i < MAX_LEVEL; i++) {
            head.next[i] = new AtomicMarkableReference<>(tail, false);
        }
    }

    /**
     * Test if element in set.
     *
     * @param x element
     * @return true if present, otherwise false
     */
    public boolean contains(T x) {
        Preconditions.checkNotNull(x);
        final int key = x.hashCode();
        Node<T> predecessor = head;
        Node<T> current = null;
        Node<T> successor;
        boolean[] markHolder = new boolean[1];
        for (int i = MAX_LEVEL - 1; i >= 0; i--) {
            current = predecessor.next(i);
            while (true) {
                successor = current.next[i].get(markHolder);
                while (markHolder[0]) {
                    current = successor;
                    successor = current.next[i].get(markHolder);
                }
                if (current.key < key) {
                    predecessor = current;
                    current = current.next(i);
                } else {
                    break;
                }
            }
        }
        return current.key == key;
    }

    /**
     * Add element.
     *
     * @param x element
     * @return true if added, false if present
     */
    @SuppressWarnings("unchecked")
    public boolean add(T x) {
        Preconditions.checkNotNull(x);
        final int key = x.hashCode();
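        // ThreadLocalRandom must be obtained on the calling thread, not shared via a field
        final int level = ThreadLocalRandom.current().nextInt(MAX_LEVEL) + 1;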
        final Node<T> node = new Node<>(x, key, level);
        Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
        Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
        boolean success;
        while (true) {
            if (find(key, predecessors, successors)) {
                return false;
            }
            node.next[0] = new AtomicMarkableReference<>(successors[0], false);
            success = predecessors[0].next[0].compareAndSet(successors[0], node, false, false);
            if (!success) {
                continue; // retry
            }
            // node added to skip list if success

            for (int i = 1; i < level; i++) {
                while (true) {
                    node.next[i] = new AtomicMarkableReference<>(successors[i], false);
                    success = predecessors[i].next[i].compareAndSet(successors[i], node, false, false);
                    if (success) {
                        break;
                    } else {
                        find(key, predecessors, successors);
                    }
                }
            }
            return true;
        }
    }

    /**
     * Find element in skip list and return predecessors and successors.
     *
     * @param key key
     * @return true if found, otherwise false
     */
    private boolean find(int key, Node<T>[] predecessors, Node<T>[] successors) {
        Node<T> predecessor;
        Node<T> current = null;
        Node<T> successor;

        boolean[] markHolder = new boolean[1];
        boolean snip;

        // A failed snip means the neighbourhood changed, so the search must
        // start over from head at the top level, not just drop one level.
        restart:
        while (true) {
            predecessor = head;
            for (int i = MAX_LEVEL - 1; i >= 0; i--) {
                current = predecessor.next(i);
                // find the first node that is
                // 1. not marked
                // 2. not less than the specified key
                while (true) {
                    successor = current.next[i].get(markHolder);
                    while (markHolder[0]) {
                        // current is logically removed, try to unlink it
                        snip = predecessor.next[i].compareAndSet(current, successor, false, false);
                        if (!snip) {
                            continue restart;
                        }
                        current = successor;
                        successor = current.next[i].get(markHolder);
                    }
                    // current is not marked
                    if (current.key < key) {
                        predecessor = current;
                        current = successor;
                    } else {
                        break;
                    }
                }
                predecessors[i] = predecessor;
                successors[i] = current;
            }
            return current.key == key;
        }
    }

    /**
     * Remove element.
     *
     * @param x element
     * @return true if success, false if not found or removed
     */
    @SuppressWarnings("unchecked")
    public boolean remove(T x) {
        final int key = x.hashCode();
        Node<T>[] predecessors = (Node<T>[]) new Node[MAX_LEVEL];
        Node<T>[] successors = (Node<T>[]) new Node[MAX_LEVEL];
        Node<T> node;
        Node<T> next;
        boolean[] markHolder = new boolean[1];
        boolean success;

        if (!find(key, predecessors, successors)) {
            return false; // not found
        }
        node = successors[0];
        for (int i = node.level() - 1; i >= 1; i--) {
            next = node.next[i].get(markHolder);
            while (!markHolder[0]) {
                success = node.next[i].attemptMark(next, true);
                if (success) {
                    break;
                }
                next = node.next[i].get(markHolder);
            }
        }
        while (true) {
            next = node.next[0].get(markHolder);
            // logically removed
            if (markHolder[0]) {
                return false;
            }
            success = node.next[0].attemptMark(next, true);
            if (success) {
                find(key, predecessors, successors);
                return true;
            }
        }
    }

    private static final class Node<T> {
        final T value;
        final int key;
        final AtomicMarkableReference<Node<T>>[] next;

        @SuppressWarnings("unchecked")
        Node(T value, int key, int level) {
            this.value = value;
            this.key = key;
            next = (AtomicMarkableReference<Node<T>>[]) new AtomicMarkableReference[level];
        }

        Node<T> next(int level) {
            return next[level].getReference();
        }

        int level() {
            return next.length;
        }
    }
}

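Note the invariant of this class: as the book points out, the bottom-level list is the authoritative one. A SkipList is really a stack of lists, and keeping every level consistent with the others is hard. For a detailed analysis of the code above, interested readers should turn to the book itself.

That concludes the background. Next, let's look at the design choices that are specific to ConcurrentSkipListMap.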

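The bottom level is, needless to say, an HM Linked List. Deletion is signalled with a marker node, and before the marker node is installed the item is CASed to null. As for the SkipList part:

  • It starts with a single level and grows as nodes are inserted, according to the randomly chosen level; the maximum, counting the bottom level, is 32 levels
  • When a node is removed, the number of levels may shrink
  • There is no tail on the right-hand side. My guess is that growing head and tail together would be difficult, and that in practice a tail is simply not needed
  • Above the bottom level, the Index levels do not use marker nodes. This can cause concurrency anomalies, but since the bottom level is authoritative the impact is small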
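
The deletion protocol of the bottom level (null the item, append a marker, unlink) can be looked at in isolation. The following is a minimal sketch, not the JDK code: `MarkerDeleteDemo`, its `Node` type and the `delete`/`liveItems` helpers are hypothetical names made up for this illustration, items are plain strings, and the demo runs single-threaded.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

final class MarkerDeleteDemo {

    /** Hypothetical node for this sketch: a marker flag, a nullable item, a next pointer. */
    static final class Node {
        final boolean marker;
        final AtomicReference<String> item;
        final AtomicReference<Node> next;

        Node(boolean marker, String item, Node next) {
            this.marker = marker;
            this.item = new AtomicReference<>(item);
            this.next = new AtomicReference<>(next);
        }
    }

    /** Deletes current, whose predecessor is known. Returns true if this call won the deletion. */
    static boolean delete(Node predecessor, Node current) {
        String value = current.item.get();
        // Step 1: logical deletion. Only one caller can CAS the item to null.
        if (value == null || !current.item.compareAndSet(value, null)) {
            return false;
        }
        Node successor = current.next.get();
        if (successor != null && successor.marker) {
            // A helper already appended the marker, only the unlink is left.
            predecessor.next.compareAndSet(current, successor.next.get());
        } else if (current.next.compareAndSet(successor, new Node(true, null, successor))) {
            // Step 2 succeeded: nothing can be inserted after current any more.
            // Step 3: unlink current together with its marker.
            predecessor.next.compareAndSet(current, successor);
        }
        // A failed step 2 or 3 is left for a later traversal to finish.
        return true;
    }

    /** Collects the live items reachable from head, skipping markers and deleted nodes. */
    static List<String> liveItems(Node head) {
        List<String> items = new ArrayList<>();
        for (Node n = head.next.get(); n != null; n = n.next.get()) {
            String value = n.item.get();
            if (!n.marker && value != null) {
                items.add(value);
            }
        }
        return items;
    }

    public static void main(String[] args) {
        Node c = new Node(false, "c", null);
        Node b = new Node(false, "b", c);
        Node a = new Node(false, "a", b);
        Node head = new Node(false, null, a);
        System.out.println(delete(a, b));     // true
        System.out.println(liveItems(head));  // [a, c]
    }
}
```

Note that the unlink in step 3 is only attempted when the marker CAS of step 2 succeeded; if that CAS fails, `current.next` has changed and the successor read earlier can no longer be trusted.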

A simplified version of ConcurrentSkipListMap with all of the characteristics above:

import com.google.common.base.Preconditions;

import javax.annotation.Nonnull;
import java.util.Comparator;
import java.util.concurrent.atomic.AtomicReference;

@SuppressWarnings("Duplicates")
public class SkipListSet3<T> {

    private final Comparator<T> comparator;
    private final AtomicReference<HeadIndex<T>> headIndex;

    public SkipListSet3(Comparator<T> comparator) {
        Preconditions.checkNotNull(comparator);
        this.comparator = comparator;

        Node<T> node = Node.ofNormal(null, null);
        headIndex = new AtomicReference<>(new HeadIndex<>(node, null, null, 1));
    }

    public boolean contains(T item) {
        Preconditions.checkNotNull(item);

        for (Node<T> node = findNode(item, false); node != null; node = node.nextNotMarker()) {
            T value = node.item.get();
            if (value == null) {
                // node is removed
                // no helping
                // go next
                continue;
            }
            int c = comparator.compare(value, item);
            if (c == 0) {
                // found
                return true;
            }
            if (c > 0) {
                // node.item > item
                // not found
                return false;
            }
        }
        // no more node
        return false;
    }

    public boolean contains2(T item) {
        Preconditions.checkNotNull(item);

        restart:
        while (true) {
            for (Node<T> predecessor = findNode(item, true),
                 current = predecessor.next.get(); current != null; ) {

                if (current.marker) {
                    // predecessor is deleted
                    continue restart;
                }
                T value = current.item.get();
                if (value == null) {
                    // current is deleted
                    helpDelete(predecessor, current);
                    continue restart;
                }
                int c = comparator.compare(value, item);
                if (c == 0) {
                    // found
                    return true;
                }
                if (c > 0) {
                    // current.item > item
                    // not found
                    return false;
                }
                // current.item < item
                predecessor = current;
                current = current.next.get();
            }
            return false;
        }
    }

    /**
     * Help delete current.
     *
     * @param predecessor predecessor
     * @param current     current
     */
    private void helpDelete(Node<T> predecessor, Node<T> current) {
        Node<T> successor = current.next.get();
        if (successor == null || !successor.marker) {
            // step 1 -> 2, insert marker, logical remove
            current.casNext(successor, Node.ofMarker(successor));
        } else {
            // step 2 -> 3, physical remove
            predecessor.casNext(current, successor.next.get());
        }
    }

    @Nonnull
    private Node<T> findNode(T item, boolean onlyPredecessor) {
        while (true) {
            for (Index<T> predecessor = headIndex.get(), current = predecessor.right.get(), successor, down; ; ) {
                if (current != null) {
                    T value = current.node.item.get();
                    if (value == null) {
                        // current is deleted
                        successor = current.right.get();
                        if (!predecessor.casRight(current, successor)) {
                            break; // restart
                        }
                        current = successor;
                        continue;
                    }
                    int c = comparator.compare(value, item);
                    if (c == 0) {
                        // found
                        return onlyPredecessor ? predecessor.node : current.node;
                    }
                    if (c < 0) {
                        // current.node.item < item
                        // go right
                        predecessor = current;
                        current = current.right.get();
                        continue;
                    }
                }
                // 1. current == null, no more index
                // 2. current.item > item
                // go down
                down = predecessor.down;
                if (down == null) {
                    // the last index level
                    return predecessor.node;
                }
                predecessor = down;
                current = predecessor.right.get();
            }
        }
    }

    public T add(T item) {
        Preconditions.checkNotNull(item);

        while (true) {
            for (Node<T> predecessor = findNode(item, true),
                 current = predecessor.nextNotMarker(), successor; ; ) {

                if (current != null) {
                    T value = current.item.get();
                    if (value == null) {
                        // current is deleted
                        successor = current.nextNotMarker();
                        if (!predecessor.casNext(current, successor)) {
                            break; // restart
                        }
                        current = successor;
                        continue;
                    }
                    int c = comparator.compare(value, item);
                    if (c == 0) {
                        if (value.equals(item) || current.casItem(value, item)) {
                            return value;
                        }
                        break; // restart
                    }
                    if (c < 0) {
                        // current.item < item
                        predecessor = current;
                        // current may be deleted at this point
                        current = current.nextNotMarker();
                        continue;
                    }
                }
                // 1. no more node
                // 2. current.item > item
                // insert node
                Node<T> newNode = Node.ofNormal(item, current);
                if (predecessor.casNext(current, newNode)) {
                    // ok
                    buildIndices(randomLevel(), newNode);
                    return null;
                }
                break; // restart
            }
        }
    }

    public T add2(T item) {
        Preconditions.checkNotNull(item);

        while (true) {
            for (Node<T> predecessor = findNode(item, true), current = predecessor.next.get(); ; ) {
                if (current != null) {
                    if (current.marker) {
                        // predecessor is deleted
                        break; // restart
                    }

                    T value = current.item.get();
                    if (value == null) {
                        // current is deleted
                        helpDelete(predecessor, current);
                        break;
                    }
                    int c = comparator.compare(value, item);
                    if (c == 0) {
                        if (value.equals(item) || current.casItem(value, item)) {
                            return value;
                        }
                        break; // restart
                    }
                    if (c < 0) {
                        // current.item < item
                        predecessor = current;
                        current = current.next.get();
                        continue;
                    }
                }
                // 1. no more node
                // 2. current.item > item
                // insert node
                Node<T> newNode = Node.ofNormal(item, current);
                if (predecessor.casNext(current, newNode)) {
                    // ok
                    buildIndices(randomLevel(), newNode);
                    return null;
                }
                break; // restart
            }
        }
    }

    private void buildIndices(int level, Node<T> node) {
        if (level < 1) {
            // no index
            return;
        }

        HeadIndex<T> head = headIndex.get();
        Index<T>[] indices;
        if (level <= head.level) {
            // the new tower fits under the current head
            indices = makeIndices(level, node);
        } else {
            // taller than the current head: grow by exactly one level
            indices = makeIndices(head.level + 1, node);
            head = increaseLevel(indices.length, head, node, indices);
        }
        insertIndices(head, indices.length, indices);
    }

    private void insertIndices(HeadIndex<T> head, int level, Index<T>[] indices) {
        final T item = indices[0].node.item.get();
        if (item == null) {
            // node was removed before its indices were linked
            return;
        }

        Index<T> predecessor;
        Index<T> current;
        Index<T> successor;

        restart:
        while (true) {
            predecessor = head;
            current = predecessor.right.get();

            int currentLevel = head.level;
            while (currentLevel > 0) {
                // true if our index is already linked at this level, either by
                // increaseLevel or by an earlier pass before a restart
                boolean linked = currentLevel <= level && current == indices[currentLevel - 1];
                if (current != null && !linked) {
                    T value = current.node.item.get();
                    if (value == null) {
                        // current is deleted
                        successor = current.right.get();
                        if (!predecessor.casRight(current, successor)) {
                            continue restart;
                        }
                        current = successor;
                        continue;
                    }
                    int c = comparator.compare(value, item);
                    if (c == 0) { // impossible: a different index with the same item
                        throw new IllegalStateException("encounter index with same item when insert index");
                    }
                    if (c < 0) {
                        // go right
                        predecessor = current;
                        current = current.right.get();
                        continue;
                    }
                }
                // 1. current == null
                // 2. current.item > item
                // 3. current is our own index
                if (currentLevel <= level && !linked) {
                    // insert index
                    indices[currentLevel - 1].lazySetRight(current);
                    if (!predecessor.casRight(current, indices[currentLevel - 1])) {
                        continue restart;
                    }
                    // node maybe deleted at this point
                }
                // go down; the lowest index level has no down pointer
                predecessor = predecessor.down;
                if (predecessor != null) {
                    current = predecessor.right.get();
                }
                currentLevel--;
            }

            // indices inserted
            return;
        }
    }

    private HeadIndex<T> increaseLevel(int expectLevel, HeadIndex<T> head, Node<T> node, Index<T>[] indices) {
        HeadIndex<T> oldHead = head;
        HeadIndex<T> newHead;

        while (oldHead.level < expectLevel) {
            // build head indices at once
            newHead = oldHead;
            for (int level = oldHead.level + 1; level <= expectLevel; level++) {
                // every head level points at the base header node, never at the node being inserted
                newHead = new HeadIndex<>(oldHead.node, indices[level - 1], newHead, level);
            }
            // newHead = new HeadIndex<>(node, indices[oldHead.level], oldHead, oldHead.level + 1);
            if (headIndex.compareAndSet(oldHead, newHead)) {
                return newHead;
            }
            oldHead = headIndex.get();
        }
        return oldHead;
    }

    @SuppressWarnings("unchecked")
    private Index<T>[] makeIndices(int level, Node<T> node) {
        assert level > 0;
        Index<T>[] indices = (Index<T>[]) new Index[level];
        Index<T> lastIndex = null;
        for (int i = 0; i < level; i++) {
            indices[i] = new Index<>(node, null, lastIndex);
            lastIndex = indices[i];
        }
        return indices;
    }

    private int randomLevel() {
        int r = (int) System.nanoTime();
        // xor shift
        r ^= r << 13;
        r ^= r >>> 17;
        r ^= r << 5;
        if ((r & 0x80000001) != 0) {
            return 0;
        }
        int level = 1;
        while (((r >>>= 1) & 1) != 0) {
            level++;
        }
        return level;
    }

    public T remove(T item) {
        Preconditions.checkNotNull(item);

        while (true) {
            for (Node<T> predecessor = findNode(item, true),
                 current = predecessor.nextNotMarker(), successor; ; ) {
                if (current == null) {
                    // not found
                    return null;
                }
                T value = current.item.get();
                if (value == null) {
                    // current is deleted
                    successor = current.nextNotMarker();
                    if (!predecessor.casNext(current, successor)) {
                        break; // restart
                    }
                    current = successor;
                    continue;
                }
                int c = comparator.compare(value, item);
                if (c > 0) {
                    // not found
                    return null;
                }
                if (c < 0) {
                    // current.item < item
                    // go next
                    predecessor = current;
                    current = current.nextNotMarker();
                    continue;
                }
                // c == 0
                // found
                if (!current.casItem(value, null)) {
                    break; // restart
                }
                successor = current.next.get();
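                if (successor == null || !successor.marker) {
                    Node<T> marker = Node.ofMarker(successor);
                    // Unlink only if our marker went in. If the CAS failed, current.next
                    // changed under us (a marker or a newly inserted node), so the
                    // successor read above is stale; a later traversal finishes the job.
                    if (current.casNext(successor, marker)) {
                        predecessor.casNext(current, successor);
                    }
                } else {
                    predecessor.casNext(current, successor.next.get());
                }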
                // unlink indices
                findNode(item, false);
                tryDecreaseLevel();
                return value;
            }
        }
    }

    public T remove2(T item) {
        Preconditions.checkNotNull(item);

        restart:
        while (true) {
            for (Node<T> predecessor = findNode(item, true),
                 current = predecessor.next.get(), successor; current != null; ) {
                if (current.marker) {
                    // predecessor is deleted
                    continue restart;
                }
                T value = current.item.get();
                if (value == null) {
                    // current is deleted
                    helpDelete(predecessor, current);
                    continue restart;
                }
                int c = comparator.compare(value, item);
                if (c > 0) {
                    // current.item > item
                    // not found
                    return null;
                }
                if (c < 0) {
                    // current.item < item
                    // go next
                    predecessor = current;
                    current = current.next.get();
                    continue;
                }
                // c == 0
                // found
                if (!current.casItem(value, null)) {
                    continue restart;
                }
                successor = current.next.get();
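                if (successor == null || !successor.marker) {
                    Node<T> marker = Node.ofMarker(successor);
                    // Unlink only if our marker went in. If the CAS failed, current.next
                    // changed under us (a marker or a newly inserted node), so the
                    // successor read above is stale; helpDelete finishes the job later.
                    if (current.casNext(successor, marker)) {
                        predecessor.casNext(current, successor);
                    }
                } else {
                    predecessor.casNext(current, successor.next.get());
                }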
                // unlink indices
                findNode(item, false);
                tryDecreaseLevel();
                return value;
            }
            // not found
            return null;
        }
    }

    private void tryDecreaseLevel() {
        HeadIndex<T> t1 = headIndex.get();
        if (t1.level <= 3) {
            return;
        }
        HeadIndex<T> t2 = (HeadIndex<T>) t1.down;
        HeadIndex<T> t3 = (HeadIndex<T>) t2.down;
        if (t3.right.get() != null || t2.right.get() != null || t1.right.get() != null) {
            return;
        }
        if (headIndex.compareAndSet(t1, t2)) {
            // rollback if right of t1 appeared
            if (t1.right.get() != null) {
                headIndex.compareAndSet(t2, t1);
            }
        }
    }

    private static class Index<T> {
        final Node<T> node;
        final AtomicReference<Index<T>> right;
        final Index<T> down;

        Index(Node<T> node, Index<T> right, Index<T> down) {
            this.node = node;
            this.right = new AtomicReference<>(right);
            this.down = down;
        }

        @SuppressWarnings("BooleanMethodIsAlwaysInverted")
        boolean casRight(Index<T> expect, Index<T> update) {
            return node.item.get() != null && right.compareAndSet(expect, update);
        }

        void lazySetRight(Index<T> right) {
            this.right.lazySet(right);
        }
    }

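    private static final class HeadIndex<T> extends Index<T> {
        final int level;

        HeadIndex(Node<T> node, Index<T> right, Index<T> down, int level) {
            super(node, right, down);
            this.level = level;
        }

        @Override
        boolean casRight(Index<T> expect, Index<T> update) {
            // The base header node has a null item and is never deleted,
            // so the liveness check of Index.casRight must be skipped here.
            return right.compareAndSet(expect, update);
        }
    }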

    private static final class Node<T> {
        final boolean marker;
        final AtomicReference<T> item;
        final AtomicReference<Node<T>> next;

        Node(boolean marker, T item, Node<T> next) {
            this.marker = marker;
            this.item = new AtomicReference<>(item);
            this.next = new AtomicReference<>(next);
        }

        static <T> Node<T> ofMarker(Node<T> next) {
            return new Node<>(true, null, next);
        }

        static <T> Node<T> ofNormal(T item, Node<T> next) {
            return new Node<>(false, item, next);
        }

        Node<T> nextNotMarker() {
            Node<T> successor = next.get();
            if (successor == null) {
                return null;
            }
            if (successor.marker) {
                return successor.next.get();
            }
            return successor;
        }

        boolean casItem(T expect, T update) {
            return item.compareAndSet(expect, update);
        }

        boolean casNext(Node<T> expect, Node<T> update) {
            return next.compareAndSet(expect, update);
        }

    }
}

In the code, contains/add/remove use nextNotMarker, while contains2/add2/remove2 are closer to the ConcurrentSkipListMap source.

In terms of code complexity: contains < remove < add.

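Let's start with contains. contains is simpler than contains2 because it does not help delete nodes while walking the bottom-level HM Linked List, so it can traverse the list like an ordinary singly linked list, skipping nodes whose item is null and nodes smaller than the target. The findNode method that supplies the starting node has nothing to do with ConcurrentSkipListMap's findNode; if anything, it resembles findPredecessor. findNode starts from the topmost Index level, helps unlink the Index entries of deleted nodes, and keeps descending until it reaches the node closest to the target, possibly the target node itself. The onlyPredecessor parameter decides whether the predecessor or the node itself is returned.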
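
The descent that findNode performs (go right while the next index is smaller, otherwise go down) is easier to see without any concurrency. Below is a single-threaded sketch on a hand-built structure; `DescentDemo`, its `Node`/`Index` types, `findPredecessor` and `sample` are hypothetical names for this illustration, keys are plain ints, and no help-delete is modelled.

```java
final class DescentDemo {

    /** Bottom-level list node (hypothetical, int keys only). */
    static final class Node {
        final int key;
        final Node next;

        Node(int key, Node next) {
            this.key = key;
            this.next = next;
        }
    }

    /** Index entry: points at a bottom-level node, at the entry below, and to the right. */
    static final class Index {
        final Node node;
        final Index down;
        final Index right;

        Index(Node node, Index down, Index right) {
            this.node = node;
            this.down = down;
            this.right = right;
        }
    }

    /** Returns an indexed bottom-level node whose key is less than key, as close to it as the index allows. */
    static Node findPredecessor(Index top, int key) {
        Index p = top;
        while (true) {
            Index r = p.right;
            if (r != null && r.node.key < key) {
                p = r;          // go right
                continue;
            }
            if (p.down == null) {
                return p.node;  // lowest index level reached
            }
            p = p.down;         // go down
        }
    }

    /** Bottom list: header -> 1 -> 3 -> 5 -> 7 -> 9. Level 1 indexes 3 and 7, level 2 indexes 7. */
    static Index sample() {
        Node n9 = new Node(9, null);
        Node n7 = new Node(7, n9);
        Node n5 = new Node(5, n7);
        Node n3 = new Node(3, n5);
        Node n1 = new Node(1, n3);
        Node header = new Node(Integer.MIN_VALUE, n1);

        Index i7Level1 = new Index(n7, null, null);
        Index i3Level1 = new Index(n3, null, i7Level1);
        Index head1 = new Index(header, null, i3Level1);
        Index i7Level2 = new Index(n7, i7Level1, null);
        return new Index(header, head1, i7Level2);
    }

    public static void main(String[] args) {
        Index top = sample();
        System.out.println(findPredecessor(top, 5).key); // 3
        System.out.println(findPredecessor(top, 8).key); // 7
        System.out.println(findPredecessor(top, 7).key); // 3
    }
}
```

The last call shows why the caller still has to walk the bottom list: the node returned for key 7 is 3, not 5, because 5 has no index entry.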

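contains2 is a bit more complex than contains. Because it helps delete nodes, it needs the predecessor. When the current node has been logically deleted, what it does depends on how far the deletion has progressed: if the item has been set to null but no marker exists yet, it appends the marker node; if the marker node is already there, it physically unlinks the node. Could this be done in one step? The item is already null, so why not unlink right away, or in other words, can the marker node be left out? My feeling is that it might work in theory, but with some probability a node whose item is null would never be physically removed and would linger in the structure as garbage, for example in the case mentioned earlier where adjacent nodes are deleted at the same time.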
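
One concrete way to see what the marker buys is to replay a single interleaving by hand. The sketch below is only an illustration under simplifying assumptions: `LostInsertDemo` and `noInsertLost` are hypothetical names, the "threads" are just statements executed in a fixed order, and the step that sets the item to null is not modelled. It plays out a deleter that unlinks b while an inserter, which still believes b is alive, links a new node x after b.

```java
import java.util.concurrent.atomic.AtomicReference;

final class LostInsertDemo {

    /** Hypothetical node for this sketch. */
    static final class Node {
        final String name;
        final boolean marker;
        final AtomicReference<Node> next;

        Node(String name, boolean marker, Node next) {
            this.name = name;
            this.marker = marker;
            this.next = new AtomicReference<>(next);
        }
    }

    /**
     * Replays one fixed interleaving on the list a -> b -> c.
     * Returns true if no insert that reported success was lost.
     */
    static boolean noInsertLost(boolean useMarker) {
        Node c = new Node("c", false, null);
        Node b = new Node("b", false, c);
        Node a = new Node("a", false, b);

        // Deleter: has logically deleted b and read b.next == c.
        Node seenByDeleter = b.next.get();
        if (useMarker) {
            // Deleter: freeze b.next by appending a marker.
            b.next.compareAndSet(seenByDeleter, new Node("marker", true, seenByDeleter));
        }

        // Inserter: still treats b as the predecessor and tries to link x before c.
        Node x = new Node("x", false, c);
        boolean inserted = b.next.compareAndSet(c, x);

        // Deleter: unlinks b using the successor it read earlier.
        a.next.compareAndSet(b, seenByDeleter);

        // Is x reachable from a?
        boolean reachable = false;
        for (Node n = a; n != null; n = n.next.get()) {
            if (n == x) {
                reachable = true;
            }
        }
        // A successful insert must be reachable; a failed one would simply be retried.
        return !inserted || reachable;
    }

    public static void main(String[] args) {
        System.out.println(noInsertLost(false)); // false: x was linked behind b and then cut off
        System.out.println(noInsertLost(true));  // true: the insert CAS fails and can be retried
    }
}
```

Without the marker, the inserter's CAS succeeds on a node that is about to be cut out of the list, so x silently disappears. With the marker, `b.next` no longer equals c, the CAS fails, and the inserter retries from a fresh predecessor.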

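Next is remove2. Removal also goes through a help-delete phase. Once the node is found, it first CASes the item to null, then, depending on the state of the successor, appends the marker and physically unlinks the node. Setting the item to null already counts as a logical deletion, so appending the marker is allowed to fail, and so is the physical unlink. At the end of the removal there is an attempt to lower the level. When there are more than 3 levels and the top 3 levels contain no Index other than the Head, ConcurrentSkipListMap tries to drop one level; if a new Index is inserted in the meantime, it abandons the decrease.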

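Last comes add2, the most complex one. It is complex mainly because there is a lot of code, which ConcurrentSkipListMap writes as one piece; I split it into several methods by function. First there is the regular HM Linked List insertion. Replacing the value is allowed, and if a value was replaced the method returns right away. Next comes index creation. A random number is generated with xorshift, and the number of consecutive 1 bits is counted starting from the low end. I tested this randomLevel locally and it roughly follows the expected distribution: the higher the level, the less often it occurs. Of course you cannot control where the node is inserted, so it still comes down to luck. Depending on randomLevel, the Head may grow by one level or stay as it is. While growing, it can run into a concurrent height decrease caused by a removal, which in theory is an ABA problem that would call for AtomicStampedReference. My view is that it is not used here because the SkipList does not take part in the invariant, so a certain amount of instability is tolerable. Finally, the indices of the new node are inserted into the SkipList. Note that because the Index levels are not HM Linked Lists, in theory an index may fail to be removed or fail to be added, but the probability is low and the consequences are tolerable.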
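
The distribution of randomLevel is easy to check with a small histogram. This is a sketch of such a test, not the one I originally ran: `RandomLevelDemo`, `levelOf` and `histogram` are hypothetical names, and a seeded `java.util.Random` stands in for `System.nanoTime()` so that the run is repeatable. The bit manipulation itself is copied from randomLevel above.

```java
import java.util.Random;

final class RandomLevelDemo {

    /** Same bit manipulation as randomLevel above, with the random input passed in. */
    static int levelOf(int r) {
        // xor shift
        r ^= r << 13;
        r ^= r >>> 17;
        r ^= r << 5;
        // lowest or highest bit set: no index at all
        if ((r & 0x80000001) != 0) {
            return 0;
        }
        // count consecutive 1 bits starting from bit 1
        int level = 1;
        while (((r >>>= 1) & 1) != 0) {
            level++;
        }
        return level;
    }

    /** counts[i] is how many of the samples produced level i. */
    static int[] histogram(int samples, long seed) {
        Random random = new Random(seed);
        int[] counts = new int[32];
        for (int i = 0; i < samples; i++) {
            counts[levelOf(random.nextInt())]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] counts = histogram(1_000_000, 42L);
        for (int level = 0; level < 8; level++) {
            System.out.println("level " + level + ": " + counts[level]);
        }
    }
}
```

With uniformly distributed input, level 0 (no index) should come out about three quarters of the time, and each further level about half as often as the one before, which matches the "higher level, fewer occurrences" observation.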

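To sum up ConcurrentSkipListMap: it relies on an HM Linked List and uses marker nodes as an optimization. The SkipList itself is a randomized data structure, and in a concurrent setting it combines well with the HM Linked List, avoiding problems such as the resize that array-based Maps have to deal with. Because it is a combination of two data structures the code is fairly complex, but it is well worth reading and thinking about.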
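
For completeness, here is a short usage sketch of the real java.util.concurrent.ConcurrentSkipListMap, showing the sorted-map operations that the structure exists to provide. `SkipListMapUsage` and `sample` are hypothetical names; the map methods used are part of the standard API.

```java
import java.util.concurrent.ConcurrentSkipListMap;

final class SkipListMapUsage {

    /** Builds a small map; insertion order does not matter, keys are kept sorted. */
    static ConcurrentSkipListMap<Integer, String> sample() {
        ConcurrentSkipListMap<Integer, String> map = new ConcurrentSkipListMap<>();
        map.put(30, "thirty");
        map.put(10, "ten");
        map.put(20, "twenty");
        return map;
    }

    public static void main(String[] args) {
        ConcurrentSkipListMap<Integer, String> map = sample();
        System.out.println(map.firstKey());      // 10
        System.out.println(map.ceilingKey(15));  // 20
        System.out.println(map.headMap(20));     // {10=ten}
        System.out.println(map.remove(10));      // ten
        System.out.println(map.firstKey());      // 20
    }
}
```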