PyTorch’s Event And Profiler#


# %pip install -q omniverse==0.0.63
from typing import Callable
import torch
import logging
import sys
from torch.profiler import profile, ProfilerActivity

import pandas as pd
F = Callable[[torch.Tensor], torch.Tensor]

assert torch.cuda.is_available()
device = torch.device("cuda")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
logger = logging.getLogger(__name__)

Torch Cuda Event#

Recall from our earlier post that CUDA operations are asynchronous, so naively timing them with timeit and no synchronization would result in inaccurate measurements. We can use torch.cuda.Event to get more precise timing instead.

def square_by_multiplication(a: torch.Tensor) -> torch.Tensor:
    return a * a


def square_by_exponentiation(a: torch.Tensor) -> torch.Tensor:
    return a**2


def profile_with_event(func: F, input: torch.Tensor, warmup_steps: int = 5) -> float:
    start = torch.cuda.Event(enable_timing=True)  # Create a start event
    end = torch.cuda.Event(enable_timing=True)  # Create an end event

    logger.info(f"Warmup for {warmup_steps} steps to warm up the GPU")
    for _ in range(warmup_steps):
        func(input)
    torch.cuda.synchronize()  # Make sure the warmup kernels have finished

    start.record()
    func(input)
    end.record()
    torch.cuda.synchronize()  # Wait until the end event has been recorded

    time_spent: float = start.elapsed_time(end)  # Elapsed time in milliseconds
    return time_spent
x = torch.randn(10000, 10000).to(device)
profile_with_event(square_by_multiplication, x)
2024-08-12 04:20:42,732 - __main__ - INFO - Warmup for 5 steps to warm up the GPU
1.480672001838684
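
Since we defined two ways of squaring, a natural follow-up is to time both with the helper above. This is just a small sketch using the functions already defined; keep in mind that Event.elapsed_time reports milliseconds.

for fn in (square_by_multiplication, square_by_exponentiation):
    elapsed_ms = profile_with_event(fn, x)  # elapsed time in milliseconds
    logger.info(f"{fn.__name__}: {elapsed_ms:.3f} ms")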

Torch Profiler#

Profiling Square Operation#

Below we refer to Christian Mills' lecture notes on CUDA MODE.

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as prof:
    torch.square(x)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::square         0.19%      11.000us        76.81%       4.423ms       4.423ms       0.000us         0.00%       1.474ms       1.474ms             1  
                                              aten::pow        65.32%       3.761ms        76.62%       4.412ms       4.412ms       1.474ms       100.00%       1.474ms       1.474ms             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.474ms       100.00%       1.474ms       1.474ms             1  
                                      aten::result_type         0.03%       2.000us         0.03%       2.000us       2.000us       0.000us         0.00%       0.000us       0.000us             1  
                                               aten::to         0.02%       1.000us         0.02%       1.000us       1.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel        11.25%     648.000us        11.25%     648.000us     648.000us       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaDeviceSynchronize        23.19%       1.335ms        23.19%       1.335ms       1.335ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 5.758ms
Self CUDA time total: 1.474ms
STAGE:2024-08-12 04:40:51 34:34 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-08-12 04:40:51 34:34 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-08-12 04:40:51 34:34 ActivityProfilerController.cpp:322] Completed Stage: Post Processing

First, PyTorch's C++ API has a foundational tensor and mathematical operation library called ATen, on top of which all other operations are built. Later in the profiler output you will see numerous aten:: prefixes, which indicate that an operation comes from the ATen library.
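
As a hedged aside, the same ATen operators can be invoked directly through the torch.ops.aten namespace, which is one way to convince yourself that aten::mul and aten::pow are what the Python-level expressions dispatch to. The specific overload names used below (mul.Tensor, pow.Tensor_Scalar) are assumptions on my part and may vary across PyTorch versions.

# Sketch: call ATen operators directly; the overload names are assumptions.
y_mul = torch.ops.aten.mul.Tensor(x, x)
y_pow = torch.ops.aten.pow.Tensor_Scalar(x, 2)
assert torch.equal(y_mul, x * x)
assert torch.equal(y_pow, x**2)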

The table has a few key rows and columns. The aten::square operation squares the input tensor element-wise, the aten::pow operation raises the input tensor to a power element-wise, and void at::native::vectorized_elementwise_kernel is the CUDA kernel PyTorch uses for optimized element-wise operations on the tensor. Note also that squaring can be done in two ways: given an input \(x\), we can compute either \(x^2\) or \(x \cdot x\).

The Self CPU % column shows the percentage of total CPU time spent exclusively in this operation, not including time in any called subroutines. What this means is that aten::square uses \(0.19\%\) of the total CPU time by itself, an absolute amount of \(11\) microseconds. Note that the Self CPU % column sums up to \(100\%\) across all operations in the table.
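
As a quick sanity check, here is a small sketch against the prof object captured above: the per-operation self CPU times should add up to the Self CPU time total footer of the table.

# self_cpu_time_total is reported in microseconds, so the sum should match
# the "Self CPU time total: 5.758ms" footer above.
total_self_cpu_us = sum(evt.self_cpu_time_total for evt in prof.key_averages())
logger.info(f"Self CPU time total: {total_self_cpu_us / 1000:.3f} ms")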

Perhaps the term exclusively becomes clearer if we look at the CPU total % column, which represents the total percentage of CPU time spent in this operation and any operations it calls. So aten::square takes up \(76.81\%\) of the CPU time including itself and any functions it invokes. Granted, I am not sure of the exact operations that aten::square calls since I did not dig further into the source code, but it evidently calls the aten::pow operation, which in turn launches the void at::native::vectorized_elementwise_kernel kernel. Something like below:

aten::square
└── aten::pow
    └── at::native::vectorized_elementwise_kernel...

The same concept applies to the Self CUDA % and CUDA total % columns, and it is not surprising that the aten::square operation has no self GPU time recorded, since the actual CUDA work is done in the void at::native::vectorized_elementwise_kernel kernel (launched under aten::pow).
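
We can also recover this hierarchy programmatically. Below is a rough sketch using the cpu_parent and cpu_children links on the averaged events (the same fields that show up as columns in the DataFrame further down); how fully they are populated may depend on the PyTorch version.

for evt in prof.key_averages():
    # Only start from events that have no CPU parent but do have children.
    if evt.cpu_parent is None and evt.cpu_children:
        print(evt.key)
        for child in evt.cpu_children:
            print("└──", child.name)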

df = pd.DataFrame(map(vars, prof.key_averages()))
df[['key','self_cpu_time_total','cpu_time_total', 'self_cuda_time_total','cuda_time_total', 'device_type']]
key self_cpu_time_total cpu_time_total self_cuda_time_total cuda_time_total device_type
0 aten::square 11 4423 0 1474 DeviceType.CPU
1 aten::pow 3761 4412 1474 1474 DeviceType.CPU
2 aten::result_type 2 2 0 0 DeviceType.CPU
3 aten::to 1 1 0 0 DeviceType.CPU
4 cudaLaunchKernel 648 648 0 0 DeviceType.CPU
5 void at::native::vectorized_elementwise_kernel... 0 0 1474 1474 DeviceType.CUDA
6 cudaDeviceSynchronize 1335 1335 0 0 DeviceType.CPU
df.sort_values(by="cuda_time_total", ascending=False)
key count node_id is_async is_remote use_device cpu_time_total cuda_time_total privateuse1_time_total self_cpu_time_total ... cuda_memory_usage privateuse1_memory_usage self_cpu_memory_usage self_cuda_memory_usage self_privateuse1_memory_usage cpu_children cpu_parent device_type is_legacy flops
0 aten::square 1 -1 False False None 3974 1473 0 9 ... 0 0 0 0 0 [<FunctionEvent id=2562 name=aten::pow device_... None DeviceType.CPU False 0
1 aten::pow 1 -1 False False None 3965 1473 0 3340 ... 0 0 0 0 0 [<FunctionEvent id=2563 name=aten::result_type... <FunctionEvent id=2561 name=aten::square devic... DeviceType.CPU False 0
5 void at::native::vectorized_elementwise_kernel... 1 -1 False False None 0 1473 0 0 ... 0 0 0 0 0 [] None DeviceType.CUDA False 0
2 aten::result_type 1 -1 False False None 2 0 0 2 ... 0 0 0 0 0 [] <FunctionEvent id=2562 name=aten::pow device_t... DeviceType.CPU False 0
3 aten::to 1 -1 False False None 0 0 0 0 ... 0 0 0 0 0 [] <FunctionEvent id=2562 name=aten::pow device_t... DeviceType.CPU False 0
4 cudaLaunchKernel 1 -1 False False None 623 0 0 623 ... 0 0 0 0 0 [] <FunctionEvent id=2562 name=aten::pow device_t... DeviceType.CPU False 0
6 cudaDeviceSynchronize 1 -1 False False None 1377 0 0 1377 ... 0 0 0 0 0 [] None DeviceType.CPU False 0

7 rows × 26 columns

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as prof:
    square_by_multiplication(x)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              aten::mul        49.96%       2.265ms        70.34%       3.189ms       3.189ms       1.477ms       100.00%       1.477ms       1.477ms             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.477ms       100.00%       1.477ms       1.477ms             1  
                                       cudaLaunchKernel        20.38%     924.000us        20.38%     924.000us     924.000us       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaDeviceSynchronize        29.66%       1.345ms        29.66%       1.345ms       1.345ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 4.534ms
Self CUDA time total: 1.477ms
STAGE:2024-08-12 04:43:12 34:34 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-08-12 04:43:12 34:34 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-08-12 04:43:12 34:34 ActivityProfilerController.cpp:322] Completed Stage: Post Processing

Trace#

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        a = torch.square(torch.randn(10000, 10000).cuda())

prof.export_chrome_trace("default_square_trace.json")
STAGE:2024-08-12 04:44:52 34:34 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-08-12 04:45:05 34:34 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-08-12 04:45:05 34:34 ActivityProfilerController.cpp:322] Completed Stage: Post Processing

We can load this JSON file into chrome://tracing/ to see a flamegraph-like visualization.


Fig. 38 Trace of the default square operation.#

## With warmup and skip
# A non-default profiler schedule allows the user to turn the profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):
    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
    prof.export_chrome_trace("non_default_trace_" + str(prof.step_num) + ".json")


with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    # In this example with wait=1, warmup=1, active=2, repeat=1,
    # the profiler will skip the first step/iteration,
    # start warming up on the second, and record
    # the third and the fourth iterations,
    # after which the trace becomes available
    # and on_trace_ready (when set) is called;
    # with repeat=1 only one such cycle is recorded
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=2,
        repeat=1,
    ),
    on_trace_ready=trace_handler,
    # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
    # can be used instead when outputting for TensorBoard
) as p:
    for step in range(10):
        torch.square(torch.randn(10000, 10000).cuda())
        # Send a signal to the profiler that the next iteration has started
        p.step()
STAGE:2024-08-12 04:47:44 34:34 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-08-12 04:47:47 34:34 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-08-12 04:47:47 34:34 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_         0.00%     111.000us         6.95%     171.340ms      85.670ms     170.696ms        98.30%     170.696ms      85.348ms             2  
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     170.696ms        98.30%     170.696ms      85.348ms             2  
                                              aten::pow         0.01%     192.000us         0.01%     278.000us     139.000us       2.951ms         1.70%       2.951ms       1.476ms             2  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.951ms         1.70%       2.951ms       1.476ms             2  
                                          ProfilerStep*         2.01%      49.532ms        99.95%        2.464s        1.232s       0.000us         0.00%     173.647ms      86.823ms             2  
                                            aten::randn         0.00%      55.000us        90.97%        2.243s        1.121s       0.000us         0.00%       0.000us       0.000us             2  
                                            aten::empty         0.00%      77.000us         0.00%      77.000us      38.500us       0.000us         0.00%       0.000us       0.000us             2  
                                          aten::normal_        90.96%        2.242s        90.96%        2.242s        1.121s       0.000us         0.00%       0.000us       0.000us             2  
                                               aten::to         0.00%      53.000us         6.96%     171.507ms      42.877ms       0.000us         0.00%     170.696ms      42.674ms             4  
                                         aten::_to_copy         0.00%      60.000us         6.95%     171.454ms      85.727ms       0.000us         0.00%     170.696ms      85.348ms             2  
                                    aten::empty_strided         0.00%      54.000us         0.00%      54.000us      27.000us       0.000us         0.00%       0.000us       0.000us             2  
                                        cudaMemcpyAsync         6.94%     171.179ms         6.94%     171.179ms      85.590ms       0.000us         0.00%       0.000us       0.000us             2  
                                  cudaStreamSynchronize         0.00%      50.000us         0.00%      50.000us      25.000us       0.000us         0.00%       0.000us       0.000us             2  
                                           aten::square         0.00%       9.000us         0.01%     287.000us     143.500us       0.000us         0.00%       2.951ms       1.476ms             2  
                                      aten::result_type         0.00%       3.000us         0.00%       3.000us       1.500us       0.000us         0.00%       0.000us       0.000us             2  
                                       cudaLaunchKernel         0.00%      83.000us         0.00%      83.000us      41.500us       0.000us         0.00%       0.000us       0.000us             2  
                                  cudaDeviceSynchronize         0.05%       1.323ms         0.05%       1.323ms       1.323ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.465s
Self CUDA time total: 173.647ms
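
To see why only two iterations show up in the table above (for example, # of Calls is 2 for aten::pow), note that torch.profiler.schedule returns a callable that maps a step number to a ProfilerAction. A minimal sketch:

# With wait=1, warmup=1, active=2, repeat=1 only two steps end up recorded.
sched = torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1)
for step in range(6):
    print(step, sched(step))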

References And Further Readings#