Profile GPT Small Time And Memory

Profile GPT Small Time And Memory#

Twitter Handle LinkedIn Profile GitHub Profile Tag Tag

%pip install -q omniverse==0.0.63
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf 24.6.1 requires cubinlinker, which is not installed.
cudf 24.6.1 requires cupy-cuda11x>=12.0.0, which is not installed.
cudf 24.6.1 requires ptxcompiler, which is not installed.
cuml 24.6.1 requires cupy-cuda11x>=12.0.0, which is not installed.
dask-cudf 24.6.1 requires cupy-cuda11x>=12.0.0, which is not installed.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
apache-beam 2.46.0 requires dill<0.3.2,>=0.3.1.1, but you have dill 0.3.8 which is incompatible.
apache-beam 2.46.0 requires numpy<1.25.0,>=1.14.3, but you have numpy 1.26.4 which is incompatible.
apache-beam 2.46.0 requires pyarrow<10.0.0,>=3.0.0, but you have pyarrow 16.1.0 which is incompatible.
beatrix-jupyterlab 2023.128.151533 requires jupyterlab~=3.6.0, but you have jupyterlab 4.2.3 which is incompatible.
cudf 24.6.1 requires cuda-python<12.0a0,>=11.7.1, but you have cuda-python 12.5.0 which is incompatible.
jupyterlab 4.2.3 requires jupyter-lsp>=2.0.0, but you have jupyter-lsp 1.5.1 which is incompatible.
jupyterlab-lsp 5.1.0 requires jupyter-lsp>=2.0.0, but you have jupyter-lsp 1.5.1 which is incompatible.
pointpats 2.5.0 requires shapely>=2, but you have shapely 1.8.5.post1 which is incompatible.
spaghetti 1.7.6 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.
spopt 0.6.1 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.4.1 which is incompatible.
ydata-profiling 4.6.4 requires matplotlib<3.9,>=3.2, but you have matplotlib 3.9.0 which is incompatible.
ydata-profiling 4.6.4 requires numpy<1.26,>=1.16.0, but you have numpy 1.26.4 which is incompatible.
ydata-profiling 4.6.4 requires seaborn<0.13,>=0.10.1, but you have seaborn 0.13.2 which is incompatible.
Note: you may need to restart the kernel to use updated packages.

Common Functions

This module includes the GPT model definitions, shared configuration classes, and helpers for building the model and sampling random token batches.

from __future__ import annotations

from typing import Literal, Tuple, cast

import torch
from pydantic import BaseModel
from torch import nn

from omnivault.modules.activation import GELU, SoftmaxStable
from omnivault.transformer.modules.layers.normalization import RMSNorm

__tagged__ = "This code tags to `30d963e` of cs336-stanford-spring2024-assignment1-gpt-from-scratch."
__reference__ = [
    "https://github.com/marcelroed/spring2024-assignment2-systems/blob/master/writeup.pdf",
]

# Default compute device for this module: CUDA when PyTorch can see a GPU, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class General(BaseModel):
    """General run settings shared across the profiling experiments."""

    batch_size: int = 16  # number of sequences per synthetic batch (B)
    seed: int = 20230310  # random seed used to make runs reproducible


class GPTConfig(BaseModel):
    """Hyper-parameters of the GPT model defined in this module.

    Fields without a default (``d_model``, ``num_heads``, ``context_length``,
    ``vocab_size``, ``num_blocks``) must be supplied by the caller.
    """

    approximate: Literal["tanh"] | None = None  # GELU approximation mode; NOTE(review): GPTBlock does not forward this, the FFN builds GELU(approximate=None)
    activation_name: Literal["gelu"] = "gelu"  # activation used inside the position-wise FFN
    d_model: int  # model / embedding width (D)
    d_ff: int | None = None  # FFN hidden width; None falls back to 4 * d_model
    num_heads: int  # attention heads; d_model must be divisible by this
    context_length: int  # max sequence length: size of the position table and causal mask
    attn_pdrop: float = 0.0  # dropout probability on the attention weights
    resid_pdrop: float = 0.0  # dropout probability after the attention output projection and inside the FFN
    bias: bool = False  # whether Linear layers (including the LM head) carry a bias term
    vocab_size: int  # number of token ids (V)
    num_blocks: int  # number of transformer blocks
    token_position_pdrop: float = 0.0  # dropout on the summed token + position embeddings
    weight_tie: bool = False  # share the token-embedding matrix with the LM head


