Java并发研究 自己写ReentrantLock和ReentrantReadWriteLock(4)


上篇。在写完ReentrantLock之后,其实可以基于ReentrantLock写一个ReadWriteLock,《the art of multiprocessor programming》第八章有介绍。但是,本着不完全AQS(AbstractQueuedSynchronizer)介绍的系列主题,这里从零开始重新写一个ReentrantReadWriteLock。

按照ReadWriteLock的定义,任何时候都满足

  1. 没有线程持有锁
  2. 有1~n个线程持有共享锁(Read)
  3. 有1个线程持有独占锁(Write)

中的一个。

其次公平的ReadWriteLock要求新来的Read或者Write线程必须在队列中等待,非公平的ReadWriteLock允许新来的Read或者Write比队列中等待的线程先获取锁。关于非公平锁这里多说一句,理论上的非公平锁类似一群人哄抢的现象,但是实现多半是只允许新来和线程队列最前面的线程抢占锁。ReadWriteLock也是一样。如果你想要完全非公平的锁的话,可能AQS和这里的实现不满足你的需求。

为了实现ReadWriteLock的定义,你需要分别记录读写状态。考虑到独占(Write)状态只可能有一个线程,可能场景如下:

No. Reader Count Writer Count
1 0 0
2 1+ 0
3 0 1

可以看到,从没有任何线程持有锁的状态1变成2或者3时,reader count和writer count分别变化,这个特性导致你必须同时CAS reader count和writer count,或者混合成一个变量,分别用低位和高位来区别reader count和writer count(或者用正数和负数也可以)。AQS采用的是后者。这里简化一下,并给几个简单的关系式:

count = 00000000
reader_count = count & 0x00FF
writer_count = count & 0xFF00

还有一个要考虑的问题是重入(reentrant)如何实现。你可以增加reader count,或者和ReentrantLock一样每个读线程自己维护一个thread local变量。从效率上来说,后者比较好。

SimpleReadWriteLock

只有tryLock的基本实现

import javax.annotation.Nonnull;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;

public class SimpleReadWriteLock implements ReadWriteLock {
    private static final int WRITER_MASK = 0xFF00;
    private static final int WRITER_UNIT = 0x0100;

    // writer reentrant times, reader count
    private final AtomicInteger count = new AtomicInteger(0);
    private final ReadLock readLock = new ReadLock();
    private final WriteLock writeLock = new WriteLock();

    @Override
    @Nonnull
    public Lock readLock() {
        return readLock;
    }

    @Override
    @Nonnull
    public Lock writeLock() {
        return writeLock;
    }

    private static abstract class AbstractLock implements Lock {
        @Override
        public void lock() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void lockInterruptibly() throws InterruptedException {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
            throw new UnsupportedOperationException();
        }

        @Override
        @Nonnull
        public Condition newCondition() {
            throw new UnsupportedOperationException();
        }
    }

    @SuppressWarnings("Duplicates")
    private class ReadLock extends AbstractLock {
        // reentrant times of current reader
        private final ThreadLocal<Integer> reentrantTimes = ThreadLocal.withInitial(() -> 0);

        @Override
        public boolean tryLock() {
            int rt = reentrantTimes.get();
            if (rt > 0) {
                reentrantTimes.set(rt + 1);
                return true;
            }
            int c = count.get();
            if (((c & WRITER_MASK) == 0) && count.compareAndSet(c, c + 1)) {
                reentrantTimes.set(1);
                return true;
            }
            return false;
        }

        @Override
        public void unlock() {
            int rt = reentrantTimes.get();
            if (rt <= 0) {
                throw new IllegalStateException("attempt to unlock without holding lock");
            }
            if (rt > 1) {
                reentrantTimes.set(rt - 1);
                return;
            }
            reentrantTimes.set(0);
            if (count.get() < 1) {
                throw new IllegalStateException("no reentrantTimes");
            }
            count.decrementAndGet();
        }
    }

    @SuppressWarnings("Duplicates")
    private class WriteLock extends AbstractLock {
        private Thread writer;

        @Override
        public boolean tryLock() {
            if (writer == Thread.currentThread()) {
                count.getAndAdd(WRITER_UNIT);
                return true;
            }
            if (count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                writer = Thread.currentThread();
                return true;
            }
            return false;
        }

