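This continues the previous post. With ReentrantLock written, a ReadWriteLock could actually be built on top of it; chapter 8 of *The Art of Multiprocessor Programming* describes how. However, in keeping with this series' theme of a not-quite-complete tour of AQS (AbstractQueuedSynchronizer), this post writes a ReentrantReadWriteLock from scratch.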
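By the definition of a ReadWriteLock, exactly one of the following holds at any moment:
- no thread holds the lock
- 1 to n threads hold the shared lock (Read)
- exactly 1 thread holds the exclusive lock (Write)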
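The three states can be observed on the JDK's own `ReentrantReadWriteLock`. The following is only a small illustration, not part of the lock written in this post; the class name `LockStatesDemo` and the helper `describe` are made up for the example.

```java
import java.util.concurrent.locks.ReentrantReadWriteLock;

class LockStatesDemo {
    /** Names the state the lock is in: "free", "shared" or "exclusive". */
    static String describe(ReentrantReadWriteLock rw) {
        if (rw.isWriteLocked()) {
            return "exclusive";
        }
        if (rw.getReadLockCount() > 0) {
            return "shared";
        }
        return "free";
    }

    public static void main(String[] args) {
        ReentrantReadWriteLock rw = new ReentrantReadWriteLock();
        System.out.println(describe(rw));

        rw.readLock().lock();
        System.out.println(describe(rw));
        // While any read lock is held, the write lock cannot be taken.
        System.out.println(rw.writeLock().tryLock());
        rw.readLock().unlock();

        rw.writeLock().lock();
        System.out.println(describe(rw));
        rw.writeLock().unlock();
    }
}
```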
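Next, a fair ReadWriteLock requires every newly arriving Read or Write thread to wait in the queue, whereas an unfair ReadWriteLock lets a newcomer acquire the lock ahead of threads that are already waiting. One more remark about unfair locks: in theory an unfair lock resembles a crowd scrambling for the same thing, but most implementations only let the newcomer compete with the thread at the front of the queue. ReadWriteLock is no different. If you want a completely unfair lock, neither AQS nor the implementation here is likely to meet your needs.
To implement the ReadWriteLock definition, you need to track read state and write state separately. Since the exclusive (Write) state can only be held by one thread, the possible scenarios are: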
No. | Reader Count | Writer Count |
---|---|---|
1 | 0 | 0 |
2 | 1+ | 0 |
3 | 0 | 1 |
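As the table shows, leaving state 1 (no thread holds the lock) for state 2 or state 3 changes reader count or writer count respectively. Because of this, you must either update both counters in one atomic step, or fold them into a single variable and use its low and high bits to tell reader count and writer count apart (positive and negative values would also work). AQS takes the latter approach. Here is a simplified version, with a few simple relations: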
- `count = 0x0000` (initial value: no reader, no writer)
- `reader_count = count & 0x00FF`
- `writer_count = (count & 0xFF00) >> 8`
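The relations can be tried out in isolation. The sketch below assumes the same layout as the code later in this post (low byte for readers, high byte for writer holds); the class `PackedCount` and its helper names are hypothetical and only serve the illustration.

```java
import java.util.concurrent.atomic.AtomicInteger;

class PackedCount {
    static final int READER_MASK = 0x00FF;
    static final int WRITER_MASK = 0xFF00;
    static final int WRITER_SHIFT = 8;
    static final int WRITER_UNIT = 1 << WRITER_SHIFT; // 0x0100

    /** Number of readers stored in the low byte. */
    static int readers(int count) {
        return count & READER_MASK;
    }

    /** Number of (reentrant) writer holds stored in the high byte. */
    static int writerHolds(int count) {
        return (count & WRITER_MASK) >>> WRITER_SHIFT;
    }

    public static void main(String[] args) {
        AtomicInteger count = new AtomicInteger(0);

        // A reader may enter only while no writer bits are set.
        int c = count.get();
        if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
            System.out.println("readers=" + readers(count.get()));
        }
        count.decrementAndGet();

        // A writer may enter only when the whole word is zero.
        if (count.compareAndSet(0, WRITER_UNIT)) {
            System.out.println("writerHolds=" + writerHolds(count.get()));
        }
    }
}
```

Because both counters live in one `int`, a single `compareAndSet` checks "no writer" and adds a reader in the same atomic step. With one byte per counter, this layout overflows beyond 255 readers or 255 reentrant writer holds.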
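Another question is how to implement reentrancy. You can either increment reader count on every reentrant acquire, or, as with ReentrantLock, have each reading thread maintain its own thread-local counter. In terms of efficiency the latter is better, because a reentrant acquire never touches the shared counter.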
SimpleReadWriteLock
A basic implementation that supports only tryLock:

    import javax.annotation.Nonnull;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.concurrent.locks.Condition;
    import java.util.concurrent.locks.Lock;
    import java.util.concurrent.locks.ReadWriteLock;

    public class SimpleReadWriteLock implements ReadWriteLock {

        private static final int WRITER_MASK = 0xFF00;
        private static final int WRITER_UNIT = 0x0100;

        // high byte: writer reentrant times, low byte: reader count
        private final AtomicInteger count = new AtomicInteger(0);
        private final ReadLock readLock = new ReadLock();
        private final WriteLock writeLock = new WriteLock();

        @Override
        @Nonnull
        public Lock readLock() {
            return readLock;
        }

        @Override
        @Nonnull
        public Lock writeLock() {
            return writeLock;
        }

        // Everything except tryLock() and unlock() is unsupported in this version.
        private static abstract class AbstractLock implements Lock {

            @Override
            public void lock() {
                throw new UnsupportedOperationException();
            }

            @Override
            public void lockInterruptibly() throws InterruptedException {
                throw new UnsupportedOperationException();
            }

            @Override
            public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
                throw new UnsupportedOperationException();
            }

            @Override
            @Nonnull
            public Condition newCondition() {
                throw new UnsupportedOperationException();
            }
        }

        @SuppressWarnings("Duplicates")
        private class ReadLock extends AbstractLock {

            // reentrant times of current reader
            private final ThreadLocal<Integer> reentrantTimes = ThreadLocal.withInitial(() -> 0);

            @Override
            public boolean tryLock() {
                int rt = reentrantTimes.get();
                if (rt > 0) {
                    // reentrant acquire: only the thread-local counter changes
                    reentrantTimes.set(rt + 1);
                    return true;
                }

                int c = count.get();
                // no writer, and nobody changed count in between
                if (((c & WRITER_MASK) == 0) && count.compareAndSet(c, c + 1)) {
                    reentrantTimes.set(1);
                    return true;
                }
                return false;
            }

            @Override
            public void unlock() {
                int rt = reentrantTimes.get();
                if (rt <= 0) {
                    throw new IllegalStateException("attempt to unlock without holding lock");
                }
                if (rt > 1) {
                    reentrantTimes.set(rt - 1);
                    return;
                }

                reentrantTimes.set(0);
                if (count.get() < 1) {
                    throw new IllegalStateException("no reentrantTimes");
                }
                count.decrementAndGet();
            }
        }

        @SuppressWarnings("Duplicates")
        private class WriteLock extends AbstractLock {

            private Thread writer;

            @Override
            public boolean tryLock() {
                if (writer == Thread.currentThread()) {
                    // reentrant acquire by the current writer
                    count.getAndAdd(WRITER_UNIT);
                    return true;
                }
                // no reader and no writer
                if (count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                    writer = Thread.currentThread();
                    return true;
                }
                return false;
            }

            @Override
            public void unlock() {
                if (writer != Thread.currentThread()) {
                    throw new IllegalStateException("attempt to unlock without holding lock");
                }
                int c = count.get();
                if (c < WRITER_UNIT) {
                    throw new IllegalStateException("no writer");
                }
                if (c == WRITER_UNIT) {
                    // last hold: clear the owner before the lock becomes free
                    writer = null;
                }
                count.set(c - WRITER_UNIT);
            }
        }
    }

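Pay particular attention to how the CAS operations handle count. I will not analyze the code in detail here; if you have read the ReentrantLock analysis, it should not be too hard.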
Next, consider how to add the queue. There are two possible choices:
- as with Write, each Read is a node of its own
- consecutive Reads share a single node
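The latter looks attractive, because the head node (a Write node here) would only have to wake one node when waking its successors, instead of several consecutive Read nodes. But AQS actually chooses the first option! Why? In my opinion the two options do not differ fundamentally when it comes to waking: waking several following Read nodes is much the same as waking a single following Read node and then waking the other Read threads inside that node. At the same time the second option brings a series of detail problems: on enqueue a thread may become a new node (when the last node is a Write node) or may join the last Read node, lost wake-ups have to be considered, and so on. In short, the gain is not worth the cost. So my incremental implementation also uses the first option.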
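For waking successor nodes, consider the following cases.
The first case contains only Write nodes, so exactly one successor has to be woken. The second has a single Read successor, so again only one successor has to be woken. In the third case, however, the head node may have to wake two successors. In the fourth, a Read thread is the head node and only has to wake the first Write node after it. The fifth case is special: it shows the middle of a run of consecutive wake-ups.
As the figure shows, if each Read thread is its own node, waking has to happen several times. More precisely: if the successor is a Read node, keep waking; otherwise (a Write node), wake only that one.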
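One more question: at which step does the wake-up happen?
- on unlock
- after the preceding Read node has acquired the lock
And who wakes whom?
- the unlocking node wakes its successor
- a Read node wakes its successor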
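These do not really combine into 2×2 = 4 designs; in practice there are mainly two workable ones.
In the first, unlock wakes the following Read node, and note that it wakes only one. That Read node, once it has acquired the lock, wakes the next Read node, and so on. In the second, unlock wakes all of the following Read nodes.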
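AQS chooses the first design. Why? In my opinion, the reason is to wake a node only when it is certain to be able to acquire the lock. Suppose the second design is used: the first Read node is woken and can acquire the lock, but has not yet advanced the queue (for how the queue is handled on acquire and release, see the earlier ReentrantLock post). If the unlocking Write node W now goes on to wake the second Read node, that node's predecessor, the first Read node, is not yet head, so the second Read node can only park again, which makes it a wasted wake-up. With the first design, the first Read node wakes the following Read node only after it has acquired the lock and set itself as head, so the second Read node is sure to be able to acquire the lock and the wake-up is not wasted.
Note that in the figure for the first design, the wake-up drawn above a node differs from the wake-up on unlock. The one above a node applies only to Read nodes, that is, it happens only when the successor is a Read node, whereas unlock unconditionally wakes the single node that follows.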
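The two wake-up rules of the first design can be modelled without any threads. The sketch below is only a simulation under a simplifying assumption: every woken Read node does acquire the lock (no newcomer barges in). `WakeUpCascade` and `wokenAfterUnlock` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

class WakeUpCascade {
    /**
     * @param shared the waiting nodes behind head, in queue order;
     *               true marks a Read node, false a Write node
     * @return indexes of the nodes that get woken after one unlock
     */
    static List<Integer> wokenAfterUnlock(boolean[] shared) {
        List<Integer> woken = new ArrayList<>();
        if (shared.length == 0) {
            return woken;
        }
        // Rule 1: unlock wakes the first successor, whatever its type.
        int i = 0;
        woken.add(i);
        // Rule 2: a Read node that got the lock wakes its successor
        // only if that successor is a Read node too.
        while (shared[i] && i + 1 < shared.length && shared[i + 1]) {
            i++;
            woken.add(i);
        }
        return woken;
    }

    public static void main(String[] args) {
        // queue behind head: R, R, W, R
        System.out.println(wokenAfterUnlock(new boolean[]{true, true, false, true}));
        // queue behind head: W, R
        System.out.println(wokenAfterUnlock(new boolean[]{false, true}));
    }
}
```

The chain stops at the first Write node; the Read node behind that Write node has to wait until the writer itself unlocks.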
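You may have noticed that nothing above mentions the PROPAGATE status of Node, even though the consecutive waking of Read nodes looks a lot like propagation. The reason is that ReadWriteLock does not need PROPAGATE; moreover, PROPAGATE can be seen as a fix aimed at the AQS-based Semaphore. Here is a rough analysis of what can go wrong in Semaphore without PROPAGATE.
A Semaphore has two main operations: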
- acquire
- release
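Since release may wake several threads waiting in acquire, Semaphore uses a model similar to ReadLock, that is, a ReadWriteLock without the WriteLock. The acquire method maps to AQS's acquireShared, and release maps to releaseShared.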
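For reference, this is how the two operations look on `java.util.concurrent.Semaphore`. It is a single-threaded sketch that only shows the permit accounting, not the queueing; `SemaphoreDemo` and `takeWithoutBlocking` are hypothetical names.

```java
import java.util.concurrent.Semaphore;

class SemaphoreDemo {
    /** Tries to take permits one by one without blocking; returns how many succeeded. */
    static int takeWithoutBlocking(Semaphore semaphore, int attempts) {
        int taken = 0;
        for (int i = 0; i < attempts; i++) {
            if (semaphore.tryAcquire()) {
                taken++;
            }
        }
        return taken;
    }

    public static void main(String[] args) {
        Semaphore semaphore = new Semaphore(2);
        // Only two of the three attempts can succeed.
        System.out.println("taken=" + takeWithoutBlocking(semaphore, 3));
        System.out.println("available=" + semaphore.availablePermits());
        // Returning two permits at once could satisfy two waiting acquirers.
        semaphore.release(2);
        System.out.println("available=" + semaphore.availablePermits());
    }
}
```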
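- Suppose a Semaphore has 2 permits: thread 1 holds one, thread 2 holds one, and threads 3 and 4 are waiting. The queue then looks like the figure above. What head is does not matter here; it may be the sentinel node, or thread 1 or 2. The two waiters are threads 3 and 4.
- Thread 1 is done with its permit and returns it through releaseShared. Head wakes the first waiter, thread 3. Thread 3 finds one permit in the Semaphore and acquires it; because there was only 1 permit, thread 3 decides not to wake thread 4 (more precisely, tryAcquireShared returns a propagate value of 0). At this point thread 3 has not yet advanced the queue.
- Thread 2 is done with its permit and returns it through releaseShared. Head has not changed, and, following the same approach as ReentrantLock, a head that has already woken its successor will not try to wake it again. So thread 3 is not woken a second time.
- Thread 3 carries on without waking its successor, thread 4. One permit remains available, but thread 4 cannot acquire it.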
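In my view, one direct fix is to use the earlier "wake all consecutive nodes on unlock" design and let the head node wake all of its successors unconditionally (without checking its own status). Of course, wasted wake-ups become more likely that way. Another approach is to make sure that, after thread 1 returns its permit, thread 3 becomes the head node before the next release happens. Frankly, that is rather hard.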
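The AQS solution mainly consists of:
- checking head's status when advancing the queue (setHeadAndPropagate)
- having releaseShared change the status from 0 to PROPAGATE (doReleaseShared)
With that, thread 3 in the scenario above finds that head is PROPAGATE and wakes its successor.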
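The two pieces can be modelled as plain functions on the head's status. This is a deliberately reduced model of what the JDK 8 sources do, not the real methods: the actual `setHeadAndPropagate` and `doReleaseShared` use CAS loops and check more conditions. The constants mirror AQS's `SIGNAL` (-1) and `PROPAGATE` (-3); `PropagateModel` and its method names are hypothetical.

```java
class PropagateModel {
    static final int NORMAL = 0;
    static final int SIGNAL = -1;
    static final int PROPAGATE = -3;

    /** Head status after one releaseShared, in the reduced model. */
    static int headStatusAfterRelease(int headStatus) {
        if (headStatus == SIGNAL) {
            return NORMAL;      // successor is unparked, signal consumed
        }
        if (headStatus == NORMAL) {
            return PROPAGATE;   // nobody to unpark right now: remember the release
        }
        return headStatus;
    }

    /** Whether a node that just acquired should wake its successor when it becomes head. */
    static boolean shouldWakeSuccessor(int propagate, int oldHeadStatus) {
        return propagate > 0 || oldHeadStatus < 0;
    }

    public static void main(String[] args) {
        int head = SIGNAL;
        head = headStatusAfterRelease(head); // thread 1 releases, thread 3 is woken
        head = headStatusAfterRelease(head); // thread 2 releases before thread 3 is head
        // Thread 3 got propagate == 0, but the old head remembers the second release.
        System.out.println(shouldWakeSuccessor(0, head));
    }
}
```

Without the second transition the old head would still be 0 when thread 3 inspects it, and thread 4 would stay parked.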
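Because PROPAGATE was added, parts of the code need extra handling for this new status. In ReadWriteLock the problem does not exist: a Read node that has acquired the lock wakes a following Read node without having to decide whether to propagate.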
Final code

    import javax.annotation.Nonnull;
    import javax.annotation.Nullable;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.concurrent.atomic.AtomicReference;
    import java.util.concurrent.locks.Condition;
    import java.util.concurrent.locks.Lock;
    import java.util.concurrent.locks.LockSupport;
    import java.util.concurrent.locks.ReadWriteLock;

    @SuppressWarnings("Duplicates")
    public class UnfairReadWriteTimedLock3 implements ReadWriteLock {

        // private static final int READER_MASK = 0x00FF;
        private static final int WRITER_MASK = 0xFF00;
        private static final int WRITER_UNIT = 0x0100;

        private final AtomicInteger count = new AtomicInteger(0);
        private final Queue queue = new Queue();
        private final ReadLock readLock = new ReadLock();
        private final WriteLock writeLock = new WriteLock();

        @Override
        @Nonnull
        public Lock readLock() {
            return readLock;
        }

        @Override
        @Nonnull
        public Lock writeLock() {
            return writeLock;
        }

        private static abstract class AbstractLock implements Lock {
            @Override
            @Nonnull
            public Condition newCondition() {
                throw new UnsupportedOperationException();
            }
        }

        @SuppressWarnings("Duplicates")
        private class ReadLock extends AbstractLock {

            private final ThreadLocal<Integer> reentrantTimes = ThreadLocal.withInitial(() -> 0);

            @Override
            public void lock() {
                if (tryLock()) {
                    return;
                }
                int c;
                final Node node = new Node(Thread.currentThread(), true);
                queue.enqueue(node);
                while (true) {
                    // re-read the predecessor each time: it changes when aborted nodes are skipped
                    if (node.predecessor.get() == queue.head.get()) {
                        c = count.get();
                        if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                            myTurn(node);
                            return;
                        }
                    }
                    if (isReadyToPark(node)) {
                        LockSupport.park(this);
                    }
                }
            }

            @Override
            public void lockInterruptibly() throws InterruptedException {
                if (tryLock()) {
                    return;
                }
                int c;
                final Node node = new Node(Thread.currentThread(), true);
                queue.enqueue(node);
                while (true) {
                    if (node.predecessor.get() == queue.head.get()) {
                        c = count.get();
                        if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                            myTurn(node);
                            return;
                        }
                    }
                    if (isReadyToPark(node)) {
                        LockSupport.park(this);
                    }
                    if (Thread.interrupted()) {
                        abort(node);
                        throw new InterruptedException();
                    }
                }
            }

            @Override
            public boolean tryLock() {
                int rt = reentrantTimes.get();
                if (rt > 0) {
                    reentrantTimes.set(rt + 1);
                    return true;
                }
                int c = count.get();
                if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                    reentrantTimes.set(1);
                    return true;
                }
                return false;
            }

            @Override
            public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
                if (tryLock()) {
                    return true;
                }
                final long deadline = System.nanoTime() + unit.toNanos(time);
                final Node node = new Node(Thread.currentThread(), true);
                queue.enqueue(node);
                long nanos;
                int c;
                while (true) {
                    if (node.predecessor.get() == queue.head.get()) {
                        c = count.get();
                        if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                            myTurn(node);
                            return true;
                        }
                    }
                    nanos = deadline - System.nanoTime();
                    if (nanos <= 0L) {
                        abort(node);
                        return false;
                    }
                    if (isReadyToPark(node)) {
                        LockSupport.parkNanos(this, nanos);
                    }
                    if (Thread.interrupted()) {
                        abort(node);
                        throw new InterruptedException();
                    }
                }
            }

            private void myTurn(@Nonnull Node node) {
                reentrantTimes.set(1);
                node.clearThread();
                queue.head.set(node);

                /*
                 * propagate if successor is reader
                 *
                 * In ReadWriteLock, there's no need to check if propagate, it always propagates.
                 */
                if (node.resetSignalStatus()) {
                    Node successor = queue.findNormalSuccessor(node);
                    if (successor != null && successor.shared) {
                        LockSupport.unpark(successor.thread.get());
                    }
                }
            }

            @Override
            public void unlock() {
                int rt = reentrantTimes.get();
                if (rt < 1) {
                    throw new IllegalStateException("not the thread holding lock");
                }
                if (rt > 1) {
                    reentrantTimes.set(rt - 1);
                    return;
                }

                // rt == 1
                reentrantTimes.set(0);
                if (count.get() < 1) {
                    throw new IllegalStateException("count < 1");
                }
                if (count.decrementAndGet() > 0) {
                    // other readers still hold the lock
                    return;
                }
                Node h = queue.head.get();
                if (h != null && h.resetSignalStatus()) {
                    unparkNormalSuccessor(h);
                }
            }
        }

        @SuppressWarnings("Duplicates")
        private class WriteLock extends AbstractLock {

            private Thread owner;

            @Override
            public void lock() {
                if (tryLock()) {
                    return;
                }
                Node node = new Node(Thread.currentThread());
                queue.enqueue(node);
                while (true) {
                    if (node.predecessor.get() == queue.head.get()
                            && count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                        myTurn(node);
                        return;
                    }
                    if (isReadyToPark(node)) {
                        LockSupport.park(this);
                    }
                }
            }

            @Override
            public void lockInterruptibly() throws InterruptedException {
                if (tryLock()) {
                    return;
                }
                Node node = new Node(Thread.currentThread());
                queue.enqueue(node);
                while (true) {
                    if (node.predecessor.get() == queue.head.get()
                            && count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                        myTurn(node);
                        return;
                    }
                    if (isReadyToPark(node)) {
                        LockSupport.park(this);
                    }
                    if (Thread.interrupted()) {
                        abort(node);
                        throw new InterruptedException();
                    }
                }
            }

            @Override
            public boolean tryLock() {
                if (owner == Thread.currentThread()) {
                    count.getAndAdd(WRITER_UNIT);
                    return true;
                }
                if (count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                    owner = Thread.currentThread();
                    return true;
                }
                return false;
            }

            @Override
            public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
                if (tryLock()) {
                    return true;
                }
                long deadline = System.nanoTime() + unit.toNanos(time);
                Node node = new Node(Thread.currentThread());
                queue.enqueue(node);
                long nanos;
                while (true) {
                    if (node.predecessor.get() == queue.head.get()
                            && count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                        myTurn(node);
                        return true;
                    }
                    nanos = deadline - System.nanoTime();
                    if (nanos <= 0L) {
                        abort(node);
                        return false;
                    }
                    if (isReadyToPark(node)) {
                        LockSupport.parkNanos(this, nanos);
                    }
                    if (Thread.interrupted()) {
                        abort(node);
                        throw new InterruptedException();
                    }
                }
            }

            private void myTurn(@Nonnull Node node) {
                node.clearThread();
                owner = Thread.currentThread();
                queue.head.set(node);
            }

            @Override
            public void unlock() {
                if (owner != Thread.currentThread()) {
                    throw new IllegalStateException("not the thread holding write lock");
                }
                int c = count.get();
                if (c < WRITER_UNIT) {
                    throw new IllegalStateException("no writer");
                }
                if (c > WRITER_UNIT) {
                    // reentrant hold released, lock still owned
                    count.set(c - WRITER_UNIT);
                    return;
                }

                // c == WRITER_UNIT
                owner = null;
                // linearization point
                count.set(0);

                // signal successor
                Node node = queue.head.get();
                if (node != null && node.status.get() == Node.STATUS_SIGNAL) {
                    node.status.set(Node.STATUS_NORMAL);
                    unparkNormalSuccessor(node);
                }
            }
        }

        private boolean isReadyToPark(@Nonnull Node node) {
            Node predecessor = node.predecessor.get();
            int s = predecessor.status.get();
            if (s == Node.STATUS_SIGNAL) {
                return true;
            }
            if (s == Node.STATUS_ABORTED) {
                predecessor = queue.skipAbortedPredecessor(node);
                predecessor.successor.set(node);
            } else {
                predecessor.status.compareAndSet(Node.STATUS_NORMAL, Node.STATUS_SIGNAL);
            }
            return false;
        }

        private void abort(@Nonnull Node node) {
            node.clearThread();

            Node p = queue.skipAbortedPredecessor(node);
            Node ps = p.successor.get();

            node.status.set(Node.STATUS_ABORTED);

            Node t = queue.tail.get();
            if (t == node && queue.tail.compareAndSet(t, p)) {
                p.successor.compareAndSet(ps, null);
                return;
            }

            if (p != queue.head.get() && p.ensureSignalStatus() && p.thread.get() != null) {
                Node s = node.successor.get();
                if (s != null && s.status.get() != Node.STATUS_ABORTED) {
                    p.successor.compareAndSet(ps, s);
                }
            } else {
                node.resetSignalStatus();
                unparkNormalSuccessor(node);
            }
        }

        private void unparkNormalSuccessor(@Nonnull Node node) {
            Node successor = queue.findNormalSuccessor(node);
            if (successor != null) {
                LockSupport.unpark(successor.thread.get());
            }
        }

        @SuppressWarnings("Duplicates")
        private static class Queue {

            final AtomicReference<Node> head = new AtomicReference<>();
            final AtomicReference<Node> tail = new AtomicReference<>();

            /**
             * Append node to the queue, creating the sentinel on first use.
             *
             * @return the predecessor of node at enqueue time
             */
            @Nonnull
            Node enqueue(@Nonnull Node node) {
                Node t;
                while (true) {
                    t = tail.get();
                    if (t == null) {
                        Node sentinel = new Node();
                        if (head.compareAndSet(null, sentinel)) {
                            tail.set(sentinel);
                        }
                    } else {
                        node.predecessor.lazySet(t);
                        if (tail.compareAndSet(t, node)) {
                            t.successor.set(node);
                            return t;
                        }
                    }
                }
            }

            @Nullable
            Node findNormalSuccessor(@Nonnull Node node) {
                Node s = node.successor.get();
                if (s != null && s.status.get() != Node.STATUS_ABORTED) {
                    return s;
                }

                // find from tail
                s = null;
                Node c = tail.get();
                while (c != null && c != node) {
                    if (c.status.get() != Node.STATUS_ABORTED) {
                        s = c;
                    }
                    c = c.predecessor.get();
                }
                return s;
            }

            @Nonnull
            Node skipAbortedPredecessor(@Nonnull Node node) {
                Node h = head.get();
                Node p = node.predecessor.get();
                while (p != h && p.status.get() == Node.STATUS_ABORTED) {
                    p = p.predecessor.get();
                    node.predecessor.set(p);
                }
                return p;
            }
        }

        private static class Node {

            static final int STATUS_NORMAL = 0;
            static final int STATUS_SIGNAL = 1;
            static final int STATUS_ABORTED = -1;

            final AtomicReference<Thread> thread;
            final boolean shared;
            final AtomicInteger status = new AtomicInteger(STATUS_NORMAL);
            final AtomicReference<Node> predecessor = new AtomicReference<>();
            // optimization
            final AtomicReference<Node> successor = new AtomicReference<>();

            Node() {
                this(null, false);
            }

            Node(@Nullable Thread thread) {
                this(thread, false);
            }

            Node(@Nullable Thread thread, boolean shared) {
                this.thread = new AtomicReference<>(thread);
                this.shared = shared;
            }

            void clearThread() {
                thread.set(null);
            }

            /**
             * Ensure signal status.
             * If current status is signal, just return.
             * If current status is normal, then try to CAS status from normal to signal.
             *
             * @return true if the status is signal afterwards, otherwise false
             */
            boolean ensureSignalStatus() {
                int s = status.get();
                return s == STATUS_SIGNAL
                        || (s == STATUS_NORMAL && status.compareAndSet(STATUS_NORMAL, STATUS_SIGNAL));
            }

            /**
             * Reset signal status.
             * SIGNAL -> NORMAL
             *
             * @return true if successful, otherwise false
             */
            boolean resetSignalStatus() {
                return status.get() == STATUS_SIGNAL && status.compareAndSet(STATUS_SIGNAL, STATUS_NORMAL);
            }
        }
    }

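Summary
This concludes the "write your own ReentrantLock and ReentrantReadWriteLock" series. All in all, concurrent programming requires thinking carefully through many situations, including all sorts of details, and sometimes a detail forces you to change your design. I hope this series helps you learn AQS and concurrent programming in Java. Discussion and feedback are welcome.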