Ablations#
Barrier#
from __future__ import annotations

import argparse
import os

import torch
import torch.multiprocessing as mp
from rich.pretty import pprint

from omnivault.distributed.core import find_free_port, is_free_port
from omnixamples.distributed.a_basic.a_setup import init_process
from omnixamples.distributed.a_basic.config import get_args_parser


def run_with_no_barrier(local_rank: int, args: argparse.Namespace) -> None:
    logger, dist_info_per_process = init_process(local_rank, args=args)
    logger.info(f"{dist_info_per_process.model_dump_json(indent=4)}")

    results = []
    logger.info("I HAVE NO BARRIER DUDE!")
    # NOTE: add `torch.distributed.barrier()` here if you want to synchronize all processes
    results.append([1, 2, 3])
    logger.info(f"Results: {results}")


def run_with_barrier(local_rank: int, args: argparse.Namespace) -> None:
    logger, dist_info_per_process = init_process(local_rank, args=args)
    logger.info(f"{dist_info_per_process.model_dump_json(indent=4)}")

    results = []
    # We use barrier to synchronize all processes before computation.
    # A barrier acts as a checkpoint in the code. When a process reaches this
    # checkpoint, it must wait until all other processes in the group also reach this
    # checkpoint.
    logger.info("I HAVE BARRIER DUDE! WAITING FOR ALL PROCESSES TO SYNCHRONIZE...")
    torch.distributed.barrier()
    results.append([1, 2, 3])
    logger.info(f"Results: {results}")


if __name__ == "__main__":
    parser = get_args_parser()
    parser.add_argument("--run_with_no_barrier", action="store_true", help="Run with no barrier.")
    args = parser.parse_args()
    pprint(args)

    master_addr, master_port = args.master_addr, args.master_port
    if not is_free_port(int(master_port)):
        master_port = find_free_port()

    os.environ["MASTER_ADDR"] = str(master_addr)
    os.environ["MASTER_PORT"] = str(master_port)

    target_fn = run_with_no_barrier if args.run_with_no_barrier else run_with_barrier
    mp.spawn(
        fn=target_fn,
        args=(args,),
        nprocs=args.nproc_per_node,
        join=True,
        daemon=False,
        start_method="spawn",
    )  # type: ignore[no-untyped-call]
No Distributed Barrier#
If you run:
python omnixamples/distributed/a_basic/d_ablations.py \
--master_addr=localhost \
--master_port=29500 \
--nnodes=1 \
--nproc_per_node=4 \
--node_rank=0 \
--world_size=4 \
--backend=gloo \
--init_method="env://" \
--run_with_no_barrier
which invokes run_with_no_barrier, we would sometimes see something like the following:
2024-05-05 13:29:55 [INFO]: I HAVE NO BARRIER DUDE! ablations.py:32
2024-05-05 13:29:55 [INFO]: I HAVE NO BARRIER DUDE! ablations.py:32
2024-05-05 13:29:55 [INFO]: Results: [[1, 2, 3]] ablations.py:36
2024-05-05 13:29:55 [INFO]: I HAVE NO BARRIER DUDE! ablations.py:32
2024-05-05 13:29:55 [INFO]: Results: [[1, 2, 3]] ablations.py:36
2024-05-05 13:29:55 [INFO]: {
    "master_addr": "localhost",
    "master_port": "29500",
    "nnodes": 1,
    "nproc_per_node": 4,
    "node_rank": 0,
    "world_size": 4,
    "backend": "gloo",
    "init_method": "env://",
    "global_rank": 3,
    "local_world_size": 4,
    "local_rank": 3,
    "hostname": "Hongnans-Mac-mini.local",
    "process_id": 20647
} ablations.py:28
2024-05-05 13:29:55 [INFO]: Results: [[1, 2, 3]] ablations.py:36
2024-05-05 13:29:55 [INFO]: I HAVE NO BARRIER DUDE! ablations.py:32
2024-05-05 13:29:55 [INFO]: Results: [[1, 2, 3]] ablations.py:36
Notice that even though the Results log comes after the I HAVE NO BARRIER DUDE! message within each process, some processes print their Results line before other processes have even printed their I HAVE NO BARRIER DUDE! message.
This happens because there is no barrier to synchronize the processes. The interleaving does not always look the same either, since the underlying distributed system is asynchronous/concurrent in nature and the order of execution across processes is not guaranteed. Just think of each process as an independent entity governed by the underlying resources (i.e. CPU, memory, the OS scheduler, etc.), so the processes may not start or progress at exactly the same time. Consequently, race conditions over ordering can arise where one process simply happens to be faster than another.
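To make the non-determinism concrete, here is a minimal, self-contained sketch that mimics the two log lines above using only Python's standard multiprocessing module (not torch.distributed); the worker function and the random sleeps are made up purely for illustration. Running it a few times will usually show a different interleaving each time:
import multiprocessing as mp
import random
import time


def worker(rank: int) -> None:
    # Simulate each process being scheduled/starting at a slightly different time.
    time.sleep(random.uniform(0.0, 0.2))
    print(f"[rank {rank}] I HAVE NO BARRIER DUDE!")
    time.sleep(random.uniform(0.0, 0.2))
    print(f"[rank {rank}] Results: [[1, 2, 3]]")


if __name__ == "__main__":
    processes = [mp.Process(target=worker, args=(rank,)) for rank in range(4)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()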
To resolve this, we can add a torch.distributed.barrier() to synchronize the processes before printing the results. Honestly, I do not know enough about its internals to discuss it at a rigorous level, but I think you can use a mental model like the one below:
We create a point of synchronization at torch.distributed.barrier(). When each process reaches this checkpoint, it must wait until all other processes in the group also reach it, so there is a waiting period before any process can proceed past the checkpoint.
logger.info("I HAVE BARRIER DUDE! WAITING FOR ALL PROCESSES TO SYNCHRONIZE...") torch.distributed.barrier()
So with this barrier, we guarantee that the message I HAVE BARRIER DUDE! WAITING FOR ALL PROCESSES TO SYNCHRONIZE... is printed by every process before any results are printed, because every process must reach the barrier before proceeding. Once all processes reach the barrier, they are all released simultaneously to continue to the next block of code.
results.append([1, 2, 3])
logger.info(f"Results: {results}")
Now every process prints its Results message after the barrier message.
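If it helps, the checkpoint mental model can also be sketched with the standard library's multiprocessing.Barrier as a toy stand-in for torch.distributed.barrier() (again, the worker function and the sleeps below are hypothetical, for illustration only): every process blocks at barrier.wait() until all of them have arrived, and only then does any of them print its results.
import multiprocessing as mp
import random
import time


def worker(rank: int, barrier) -> None:
    # Processes reach the checkpoint at different times...
    time.sleep(random.uniform(0.0, 0.5))
    print(f"[rank {rank}] reached the barrier, waiting for the others...")
    barrier.wait()  # ...but none is released until all `parties` have arrived.
    print(f"[rank {rank}] Results: [[1, 2, 3]]")


if __name__ == "__main__":
    world_size = 4
    barrier = mp.Barrier(parties=world_size)
    processes = [mp.Process(target=worker, args=(rank, barrier)) for rank in range(world_size)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()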
With Distributed Barrier#
If you run:
python omnixamples/distributed/a_basic/d_ablations.py \
--master_addr=localhost \
--master_port=29500 \
--nnodes=1 \
--nproc_per_node=4 \
--node_rank=0 \
--world_size=4 \
--backend=gloo \
--init_method="env://"
You would then see that the order of the messages is guaranteed:
INFO 2024-08-04 16:45:34 [INFO]: I HAVE BARRIER DUDE! WAITING FOR ALL PROCESSES TO SYNCHRONIZE... d_ablations.py:65
INFO 2024-08-04 16:45:34 [INFO]: I HAVE BARRIER DUDE! WAITING FOR ALL PROCESSES TO SYNCHRONIZE... d_ablations.py:65
INFO 2024-08-04 16:45:34 [INFO]: I HAVE BARRIER DUDE! WAITING FOR ALL PROCESSES TO SYNCHRONIZE... d_ablations.py:65
INFO 2024-08-04 16:45:34 [INFO]: I HAVE BARRIER DUDE! WAITING FOR ALL PROCESSES TO SYNCHRONIZE... d_ablations.py:65
INFO 2024-08-04 16:45:34 [INFO]: Results: [[1, 2, 3]] d_ablations.py:70
INFO 2024-08-04 16:45:34 [INFO]: Results: [[1, 2, 3]] d_ablations.py:70
INFO 2024-08-04 16:45:34 [INFO]: Results: [[1, 2, 3]] d_ablations.py:70
INFO 2024-08-04 16:45:34 [INFO]: Results: [[1, 2, 3]] d_ablations.py:70