        @Override
        public void unlock() {
            if (writer != Thread.currentThread()) {
                throw new IllegalStateException("attempt to unlock without holding lock");
            }
            int c = count.get();
            if (c < WRITER_UNIT) {
                throw new IllegalStateException("no writer");
            }
            if (c == WRITER_UNIT) {
                writer = null;
            }
            count.set(c - WRITER_UNIT);
        }
    }
}

重点关注其中的CAS操作是如何处理count的。代码这里不具体分析了,相信看过ReentrantLock分析的话,这里也不会太难。

接下来考虑如何加入队列。实现可能有两种选择

  1. 和Write一样,每个Read作为单独一个节点
  2. 连续的Read单独作为一个节点

后者看起来不错,因为头节点(这里指Write节点)在唤醒后续节点的时候只需要唤醒一个就行了,否则要唤醒多个连续的Read节点。但是实际上AQS选择的是第一种!原因是什么?个人认为第二种在唤醒上和第一种在本质上没有太大差别:唤醒后续的多个Read节点和唤醒后续的单个Read节点后再唤醒这个Read节点里其他的Read线程。于此同时带来了一系列细节问题,入队时有可能直接入队(最后的节点是Write节点时),也有可能跟随最后的一个Read节点,还要考虑唤醒丢失的问题等等。总之,得不偿失。所以个人的渐进式实现也是用第一种方法。

在唤醒后续节点上,考虑以下几种情况

 

第一种是只有Write的情况,必然只需要唤醒一个后继节点。第二种只有一个后继Read节点,所以也只要唤醒一个后继节点。但是第三种情况下,首节点可能要唤醒两个后续节点。第四种,Read线程作为首节点,只需要唤醒后续第一个Write节点。第五种情况比较特殊:这里表示连续唤醒的中途。

从图中可以看出,将一个Read线程作为一个节点的话,唤醒必须是多次的。更准确一点来说,如果后继节点是Read节点,那么连续唤醒,否则(Write节点),只唤醒一个。

再考虑一个问题,在哪个步骤唤醒?

  1. unlock
  2. 前一个Read节点获取了锁之后

以及谁唤醒谁?

  1. unlock的节点唤醒后续节点
  2. Read节点唤醒后续节点

这里其实没有2×2总共4种设计,实际可行的主要就以上两种。

第一种在unlock时唤醒后续的Read节点,注意只唤醒一个。然后Read节点在获取了锁之后唤醒后续的一个Read节点,依此类推。第二种在unlock时唤醒后续所有的Read节点。

AQS选择了第一种。原因是什么?个人认为原因是保证能够获取锁时才唤醒。假设使用第二种设计,第一个Read节点被唤醒,可以获取锁了,但是还没推进队列(acquire和release时队列的处理请参见之前ReentrantLock的介绍),此时W继续唤醒第二个Read节点时,由于第二个Read节点的前置节点,即第一个Read节点还没有成为head,所以只能继续park,等于无效唤醒。但是采用第一种设计的话,第一个Read节点只有在获取了锁之后,设置自己为head,接着唤醒后续Read节点,此时第二个Read节点肯定能够获取锁,所以不会是一次无效唤醒。

注意,第一种方法的图中,节点上方的唤醒和unlock时的唤醒是不同的,上方的唤醒只针对Read节点,也就是说,后继节点是Read节点才会唤醒,相对的,unlock无条件唤醒后续的单个节点。

你可能注意到,上述方法中并没有提到Node的PROPAGATE,即使看上去Read节点的连续唤醒很像传播。原因是ReadWriteLock不需要PROPAGATE,其次PROPAGATE可以说是针对基于AQS的Semaphore的解决方案。这里大致分析一下Semaphore没有PROPAGATE的话可能存在的问题。

对于Semaphore来说,主要有两种操作

  • acquire
  • release

其中release可能会唤醒多个等待acquire的线程,所以采用的是类似ReadLock,即没有WriteLock的ReadWriteLock模型。acquire方法对应AQS的acquireShared,release方法对应releaseShared。

  1. 假设一个Semaphore有permits为2,线程1持有1,线程2持有1,线程3等待中,线程4等待中。此时的队列类似上图。head这里不重要,可能是sentinel节点,也有可能是线程1或者2。两个waiter分别是线程3和4
  2. 线程1使用完permit,通过releaseShared返还。这时head唤醒第一个waiter,即线程3。线程3发现Semaphore中有一个permit,成功获取,同时因为只有1个permit,线程4不决定唤醒线程4(准确来说是tryAcquireShared返回的propagate为0)。此时线程3还未推进队列
  3. 线程2使用完permit,通过releaseShared返还。这时head没有变化,同时按照ReentrantLock一样的方式,如果唤醒过了后续节点的话,不会再尝试唤醒。即线程3不会被重复唤醒
  4. 线程3继续执行,没有唤醒后续节点4,同时permits保留1,但是线程4无法获取

