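Continuing from the previous post. With ReentrantLock finished, a ReadWriteLock could be built on top of it, as described in chapter 8 of "The Art of Multiprocessor Programming". However, in keeping with this series' theme of a partial introduction to AQS (AbstractQueuedSynchronizer), this post writes a ReentrantReadWriteLock from scratch.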
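By the definition of ReadWriteLock, exactly one of the following holds at any time:
- no thread holds the lock
- 1 to n threads hold the shared (Read) lock
- 1 thread holds the exclusive (Write) lock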
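Next, a fair ReadWriteLock requires a newly arriving Read or Write thread to wait in the queue, while an unfair ReadWriteLock allows a newly arriving Read or Write thread to take the lock ahead of threads already waiting in the queue. One more remark on unfair locks: in theory an unfair lock resembles a crowd scrambling for the same thing, but most implementations only let a newcomer compete with the thread at the front of the queue. ReadWriteLock is the same. If you want a completely unfair lock, AQS and the implementation here may not meet your needs.
To implement the definition of ReadWriteLock you need to record the read and write state separately. Since the exclusive (Write) state can belong to only one thread, the possible scenarios are: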
| No. | Reader Count | Writer Count |
|---|---|---|
| 1 | 0 | 0 |
| 2 | 1+ | 0 |
| 3 | 0 | 1 |
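As the table shows, when the lock moves from state 1 (no thread holds it) to state 2 or 3, reader count and writer count change separately. This means you must either CAS reader count and writer count together, or merge them into one variable and use the low and high bits to tell reader count from writer count (positive and negative numbers would also work). AQS takes the latter approach. Here is a simplified version with a few simple relations: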
- `count = 0x00000000`
- `reader_count = count & 0x00FF`
- `writer_count = (count & 0xFF00) >> 8`
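The relations above can be sketched as a few helper methods. This is only an illustration: the class name `PackedCount`, the method names, and `WRITER_SHIFT` are assumptions made for this example, while the mask and unit values follow the constants used in the code later in this post.

```java
final class PackedCount {
    static final int READER_MASK = 0x00FF;  // low byte: number of readers
    static final int WRITER_MASK = 0xFF00;  // high byte: writer reentrant count
    static final int WRITER_UNIT = 0x0100;  // adding this increments the writer part by one
    static final int WRITER_SHIFT = 8;      // assumption: writer part starts at bit 8

    static int readers(int count) {
        return count & READER_MASK;
    }

    static int writerHolds(int count) {
        return (count & WRITER_MASK) >>> WRITER_SHIFT;
    }

    // A state is legal only if readers and a writer never coexist.
    static boolean isValid(int count) {
        return readers(count) == 0 || writerHolds(count) == 0;
    }

    public static void main(String[] args) {
        int twoWriterHolds = 2 * WRITER_UNIT;
        System.out.println(readers(3));                  // three readers, no writer
        System.out.println(writerHolds(twoWriterHolds)); // writer re-entered once
        System.out.println(isValid(WRITER_UNIT + 1));    // reader and writer together
    }
}
```

Because both parts live in one `int`, a single CAS on that `int` is enough to move between the three states of the table.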
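Another question is how to implement reentrancy. You can either increment reader count on every reentry, or, as with the ReentrantLock written earlier in this series, let each reader thread maintain its own thread-local variable. In terms of efficiency the latter is better, because a reentrant acquire never touches the shared count.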
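A minimal sketch of the thread-local option, assuming a hypothetical helper class `HoldCounter` (not part of the code in this post): only the first `enter` and the last `exit` of a thread would need to touch the shared count.

```java
final class HoldCounter {
    // per-thread number of times the current thread holds the read lock
    private final ThreadLocal<Integer> holds = ThreadLocal.withInitial(() -> 0);

    /** @return true if this is the thread's first hold, i.e. the shared count must be updated */
    boolean enter() {
        int h = holds.get();
        holds.set(h + 1);
        return h == 0;
    }

    /** @return true if this was the thread's last hold, i.e. the shared count must be updated */
    boolean exit() {
        int h = holds.get();
        if (h <= 0) {
            throw new IllegalStateException("attempt to unlock without holding lock");
        }
        holds.set(h - 1);
        return h == 1;
    }

    public static void main(String[] args) {
        HoldCounter counter = new HoldCounter();
        System.out.println(counter.enter()); // first hold on this thread
        System.out.println(counter.enter()); // reentrant hold
        System.out.println(counter.exit());  // still held once
        System.out.println(counter.exit());  // fully released
    }
}
```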
SimpleReadWriteLock
A basic implementation that supports only tryLock
```java
import javax.annotation.Nonnull;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;

public class SimpleReadWriteLock implements ReadWriteLock {

    private static final int WRITER_MASK = 0xFF00;
    private static final int WRITER_UNIT = 0x0100;

    // writer reentrant times, reader count
    private final AtomicInteger count = new AtomicInteger(0);
    private final ReadLock readLock = new ReadLock();
    private final WriteLock writeLock = new WriteLock();

    @Override
    @Nonnull
    public Lock readLock() {
        return readLock;
    }

    @Override
    @Nonnull
    public Lock writeLock() {
        return writeLock;
    }

    private static abstract class AbstractLock implements Lock {

        @Override
        public void lock() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void lockInterruptibly() throws InterruptedException {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
            throw new UnsupportedOperationException();
        }

        @Override
        @Nonnull
        public Condition newCondition() {
            throw new UnsupportedOperationException();
        }
    }

    @SuppressWarnings("Duplicates")
    private class ReadLock extends AbstractLock {

        // reentrant times of current reader
        private final ThreadLocal<Integer> reentrantTimes = ThreadLocal.withInitial(() -> 0);

        @Override
        public boolean tryLock() {
            int rt = reentrantTimes.get();
            if (rt > 0) {
                reentrantTimes.set(rt + 1);
                return true;
            }
            int c = count.get();
            if (((c & WRITER_MASK) == 0) && count.compareAndSet(c, c + 1)) {
                reentrantTimes.set(1);
                return true;
            }
            return false;
        }

        @Override
        public void unlock() {
            int rt = reentrantTimes.get();
            if (rt <= 0) {
                throw new IllegalStateException("attempt to unlock without holding lock");
            }
            if (rt > 1) {
                reentrantTimes.set(rt - 1);
                return;
            }
            reentrantTimes.set(0);
            if (count.get() < 1) {
                throw new IllegalStateException("no reentrantTimes");
            }
            count.decrementAndGet();
        }
    }

    @SuppressWarnings("Duplicates")
    private class WriteLock extends AbstractLock {

        private Thread writer;

        @Override
        public boolean tryLock() {
            if (writer == Thread.currentThread()) {
                count.getAndAdd(WRITER_UNIT);
                return true;
            }
            if (count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                writer = Thread.currentThread();
                return true;
            }
            return false;
        }

        @Override
        public void unlock() {
            if (writer != Thread.currentThread()) {
                throw new IllegalStateException("attempt to unlock without holding lock");
            }
            int c = count.get();
            if (c < WRITER_UNIT) {
                throw new IllegalStateException("no writer");
            }
            if (c == WRITER_UNIT) {
                writer = null;
            }
            count.set(c - WRITER_UNIT);
        }
    }
}
```
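Pay particular attention to how the CAS operations handle count. The code is not analyzed in detail here; if you have read the ReentrantLock analysis, it should not be too hard.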
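To see what these CAS checks are meant to achieve, the four tryLock combinations can be observed against the JDK's own `java.util.concurrent.locks.ReentrantReadWriteLock`. The sketch below does not use SimpleReadWriteLock, and the helper name `tryFromOtherThread` is made up for this example; reading the code above, SimpleReadWriteLock should give the same answers for these four cases.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

final class TryLockSemantics {

    /** Holds {@code held} on the calling thread and lets another thread call tryLock on {@code attempted}. */
    static boolean tryFromOtherThread(Lock held, Lock attempted) {
        held.lock();
        try {
            AtomicBoolean acquired = new AtomicBoolean();
            Thread other = new Thread(() -> {
                if (attempted.tryLock()) {
                    acquired.set(true);
                    attempted.unlock(); // release immediately, only the result matters
                }
            });
            other.start();
            try {
                other.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
            return acquired.get();
        } finally {
            held.unlock();
        }
    }

    public static void main(String[] args) {
        ReadWriteLock lock = new ReentrantReadWriteLock();
        System.out.println("read while read held:   " + tryFromOtherThread(lock.readLock(), lock.readLock()));
        System.out.println("write while read held:  " + tryFromOtherThread(lock.readLock(), lock.writeLock()));
        System.out.println("read while write held:  " + tryFromOtherThread(lock.writeLock(), lock.readLock()));
        System.out.println("write while write held: " + tryFromOtherThread(lock.writeLock(), lock.writeLock()));
    }
}
```

Only the read-while-read case succeeds, which is exactly the "readers share, writers exclude" rule from the state table.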
Next, consider how to add a queue. There are two possible choices:
- like Write, each Read is a separate node
- consecutive Reads share a single node
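The latter looks appealing, because the head node (here a Write node) only has to wake one node when it wakes its successors, instead of several consecutive Read nodes. But AQS actually chose the first! Why? In my view, the second is not fundamentally different from the first when it comes to waking: waking several successive Read nodes is much the same as waking a single Read node and then waking the other Read threads inside it. At the same time it brings a series of detail problems: on enqueue a thread may be enqueued directly (when the last node is a Write node) or may have to join the last Read node, lost wake-ups have to be considered, and so on. In short, the cost outweighs the benefit. So my incremental implementation also uses the first approach.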
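When waking successors, consider the following cases.
The first is the Write-only case, where only one successor ever needs to be woken. In the second there is a single successor Read node, so again only one successor needs to be woken. In the third, however, the head may need to wake two successors. In the fourth, a Read thread is the head and only needs to wake the first Write node after it. The fifth is special: it shows the middle of a continuous wake-up.
As the figure shows, when each Read thread is its own node, waking has to happen multiple times. More precisely: if the successor is a Read node, keep waking; otherwise (a Write node), wake only that one.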
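One more question: at which step does the wake-up happen?
- in unlock
- after the preceding Read node has acquired the lock
And who wakes whom?
- the unlocking node wakes its successors
- a Read node wakes its successor
This does not give 2×2 = 4 designs; in practice mainly two are viable.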
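The first wakes the following Read node in unlock, and only that one. The Read node, after acquiring the lock, wakes the next Read node, and so on. The second wakes all following Read nodes in unlock.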
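AQS chose the first. Why? In my view the reason is to wake a thread only when it can actually acquire the lock. Suppose the second design is used: the first Read node is woken and can acquire the lock, but has not yet advanced the queue (for how acquire and release handle the queue, see the earlier ReentrantLock post). If W now goes on to wake the second Read node, that node's predecessor, the first Read node, has not become head yet, so the second Read node can only park again, which makes it a useless wake-up. With the first design, the first Read node acquires the lock, sets itself as head, and only then wakes the following Read node. At that point the second Read node is certain to acquire the lock, so the wake-up is never wasted.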
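Note that in the figure for the first design, the wake-up drawn above the nodes differs from the wake-up in unlock: the former applies only to Read nodes, that is, it happens only if the successor is a Read node, whereas unlock unconditionally wakes a single successor.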
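The wake-up rule of the first design can be written down as a small pure function. This is a model for illustration only, not code from the lock itself: the class `WakeChain` and the encoding of the queue as a string of `'R'` and `'W'` characters are assumptions of this sketch, and it ignores threads that arrive or abort while the chain is running.

```java
import java.util.ArrayList;
import java.util.List;

final class WakeChain {

    /**
     * @param waiters nodes queued behind the head, in queue order; 'R' = Read node, 'W' = Write node
     * @return indexes of the nodes woken by one unlock and the chain that follows it
     */
    static List<Integer> wokenIndexes(String waiters) {
        List<Integer> woken = new ArrayList<>();
        if (waiters.isEmpty()) {
            return woken;
        }
        // unlock wakes exactly one successor, whatever its type
        woken.add(0);
        int i = 0;
        // a Read node that got the lock wakes its successor only if that is a Read node too
        while (waiters.charAt(i) == 'R'
                && i + 1 < waiters.length()
                && waiters.charAt(i + 1) == 'R') {
            i++;
            woken.add(i);
        }
        return woken;
    }

    public static void main(String[] args) {
        System.out.println(wokenIndexes("RRW")); // both readers, the writer keeps waiting
        System.out.println(wokenIndexes("WRR")); // only the writer
        System.out.println(wokenIndexes("RWR")); // only the first reader
    }
}
```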
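You may have noticed that the approach above never mentions Node's PROPAGATE, even though the continuous waking of Read nodes looks a lot like propagation. The reason is that ReadWriteLock does not need PROPAGATE; PROPAGATE is essentially a fix for the AQS-based Semaphore. Here is a rough analysis of what could go wrong in Semaphore without PROPAGATE.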
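Semaphore has two main operations:
- acquire
- release
A release may wake several threads waiting in acquire, so Semaphore uses a model similar to ReadLock, that is, a ReadWriteLock without a WriteLock. The acquire method maps to AQS's acquireShared and the release method to releaseShared.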
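- Suppose a Semaphore has 2 permits; thread 1 holds 1, thread 2 holds 1, and threads 3 and 4 are waiting. The queue then looks like the figure above. head is not important here: it may be the sentinel node, or it may be thread 1 or 2. The two waiters are threads 3 and 4.
- Thread 1 finishes with its permit and returns it via releaseShared. head wakes the first waiter, thread 3. Thread 3 finds one permit in the Semaphore and acquires it; because there was only 1 permit, thread 3 decides not to wake thread 4 (more precisely, tryAcquireShared returns a propagate value of 0). At this point thread 3 has not yet advanced the queue.
- Thread 2 finishes with its permit and returns it via releaseShared. head has not changed, and, as in ReentrantLock, once the successor has been woken no further wake-up is attempted. So thread 3 is not woken a second time.
- Thread 3 continues without waking its successor, thread 4. The Semaphore still has 1 permit, yet thread 4 cannot acquire it.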
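The scenario can be replayed against the JDK's `java.util.concurrent.Semaphore` to check that both waiters do get a permit in the end. The sketch below is only a smoke test under stated assumptions: the calling thread stands in for threads 1 and 2 (Semaphore permits are not tied to a thread), the class name `SemaphoreScenario` is made up, and a single run cannot prove that the narrow race described above is handled, only that the outcome is the expected one.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

final class SemaphoreScenario {

    /** @return true if both waiting threads obtained a permit within the timeout */
    static boolean bothWaitersProceed() {
        Semaphore semaphore = new Semaphore(2);
        // the calling thread plays threads 1 and 2, holding one permit each
        semaphore.acquireUninterruptibly(2);

        CountDownLatch acquired = new CountDownLatch(2);
        Runnable waiter = () -> {
            semaphore.acquireUninterruptibly();
            acquired.countDown();
        };
        Thread t3 = new Thread(waiter, "thread-3");
        Thread t4 = new Thread(waiter, "thread-4");
        t3.setDaemon(true);
        t4.setDaemon(true);
        t3.start();
        t4.start();

        // wait until both waiters are queued (getQueueLength is only an estimate)
        while (semaphore.getQueueLength() < 2) {
            Thread.yield();
        }

        semaphore.release(); // "thread 1" returns its permit
        semaphore.release(); // "thread 2" returns its permit

        try {
            return acquired.await(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println("both waiters proceeded: " + bothWaitersProceed());
    }
}
```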
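One direct fix, in my view, is to use the design from earlier that wakes continuously in unlock, with the head unconditionally waking all successors (without checking its own status). Of course this makes useless wake-ups more likely. Another approach is to ensure, when thread 1 returns its permit, that thread 3 becomes head before the next release. Honestly, that is rather hard.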
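AQS's solution mainly consists of:
- checking head's status when advancing the queue (setHeadAndPropagate)
- having releaseShared change the status from 0 to PROPAGATE (doReleaseShared)
This way, thread 3 in the scenario above sees that head is PROPAGATE and wakes its successor.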
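The two steps can be modeled with a single status variable. The sketch below is a heavily simplified model, not the AQS source: the status values follow the `SIGNAL = -1` and `PROPAGATE = -3` constants of the JDK 8 era `AbstractQueuedSynchronizer.Node`, while the class `PropagateSketch`, its methods, and the reduced propagation check (`remaining > 0 || oldHeadStatus < 0`) are assumptions made for this example.

```java
import java.util.concurrent.atomic.AtomicInteger;

final class PropagateSketch {
    static final int NORMAL = 0;
    static final int SIGNAL = -1;
    static final int PROPAGATE = -3;

    /** Simplified doReleaseShared. @return true if the caller should unpark head's successor */
    static boolean release(AtomicInteger headStatus) {
        while (true) {
            int s = headStatus.get();
            if (s == SIGNAL) {
                if (headStatus.compareAndSet(SIGNAL, NORMAL)) {
                    return true;
                }
            } else if (s == NORMAL) {
                // nobody to unpark right now, but leave a mark for whoever advances the queue
                if (headStatus.compareAndSet(NORMAL, PROPAGATE)) {
                    return false;
                }
            } else {
                return false; // already PROPAGATE
            }
        }
    }

    /** Simplified check from setHeadAndPropagate. */
    static boolean shouldWakeNext(int remainingPermits, int oldHeadStatus) {
        return remainingPermits > 0 || oldHeadStatus < 0;
    }

    /** Replays the four steps of the scenario; returns whether thread 3 wakes thread 4. */
    static boolean thread3WakesThread4(boolean usePropagate) {
        AtomicInteger head = new AtomicInteger(SIGNAL);
        release(head);                 // thread 1 releases: SIGNAL -> 0, thread 3 is unparked
        int remaining = 0;             // thread 3 takes the only permit it can see
        if (usePropagate) {
            release(head);             // thread 2 releases: 0 -> PROPAGATE
        }                              // without PROPAGATE the second release leaves status 0
        return shouldWakeNext(remaining, head.get());
    }

    public static void main(String[] args) {
        System.out.println("with PROPAGATE:    " + thread3WakesThread4(true));
        System.out.println("without PROPAGATE: " + thread3WakesThread4(false));
    }
}
```

In the model, the second release is what turns the old head's status negative, and that is the only thing that makes thread 3 wake thread 4 even though it saw no remaining permits.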
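Because PROPAGATE was added, part of the code needs extra handling for this new status. ReadWriteLock does not have this problem: a Read node always wakes a following Read node, without any propagate check.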
Final code
```java
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReadWriteLock;

@SuppressWarnings("Duplicates")
public class UnfairReadWriteTimedLock3 implements ReadWriteLock {

    // private static final int READER_MASK = 0x00FF;
    private static final int WRITER_MASK = 0xFF00;
    private static final int WRITER_UNIT = 0x0100;

    private final AtomicInteger count = new AtomicInteger(0);
    private final Queue queue = new Queue();
    private final ReadLock readLock = new ReadLock();
    private final WriteLock writeLock = new WriteLock();

    @Override
    @Nonnull
    public Lock readLock() {
        return readLock;
    }

    @Override
    @Nonnull
    public Lock writeLock() {
        return writeLock;
    }

    private static abstract class AbstractLock implements Lock {

        @Override
        @Nonnull
        public Condition newCondition() {
            throw new UnsupportedOperationException();
        }
    }

    @SuppressWarnings("Duplicates")
    private class ReadLock extends AbstractLock {

        private final ThreadLocal<Integer> reentrantTimes = ThreadLocal.withInitial(() -> 0);

        @Override
        public void lock() {
            if (tryLock()) {
                return;
            }
            int c;
            final Node node = new Node(Thread.currentThread(), true);
            queue.enqueue(node);
            while (true) {
                if (node.predecessor.get() == queue.head.get()) {
                    c = count.get();
                    if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                        myTurn(node);
                        return;
                    }
                }
                if (isReadyToPark(node)) {
                    LockSupport.park(this);
                }
            }
        }

        @Override
        public void lockInterruptibly() throws InterruptedException {
            if (tryLock()) {
                return;
            }
            int c;
            final Node node = new Node(Thread.currentThread(), true);
            queue.enqueue(node);
            while (true) {
                // re-read the predecessor on every pass: it changes when aborted nodes are skipped
                if (node.predecessor.get() == queue.head.get()) {
                    c = count.get();
                    if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                        myTurn(node);
                        return;
                    }
                }
                if (isReadyToPark(node)) {
                    LockSupport.park(this);
                }
                if (Thread.interrupted()) {
                    abort(node);
                    throw new InterruptedException();
                }
            }
        }

        @Override
        public boolean tryLock() {
            int rt = reentrantTimes.get();
            if (rt > 0) {
                reentrantTimes.set(rt + 1);
                return true;
            }
            int c = count.get();
            if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                reentrantTimes.set(1);
                return true;
            }
            return false;
        }

        @Override
        public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
            if (tryLock()) {
                return true;
            }
            final long deadline = System.nanoTime() + unit.toNanos(time);
            final Node node = new Node(Thread.currentThread(), true);
            queue.enqueue(node);
            long nanos;
            int c;
            while (true) {
                // re-read the predecessor on every pass: it changes when aborted nodes are skipped
                if (node.predecessor.get() == queue.head.get()) {
                    c = count.get();
                    if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) {
                        myTurn(node);
                        return true;
                    }
                }
                nanos = deadline - System.nanoTime();
                if (nanos <= 0L) {
                    abort(node);
                    return false;
                }
                if (isReadyToPark(node)) {
                    LockSupport.parkNanos(this, nanos);
                }
                if (Thread.interrupted()) {
                    abort(node);
                    throw new InterruptedException();
                }
            }
        }

        private void myTurn(@Nonnull Node node) {
            reentrantTimes.set(1);
            node.clearThread();
            queue.head.set(node);
            /*
             * propagate if successor is reader
             *
             * In ReadWriteLock, there's no need to check if propagate, it always propagates.
             *
             * The SIGNAL status is consumed only when a reader is actually woken here.
             * For a writer successor it must stay set, so that the last reader's unlock
             * still finds SIGNAL and wakes the writer.
             */
            Node successor = queue.findNormalSuccessor(node);
            if (successor != null && successor.shared && node.resetSignalStatus()) {
                LockSupport.unpark(successor.thread.get());
            }
        }

        @Override
        public void unlock() {
            int rt = reentrantTimes.get();
            if (rt < 1) {
                throw new IllegalStateException("not the thread holding lock");
            }
            if (rt > 1) {
                reentrantTimes.set(rt - 1);
                return;
            }
            // rt == 1
            reentrantTimes.set(0);
            if (count.get() < 1) {
                throw new IllegalStateException("count < 1");
            }
            if (count.decrementAndGet() > 0) {
                return;
            }
            Node h = queue.head.get();
            if (h != null && h.resetSignalStatus()) {
                unparkNormalSuccessor(h);
            }
        }
    }

    @SuppressWarnings("Duplicates")
    private class WriteLock extends AbstractLock {

        private Thread owner;

        @Override
        public void lock() {
            if (tryLock()) {
                return;
            }
            Node node = new Node(Thread.currentThread());
            queue.enqueue(node);
            while (true) {
                // re-read the predecessor on every pass: it changes when aborted nodes are skipped
                if (node.predecessor.get() == queue.head.get() &&
                        count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                    myTurn(node);
                    return;
                }
                if (isReadyToPark(node)) {
                    LockSupport.park(this);
                }
            }
        }

        @Override
        public void lockInterruptibly() throws InterruptedException {
            if (tryLock()) {
                return;
            }
            Node node = new Node(Thread.currentThread());
            queue.enqueue(node);
            while (true) {
                if (node.predecessor.get() == queue.head.get() &&
                        count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                    myTurn(node);
                    return;
                }
                if (isReadyToPark(node)) {
                    LockSupport.park(this);
                }
                if (Thread.interrupted()) {
                    abort(node);
                    throw new InterruptedException();
                }
            }
        }

        @Override
        public boolean tryLock() {
            if (owner == Thread.currentThread()) {
                count.getAndAdd(WRITER_UNIT);
                return true;
            }
            if (count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                owner = Thread.currentThread();
                return true;
            }
            return false;
        }

        @Override
        public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException {
            if (tryLock()) {
                return true;
            }
            long deadline = System.nanoTime() + unit.toNanos(time);
            Node node = new Node(Thread.currentThread());
            queue.enqueue(node);
            long nanos;
            while (true) {
                if (node.predecessor.get() == queue.head.get() &&
                        count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) {
                    myTurn(node);
                    return true;
                }
                nanos = deadline - System.nanoTime();
                if (nanos <= 0L) {
                    abort(node);
                    return false;
                }
                if (isReadyToPark(node)) {
                    LockSupport.parkNanos(this, nanos);
                }
                if (Thread.interrupted()) {
                    abort(node);
                    throw new InterruptedException();
                }
            }
        }

        private void myTurn(@Nonnull Node node) {
            node.clearThread();
            owner = Thread.currentThread();
            queue.head.set(node);
        }

        @Override
        public void unlock() {
            if (owner != Thread.currentThread()) {
                throw new IllegalStateException("not the thread holding write lock");
            }
            int c = count.get();
            if (c < WRITER_UNIT) {
                throw new IllegalStateException("no writer");
            }
            if (c > WRITER_UNIT) {
                count.set(c - WRITER_UNIT);
                return;
            }
            // c == WRITER_UNIT
            owner = null;
            // linearization point
            count.set(0);
            // signal successor
            Node node = queue.head.get();
            if (node != null && node.resetSignalStatus()) {
                unparkNormalSuccessor(node);
            }
        }
    }

    private boolean isReadyToPark(@Nonnull Node node) {
        Node predecessor = node.predecessor.get();
        int s = predecessor.status.get();
        if (s == Node.STATUS_SIGNAL) {
            return true;
        }
        if (s == Node.STATUS_ABORTED) {
            predecessor = queue.skipAbortedPredecessor(node);
            predecessor.successor.set(node);
        } else {
            predecessor.status.compareAndSet(Node.STATUS_NORMAL, Node.STATUS_SIGNAL);
        }
        return false;
    }

    private void abort(@Nonnull Node node) {
        node.clearThread();
        Node p = queue.skipAbortedPredecessor(node);
        Node ps = p.successor.get();
        node.status.set(Node.STATUS_ABORTED);
        Node t = queue.tail.get();
        if (t == node && queue.tail.compareAndSet(t, p)) {
            p.successor.compareAndSet(ps, null);
            return;
        }
        if (p != queue.head.get() && p.ensureSignalStatus() && p.thread.get() != null) {
            Node s = node.successor.get();
            if (s != null && s.status.get() != Node.STATUS_ABORTED) {
                p.successor.compareAndSet(ps, s);
            }
        } else {
            node.resetSignalStatus();
            unparkNormalSuccessor(node);
        }
    }

    private void unparkNormalSuccessor(@Nonnull Node node) {
        Node successor = queue.findNormalSuccessor(node);
        if (successor != null) {
            LockSupport.unpark(successor.thread.get());
        }
    }

    @SuppressWarnings("Duplicates")
    private static class Queue {

        final AtomicReference<Node> head = new AtomicReference<>();
        final AtomicReference<Node> tail = new AtomicReference<>();

        @Nonnull
        Node enqueue(@Nonnull Node node) {
            Node t;
            while (true) {
                t = tail.get();
                if (t == null) {
                    Node sentinel = new Node();
                    if (head.compareAndSet(null, sentinel)) {
                        tail.set(sentinel);
                    }
                } else {
                    node.predecessor.lazySet(t);
                    if (tail.compareAndSet(t, node)) {
                        t.successor.set(node);
                        return t;
                    }
                }
            }
        }

        @Nullable
        Node findNormalSuccessor(@Nonnull Node node) {
            Node s = node.successor.get();
            if (s != null && s.status.get() != Node.STATUS_ABORTED) {
                return s;
            }
            // find from tail
            s = null;
            Node c = tail.get();
            while (c != null && c != node) {
                if (c.status.get() != Node.STATUS_ABORTED) {
                    s = c;
                }
                c = c.predecessor.get();
            }
            return s;
        }

        @Nonnull
        Node skipAbortedPredecessor(@Nonnull Node node) {
            Node h = head.get();
            Node p = node.predecessor.get();
            while (p != h && p.status.get() == Node.STATUS_ABORTED) {
                p = p.predecessor.get();
                node.predecessor.set(p);
            }
            return p;
        }
    }

    private static class Node {

        static final int STATUS_NORMAL = 0;
        static final int STATUS_SIGNAL = 1;
        static final int STATUS_ABORTED = -1;

        final AtomicReference<Thread> thread;
        final boolean shared;
        final AtomicInteger status = new AtomicInteger(STATUS_NORMAL);
        final AtomicReference<Node> predecessor = new AtomicReference<>();
        // optimization
        final AtomicReference<Node> successor = new AtomicReference<>();

        Node() {
            this(null, false);
        }

        Node(@Nullable Thread thread) {
            this(thread, false);
        }

        Node(@Nullable Thread thread, boolean shared) {
            this.thread = new AtomicReference<>(thread);
            this.shared = shared;
        }

        void clearThread() {
            thread.set(null);
        }

        /**
         * Ensure signal status.
         * If current status is signal, just return.
         * If current status is normal, then try to CAS status from normal to signal.
         *
         * @return true if the status is signal afterwards (already signal or changed to it), otherwise false
         */
        boolean ensureSignalStatus() {
            int s = status.get();
            return s == STATUS_SIGNAL || (s == STATUS_NORMAL && status.compareAndSet(STATUS_NORMAL, STATUS_SIGNAL));
        }

        /**
         * Reset signal status.
         * SIGNAL -> NORMAL
         *
         * @return true if successful, otherwise false
         */
        boolean resetSignalStatus() {
            return status.get() == STATUS_SIGNAL && status.compareAndSet(STATUS_SIGNAL, STATUS_NORMAL);
        }
    }
}
```
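Summary
The "Write your own ReentrantLock and ReentrantReadWriteLock" series ends here. Overall, concurrent programming requires thinking carefully through many situations, including all sorts of details, and sometimes a detail will force your design to change. I hope this series helps you learn AQS and concurrent programming in Java. Feedback and discussion are welcome.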