class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: ``W2 · dropout(act(W1 · z))``.

    Args:
        d_model: Input/output width ``D``.
        d_ff: Hidden width; falls back to ``4 * d_model`` when falsy (``None`` or 0).
        bias: Whether both Linear layers carry a bias term.
        activation_name: Activation to use; only ``"gelu"`` (exact, no tanh
            approximation) is supported.
        dropout: Dropout probability applied right after the activation.

    Raises:
        ValueError: If ``activation_name`` is not supported.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int | None = None,
        bias: bool = False,
        activation_name: Literal["gelu"] = "gelu",
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff or 4 * d_model
        self.bias = bias  # bias False in this exercise
        self.activation_name = activation_name
        self.dropout = dropout

        # NOTE: key order is kept as-is; nn.Linear draws its default init from the
        # global RNG, so reordering would change the initial weights.
        self.ffn = nn.ModuleDict(
            {
                # incoming `B x T x D` and we are interested in `T x D` so weight is `D x d_ff`
                # so that `Z @ W1 -> (T x D) @ (D x d_ff)`
                "context_fc": nn.Linear(in_features=self.d_model, out_features=self.d_ff, bias=self.bias),
                "activation": self.activation,
                # apply dropout after activation for random lights out
                "dropout": nn.Dropout(p=self.dropout, inplace=False),
                # incoming is Z @ W1 -> T x d_ff -> (T x d_ff) @ (d_ff x D) project back to D
                "context_projection": nn.Linear(in_features=self.d_ff, out_features=self.d_model, bias=self.bias),
            }
        )

    @property
    def activation(self) -> nn.Module:
        """Build a NEW activation module for ``self.activation_name``.

        Every access constructs a fresh instance; the module actually used by
        :meth:`forward` is the one stored in ``self.ffn["activation"]`` at construction.

        Raises:
            ValueError: If ``activation_name`` is not supported.
        """
        if self.activation_name == "gelu":
            return GELU(approximate=None)  # exact GELU, no tanh approximation
        raise ValueError(f"Unsupported activation: {self.activation_name}")

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the FFN independently at every position: ``[B, T, D] -> [B, T, D]``."""
        # fmt: off
        z = self.ffn["context_fc"](z)           # Z @ W1 = [B, T, D] @ [D, d_ff] = [B, T, d_ff]
        z = self.ffn["activation"](z)           # \sigma(Z @ W1) = [B, T, d_ff]
        z = self.ffn["dropout"](z)              # \dropout(\sigma(Z @ W1)) = [B, T, d_ff]
        z = self.ffn["context_projection"](z)   # \dropout(\sigma(Z @ W1)) @ W2 = [B, T, d_ff] @ [d_ff, D] = [B, T, D]
        # fmt: on
        return z


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_q)) V`` with optional masking.

    Args:
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dropout: float = 0.0) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.BoolTensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend ``query`` over ``key`` / ``value``.

        Args:
            query: Queries with the sequence in dim -2 and features in dim -1
                (``[B, H, T, d_q]`` when called from this module).
            key: Keys, same layout as ``query``.
            value: Values, ``[..., T, d_v]``.
            mask: Optional 4-D mask; it is sliced to ``[:, :, :T, :T]`` so a mask built
                for a longer context can be reused. Entries equal to 1/True are
                excluded (set to ``-inf`` before the softmax).

        Returns:
            ``(context_vector, attention_weights)`` with shapes ``[..., T, d_v]`` and
            ``[..., T, T]``; the returned weights are the post-dropout weights.
        """
        # fmt: off
        T, d_q = query.size(-2), query.size(-1)

        # A plain Python scalar scale avoids building a new host tensor on every call.
        scale             = d_q ** 0.5
        attention_scores  = torch.matmul(query, key.transpose(dim0=-2, dim1=-1)) / scale  # Q @ K.T = [B, H, T, d_q] @ [B, H, d_q, T] = [B, H, T, T]

        if mask is not None:
            mask = mask[:, :, :T, :T] # type: ignore[assignment]
            attention_scores  = attention_scores.masked_fill(mask == 1, float("-inf"))    # [B, H, T, T]

        softmax           = SoftmaxStable(dim=-1)
        attention_weights = softmax(attention_scores)               # [B, H, T, T]
        attention_weights = self.dropout(attention_weights)         # [B, H, T, T]

        context_vector    = torch.matmul(attention_weights, value)  # [B, H, T, T] @ [B, H, T, d_v] = [B, H, T, d_v]
        # fmt: on
        return context_vector, attention_weights


class CausalMultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention restricted by a fixed causal (look-left-only) mask.

    Args:
        d_model: Model width ``D``; must be divisible by ``num_heads``.
        num_heads: Number of heads ``H``; each head works on ``D // H`` features.
        context_length: Side length of the precomputed causal mask.
        attn_pdrop: Dropout probability on the attention weights.
        resid_pdrop: Dropout probability on the projected output.
        bias: Whether the Q/K/V and output projections carry a bias term.

    After each forward call the module holds the merged head outputs in
    ``context_vector`` ([B, T, D]) and the weights in ``attention_weights``
    ([B, H, T, T]); these references keep the tensors alive until the next call.
    """

    context_vector: torch.Tensor
    attention_weights: torch.Tensor

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        context_length: int,
        attn_pdrop: float = 0.0,  # pdrop means prob of dropout
        resid_pdrop: float = 0.0,
        bias: bool = False,
    ) -> None:
        super().__init__()

        assert d_model % num_heads == 0

        self.d_model, self.H, self.context_length = d_model, num_heads, context_length
        self.attn_pdrop, self.resid_pdrop, self.bias = attn_pdrop, resid_pdrop, bias

        # Creation order (W_Q, W_K, W_V, W_O) is deliberate: nn.Linear draws its
        # default init from the global RNG.
        self.W_Q = self._make_square_linear()
        self.W_K = self._make_square_linear()
        self.W_V = self._make_square_linear()
        self.context_projection = self._make_square_linear()  # a.k.a. W_O

        self.resid_dropout = nn.Dropout(self.resid_pdrop)
        self.attention = ScaledDotProductAttention(dropout=self.attn_pdrop)

        # True strictly above the diagonal = future positions to hide. Kept as a buffer
        # so it follows .to(device) without being a learnable parameter.
        side = self.context_length
        future = torch.triu(torch.ones((side, side)).bool(), diagonal=1)
        self.register_buffer("causal_mask", future.view(1, 1, side, side))

    def _make_square_linear(self) -> nn.Linear:
        """Return a fresh ``D -> D`` Linear layer honoring ``self.bias``."""
        return nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``[B, T, D]`` into ``[B, H, T, D // H]``."""
        batch, seq_len, width = x.size()
        per_head = x.contiguous().view(batch, seq_len, self.H, width // self.H)
        return per_head.transpose(dim0=1, dim1=2)

    def forward(self, *, z: torch.Tensor) -> torch.Tensor:
        """Causal self-attention over ``z`` of shape ``[B, T, D]``; returns ``[B, T, D]``."""
        B, T, D = z.size()

        Q, K, V = (self._split_heads(projection(z)) for projection in (self.W_Q, self.W_K, self.W_V))

        self.context_vector, self.attention_weights = self.attention(query=Q, key=K, value=V, mask=self.causal_mask)
        assert isinstance(self.context_vector, torch.Tensor)  # narrows the type for IDEs / checkers

        # Merge heads: [B, H, T, D // H] -> [B, T, H, D // H] -> [B, T, D]
        self.context_vector = self.context_vector.transpose(dim0=1, dim1=2).contiguous().view(B, T, D)

        return self.resid_dropout(self.context_projection(self.context_vector))


class GPTBlock(nn.Module):
    """Pre-norm transformer block.

    Computes ``z + attn(rms_norm(z))`` and then ``z + ffn(rms_norm(z))``.

    Args:
        config: Model hyper-parameters. ``resid_pdrop`` drives both the attention
            output dropout and the FFN's internal dropout.
    """

    def __init__(
        self,
        config: GPTConfig,
    ) -> None:
        super().__init__()

        norm_eps = 1e-5
        # Submodules are created in the original order (norm, attn, norm, ffn) so that
        # registration order and RNG-driven default init are unchanged.
        self.rmns_1 = RMSNorm(d_model=config.d_model, eps=norm_eps)
        self.attn = CausalMultiHeadSelfAttention(
            d_model=config.d_model,
            num_heads=config.num_heads,
            context_length=config.context_length,
            attn_pdrop=config.attn_pdrop,
            resid_pdrop=config.resid_pdrop,
            bias=config.bias,
        )
        self.rmns_2 = RMSNorm(d_model=config.d_model, eps=norm_eps)
        self.ffn = PositionwiseFeedForward(
            d_model=config.d_model,
            d_ff=config.d_ff,
            bias=config.bias,
            activation_name=config.activation_name,
            dropout=config.resid_pdrop,
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the attention and FFN sub-layers with residual connections."""
        attn_update = self.attn(z=self.rmns_1(z))
        z = z + attn_update
        ffn_update = self.ffn(self.rmns_2(z))
        return z + ffn_update


class GPT(nn.Module):
    """Decoder-only GPT language model built from :class:`GPTBlock` layers.

    Token and learned position embeddings are summed and passed through dropout.
    The result then goes through ``num_blocks`` pre-norm blocks and a final RMSNorm,
    and is projected to vocabulary logits by ``head``.

    Args:
        config: Model hyper-parameters. The same object is stored as ``self.config``
            and is mutated by :meth:`crop_context_length`.
    """

    def __init__(self, config: GPTConfig) -> None:
        super().__init__()

        self.config = config
        self.d_model = config.d_model
        self.num_blocks = config.num_blocks
        self.vocab_size = config.vocab_size

        # NOTE: this ModuleList is registered twice (as `blocks` and `backbone.layers`).
        # Parameters are shared, but `state_dict()` lists both key prefixes.
        self.blocks = nn.ModuleList([GPTBlock(config=config) for _ in range(self.num_blocks)])

        self.backbone = nn.ModuleDict(
            dict(  # noqa: C408
                token_embeddings=nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.d_model),
                position_embeddings=nn.Embedding(num_embeddings=config.context_length, embedding_dim=self.d_model),
                dropout=nn.Dropout(p=config.token_position_pdrop),
                layers=self.blocks,
                ln_final=RMSNorm(d_model=self.d_model, eps=1e-5),
            )
        )
        self.head = nn.Linear(in_features=self.d_model, out_features=self.vocab_size, bias=config.bias)

        self.apply(self._init_weights)

        # GPT-2 style scaled init for the projections that write back into the residual
        # stream. Every parameter named `*context_projection.weight` (the attention output
        # projection and the FFN's second Linear) gets std 0.02 / sqrt(2 * num_blocks).
        # The std is loop-invariant, so it is computed once.
        context_projections = "context_projection.weight"
        std_dev = (0.02 / torch.sqrt(torch.tensor(2 * config.num_blocks, dtype=torch.float))).item()
        for parameter_name, parameter in self.named_parameters():
            if parameter_name.endswith(context_projections):
                torch.nn.init.normal_(parameter, mean=0.0, std=std_dev)

        if config.weight_tie:
            # Share a single [V, D] matrix between the input embedding and the LM head.
            self.backbone.token_embeddings.weight = self.head.weight

    def crop_context_length(self, context_length: int) -> None:
        """Shrink the model's maximum context length in place.

        Truncates the position-embedding table and slices each block's causal mask.
        It also updates ``config.context_length``; ``self.config`` is the object
        given to the constructor, so the caller's config changes too.

        Args:
            context_length: New maximum length; must not exceed the current one.
        """
        # NOTE: adapted from Karpathy's nanoGPT cropping logic.
        assert context_length <= self.config.context_length, (
            f"cannot crop to {context_length}: current context_length is {self.config.context_length}"
        )
        self.config.context_length = context_length  # update config

        position_embeddings = self.backbone.position_embeddings
        position_embeddings.weight = nn.Parameter(position_embeddings.weight[:context_length])
        # Keep the module's metadata (used in its repr) consistent with the truncated table.
        position_embeddings.num_embeddings = context_length

        for block in self.backbone.layers:
            if hasattr(block.attn, "causal_mask"):
                block.attn.causal_mask = block.attn.causal_mask[:, :, :context_length, :context_length]

            # update context length attribute in MultiHeadSelfAttention
            block.attn.context_length = context_length

    def _init_weights(self, module: nn.Module) -> None:
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any existing bias."""
        normal_init_modules = (nn.Linear, nn.Embedding)
        if isinstance(module, normal_init_modules):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, "bias") and module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, in_indices: torch.LongTensor) -> torch.FloatTensor:
        """Compute next-token logits.

        Args:
            in_indices: Token ids of shape ``[B, T]``.

        Returns:
            Logits of shape ``[B, T, vocab_size]``.

        Raises:
            ValueError: If ``T`` exceeds the number of learned positions. The up-front
                check gives a clear message instead of an out-of-range embedding lookup,
                which on CUDA surfaces as an opaque device-side assert.
        """
        _, T = in_indices.size()
        max_positions = self.backbone.position_embeddings.weight.size(0)
        if T > max_positions:
            raise ValueError(f"Sequence length {T} exceeds the model's context length {max_positions}.")

        positions = torch.arange(0, T, dtype=torch.long, device=in_indices.device)  # [T]
        token_embeddings = self.backbone.token_embeddings(in_indices)  # [B, T, D]
        # [T, D] -> [1, T, D]; broadcasting over the batch makes an explicit expand unnecessary.
        positional_embeddings = self.backbone.position_embeddings(positions).unsqueeze(0)

        z = self.backbone.dropout(token_embeddings + positional_embeddings)  # [B, T, D]

        for block in self.backbone.layers:
            z = block(z)  # [B, T, D]

        z = self.backbone.ln_final(z)  # [B, T, D]

        logits = self.head(z)  # [B, T, V]
        return cast(torch.FloatTensor, logits)  # [B, T, V]