这里个人认为一个直接的解决方法是,在唤醒上使用之前的unlock时连续唤醒,同时首节点无条件唤醒后续所有节点(不检查自己的status)。当然这样做的话无效唤醒的可能性会比较大。另一种方法是线程1返还permit时,确保在下一次release之前线程3成为首节点。老实说这比较难。

AQS的解决方法主要是

  • 在推进队列时检查head的status(setHeadAndPropagate)
  • releaseShared是从0改成PROPAGATE(doReleaseShared)

这样,上述场景的线程3发现head为PROPAGATE,并唤醒后续节点。

由于加入了PROPAGATE,一部分代码需要针对这个新状态做额外处理。但是在ReadWriteLock中,不存在这个问题:Read节点会无条件唤醒后续节点。

最终代码

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReadWriteLock;

@SuppressWarnings("Duplicates")
public class UnfairReadWriteTimedLock3 implements ReadWriteLock {
    // private static final int READER_MASK = 0x00FF;
    private static final int WRITER_MASK = 0xFF00;
    private static final int WRITER_UNIT = 0x0100;

    private final AtomicInteger count = new AtomicInteger(0);
    private final Queue queue = new Queue();
    private final ReadLock readLock = new ReadLock();
    private final WriteLock writeLock = new WriteLock();

    @Override
    @Nonnull
    public Lock readLock() {
        return readLock;
    }

    @Override
    @Nonnull
    public Lock writeLock() {
        return writeLock;
    }

    private static abstract class AbstractLock implements Lock {
        @Override
        @Nonnull
        public Condition newCondition() {
            throw new UnsupportedOperationException();
        }
    }

    @SuppressWarnings("Duplicates")
    private class ReadLock extends AbstractLock {
        private final ThreadLocal<Integer> reentrantTimes = ThreadLocal.withInitial(() -> 0);

        @Override
        public void lock() {
            if (tryLock()) {
                return;
            }
            int c;
            final Node node = new Node(Thread.currentThread(), true);
            queue.enqueue(node);
            while (true) {
                if (node.predecessor.get() == queue.head.get()) {
                    c = count.get();
                    if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                        myTurn(node);
                        return;
                    }
                }
                if (isReadyToPark(node)) {
                    LockSupport.park(this);
                }
            }
        }

