Skip to content

Commit 94770c5

Browse files
AnthonyMDevgh-action-runner
authored andcommitted
Fix data race in AsyncReadWriteLock (apollographql/apollo-ios-dev#862)
1 parent 68aba4b commit 94770c5

File tree

1 file changed

+67
-48
lines changed

1 file changed

+67
-48
lines changed

Sources/Apollo/Internal Utilities/AsyncReadWriteLock.swift

Lines changed: 67 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,9 @@ actor AsyncReadWriteLock {
77
case writing
88
}
99

10-
private final class ReadTask: Sendable {
11-
let task: Task<Void, any Swift.Error>
12-
13-
init(_ body: @Sendable @escaping () async throws -> Void) {
14-
task = Task {
15-
try await body()
16-
}
17-
}
18-
}
19-
20-
private var currentReadTasks: [ObjectIdentifier: ReadTask] = [:]
21-
private var currentWriteTask: Task<Void, any Swift.Error>?
22-
private var queue: [(handle: CheckedContinuation<Void, Never>, isWriter: Bool)] = []
23-
24-
private var state: State {
25-
if currentWriteTask != nil { return .writing }
26-
if !currentReadTasks.isEmpty { return .reading }
27-
return .idle
28-
}
10+
private var state: State = .idle
11+
private var queue: Queue = Queue()
12+
private var runningReadCount: UInt = 0
2913

3014
/// Waits for all current reads/writes to be completed, then calls the provided closure while preventing
3115
/// any other reads/writes from beginning.
@@ -38,17 +22,14 @@ actor AsyncReadWriteLock {
3822
await addToQueueAndWait(isWriter: true)
3923

4024
case .idle:
41-
break
25+
state = .writing
4226
}
4327

44-
defer {
45-
currentWriteTask = nil
46-
writeTaskDidFinish()
47-
}
28+
defer { writeTaskDidFinish() }
29+
4830
let writeTask = Task {
4931
try await body()
5032
}
51-
currentWriteTask = writeTask
5233

5334
try await writeTask.value
5435
}
@@ -69,29 +50,36 @@ actor AsyncReadWriteLock {
6950
if !queue.isEmpty {
7051
await addToQueueAndWait(isWriter: false)
7152
}
53+
7254
case .idle:
73-
break
55+
state = .reading
7456
}
7557

76-
let readTask = ReadTask(body)
77-
let taskID = ObjectIdentifier(readTask)
78-
defer {
79-
currentReadTasks[taskID] = nil
80-
readTaskDidFinish()
58+
runningReadCount += 1
59+
60+
defer { readTaskDidFinish() }
61+
62+
let readTask = Task {
63+
try await body()
8164
}
82-
currentReadTasks[taskID] = readTask
8365

84-
try await readTask.task.value
66+
try await readTask.value
8567
}
8668

8769
private func addToQueueAndWait(isWriter: Bool) async {
8870
await withCheckedContinuation { continuation in
89-
queue.append((handle: continuation, isWriter: isWriter))
71+
switch isWriter {
72+
case true:
73+
queue.addWriteTask(continuation)
74+
case false:
75+
queue.addReadTask(continuation)
76+
}
9077
}
9178
}
9279

9380
private func readTaskDidFinish() {
94-
if state == .idle {
81+
runningReadCount -= 1
82+
if runningReadCount == 0 {
9583
wakeNext()
9684
}
9785
}
@@ -101,25 +89,56 @@ actor AsyncReadWriteLock {
10189
}
10290

10391
private func wakeNext() {
104-
guard !queue.isEmpty else {
92+
guard let nextItem = queue.pop() else {
93+
state = .idle
10594
return
10695
}
10796

108-
let next = queue[0]
109-
next.handle.resume(returning: ())
97+
switch nextItem {
98+
case let .write(continuation):
99+
state = .writing
100+
continuation.resume()
110101

111-
if next.isWriter {
112-
queue.remove(at: 0)
113-
return
102+
case let .readBatch(continuations):
103+
state = .reading
104+
for continuation in continuations {
105+
continuation.resume()
106+
}
107+
}
108+
}
109+
110+
/// MARK: - Queue
111+
private struct Queue {
112+
enum Item {
113+
case write(CheckedContinuation<Void, Never>)
114+
case readBatch([CheckedContinuation<Void, Never>])
115+
}
116+
117+
private var queueItems: [Item] = []
118+
119+
mutating func addWriteTask(_ continuation: CheckedContinuation<Void, Never>) {
120+
queueItems.append(.write(continuation))
121+
}
122+
123+
mutating func addReadTask(_ continuation: CheckedContinuation<Void, Never>) {
124+
if case var .readBatch(batch) = queueItems.first {
125+
batch.append(continuation)
126+
queueItems[0] = .readBatch(batch)
114127

115-
} else {
116-
var lastReader = 0
117-
for i in 1..<queue.count {
118-
guard !queue[i].isWriter else { break }
119-
queue[i].handle.resume(returning: ())
120-
lastReader = i
128+
} else {
129+
queueItems.append(.readBatch([continuation]))
121130
}
122-
queue.removeSubrange(0...lastReader)
131+
}
132+
133+
mutating func pop() -> Item? {
134+
guard !queueItems.isEmpty else {
135+
return nil
136+
}
137+
return queueItems.removeFirst()
138+
}
139+
140+
var isEmpty: Bool {
141+
queueItems.isEmpty
123142
}
124143
}
125144

0 commit comments

Comments
 (0)