接上篇。在写完ReentrantLock之后,其实可以基于ReentrantLock写一个ReadWriteLock,《the art of multiprocessor programming》第八章有介绍。但是,本着不完全AQS(AbstractQueuedSynchronizer)介绍的系列主题,这里从零开始重新写一个ReentrantReadWriteLock。
按照ReadWriteLock的定义,任何时候都满足
- 没有线程持有锁
- 有1~n个线程持有共享锁(Read)
- 有1个线程持有独占锁(Write)
中的一个。
其次公平的ReadWriteLock要求新来的Read或者Write线程必须在队列中等待,非公平的ReadWriteLock允许新来的Read或者Write比队列中等待的线程先获取锁。关于非公平锁这里多说一句,理论上的非公平锁类似一群人哄抢的现象,但是实现多半是只允许新来和线程队列最前面的线程抢占锁。ReadWriteLock也是一样。如果你想要完全非公平的锁的话,可能AQS和这里的实现不满足你的需求。
为了实现ReadWriteLock的定义,你需要分别记录读写状态。考虑到独占(Write)状态只可能有一个线程,可能场景如下:
No. | Reader Count | Writer Count |
---|---|---|
1 | 0 | 0 |
2 | 1+ | 0 |
3 | 0 | 1 |
可以看到,从没有任何线程持有锁的状态1变成2或者3时,reader count和writer count分别变化,这个特性导致你必须同时CAS reader count和writer count,或者混合成一个变量,分别用低位和高位来区别reader count和writer count(或者用正数和负数也可以)。AQS采用的是后者。这里简化一下,并给几个简单的关系式:
1 2 3 |
count = 00000000 reader_count = count & 0x00FF writer_count = count & 0xFF00 |
还有一个要考虑的问题是重入(reentrant)如何实现。你可以增加reader count,或者和ReentrantLock一样每个读线程自己维护一个thread local变量。从效率上来说,后者比较好。
SimpleReadWriteLock
只有tryLock的基本实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import javax.annotation.Nonnull; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; public class SimpleReadWriteLock implements ReadWriteLock { private static final int WRITER_MASK = 0xFF00; private static final int WRITER_UNIT = 0x0100; // writer reentrant times, reader count private final AtomicInteger count = new AtomicInteger(0); private final ReadLock readLock = new ReadLock(); private final WriteLock writeLock = new WriteLock(); @Override @Nonnull public Lock readLock() { return readLock; } @Override @Nonnull public Lock writeLock() { return writeLock; } private static abstract class AbstractLock implements Lock { @Override public void lock() { throw new UnsupportedOperationException(); } @Override public void lockInterruptibly() throws InterruptedException { throw new UnsupportedOperationException(); } @Override public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException { throw new UnsupportedOperationException(); } @Override @Nonnull public Condition newCondition() { throw new UnsupportedOperationException(); } } @SuppressWarnings(“Duplicates”) private class ReadLock extends AbstractLock { // reentrant times of current reader private final ThreadLocal<Integer> reentrantTimes = ThreadLocal.withInitial(() –> 0); @Override public boolean tryLock() { int rt = reentrantTimes.get(); if (rt > 0) { reentrantTimes.set(rt + 1); return true; } int c = count.get(); if (((c & WRITER_MASK) == 0) && count.compareAndSet(c, c + 1)) { reentrantTimes.set(1); return true; } return false; } @Override public void unlock() { int rt = reentrantTimes.get(); if (rt <= 0) { throw new IllegalStateException(“attempt to unlock without holding lock”); } if (rt > 1) { reentrantTimes.set(rt – 1); return; } reentrantTimes.set(0); if (count.get() < 1) { throw new IllegalStateException(“no reentrantTimes”); } count.decrementAndGet(); } } @SuppressWarnings(“Duplicates”) private class WriteLock extends AbstractLock { private Thread writer; @Override public boolean tryLock() { if (writer == Thread.currentThread()) { count.getAndAdd(WRITER_UNIT); return true; } if (count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) { writer = Thread.currentThread(); return true; } return false; } @Override public void unlock() { if (writer != Thread.currentThread()) { throw new IllegalStateException(“attempt to unlock without holding lock”); } int c = count.get(); if (c < WRITER_UNIT) { throw new IllegalStateException(“no writer”); } if (c == WRITER_UNIT) { writer = null; } count.set(c – WRITER_UNIT); } } } |
重点关注其中的CAS操作是如何处理count的。代码这里不具体分析了,相信看过ReentrantLock分析的话,这里也不会太难。
接下来考虑如何加入队列。实现可能有两种选择
- 和Write一样,每个Read作为单独一个节点
- 连续的Read单独作为一个节点
后者看起来不错,因为头节点(这里指Write节点)在唤醒后续节点的时候只需要唤醒一个就行了,否则要唤醒多个连续的Read节点。但是实际上AQS选择的是第一种!原因是什么?个人认为第二种在唤醒上和第一种在本质上没有太大差别:唤醒后续的多个Read节点和唤醒后续的单个Read节点后再唤醒这个Read节点里其他的Read线程。于此同时带来了一系列细节问题,入队时有可能直接入队(最后的节点是Write节点时),也有可能跟随最后的一个Read节点,还要考虑唤醒丢失的问题等等。总之,得不偿失。所以个人的渐进式实现也是用第一种方法。
在唤醒后续节点上,考虑以下几种情况
第一种是只有Write的情况,必然只需要唤醒一个后继节点。第二种只有一个后继Read节点,所以也只要唤醒一个后继节点。但是第三种情况下,首节点可能要唤醒两个后续节点。第四种,Read线程作为首节点,只需要唤醒后续第一个Write节点。第五种情况比较特殊:这里表示连续唤醒的中途。
从图中可以看出,将一个Read线程作为一个节点的话,唤醒必须是多次的。更准确一点来说,如果后继节点是Read节点,那么连续唤醒,否则(Write节点),只唤醒一个。
再考虑一个问题,在哪个步骤唤醒?
- unlock
- 前一个Read节点获取了锁之后
以及谁唤醒谁?
- unlock的节点唤醒后续节点
- Read节点唤醒后续节点
这里其实没有2×2总共4种设计,实际可行的主要就以上两种。
第一种在unlock时唤醒后续的Read节点,注意只唤醒一个。然后Read节点在获取了锁之后唤醒后续的一个Read节点,依此类推。第二种在unlock时唤醒后续所有的Read节点。
AQS选择了第一种。原因是什么?个人认为原因是保证能够获取锁时才唤醒。假设使用第二种设计,第一个Read节点被唤醒,可以获取锁了,但是还没推进队列(acquire和release时队列的处理请参见之前ReentrantLock的介绍),此时W继续唤醒第二个Read节点时,由于第二个Read节点的前置节点,即第一个Read节点还没有成为head,所以只能继续park,等于无效唤醒。但是采用第一种设计的话,第一个Read节点只有在获取了锁之后,设置自己为head,接着唤醒后续Read节点,此时第二个Read节点肯定能够获取锁,所以不会是一次无效唤醒。
注意,第一种方法的图中,节点上方的唤醒和unlock时的唤醒是不同的,上方的唤醒只针对Read节点,也就是说,后继节点是Read节点才会唤醒,相对的,unlock无条件唤醒后续的单个节点。
你可能注意到,上述方法中并没有提到Node的PROPAGATE,即使看上去Read节点的连续唤醒很像传播。原因是ReadWriteLock不需要PROPAGATE,其次PROPAGATE可以说是针对基于AQS的Semaphore的解决方案。这里大致分析一下Semaphore没有PROPAGATE的话可能存在的问题。
对于Semaphore来说,主要有两种操作
- acquire
- release
其中release可能会唤醒多个等待acquire的线程,所以采用的是类似ReadLock,即没有WriteLock的ReadWriteLock模型。acquire方法对应AQS的acquireShared,release方法对应releaseShared。
- 假设一个Semaphore有permits为2,线程1持有1,线程2持有1,线程3等待中,线程4等待中。此时的队列类似上图。head这里不重要,可能是sentinel节点,也有可能是线程1或者2。两个waiter分别是线程3和4
- 线程1使用完permit,通过releaseShared返还。这时head唤醒第一个waiter,即线程3。线程3发现Semaphore中有一个permit,成功获取,同时因为只有1个permit,线程4不决定唤醒线程4(准确来说是tryAcquireShared返回的propagate为0)。此时线程3还未推进队列
- 线程2使用完permit,通过releaseShared返还。这时head没有变化,同时按照ReentrantLock一样的方式,如果唤醒过了后续节点的话,不会再尝试唤醒。即线程3不会被重复唤醒
- 线程3继续执行,没有唤醒后续节点4,同时permits保留1,但是线程4无法获取
这里个人认为一个直接的解决方法是,在唤醒上使用之前的unlock时连续唤醒,同时首节点无条件唤醒后续所有节点(不检查自己的status)。当然这样做的话无效唤醒的可能性会比较大。另一种方法是线程1返还permit时,确保在下一次release之前线程3成为首节点。老实说这比较难。
AQS的解决方法主要是
- 在推进队列时检查head的status(setHeadAndPropagate)
- releaseShared是从0改成PROPAGATE(doReleaseShared)
这样,上述场景的线程3发现head为PROPAGATE,并唤醒后续节点。
由于加入了PROPAGATE,一部分代码需要针对这个新状态做额外处理。但是在ReadWriteLock中,不存在这个问题:Read节点会无条件唤醒后续节点。
最终代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 |
import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.LockSupport; import java.util.concurrent.locks.ReadWriteLock; @SuppressWarnings(“Duplicates”) public class UnfairReadWriteTimedLock3 implements ReadWriteLock { // private static final int READER_MASK = 0x00FF; private static final int WRITER_MASK = 0xFF00; private static final int WRITER_UNIT = 0x0100; private final AtomicInteger count = new AtomicInteger(0); private final Queue queue = new Queue(); private final ReadLock readLock = new ReadLock(); private final WriteLock writeLock = new WriteLock(); @Override @Nonnull public Lock readLock() { return readLock; } @Override @Nonnull public Lock writeLock() { return writeLock; } private static abstract class AbstractLock implements Lock { @Override @Nonnull public Condition newCondition() { throw new UnsupportedOperationException(); } } @SuppressWarnings(“Duplicates”) private class ReadLock extends AbstractLock { private final ThreadLocal<Integer> reentrantTimes = ThreadLocal.withInitial(() –> 0); @Override public void lock() { if (tryLock()) { return; } int c; final Node node = new Node(Thread.currentThread(), true); queue.enqueue(node); while (true) { if (node.predecessor.get() == queue.head.get()) { c = count.get(); if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) { myTurn(node); return; } } if (isReadyToPark(node)) { LockSupport.park(this); } } } @Override public void lockInterruptibly() throws InterruptedException { if (tryLock()) { return; } int c; final Node node = new Node(Thread.currentThread(), true); Node predecessor = queue.enqueue(node); while (true) { if (predecessor == queue.head.get()) { c = count.get(); if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) { myTurn(node); return; } } if (isReadyToPark(node)) { LockSupport.park(this); } if (Thread.interrupted()) { abort(node); throw new InterruptedException(); } } } @Override public boolean tryLock() { int rt = reentrantTimes.get(); if (rt > 0) { reentrantTimes.set(rt + 1); return true; } int c = count.get(); if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) { reentrantTimes.set(1); return true; } return false; } @Override public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException { if (tryLock()) { return true; } final long deadline = System.nanoTime() + unit.toNanos(time); final Node node = new Node(Thread.currentThread(), true); Node predecessor = queue.enqueue(node); long nanos; int c; while (true) { if (predecessor == queue.head.get()) { c = count.get(); if ((c & WRITER_MASK) == 0 && count.compareAndSet(c, c + 1)) { myTurn(node); return true; } } nanos = deadline – System.nanoTime(); if (nanos <= 0L) { abort(node); return false; } if (isReadyToPark(node)) { LockSupport.parkNanos(this, nanos); } if (Thread.interrupted()) { abort(node); throw new InterruptedException(); } } } private void myTurn(@Nonnull Node node) { reentrantTimes.set(1); node.clearThread(); queue.head.set(node); /* * propagate if successor is reader * * In ReadWriteLock, there’s no need to check if propagate, it always propagates. */ if (node.resetSignalStatus()) { Node successor = queue.findNormalSuccessor(node); if (successor != null && successor.shared) { LockSupport.unpark(successor.thread.get()); } } } @Override public void unlock() { int rt = reentrantTimes.get(); if (rt < 1) { throw new IllegalStateException(“not the thread holding lock”); } if (rt > 1) { reentrantTimes.set(rt – 1); return; } // rt == 1 reentrantTimes.set(0); if (count.get() < 1) { throw new IllegalStateException(“count < 1”); } if (count.decrementAndGet() > 0) { return; } Node h = queue.head.get(); if (h != null && h.resetSignalStatus()) { unparkNormalSuccessor(h); } } } @SuppressWarnings(“Duplicates”) private class WriteLock extends AbstractLock { private Thread owner; @Override public void lock() { if (tryLock()) { return; } Node node = new Node(Thread.currentThread()); Node predecessor = queue.enqueue(node); while (true) { if (predecessor == queue.head.get() && count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) { myTurn(node); return; } if (isReadyToPark(node)) { LockSupport.park(this); } } } @Override public void lockInterruptibly() throws InterruptedException { if (tryLock()) { return; } Node node = new Node(Thread.currentThread()); Node predecessor = queue.enqueue(node); while (true) { if (predecessor == queue.head.get() && count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) { myTurn(node); return; } if (isReadyToPark(node)) { LockSupport.park(this); } if (Thread.interrupted()) { abort(node); throw new InterruptedException(); } } } @Override public boolean tryLock() { if (owner == Thread.currentThread()) { count.getAndAdd(WRITER_UNIT); return true; } if (count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) { owner = Thread.currentThread(); return true; } return false; } @Override public boolean tryLock(long time, @Nonnull TimeUnit unit) throws InterruptedException { if (tryLock()) { return true; } long deadline = System.nanoTime() + unit.toNanos(time); Node node = new Node(Thread.currentThread()); Node predecessor = queue.enqueue(node); long nanos; while (true) { if (predecessor == queue.head.get() && count.get() == 0 && count.compareAndSet(0, WRITER_UNIT)) { myTurn(node); return true; } nanos = deadline – System.nanoTime(); if (nanos <= 0L) { abort(node); return false; } if (isReadyToPark(node)) { LockSupport.parkNanos(this, nanos); } if (Thread.interrupted()) { abort(node); throw new InterruptedException(); } } } private void myTurn(@Nonnull Node node) { node.clearThread(); owner = Thread.currentThread(); queue.head.set(node); } @Override public void unlock() { if (owner != Thread.currentThread()) { throw new IllegalStateException(“not the thread holding write lock”); } int c = count.get(); if (c < WRITER_UNIT) { throw new IllegalStateException(“no writer”); } if (c > WRITER_UNIT) { count.set(c – WRITER_UNIT); return; } // c == WRITER_UNIT owner = null; // linearization point count.set(0); // signal successor Node node = queue.head.get(); if (node != null && node.status.get() == Node.STATUS_SIGNAL) { node.status.set(Node.STATUS_NORMAL); unparkNormalSuccessor(node); } } } private boolean isReadyToPark(@Nonnull Node node) { Node predecessor = node.predecessor.get(); int s = predecessor.status.get(); if (s == Node.STATUS_SIGNAL) { return true; } if (s == Node.STATUS_ABORTED) { predecessor = queue.skipAbortedPredecessor(node); predecessor.successor.set(node); } else { predecessor.status.compareAndSet(Node.STATUS_NORMAL, Node.STATUS_SIGNAL); } return false; } private void abort(@Nonnull Node node) { node.clearThread(); Node p = queue.skipAbortedPredecessor(node); Node ps = p.successor.get(); node.status.set(Node.STATUS_ABORTED); Node t = queue.tail.get(); if (t == node && queue.tail.compareAndSet(t, p)) { p.successor.compareAndSet(ps, null); return; } if (p != queue.head.get() && p.ensureSignalStatus() && p.thread.get() != null) { Node s = node.successor.get(); if (s != null && s.status.get() != Node.STATUS_ABORTED) { p.successor.compareAndSet(ps, s); } } else { node.resetSignalStatus(); unparkNormalSuccessor(node); } } private void unparkNormalSuccessor(@Nonnull Node node) { Node successor = queue.findNormalSuccessor(node); if (successor != null) { LockSupport.unpark(successor.thread.get()); } } @SuppressWarnings(“Duplicates”) private static class Queue { final AtomicReference<Node> head = new AtomicReference<>(); final AtomicReference<Node> tail = new AtomicReference<>(); @Nonnull Node enqueue(@Nonnull Node node) { Node t; while (true) { t = tail.get(); if (t == null) { Node sentinel = new Node(); if (head.compareAndSet(null, sentinel)) { tail.set(sentinel); } } else { node.predecessor.lazySet(t); if (tail.compareAndSet(t, node)) { t.successor.set(node); return t; } } } } @Nullable Node findNormalSuccessor(@Nonnull Node node) { Node s = node.successor.get(); if (s != null && s.status.get() != Node.STATUS_ABORTED) { return s; } // find from tail s = null; Node c = tail.get(); while (c != null && c != node) { if (c.status.get() != Node.STATUS_ABORTED) { s = c; } c = c.predecessor.get(); } return s; } @Nonnull Node skipAbortedPredecessor(@Nonnull Node node) { Node h = head.get(); Node p = node.predecessor.get(); while (p != h && p.status.get() == Node.STATUS_ABORTED) { p = p.predecessor.get(); node.predecessor.set(p); } return p; } } private static class Node { static final int STATUS_NORMAL = 0; static final int STATUS_SIGNAL = 1; static final int STATUS_ABORTED = –1; final AtomicReference<Thread> thread; final boolean shared; final AtomicInteger status = new AtomicInteger(STATUS_NORMAL); final AtomicReference<Node> predecessor = new AtomicReference<>(); // optimization final AtomicReference<Node> successor = new AtomicReference<>(); Node() { this(null, false); } Node(@Nullable Thread thread) { this(thread, false); } Node(@Nullable Thread thread, boolean shared) { this.thread = new AtomicReference<>(thread); this.shared = shared; } void clearThread() { thread.set(null); } /** * Ensure signal status. * If current status is signal, just return. * If current status is normal, then try to CAS status from normal to signal. * * @return true if changed to signal, otherwise false */ boolean ensureSignalStatus() { int s = status.get(); return s == STATUS_SIGNAL || (s == STATUS_NORMAL && status.compareAndSet(STATUS_NORMAL, STATUS_SIGNAL)); } /** * Reset signal status. * SIGNAL -> NORMAL * * @return true if successful, otherwise false */ boolean resetSignalStatus() { return status.get() == STATUS_SIGNAL && status.compareAndSet(STATUS_SIGNAL, STATUS_NORMAL); } } } |
总结
“自己写ReentrantLock和ReentrantReadWriteLock“系列至此结束。总得来说,并发编程需要仔细思考很多情况,包括各种细节,有些时候细节会导致你的设计有所变化。希望这个系列对你学习AQS,学习Java语言下的并发编程有用,以及欢迎交流交换意见。