        @Override
        public void lockInterruptibly() throws InterruptedException {
            if (tryLock()) {
                return;
            }
            int c;
            final Node node = new Node(Thread.currentThread(), true);
            Node predecessor = queue.enqueue(node);
            while (true) {
                if (predecessor == queue.head.get()) {
                    c = count.get();
                    if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                        myTurn(node);
                        return;
                    }
                }
                if (isReadyToPark(node)) {
                    LockSupport.park(this);
                }
                if (Thread.interrupted()) {
                    abort(node);
                    throw new InterruptedException();
                }
            }
        }

        @Override
        public boolean tryLock() {
            int rt = reentrantTimes.get();
            if (rt > 0) {
                reentrantTimes.set(rt + 1);
                return true;
            }
            int c = count.get();
            if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                reentrantTimes.set(1);
                return true;
            }
            return false;
        }

        @Override
        public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
            if (tryLock()) {
                return true;
            }
            final long deadline = System.nanoTime() + unit.toNanos(time);
            final Node node = new Node(Thread.currentThread(), true);
            Node predecessor = queue.enqueue(node);
            long nanos;
            int c;
            while (true) {
                if (predecessor == queue.head.get()) {
                    c = count.get();
                    if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                        myTurn(node);
                        return true;
                    }
                }
                nanos = deadline - System.nanoTime();
                if (nanos <= 0L) {
                    abort(node);
                    return false;
                }
                if (isReadyToPark(node)) {
                    LockSupport.parkNanos(this, nanos);
                }
                if (Thread.interrupted()) {
                    abort(node);
                    throw new InterruptedException();
                }
            }
        }

        private void myTurn(@Nonnull Node node) {
            reentrantTimes.set(1);
            node.clearThread();
            queue.head.set(node);

            /*
             * propagate if successor is reader
             *
             * In ReadWriteLock, there's no need to check if propagate, it always propagates.
             */
            if (node.resetSignalStatus()) {
                Node successor = queue.findNormalSuccessor(node);
                if (successor != null && successor.shared) {
                    LockSupport.unpark(successor.thread.get());
                }
            }
        }

        @Override
        public void unlock() {
            int rt = reentrantTimes.get();
            if (rt < 1) {
                throw new IllegalStateException("not the thread holding lock");
            }
            if (rt > 1) {
                reentrantTimes.set(rt - 1);
                return;
            }
            // rt == 1
            reentrantTimes.set(0);
            if (count.get() < 1) {
                throw new IllegalStateException("count < 1");
            }
            if (count.decrementAndGet() > 0) {
                return;
            }
            Node h = queue.head.get();
            if (h != null && h.resetSignalStatus()) {
                unparkNormalSuccessor(h);
            }
        }
    }

    @SuppressWarnings("Duplicates")
    private class WriteLock extends AbstractLock {
        private Thread owner;

        @Override
        public void lock() {
            if (tryLock()) {
                return;
            }
            Node node = new Node(Thread.currentThread());
            Node predecessor = queue.enqueue(node);
            while (true) {
                if (predecessor == queue.head.get() &&
                        count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                    myTurn(node);
                    return;
                }
                if (isReadyToPark(node)) {
                    LockSupport.park(this);
                }
            }
        }

        @Override
        public void lockInterruptibly() throws InterruptedException {
            if (tryLock()) {
                return;
            }
            Node node = new Node(Thread.currentThread());
            Node predecessor = queue.enqueue(node);
            while (true) {
                if (predecessor == queue.head.get() &&
                        count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                    myTurn(node);
                    return;
                }
                if (isReadyToPark(node)) {
                    LockSupport.park(this);
                }
                if (Thread.interrupted()) {
                    abort(node);
                    throw new InterruptedException();
                }
            }
        }

        @Override
        public boolean tryLock() {
            if (owner == Thread.currentThread()) {
                count.getAndAdd(WRITER_UNIT);
                return true;
            }
            if (count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                owner = Thread.currentThread();
                return true;
            }
            return false;
        }

        @Override
        public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
            if (tryLock()) {
                return true;
            }
            long deadline = System.nanoTime() + unit.toNanos(time);
            Node node = new Node(Thread.currentThread());
            Node predecessor = queue.enqueue(node);
            long nanos;
            while (true) {
                if (predecessor == queue.head.get() &&
                        count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                    myTurn(node);
                    return true;
                }
                nanos = deadline - System.nanoTime();
                if (nanos <= 0L) {
                    abort(node);
                    return false;
                }
                if (isReadyToPark(node)) {
                    LockSupport.parkNanos(this, nanos);
                }
                if (Thread.interrupted()) {
                    abort(node);
                    throw new InterruptedException();
                }
            }
        }

        private void myTurn(@Nonnull Node node) {
            node.clearThread();
            owner = Thread.currentThread();
            queue.head.set(node);
        }

        @Override
        public void unlock() {
            if (owner != Thread.currentThread()) {
                throw new IllegalStateException("not the thread holding write lock");
            }
            int c = count.get();
            if (c < WRITER_UNIT) {
                throw new IllegalStateException("no writer");
            }
            if (c > WRITER_UNIT) {
                count.set(c - WRITER_UNIT);
                return;
            }
            // c == WRITER_UNIT
            owner = null;
            // linearization point
            count.set(0);

            // signal successor
            Node node = queue.head.get();
            if (node != null && node.status.get() == Node.STATUS_SIGNAL) {
                node.status.set(Node.STATUS_NORMAL);
                unparkNormalSuccessor(node);
            }
        }
    }

    private boolean isReadyToPark(@Nonnull Node node) {
        Node predecessor = node.predecessor.get();
        int s = predecessor.status.get();
        if (s == Node.STATUS_SIGNAL) {
            return true;
        }
        if (s == Node.STATUS_ABORTED) {
            predecessor = queue.skipAbortedPredecessor(node);
            predecessor.successor.set(node);
        } else {
            predecessor.status.compareAndSet(Node.STATUS_NORMAL, Node.STATUS_SIGNAL);
        }
        return false;
    }

    private void abort(@Nonnull Node node) {
        node.clearThread();

        Node p = queue.skipAbortedPredecessor(node);
        Node ps = p.successor.get();

        node.status.set(Node.STATUS_ABORTED);

        Node t = queue.tail.get();
        if (t == node && queue.tail.compareAndSet(t, p)) {
            p.successor.compareAndSet(ps, null);
            return;
        }

        if (p != queue.head.get() && p.ensureSignalStatus() && p.thread.get() != null) {
            Node s = node.successor.get();
            if (s != null && s.status.get() != Node.STATUS_ABORTED) {
                p.successor.compareAndSet(ps, s);
            }
        } else {
            node.resetSignalStatus();
            unparkNormalSuccessor(node);
        }
    }

    private void unparkNormalSuccessor(@Nonnull Node node) {
        Node successor = queue.findNormalSuccessor(node);
        if (successor != null) {
            LockSupport.unpark(successor.thread.get());
        }
    }

    @SuppressWarnings("Duplicates")
    private static class Queue {
        final AtomicReference<Node> head = new AtomicReference<>();
        final AtomicReference<Node> tail = new AtomicReference<>();

        @Nonnull
        Node enqueue(@Nonnull Node node) {
            Node t;
            while (true) {
                t = tail.get();
                if (t == null) {
                    Node sentinel = new Node();
                    if (head.compareAndSet(null, sentinel)) {
                        tail.set(sentinel);
                    }
                } else {
                    node.predecessor.lazySet(t);
                    if (tail.compareAndSet(t, node)) {
                        t.successor.set(node);
                        return t;
                    }
                }
            }
        }

        @Nullable
        Node findNormalSuccessor(@Nonnull Node node) {
            Node s = node.successor.get();
            if (s != null && s.status.get() != Node.STATUS_ABORTED) {
                return s;
            }

            // find from tail
            s = null;
            Node c = tail.get();
            while (c != null && c != node) {
                if (c.status.get() != Node.STATUS_ABORTED) {
                    s = c;
                }
                c = c.predecessor.get();
            }
            return s;
        }

        @Nonnull
        Node skipAbortedPredecessor(@Nonnull Node node) {
            Node h = head.get();
            Node p = node.predecessor.get();
            while (p != h && p.status.get() == Node.STATUS_ABORTED) {
                p = p.predecessor.get();
                node.predecessor.set(p);
            }
            return p;
        }
    }

    private static class Node {
        static final int STATUS_NORMAL = 0;
        static final int STATUS_SIGNAL = 1;
        static final int STATUS_ABORTED = -1;

        final AtomicReference<Thread> thread;
        final boolean shared;
        final AtomicInteger status = new AtomicInteger(STATUS_NORMAL);
        final AtomicReference<Node> predecessor = new AtomicReference<>();
        // optimization
        final AtomicReference<Node> successor = new AtomicReference<>();

        Node() {
            this(null, false);
        }

        Node(@Nullable Thread thread) {
            this(thread, false);
        }

        Node(@Nullable Thread thread, boolean shared) {
            this.thread = new AtomicReference<>(thread);
            this.shared = shared;
        }

        void clearThread() {
            thread.set(null);
        }

        /**
         * Ensure signal status.
         * If current status is signal, just return.
         * If current status is normal, then try to CAS status from normal to signal.
         *
         * @return true if changed to signal, otherwise false
         */
        boolean ensureSignalStatus() {
            int s = status.get();
            return s == STATUS_SIGNAL || (s == STATUS_NORMAL && status.compareAndSet(STATUS_NORMAL, STATUS_SIGNAL));
        }

        /**
         * Reset signal status.
         * SIGNAL -> NORMAL
         *
         * @return true if successful, otherwise false
         */
        boolean resetSignalStatus() {
            return status.get() == STATUS_SIGNAL && status.compareAndSet(STATUS_SIGNAL, STATUS_NORMAL);
        }
    }
}

总结

“自己写ReentrantLock和ReentrantReadWriteLock“系列至此结束。总得来说,并发编程需要仔细思考很多情况,包括各种细节,有些时候细节会导致你的设计有所变化。希望这个系列对你学习AQS,学习Java语言下的并发编程有用,以及欢迎交流交换意见。