def initialize_model(
    config: GPTConfig,
    device: str = "cuda",
) -> GPT:
    """Construct a :class:`GPT` for ``config`` and move it to ``device``.

    Side effect: if ``config.d_ff`` is ``None``, it is set on the caller's config
    object to ``4 * config.d_model`` before the model is built.

    Returns:
        The model, already placed on ``device``.
    """
    if config.d_ff is None:
        config.d_ff = 4 * config.d_model
    return GPT(config).to(device)


def get_random_batch(
    batch_size: int,
    context_length: int,
    vocab_size: int,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw a synthetic ``(inputs, targets)`` pair of token ids for profiling.

    Both tensors have shape ``[batch_size, context_length]`` and dtype ``torch.long``.
    Their values are sampled uniformly from ``[0, vocab_size)``. ``inputs`` is drawn
    before ``targets`` from torch's default generator. The two are independent:
    targets are not a shifted copy of inputs.
    """
    shape = (batch_size, context_length)
    inputs, targets = (
        torch.randint(0, vocab_size, shape, dtype=torch.long, device=device)
        for _ in range(2)
    )
    return inputs, targets

Main Profiling Code

from __future__ import annotations

import logging
import socket
import sys
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Iterable, Tuple

import torch
from torch import nn
from torch._C._profiler import _ExperimentalConfig
from torch.autograd.profiler import record_function
from torch.profiler import ProfilerActivity, profile, record_function

from omnivault.modules.loss import CrossEntropyLoss
from omnivault.utils.reproducibility.seed import seed_all

# Seed all RNGs at import time; the two positional flags are interpreted by the
# project's `seed_all` helper.
seed_all(42, True, False)


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,  # drop any root handlers configured earlier so this setup takes effect
)
logger = logging.getLogger(__name__)

# strftime pattern used to timestamp exported profiler artifacts.
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"


def profile_one_step(
    model: nn.Module,
    batch: Tuple[torch.Tensor, torch.Tensor],
    optimizer: torch.optim.Optimizer,
    criterion: CrossEntropyLoss,
    enable_backward: bool,
    enable_optimizer: bool,
    mixed_precision: bool = False,
) -> None:
    """Run one training step whose phases are labelled for the profiler.

    The phases are recorded as follows:

    - forward pass and loss under ``"forward_pass"``;
    - ``loss.backward()`` under ``"backward_pass"`` (when ``enable_backward``);
    - ``optimizer.step()`` plus ``zero_grad(set_to_none=True)`` under ``"optimizer"``
      (when ``enable_optimizer``).

    Args:
        model: Module mapping ``inputs`` to logits.
        batch: ``(inputs, targets)`` pair.
        optimizer: Optimizer stepped when ``enable_optimizer`` is True.
        criterion: Loss callable ``criterion(logits, targets)``.
        enable_backward: Run the backward pass.
        enable_optimizer: Step the optimizer. Only takes effect when ``enable_backward``
            is also True, since no gradients exist otherwise.
        mixed_precision: Run the forward pass and loss under CUDA bf16 autocast.
    """
    inputs, targets = batch[0], batch[1]
    autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if mixed_precision else nullcontext()

    # Only the forward pass and loss belong inside autocast; PyTorch's AMP guidance
    # is to run backward (and hence the optimizer step) outside the autocast region.
    with autocast_ctx:
        with record_function(name="forward_pass"):
            logits = model(inputs)
            loss = criterion(logits, targets)

    if not enable_backward:
        return

    with record_function(name="backward_pass"):
        loss.backward()

    if enable_optimizer:
        with record_function(name="optimizer"):
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)


def run_warmup(
    model: nn.Module,
    batch: Tuple[torch.Tensor, torch.Tensor],
    optimizer: torch.optim.Optimizer,
    criterion: CrossEntropyLoss,
    *,
    num_steps: int = 1,
) -> None:
    """Run un-profiled training steps so first-call overheads stay out of the profile.

    Each step does forward, loss, backward, ``optimizer.step()`` and
    ``optimizer.zero_grad(set_to_none=True)``.

    Args:
        model: Module mapping ``inputs`` to logits.
        batch: ``(inputs, targets)`` pair.
        optimizer: Optimizer stepped once per warm-up step.
        criterion: Loss callable ``criterion(logits, targets)``.
        num_steps: Number of warm-up steps (default 1, matching the original behaviour).

    When ``inputs`` lives on a CUDA device, the queued GPU work is waited for before
    returning so it does not bleed into a following profiling window. On CPU no sync
    is needed, and ``torch.cuda.synchronize`` would fail without CUDA.
    """
    inputs, targets = batch[0], batch[1]
    for _ in range(num_steps):
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    if inputs.is_cuda:
        torch.cuda.synchronize(inputs.device)


def trace_handler(prof: torch.profiler.profile, *, device: str = "cuda:0") -> None:
    """Export artifacts for one finished profiler cycle; intended as ``on_trace_ready``.

    Files share the prefix ``{hostname}_{timestamp}_step{step_num}``:

    - ``.json.gz``: Chrome trace;
    - ``.html``: memory timeline for ``device``;
    - ``_stacks.txt``: stacks weighted by ``self_cuda_time_total``.

    The step number is part of the prefix because, with a repeating schedule, the
    profiler calls this handler several times within the same second. A
    seconds-resolution timestamp alone would let later cycles overwrite earlier ones.

    NOTE: the memory timeline needs the profiler to have run with ``profile_memory``,
    ``record_shapes`` and ``with_stack``. Stack export needs ``with_stack`` with the
    verbose experimental config.
    """
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = f"{host_name}_{timestamp}_step{prof.step_num}"

    # Chrome / Perfetto trace.
    prof.export_chrome_trace(f"{file_prefix}.json.gz")

    # Memory timeline.
    prof.export_memory_timeline(f"{file_prefix}.html", device=device)

    # Flame-graph style stacks, one file per cycle.
    prof.export_stacks(f"{file_prefix}_stacks.txt", "self_cuda_time_total")


def run_profiler(
    model: nn.Module,
    batch: Tuple[torch.Tensor, torch.Tensor],
    optimizer: torch.optim.Optimizer,
    criterion: CrossEntropyLoss,
    enable_backward: bool,
    enable_optimizer: bool,
    mixed_precision: bool = False,
    profile_steps: int = 5,
    *,
    activities: Iterable[ProfilerActivity] | None = None,
    profile_memory: bool = False,
    with_stack: bool = True,
    record_shapes: bool = True,
    with_flops: bool = False,
    **profile_kwargs: Any,
) -> torch.profiler.profile:
    """Profile ``profile_steps`` training steps with ``torch.profiler.profile``.

    Each iteration runs :func:`profile_one_step` and then ``prof.step()``, so a
    ``schedule`` / ``on_trace_ready`` passed through ``profile_kwargs`` is honoured.

    Args:
        model, batch, optimizer, criterion, enable_backward, enable_optimizer,
        mixed_precision: Forwarded to :func:`profile_one_step`.
        profile_steps: Number of steps run inside the profiler context.
        activities: Profiler activities; ``None`` leaves the choice to the profiler.
        profile_memory: Track tensor allocations. Forces ``with_stack`` and
            ``record_shapes`` on. When CUDA is available it also records the CUDA
            allocator history and dumps it to ``memory_snapshot.pickle``.
        with_stack: Record Python stacks; also enables the verbose experimental
            config that ``export_stacks`` needs.
        record_shapes: Record operator input shapes.
        with_flops: Estimate FLOPs for supported operators.
        **profile_kwargs: Extra keyword arguments for ``torch.profiler.profile``.

    Returns:
        The finished profiler object.
    """
    # profile_memory requires with_stack and record_shapes (see torch.profiler.profiler._memory_profile),
    # so force them on, warning only when that actually overrides the caller's choice.
    if profile_memory and not (with_stack and record_shapes):
        logger.warning(
            "`profile_memory` requires `with_stack` and `record_shapes`, these will be enabled since `profile_memory` is True"
        )
    with_stack = with_stack or profile_memory
    record_shapes = record_shapes or profile_memory

    # experimental config is needed to export stacks: see https://github.com/pytorch/pytorch/issues/100253
    experimental_config = _ExperimentalConfig(verbose=True) if with_stack else None

    # CUDA allocator history can only be recorded when CUDA is present.
    record_cuda_history = profile_memory and torch.cuda.is_available()
    if record_cuda_history:
        torch.cuda.memory._record_memory_history(max_entries=1_000_000)

    try:
        with profile(
            activities=activities,  # [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
            experimental_config=experimental_config,
            record_shapes=record_shapes,
            profile_memory=profile_memory,
            with_stack=with_stack,
            with_flops=with_flops,
            **profile_kwargs,
        ) as prof:
            for _ in range(profile_steps):
                profile_one_step(
                    model=model,
                    batch=batch,
                    optimizer=optimizer,
                    criterion=criterion,
                    enable_backward=enable_backward,
                    enable_optimizer=enable_optimizer,
                    mixed_precision=mixed_precision,
                )
                prof.step()

        if record_cuda_history:
            torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
    finally:
        # Always stop allocator-history recording, even if profiling raised, so it does
        # not keep running (and costing time/memory) for the rest of the session.
        if record_cuda_history:
            torch.cuda.memory._record_memory_history(enabled=None)
    return prof  # type: ignore[no-any-return]
gpt_small_config = GPTConfig(
    context_length=128,
    vocab_size=10_000,
    d_model=768,
    num_blocks=12,
    num_heads=12,
)
general = General()

seed_all(general.seed, True, False)

# Sample the synthetic batch on the model's device. `get_random_batch` defaults to
# "cuda", which fails on CPU-only machines and can mismatch the model's placement.
batch = get_random_batch(
    batch_size=general.batch_size,
    context_length=gpt_small_config.context_length,
    vocab_size=gpt_small_config.vocab_size,
    device=str(device),
)

model = GPT(gpt_small_config).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()
# One un-profiled step so first-call overheads (e.g. lazy library initialization,
# optimizer state allocation) stay out of the profile.
run_warmup(model=model, batch=batch, optimizer=optimizer, criterion=criterion)

# Time-only profile (no memory tracking) of 5 full training steps.
profiled = run_profiler(
    model=model,
    batch=batch,
    optimizer=optimizer,
    criterion=criterion,
    enable_backward=True,
    enable_optimizer=True,
    mixed_precision=False,
    profile_steps=5,
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    profile_memory=False,
    with_stack=True,
    record_shapes=True,
    with_flops=False,
)
STAGE:2024-08-12 08:48:52 34:34 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-08-12 08:48:53 34:34 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-08-12 08:48:53 34:34 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
profiled.export_stacks("lm_profiler_stacks.txt", "self_cuda_time_total")

# Top-10 operators ranked by total CPU time, then by total CUDA time.
for sort_key in ("cpu_time_total", "cuda_time_total"):
    print(profiled.key_averages().table(sort_by=sort_key, row_limit=10))

profiled.export_chrome_trace("trace.json")
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          backward_pass        36.08%     643.198ms        36.09%     643.418ms     128.684ms       0.000us         0.00%       5.000us       1.000us             5  
                                       cudaLaunchKernel        34.75%     619.425ms        34.75%     619.425ms      80.654us       0.000us         0.00%       0.000us       0.000us          7680  
                                           forward_pass         3.08%      54.873ms        18.44%     328.673ms      65.735ms       0.000us         0.00%     355.509ms      71.102ms             5  
      autograd::engine::evaluate_function: DivBackward0         0.34%       5.976ms        11.06%     197.217ms     458.644us       0.000us         0.00%      57.947ms     134.760us           430  
                                              aten::mul         1.23%      21.851ms         8.95%     159.555ms     112.363us      83.133ms         7.15%      83.133ms      58.544us          1420  
                                  cudaDeviceSynchronize         8.42%     150.048ms         8.42%     150.048ms     150.048ms       0.000us         0.00%       0.000us       0.000us             1  
                                           DivBackward0         0.21%       3.775ms         6.89%     122.752ms     285.470us       0.000us         0.00%      47.889ms     111.370us           430  
                                              aten::div         1.22%      21.803ms         6.68%     119.108ms      87.259us      50.880ms         4.38%      50.880ms      37.275us          1365  
                                               aten::mm         1.38%      24.547ms         6.10%     108.698ms      99.268us     757.305ms        65.15%     757.305ms     691.603us          1095  
       autograd::engine::evaluate_function: MmBackward0         0.20%       3.608ms         5.07%      90.308ms     247.419us       0.000us         0.00%     507.949ms       1.392ms           365  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.783s
Self CUDA time total: 1.162s

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::mm         1.38%      24.547ms         6.10%     108.698ms      99.268us     757.305ms        65.15%     757.305ms     691.603us          1095  
       autograd::engine::evaluate_function: MmBackward0         0.20%       3.608ms         5.07%      90.308ms     247.419us       0.000us         0.00%     507.949ms       1.392ms           365  
                                            MmBackward0         0.26%       4.547ms         4.86%      86.700ms     237.534us       0.000us         0.00%     507.949ms       1.392ms           365  
                                           forward_pass         3.08%      54.873ms        18.44%     328.673ms      65.735ms       0.000us         0.00%     355.509ms      71.102ms             5  
                                           aten::matmul         0.30%       5.300ms         3.74%      66.703ms     137.532us       0.000us         0.00%     269.186ms     555.023us           485  
                                           aten::linear         0.15%       2.613ms         2.51%      44.805ms     122.753us       0.000us         0.00%     237.179ms     649.805us           365  
                                 sgemm_128x128x8_TN_vec         0.00%       0.000us         0.00%       0.000us       0.000us     184.272ms        15.85%     184.272ms       1.474ms           125  
                                maxwell_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us     181.512ms        15.62%     181.512ms     427.087us           425  
                                maxwell_sgemm_128x64_tn         0.00%       0.000us         0.00%       0.000us       0.000us     166.766ms        14.35%     166.766ms     397.062us           420  
                                 sgemm_128x128x8_NT_vec         0.00%       0.000us         0.00%       0.000us       0.000us      94.916ms         8.17%      94.916ms       1.460ms            65  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.783s
Self CUDA time total: 1.162s
# Warm up again, then profile with memory tracking. The schedule records three
# one-step cycles (wait=0, warmup=0, active=1, repeat=3) and hands each finished
# cycle to `trace_handler`.
run_warmup(model=model, batch=batch, optimizer=optimizer, criterion=criterion)

memory_schedule = torch.profiler.schedule(wait=0, warmup=0, active=1, repeat=3)
profiled = run_profiler(
    model=model,
    batch=batch,
    optimizer=optimizer,
    criterion=criterion,
    enable_backward=True,
    enable_optimizer=True,
    mixed_precision=False,
    profile_steps=5,
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    profile_memory=True,
    with_stack=True,
    record_shapes=True,
    with_flops=False,
    schedule=memory_schedule,
    on_trace_ready=trace_handler,
)

To inspect the CUDA allocation history, drag and drop the generated `memory_snapshot.pickle` into https://pytorch.org/memory_viz.

# Top-50 operators ranked by CUDA memory usage.
memory_table = profiled.key_averages().table(sort_by="cuda_memory_usage", row_limit=50)
print(memory_table)
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              aten::mul         1.81%       5.259ms         2.62%       7.591ms      26.729us      16.628ms         7.17%      16.628ms      58.549us          24 b          24 b       3.21 Gb       3.21 Gb           284  
                                           forward_pass         4.13%      11.992ms        12.06%      34.999ms      34.999ms       0.000us         0.00%      70.702ms      70.702ms         440 b         244 b       2.15 Gb      -1.94 Gb             1  
                                              aten::div         1.63%       4.720ms         2.37%       6.889ms      25.234us      10.180ms         4.39%      10.180ms      37.289us           8 b           8 b       1.99 Gb       1.99 Gb           273  
                                               aten::mm         1.94%       5.618ms         3.01%       8.729ms      39.858us     150.881ms        65.08%     150.881ms     688.954us           0 b           0 b       1.71 Gb       1.71 Gb           219  
                                            aten::empty         0.72%       2.082ms         0.72%       2.082ms      13.788us       0.000us         0.00%       0.000us       0.000us          96 b          96 b       1.17 Gb       1.17 Gb           151  
                                           MulBackward0         0.10%     297.000us         0.90%       2.600ms      53.061us       0.000us         0.00%       5.903ms     120.469us           0 b           0 b       1.14 Gb      24.00 Mb            49  
                                           aten::matmul         0.32%     925.000us         2.77%       8.047ms      82.959us       0.000us         0.00%      53.449ms     551.021us           0 b           0 b       1.13 Gb           0 b            97  
                                            MmBackward0         0.38%       1.100ms         2.77%       8.050ms     110.274us       0.000us         0.00%     101.425ms       1.389ms           0 b           0 b       1.00 Gb           0 b            73  
                                           DivBackward0         0.30%     867.000us         2.32%       6.722ms      78.163us       0.000us         0.00%       9.584ms     111.442us           0 b           0 b    1020.20 Mb    -882.00 Mb            86  
                                       aten::empty_like         0.14%     419.000us         0.56%       1.627ms      13.446us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     864.00 Mb      24.00 Mb           121  
                                            aten::clone         0.25%     726.000us         1.85%       5.358ms      44.650us       0.000us         0.00%       5.261ms      43.842us           0 b           0 b     864.00 Mb     -24.00 Mb           120  
                                              aten::neg         0.35%       1.002ms         0.55%       1.589ms      24.828us       3.204ms         1.38%       3.204ms      50.062us           0 b           0 b     804.14 Mb     804.14 Mb            64  
                                           aten::linear         0.12%     353.000us         1.87%       5.432ms      74.411us       0.000us         0.00%      47.933ms     656.616us           0 b           0 b     726.12 Mb      30.00 Mb            73  
                                              aten::pow         0.47%       1.376ms         0.82%       2.391ms      38.565us       1.737ms         0.75%       2.377ms      38.339us           0 b           0 b     588.00 Mb     588.00 Mb            62  
                                              aten::bmm         0.61%       1.767ms         0.83%       2.409ms      33.458us       7.221ms         3.11%       7.221ms     100.292us           0 b           0 b     576.00 Mb     576.00 Mb            72  
                                              aten::add         0.47%       1.369ms         0.67%       1.945ms      25.933us       2.466ms         1.06%       2.466ms      32.880us           0 b           0 b     510.20 Mb     510.20 Mb            75  
                                              aten::exp         0.15%     433.000us         0.22%     632.000us      25.280us       2.020ms         0.87%       2.020ms      80.800us           0 b           0 b     510.12 Mb     510.12 Mb            25  
                                          aten::reshape         0.41%       1.199ms         2.11%       6.133ms      15.807us       0.000us         0.00%       3.612ms       9.309us           0 b           0 b     504.00 Mb           0 b           388  
                                    aten::empty_strided         0.25%     723.000us         0.25%     723.000us       3.394us       0.000us         0.00%       0.000us       0.000us         300 b         300 b     397.07 Mb     397.07 Mb           213  
                                    aten::_foreach_sqrt         0.11%     310.000us         0.34%     981.000us     981.000us       1.601ms         0.69%       1.601ms       1.601ms           0 b           0 b     397.07 Mb           0 b             1  
                                           BmmBackward0         0.10%     285.000us         0.75%       2.185ms      91.042us       0.000us         0.00%       4.825ms     201.042us           0 b           0 b     360.00 Mb           0 b            24  
                                           ErfBackward0         0.07%     207.000us         0.60%       1.739ms     144.917us       0.000us         0.00%       6.193ms     516.083us           0 b           0 b     336.00 Mb      -1.08 Gb            12  
                                      aten::masked_fill         0.11%     325.000us         0.71%       2.070ms      86.250us       0.000us         0.00%       2.533ms     105.542us           0 b           0 b     288.00 Mb           0 b            24  
                                              aten::erf         0.06%     180.000us         0.09%     254.000us      21.167us       1.142ms         0.49%       1.142ms      95.167us           0 b           0 b     288.00 Mb     288.00 Mb            12  
                                            aten::zeros         0.03%      84.000us         0.20%     589.000us      39.267us       0.000us         0.00%     438.000us      29.200us           0 b           0 b     251.80 Mb           0 b            15  
                                              aten::sub         0.08%     239.000us         0.11%     331.000us      25.462us     984.000us         0.42%     984.000us      75.692us           0 b           0 b     222.12 Mb     222.12 Mb            13  
                                           ExpBackward0         0.03%      89.000us         0.15%     430.000us      33.077us       0.000us         0.00%       1.173ms      90.231us           0 b           0 b     222.12 Mb           0 b            13  
                                           SubBackward0         0.01%      40.000us         0.12%     361.000us      27.769us       0.000us         0.00%     880.000us      67.692us           0 b           0 b     222.12 Mb           0 b            13  
                                           MaxBackward0         0.02%      47.000us         0.38%       1.113ms      85.615us       0.000us         0.00%     521.000us      40.077us           0 b           0 b     222.12 Mb           0 b            13  
               aten::value_selecting_reduction_backward         0.04%     108.000us         0.36%       1.042ms      80.154us       0.000us         0.00%     511.000us      39.308us           0 b           0 b     222.12 Mb           0 b            13  
                                           PowBackward0         0.09%     265.000us         0.97%       2.820ms     112.800us       0.000us         0.00%       1.976ms      79.040us           0 b          -8 b     174.00 Mb    -276.00 Mb            25  
                                    UnsafeViewBackward0         0.16%     473.000us         0.76%       2.216ms      16.662us       0.000us         0.00%       1.032ms       7.759us           0 b           0 b     144.00 Mb      12.00 Mb           133  
                                          ViewBackward0         0.16%     464.000us         0.75%       2.170ms      16.316us       0.000us         0.00%     792.000us       5.955us           0 b           0 b     144.00 Mb      12.00 Mb           133  
                                    MaskedFillBackward0         0.07%     205.000us         0.39%       1.136ms      94.667us       0.000us         0.00%       1.057ms      88.083us           0 b           0 b     144.00 Mb      24.00 Mb            12  
                                        GatherBackward0         0.00%       6.000us         0.03%      88.000us      88.000us       0.000us         0.00%     145.000us     145.000us           0 b           0 b      78.12 Mb           0 b             1  
                                  aten::gather_backward         0.00%       8.000us         0.03%      82.000us      82.000us       0.000us         0.00%     145.000us     145.000us           0 b           0 b      78.12 Mb           0 b             1  
                                        aten::new_zeros         0.00%       7.000us         0.01%      40.000us      40.000us       0.000us         0.00%     136.000us     136.000us           0 b           0 b      78.12 Mb           0 b             1  
                                        aten::new_empty         0.00%       3.000us         0.00%      13.000us      13.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      78.12 Mb           0 b             1  
                                       aten::contiguous         0.01%      21.000us         0.18%     513.000us      42.750us       0.000us         0.00%     432.000us      36.000us           0 b           0 b      72.00 Mb           0 b            12  
                                     EmbeddingBackward0         0.00%       6.000us         0.04%     108.000us      54.000us       0.000us         0.00%     214.000us     107.000us           0 b           0 b      29.67 Mb           0 b             2  
                               aten::embedding_backward         0.00%       5.000us         0.04%     102.000us      51.000us       0.000us         0.00%     214.000us     107.000us           0 b           0 b      29.67 Mb           0 b             2  
                         aten::embedding_dense_backward         0.01%      29.000us         0.03%      97.000us      48.500us     161.000us         0.07%     214.000us     107.000us           0 b           0 b      29.67 Mb           0 b             2  
autograd::engine::evaluate_function: EmbeddingBackwa...         0.01%      18.000us         0.04%     126.000us      63.000us       0.000us         0.00%     214.000us     107.000us           0 b           0 b      23.30 Mb      -6.38 Mb             2  
                                          aten::resize_         0.01%      34.000us         0.01%      34.000us      11.333us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       6.38 Mb       6.38 Mb             3  
                                        aten::embedding         0.01%      36.000us         0.05%     156.000us      78.000us       0.000us         0.00%      80.000us      40.000us           0 b           0 b       6.38 Mb           0 b             2  
                                     aten::index_select         0.02%      50.000us         0.04%     111.000us      55.500us      80.000us         0.03%      80.000us      40.000us           0 b           0 b       6.38 Mb           0 b             2  
       autograd::engine::evaluate_function: MmBackward0         0.32%     921.000us         3.09%       8.971ms     122.890us       0.000us         0.00%     101.425ms       1.389ms           0 b           0 b       4.92 Mb   -1020.12 Mb            73  
                                              aten::sum         0.99%       2.874ms         1.35%       3.923ms      34.412us       4.123ms         1.78%       4.123ms      36.167us           0 b           0 b       4.23 Mb       4.23 Mb           114  
                                              aten::max         0.15%     442.000us         0.19%     544.000us      41.846us       1.390ms         0.60%       1.390ms     106.923us           0 b           0 b       3.40 Mb       3.40 Mb            13  
      autograd::engine::evaluate_function: SubBackward0         0.06%     175.000us         0.32%     921.000us      70.846us       0.000us         0.00%       1.718ms     132.154us           0 b           0 b       1.13 Mb    -222.12 Mb            13  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 290.215ms
Self CUDA time total: 231.841ms

References And Further Readings#