Thread Safety
With pre-emptive multitasking, the OS can interrupt a thread at any time, even in the middle of executing a single Python statement. This can lead to issues like race conditions, where the program’s behavior depends on the unpredictable timing of threads.
Consider the following code:
```python
"""With reference to Effective Python, Item 54.
Ref: https://github.com/bslatkin/effectivepython/blob/master/example_code/item_54.py
"""
import logging
import threading
from threading import Barrier
from typing import List

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")

NUM_THREADS = 5
BARRIER = Barrier(NUM_THREADS)


class Counter:
    def __init__(self) -> None:
        self.count = 0

    def increment(self, offset: int) -> None:
        # Not atomic: reads self.count, adds offset, then writes the result back.
        self.count += offset


def worker(thread_index: int, total_iterations: int, counter: Counter) -> None:
    """Increment the shared counter total_iterations times.

    The barrier makes all threads start counting at the same moment. Without it,
    the first thread would usually start first, then the second, and so on, so
    their loops would overlap less and a race condition would be harder to observe.
    """
    BARRIER.wait()
    logging.debug("Thread %s, starting", thread_index)
    for _ in range(total_iterations):
        counter.increment(1)


def thread_unsafe(total_iterations: int) -> None:
    counter = Counter()
    threads: List[threading.Thread] = []
    for index in range(NUM_THREADS):
        thread = threading.Thread(target=worker, args=(index, total_iterations, counter))
        threads.append(thread)

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    expected = total_iterations * NUM_THREADS
    found = counter.count
    logging.info("Counter should be %s, got %s", expected, found)


if __name__ == "__main__":
    total_iterations = 10**6
    thread_unsafe(total_iterations)
```
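There are 5 threads, each incrementing the counter by 1 a total of \(10^6\) times, so the expected final value is \(5 \times 10^6 = 5{,}000{,}000\). However, the actual result can be lower and may differ from run to run; for example, a run might end with \(2{,}000{,}000\) or \(3{,}000{,}000\). How easily the shortfall reproduces depends on the interpreter version and on OS scheduling, so a correct total on one run does not prove the code is thread-safe.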
Why Does This Happen?
The line `counter.increment(1)` looks atomic, but the `self.count += offset` inside it is actually shorthand for three separate steps:

1. Read: retrieve the current value of `counter`: `current_value = counter.count`
2. Add: increment the value by 1: `new_value = current_value + 1`
3. Write: store the new value back to `counter`: `counter.count = new_value`

If the OS pre-empts a thread after it reads but before it writes, another thread can read the same old value, and one of the two updates is lost.
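One way to see these steps is to disassemble a method like `Counter.increment` with the standard-library `dis` module. The sketch below redefines a minimal `Counter` so it runs on its own. The exact opcode names differ between CPython versions (for example `INPLACE_ADD` in older releases versus `BINARY_OP` in newer ones), but the attribute read (`LOAD_ATTR`) and the attribute write (`STORE_ATTR`) always show up as separate instructions, so the update is not one indivisible step.

```python
import dis


class Counter:
    def __init__(self) -> None:
        self.count = 0

    def increment(self, offset: int) -> None:
        self.count += offset


# Print the full bytecode listing for increment(); the output varies by version.
dis.dis(Counter.increment)

# Collect just the opcode names so we can locate the read and the write.
opnames = [instr.opname for instr in dis.get_instructions(Counter.increment)]

# The read of self.count and the write back to it are distinct instructions,
# with the addition happening in between.
read_index = opnames.index("LOAD_ATTR")
write_index = opnames.index("STORE_ATTR")
print(opnames[read_index : write_index + 1])
```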
Let's play out a scenario where this can happen:

1. Thread 1 reads: `thread_1_current_value = counter.count = 0`
2. The OS pauses Thread 1 and runs Thread 2. This is called a context switch. Thread 2 reads: `thread_2_current_value = counter.count = 0`
3. Thread 2 adds: `thread_2_new_value = thread_2_current_value + 1 = 1`
4. Thread 2 writes: `counter.count = thread_2_new_value = 1`
5. The OS switches back to Thread 1, which adds using the value it read earlier: `thread_1_new_value = thread_1_current_value + 1 = 1`
6. Thread 1 writes: `counter.count = thread_1_new_value = 1`
So Thread 2 interrupted Thread 1 after Thread 1 had read the value of `counter.count` but before it could write the new value back. Even though the counter should now be 2, it is 1, because Thread 1 overwrote Thread 2's update with a value computed from a stale read.
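The interleaving above can be replayed deterministically in a single thread, which makes the lost update reproducible on every run. This is only a simulation of that schedule (no real threads are involved), and the `SimpleCounter` class is a hypothetical stand-in for `Counter`.

```python
class SimpleCounter:
    """Hypothetical stand-in for Counter; it only holds the shared value."""

    def __init__(self) -> None:
        self.count = 0


counter = SimpleCounter()

# Step 1: Thread 1 reads the shared value.
thread_1_current_value = counter.count  # 0

# Step 2: context switch; Thread 2 reads the same, not-yet-updated value.
thread_2_current_value = counter.count  # 0

# Steps 3-4: Thread 2 adds and writes its result back.
thread_2_new_value = thread_2_current_value + 1
counter.count = thread_2_new_value  # 1

# Steps 5-6: Thread 1 resumes with its stale read and overwrites Thread 2's write.
thread_1_new_value = thread_1_current_value + 1
counter.count = thread_1_new_value  # still 1

print(f"Expected 2, got {counter.count}")  # prints "Expected 2, got 1"
```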
Lock
Python's `threading` module offers a suite of tools to prevent issues like data races and data structure corruption in multithreaded environments. Among these, the `Lock` class stands out as a particularly useful and straightforward option. It implements a mutual-exclusion lock, commonly known as a mutex.

By incorporating a `Lock` into the `Counter` class, you can safeguard its current value from concurrent access by multiple threads. The lock ensures that only one thread at a time can execute the read-add-write sequence on the protected data. To manage the lock safely, use Python's `with` statement, which acquires the lock on entry and releases it on exit, even if an exception is raised inside the block.
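Roughly speaking, `with self.lock:` behaves like calling `acquire()` before the block and `release()` in a `finally` clause after it. The sketch below spells that pattern out with a hypothetical `ExplicitLockCounter` class and runs a few threads against it. It is meant to show what the `with` form does for you, not to replace it; the `with` form is shorter and harder to get wrong.

```python
import threading


class ExplicitLockCounter:
    """Hypothetical counter that manages its lock without a with statement."""

    def __init__(self) -> None:
        self.count = 0
        self.lock = threading.Lock()

    def increment(self, offset: int) -> None:
        self.lock.acquire()  # blocks until no other thread holds the lock
        try:
            # Read, add, and write all happen while this thread holds the lock.
            self.count += offset
        finally:
            self.lock.release()  # released even if the body raises


def add_many(counter: ExplicitLockCounter, iterations: int) -> None:
    for _ in range(iterations):
        counter.increment(1)


def run(num_threads: int = 4, iterations: int = 10_000) -> int:
    counter = ExplicitLockCounter()
    threads = [
        threading.Thread(target=add_many, args=(counter, iterations))
        for _ in range(num_threads)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return counter.count


if __name__ == "__main__":
    print(run())  # prints 40000
```

Because every increment holds the lock for the whole read-add-write sequence, the total is always `num_threads * iterations`.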
```python
"""With reference to Effective Python, Item 54.
Ref: https://github.com/bslatkin/effectivepython/blob/master/example_code/item_54.py
"""
import logging
import threading
from threading import Barrier
from typing import List

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")

NUM_THREADS = 5
BARRIER = Barrier(NUM_THREADS)


class CounterLock:
    def __init__(self) -> None:
        self.count = 0
        self.lock = threading.Lock()

    def increment(self, offset: int) -> None:
        # Only one thread at a time can run the read-add-write below.
        with self.lock:
            self.count += offset


def worker(thread_index: int, total_iterations: int, counter: CounterLock) -> None:
    """Increment the shared counter total_iterations times.

    The barrier makes all threads start counting at the same moment. Without it,
    the first thread would usually start first, then the second, and so on, so
    their loops would overlap less and a race condition would be harder to observe.
    """
    BARRIER.wait()
    logging.debug("Thread %s, starting", thread_index)
    for _ in range(total_iterations):
        counter.increment(1)


def thread_safe(total_iterations: int) -> None:
    counter = CounterLock()
    threads: List[threading.Thread] = []
    for index in range(NUM_THREADS):
        thread = threading.Thread(target=worker, args=(index, total_iterations, counter))
        threads.append(thread)

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    expected = total_iterations * NUM_THREADS
    found = counter.count
    logging.info("Counter should be %s, got %s", expected, found)


if __name__ == "__main__":
    total_iterations = 10**6
    thread_safe(total_iterations)
```