Profile GPT Small Time And Memory#
%pip install -q omniverse==0.0.63
Common Functions#
This module includes the GPT model definitions as well as some common configuration.
from __future__ import annotations
from typing import Literal, Tuple, cast
import torch
from pydantic import BaseModel
from torch import nn
from omnivault.modules.activation import GELU, SoftmaxStable
from omnivault.transformer.modules.layers.normalization import RMSNorm
__tagged__ = "This code is pinned to commit `30d963e` of cs336-stanford-spring2024-assignment1-gpt-from-scratch."
__reference__ = ["https://github.com/marcelroed/spring2024-assignment2-systems/blob/master/writeup.pdf"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class General(BaseModel):
batch_size: int = 16
seed: int = 20230310
class GPTConfig(BaseModel):
approximate: Literal["tanh"] | None = None
activation_name: Literal["gelu"] = "gelu"
d_model: int
d_ff: int | None = None
num_heads: int
context_length: int
attn_pdrop: float = 0.0
resid_pdrop: float = 0.0
bias: bool = False
vocab_size: int
num_blocks: int
token_position_pdrop: float = 0.0
weight_tie: bool = False
class PositionwiseFeedForward(nn.Module):
def __init__(
self,
d_model: int,
d_ff: int | None = None,
bias: bool = False,
activation_name: Literal["gelu"] = "gelu",
dropout: float = 0.0,
) -> None:
super().__init__()
self.d_model = d_model
self.d_ff = d_ff or 4 * d_model
self.bias = bias # bias False in this exercise
self.activation_name = activation_name
self.dropout = dropout
self.ffn = nn.ModuleDict(
{
# incoming `B x T x D` and we are interested in `T x D` so weight is `D x d_ff`
# so that `Z @ W1 -> (T x D) @ (D x d_ff)`
"context_fc": nn.Linear(in_features=self.d_model, out_features=self.d_ff, bias=self.bias),
"activation": self.activation,
# apply dropout after the activation to randomly zero out activations
"dropout": nn.Dropout(p=self.dropout, inplace=False),
# incoming is Z @ W1 -> T x d_ff -> (T x d_ff) @ (d_ff x D) project back to D
"context_projection": nn.Linear(in_features=self.d_ff, out_features=self.d_model, bias=self.bias),
}
)
@property
def activation(self) -> nn.Module:
if self.activation_name == "gelu":
activation = GELU(approximate=None) # no approx using tanh
else:
            raise ValueError(f"Unsupported activation: {self.activation_name}")
return activation
def forward(self, z: torch.Tensor) -> torch.Tensor:
# fmt: off
z = self.ffn["context_fc"](z) # Z @ W1 = [B, T, D] @ [D, d_ff] = [B, T, d_ff]
z = self.ffn["activation"](z) # \sigma(Z @ W1) = [B, T, d_ff]
z = self.ffn["dropout"](z) # \dropout(\sigma(Z @ W1)) = [B, T, d_ff]
z = self.ffn["context_projection"](z) # \dropout(\sigma(Z @ W1)) @ W2 = [B, T, d_ff] @ [d_ff, D] = [B, T, D]
# fmt: on
return z
class ScaledDotProductAttention(nn.Module):
def __init__(self, dropout: float = 0.0) -> None:
super().__init__()
self.dropout = nn.Dropout(p=dropout)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.BoolTensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# fmt: off
T, d_q = query.size(-2), query.size(-1)
attention_scores = torch.matmul(query, key.transpose(dim0=-2, dim1=-1)) / torch.sqrt(torch.tensor(d_q).float()) # Q @ K.T = [B, H, T, d_q] @ [B, H, d_q, T] = [B, H, T, T]
if mask is not None:
mask = mask[:, :, :T, :T] # type: ignore[assignment]
attention_scores = attention_scores.masked_fill(mask == 1, float("-inf")) if mask is not None else attention_scores # [B, H, T, T]
softmax = SoftmaxStable(dim=-1)
attention_weights = softmax(attention_scores) # [B, H, T, T]
attention_weights = self.dropout(attention_weights) # [B, H, T, T]
context_vector = torch.matmul(attention_weights, value) # [B, H, T, T] @ [B, H, T, d_v] = [B, H, T, d_v]
# fmt: on
return context_vector, attention_weights
class CausalMultiHeadSelfAttention(nn.Module):
context_vector: torch.Tensor
attention_weights: torch.Tensor
def __init__(
self,
d_model: int,
num_heads: int,
context_length: int,
attn_pdrop: float = 0.0, # pdrop means prob of dropout
resid_pdrop: float = 0.0,
bias: bool = False,
) -> None:
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.H = num_heads
self.context_length = context_length
self.attn_pdrop = attn_pdrop
self.resid_pdrop = resid_pdrop
self.bias = bias
self.W_Q = nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
self.W_K = nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
self.W_V = nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
# alias of W_O
self.context_projection = nn.Linear(in_features=self.d_model, out_features=self.d_model, bias=self.bias)
# regularization
self.resid_dropout = nn.Dropout(self.resid_pdrop)
self.attention = ScaledDotProductAttention(dropout=self.attn_pdrop)
# causal mask to ensure that attention is only applied to the left in the input sequence
        # registered as a buffer because the mask is not a learnable parameter
self.register_buffer(
"causal_mask",
torch.triu(
torch.ones((self.context_length, self.context_length)).bool(),
diagonal=1,
).view(1, 1, self.context_length, self.context_length),
)
def forward(self, *, z: torch.Tensor) -> torch.Tensor:
B, T, D = z.size()
# fmt: off
Q: torch.Tensor = self.W_Q(z).contiguous() # Z @ W_Q = [B, T, D] @ [D, D] = [B, T, D]
K: torch.Tensor = self.W_K(z).contiguous() # Z @ W_K = [B, T, D] @ [D, D] = [B, T, D]
V: torch.Tensor = self.W_V(z).contiguous() # Z @ W_V = [B, T, D] @ [D, D] = [B, T, D]
Q = Q.view(B, T, self.H, D // self.H).transpose(dim0=1, dim1=2) # [B, T, D] -> [B, T, H, D // H] -> [B, H, T, D//H]
K = K.view(B, T, self.H, D // self.H).transpose(dim0=1, dim1=2)
V = V.view(B, T, self.H, D // self.H).transpose(dim0=1, dim1=2)
# Now pass them to self attention
self.context_vector, self.attention_weights = self.attention(query=Q, key=K, value=V, mask=self.causal_mask) # ([B, H, T, D // H], [B, H, T, T])
assert isinstance(self.context_vector, torch.Tensor) # do this for type hint in IDE
# Now context vector is shape [B, H, T, D // H] but we want [B, T, D] to matmul with W_O/context_projection
self.context_vector = self.context_vector.transpose(dim0=1, dim1=2).contiguous().view(B, T, D) # merge all heads together
# fmt: on
projected_context_vector: torch.Tensor = self.resid_dropout(
self.context_projection(self.context_vector) # [B, T, D] @ [D, D] = [B, T, D]
)
return projected_context_vector
class GPTBlock(nn.Module):
def __init__(
self,
config: GPTConfig,
) -> None:
super().__init__()
self.rmns_1 = RMSNorm(d_model=config.d_model, eps=1e-5)
self.attn = CausalMultiHeadSelfAttention(
d_model=config.d_model,
num_heads=config.num_heads,
context_length=config.context_length,
attn_pdrop=config.attn_pdrop,
resid_pdrop=config.resid_pdrop,
bias=config.bias,
)
self.rmns_2 = RMSNorm(d_model=config.d_model, eps=1e-5)
self.ffn = PositionwiseFeedForward(
d_model=config.d_model,
d_ff=config.d_ff,
bias=config.bias,
activation_name=config.activation_name,
dropout=config.resid_pdrop,
)
def forward(self, z: torch.Tensor) -> torch.Tensor:
z = z + self.attn(z=self.rmns_1(z))
z = z + self.ffn(self.rmns_2(z))
return z
class GPT(nn.Module):
def __init__(self, config: GPTConfig) -> None:
super().__init__()
self.config = config
self.d_model = config.d_model
self.num_blocks = config.num_blocks
self.vocab_size = config.vocab_size
self.blocks = nn.ModuleList([GPTBlock(config=config) for _ in range(self.num_blocks)])
self.backbone = nn.ModuleDict(
dict( # noqa: C408
token_embeddings=nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.d_model),
position_embeddings=nn.Embedding(num_embeddings=config.context_length, embedding_dim=self.d_model),
dropout=nn.Dropout(p=config.token_position_pdrop),
layers=self.blocks,
ln_final=RMSNorm(d_model=self.d_model, eps=1e-5),
)
)
self.head = nn.Linear(in_features=self.d_model, out_features=self.vocab_size, bias=config.bias)
self.apply(self._init_weights)
context_projections = "context_projection.weight"
# apply special scaled init to the residual projections, per GPT-2 paper
for parameter_name, parameter in self.named_parameters():
            # NOTE: W_O is also a projection, but I did not have the foresight to name it as such.
if parameter_name.endswith(context_projections):
mean = 0.0
std_dev = 0.02 / torch.sqrt(torch.tensor(2 * config.num_blocks, dtype=torch.float))
torch.nn.init.normal_(parameter, mean=mean, std=std_dev)
if config.weight_tie:
self.backbone.token_embeddings.weight = self.head.weight
def crop_context_length(self, context_length: int) -> None:
# NOTE: conveniently took Karpathy's implementation here for cropping
assert context_length <= self.config.context_length
self.config.context_length = context_length # update config
self.backbone.position_embeddings.weight = nn.Parameter(
self.backbone.position_embeddings.weight[:context_length]
)
for block in self.backbone.layers:
if hasattr(block.attn, "causal_mask"):
block.attn.causal_mask = block.attn.causal_mask[:, :, :context_length, :context_length]
# update context length attribute in MultiHeadSelfAttention
block.attn.context_length = context_length
def _init_weights(self, module: nn.Module) -> None:
normal_init_modules = (nn.Linear, nn.Embedding)
if isinstance(module, normal_init_modules):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if hasattr(module, "bias") and module.bias is not None:
torch.nn.init.zeros_(module.bias)
def forward(self, in_indices: torch.LongTensor) -> torch.FloatTensor:
device = in_indices.device
B, T = in_indices.size()
positions = torch.arange(0, T, dtype=torch.long, device=device) # [T]
token_embeddings = self.backbone.token_embeddings(in_indices) # [B, T, D]
positional_embeddings = self.backbone.position_embeddings(positions) # [T, D]
# fmt: off
positional_embeddings = positional_embeddings.unsqueeze(0) # .expand(B, -1, -1) # [B, T, D]
# fmt: on
z = self.backbone.dropout(token_embeddings + positional_embeddings) # [B, T, D]
for block in self.backbone.layers:
z = block(z) # [B, T, D]
z = self.backbone.ln_final(z) # [B, T, D]
logits = self.head(z) # [B, T, V]
return cast(torch.FloatTensor, logits) # [B, T, V]
def initialize_model(
config: GPTConfig,
device: str = "cuda",
) -> GPT:
if config.d_ff is None:
config.d_ff = 4 * config.d_model
model = GPT(config)
return model.to(device)
def get_random_batch(
batch_size: int,
context_length: int,
vocab_size: int,
device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor]:
inputs = torch.randint( # [B, T]
0,
vocab_size,
(batch_size, context_length),
dtype=torch.long,
device=device,
)
targets = torch.randint( # [B, T]
0,
vocab_size,
(batch_size, context_length),
dtype=torch.long,
device=device,
)
return inputs, targets
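As a quick sanity check that these pieces compose, the short sketch below (toy sizes chosen arbitrarily here, not the GPT-small configuration profiled later) builds a model with initialize_model, draws a random batch, and confirms that the logits come out as [B, T, V].
# Sanity-check sketch with arbitrary toy sizes (not the profiled GPT-small config).
tiny_config = GPTConfig(
    context_length=8,
    vocab_size=100,
    d_model=32,
    num_blocks=2,
    num_heads=4,
)
tiny_model = initialize_model(tiny_config, device=str(device))
tiny_inputs, _ = get_random_batch(
    batch_size=2,
    context_length=tiny_config.context_length,
    vocab_size=tiny_config.vocab_size,
    device=str(device),
)
with torch.no_grad():
    tiny_logits = tiny_model(tiny_inputs)
assert tiny_logits.shape == (2, tiny_config.context_length, tiny_config.vocab_size)  # [B, T, V]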
Main Profiling Code#
from __future__ import annotations
import logging
import socket
import sys
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Iterable, Tuple
import torch
from torch import nn
from torch._C._profiler import _ExperimentalConfig
from torch.profiler import ProfilerActivity, profile, record_function
from omnivault.modules.loss import CrossEntropyLoss
from omnivault.utils.reproducibility.seed import seed_all
seed_all(42, True, False)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
force=True,
)
logger = logging.getLogger(__name__)
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"
def profile_one_step(
model: nn.Module,
batch: Tuple[torch.Tensor, torch.Tensor],
optimizer: torch.optim.Optimizer,
criterion: CrossEntropyLoss,
enable_backward: bool,
enable_optimizer: bool,
mixed_precision: bool = False,
) -> None:
context = torch.autocast("cuda", dtype=torch.bfloat16) if mixed_precision else nullcontext()
inputs, targets = batch[0], batch[1] # typically this doesn't need to be under context.
with context: # type: ignore[attr-defined]
with record_function(name="forward_pass"):
logits = model(inputs)
loss = criterion(logits, targets)
if enable_backward:
with record_function(name="backward_pass"):
loss.backward()
if enable_optimizer:
with record_function(name="optimizer"):
optimizer.step()
optimizer.zero_grad(set_to_none=True)
def run_warmup(
model: nn.Module,
batch: Tuple[torch.Tensor, torch.Tensor],
optimizer: torch.optim.Optimizer,
criterion: CrossEntropyLoss,
) -> None:
inputs, targets = batch[0], batch[1]
logits = model(inputs)
loss = criterion(logits, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
torch.cuda.synchronize()
def trace_handler(prof: torch.profiler.profile) -> None:
# Prefix for file names.
host_name = socket.gethostname()
timestamp = datetime.now().strftime(TIME_FORMAT_STR)
file_prefix = f"{host_name}_{timestamp}"
# Construct the trace file.
prof.export_chrome_trace(f"{file_prefix}.json.gz")
# Construct the memory timeline file.
prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")
prof.export_stacks("lm_profiler_stacks.txt", "self_cuda_time_total")
def run_profiler(
model: nn.Module,
batch: Tuple[torch.Tensor, torch.Tensor],
optimizer: torch.optim.Optimizer,
criterion: CrossEntropyLoss,
enable_backward: bool,
enable_optimizer: bool,
mixed_precision: bool = False,
profile_steps: int = 5,
*,
activities: Iterable[ProfilerActivity] | None = None,
profile_memory: bool = False,
with_stack: bool = True,
record_shapes: bool = True,
with_flops: bool = False,
**profile_kwargs: Any,
) -> torch.profiler.profile:
# profile_memory requires with_stack and record_shapes, hence we override these if profile_memory is True
# See torch.profiler.profiler._memory_profile
if profile_memory:
logger.warning(
"`profile_memory` requires `with_stack` and `record_shapes`, these will be enabled since `profile_memory` is True"
)
with_stack = with_stack or profile_memory
record_shapes = record_shapes or profile_memory
# experimental config is needed to export stacks: see https://github.com/pytorch/pytorch/issues/100253
experimental_config = _ExperimentalConfig(verbose=True) if with_stack else None
if profile_memory:
torch.cuda.memory._record_memory_history(max_entries=1_000_000)
with profile(
activities=activities, # [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
experimental_config=experimental_config,
record_shapes=record_shapes,
profile_memory=profile_memory,
with_stack=with_stack,
with_flops=with_flops,
**profile_kwargs,
) as prof:
for _ in range(profile_steps):
profile_one_step(
model=model,
batch=batch,
optimizer=optimizer,
criterion=criterion,
enable_backward=enable_backward,
enable_optimizer=enable_optimizer,
mixed_precision=mixed_precision,
)
prof.step()
if profile_memory:
torch.cuda.memory._dump_snapshot('memory_snapshot.pickle')
torch.cuda.memory._record_memory_history(enabled=None)
return prof # type: ignore[no-any-return]
gpt_small_config = GPTConfig(
context_length=128,
vocab_size=10_000,
d_model=768,
num_blocks=12,
num_heads=12,
)
general = General()
seed_all(general.seed, True, False)
batch = get_random_batch(
batch_size=general.batch_size,
context_length=gpt_small_config.context_length,
vocab_size=gpt_small_config.vocab_size,
)
model = GPT(gpt_small_config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()
run_warmup(model=model, batch=batch, optimizer=optimizer, criterion=criterion)
profiled = run_profiler(
model=model,
batch=batch,
optimizer=optimizer,
criterion=criterion,
enable_backward=True,
enable_optimizer=True,
mixed_precision=False,
profile_steps=5,
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
profile_memory=False,
with_stack=True,
record_shapes=True,
with_flops=False,
)
STAGE:2024-08-12 08:48:52 34:34 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-08-12 08:48:53 34:34 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-08-12 08:48:53 34:34 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
profiled.export_stacks("lm_profiler_stacks.txt", "self_cuda_time_total")
print(profiled.key_averages().table(sort_by="cpu_time_total", row_limit=10))
print(profiled.key_averages().table(sort_by="cuda_time_total", row_limit=10))
profiled.export_chrome_trace("trace.json")
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
backward_pass 36.08% 643.198ms 36.09% 643.418ms 128.684ms 0.000us 0.00% 5.000us 1.000us 5
cudaLaunchKernel 34.75% 619.425ms 34.75% 619.425ms 80.654us 0.000us 0.00% 0.000us 0.000us 7680
forward_pass 3.08% 54.873ms 18.44% 328.673ms 65.735ms 0.000us 0.00% 355.509ms 71.102ms 5
autograd::engine::evaluate_function: DivBackward0 0.34% 5.976ms 11.06% 197.217ms 458.644us 0.000us 0.00% 57.947ms 134.760us 430
aten::mul 1.23% 21.851ms 8.95% 159.555ms 112.363us 83.133ms 7.15% 83.133ms 58.544us 1420
cudaDeviceSynchronize 8.42% 150.048ms 8.42% 150.048ms 150.048ms 0.000us 0.00% 0.000us 0.000us 1
DivBackward0 0.21% 3.775ms 6.89% 122.752ms 285.470us 0.000us 0.00% 47.889ms 111.370us 430
aten::div 1.22% 21.803ms 6.68% 119.108ms 87.259us 50.880ms 4.38% 50.880ms 37.275us 1365
aten::mm 1.38% 24.547ms 6.10% 108.698ms 99.268us 757.305ms 65.15% 757.305ms 691.603us 1095
autograd::engine::evaluate_function: MmBackward0 0.20% 3.608ms 5.07% 90.308ms 247.419us 0.000us 0.00% 507.949ms 1.392ms 365
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 1.783s
Self CUDA time total: 1.162s
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 1.38% 24.547ms 6.10% 108.698ms 99.268us 757.305ms 65.15% 757.305ms 691.603us 1095
autograd::engine::evaluate_function: MmBackward0 0.20% 3.608ms 5.07% 90.308ms 247.419us 0.000us 0.00% 507.949ms 1.392ms 365
MmBackward0 0.26% 4.547ms 4.86% 86.700ms 237.534us 0.000us 0.00% 507.949ms 1.392ms 365
forward_pass 3.08% 54.873ms 18.44% 328.673ms 65.735ms 0.000us 0.00% 355.509ms 71.102ms 5
aten::matmul 0.30% 5.300ms 3.74% 66.703ms 137.532us 0.000us 0.00% 269.186ms 555.023us 485
aten::linear 0.15% 2.613ms 2.51% 44.805ms 122.753us 0.000us 0.00% 237.179ms 649.805us 365
sgemm_128x128x8_TN_vec 0.00% 0.000us 0.00% 0.000us 0.000us 184.272ms 15.85% 184.272ms 1.474ms 125
maxwell_sgemm_128x64_nn 0.00% 0.000us 0.00% 0.000us 0.000us 181.512ms 15.62% 181.512ms 427.087us 425
maxwell_sgemm_128x64_tn 0.00% 0.000us 0.00% 0.000us 0.000us 166.766ms 14.35% 166.766ms 397.062us 420
sgemm_128x128x8_NT_vec 0.00% 0.000us 0.00% 0.000us 0.000us 94.916ms 8.17% 94.916ms 1.460ms 65
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 1.783s
Self CUDA time total: 1.162s
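The tables above show that aten::mm alone accounts for roughly 65% of self CUDA time, so the step is dominated by GEMMs. If you prefer to pull such numbers out programmatically rather than reading the table, key_averages() returns averaged event records that can be filtered; the sketch below assumes each record exposes key and self_cuda_time_total (in microseconds), the same metric name passed to export_stacks above.
# Hedged post-processing sketch: pull the matmul self-CUDA time out of the
# aggregated events instead of eyeballing the table. `self_cuda_time_total`
# is reported in microseconds; `key` is the operator name shown in the table.
matmul_keys = ("aten::mm", "aten::bmm")
matmul_self_cuda_us = sum(
    evt.self_cuda_time_total
    for evt in profiled.key_averages()
    if evt.key in matmul_keys
)
print(f"self CUDA time spent in matmuls: {matmul_self_cuda_us / 1e3:.1f} ms across the profiled steps")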
run_warmup(model=model, batch=batch, optimizer=optimizer, criterion=criterion)
profiled = run_profiler(
model=model,
batch=batch,
optimizer=optimizer,
criterion=criterion,
enable_backward=True,
enable_optimizer=True,
mixed_precision=False,
profile_steps=5,
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
profile_memory=True,
with_stack=True,
record_shapes=True,
with_flops=False,
schedule=torch.profiler.schedule(wait=0, warmup=0, active=1, repeat=3),
on_trace_ready=trace_handler,
)
You can load the exported memory_snapshot.pickle into https://pytorch.org/memory_viz to inspect the CUDA memory timeline interactively.
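If you would rather render the snapshot locally instead of uploading it, PyTorch ships a private helper for this; the sketch below assumes torch.cuda._memory_viz and its trace_plot function are available in this build. This is internal API and may change between releases, so treat it as an assumption rather than a stable interface.
# Offline alternative (assumption: the private torch.cuda._memory_viz module
# and its trace_plot helper exist in this PyTorch build; internal API,
# subject to change).
import pickle
from torch.cuda import _memory_viz

with open("memory_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)
with open("memory_snapshot.html", "w") as f:
    f.write(_memory_viz.trace_plot(snapshot))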
print(profiled.key_averages().table(sort_by="cuda_memory_usage", row_limit=50))
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg CPU Mem Self CPU Mem CUDA Mem Self CUDA Mem # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::mul 1.81% 5.259ms 2.62% 7.591ms 26.729us 16.628ms 7.17% 16.628ms 58.549us 24 b 24 b 3.21 Gb 3.21 Gb 284
forward_pass 4.13% 11.992ms 12.06% 34.999ms 34.999ms 0.000us 0.00% 70.702ms 70.702ms 440 b 244 b 2.15 Gb -1.94 Gb 1
aten::div 1.63% 4.720ms 2.37% 6.889ms 25.234us 10.180ms 4.39% 10.180ms 37.289us 8 b 8 b 1.99 Gb 1.99 Gb 273
aten::mm 1.94% 5.618ms 3.01% 8.729ms 39.858us 150.881ms 65.08% 150.881ms 688.954us 0 b 0 b 1.71 Gb 1.71 Gb 219
aten::empty 0.72% 2.082ms 0.72% 2.082ms 13.788us 0.000us 0.00% 0.000us 0.000us 96 b 96 b 1.17 Gb 1.17 Gb 151
MulBackward0 0.10% 297.000us 0.90% 2.600ms 53.061us 0.000us 0.00% 5.903ms 120.469us 0 b 0 b 1.14 Gb 24.00 Mb 49
aten::matmul 0.32% 925.000us 2.77% 8.047ms 82.959us 0.000us 0.00% 53.449ms 551.021us 0 b 0 b 1.13 Gb 0 b 97
MmBackward0 0.38% 1.100ms 2.77% 8.050ms 110.274us 0.000us 0.00% 101.425ms 1.389ms 0 b 0 b 1.00 Gb 0 b 73
DivBackward0 0.30% 867.000us 2.32% 6.722ms 78.163us 0.000us 0.00% 9.584ms 111.442us 0 b 0 b 1020.20 Mb -882.00 Mb 86
aten::empty_like 0.14% 419.000us 0.56% 1.627ms 13.446us 0.000us 0.00% 0.000us 0.000us 0 b 0 b 864.00 Mb 24.00 Mb 121
aten::clone 0.25% 726.000us 1.85% 5.358ms 44.650us 0.000us 0.00% 5.261ms 43.842us 0 b 0 b 864.00 Mb -24.00 Mb 120
aten::neg 0.35% 1.002ms 0.55% 1.589ms 24.828us 3.204ms 1.38% 3.204ms 50.062us 0 b 0 b 804.14 Mb 804.14 Mb 64
aten::linear 0.12% 353.000us 1.87% 5.432ms 74.411us 0.000us 0.00% 47.933ms 656.616us 0 b 0 b 726.12 Mb 30.00 Mb 73
aten::pow 0.47% 1.376ms 0.82% 2.391ms 38.565us 1.737ms 0.75% 2.377ms 38.339us 0 b 0 b 588.00 Mb 588.00 Mb 62
aten::bmm 0.61% 1.767ms 0.83% 2.409ms 33.458us 7.221ms 3.11% 7.221ms 100.292us 0 b 0 b 576.00 Mb 576.00 Mb 72
aten::add 0.47% 1.369ms 0.67% 1.945ms 25.933us 2.466ms 1.06% 2.466ms 32.880us 0 b 0 b 510.20 Mb 510.20 Mb 75
aten::exp 0.15% 433.000us 0.22% 632.000us 25.280us 2.020ms 0.87% 2.020ms 80.800us 0 b 0 b 510.12 Mb 510.12 Mb 25
aten::reshape 0.41% 1.199ms 2.11% 6.133ms 15.807us 0.000us 0.00% 3.612ms 9.309us 0 b 0 b 504.00 Mb 0 b 388
aten::empty_strided 0.25% 723.000us 0.25% 723.000us 3.394us 0.000us 0.00% 0.000us 0.000us 300 b 300 b 397.07 Mb 397.07 Mb 213
aten::_foreach_sqrt 0.11% 310.000us 0.34% 981.000us 981.000us 1.601ms 0.69% 1.601ms 1.601ms 0 b 0 b 397.07 Mb 0 b 1
BmmBackward0 0.10% 285.000us 0.75% 2.185ms 91.042us 0.000us 0.00% 4.825ms 201.042us 0 b 0 b 360.00 Mb 0 b 24
ErfBackward0 0.07% 207.000us 0.60% 1.739ms 144.917us 0.000us 0.00% 6.193ms 516.083us 0 b 0 b 336.00 Mb -1.08 Gb 12
aten::masked_fill 0.11% 325.000us 0.71% 2.070ms 86.250us 0.000us 0.00% 2.533ms 105.542us 0 b 0 b 288.00 Mb 0 b 24
aten::erf 0.06% 180.000us 0.09% 254.000us 21.167us 1.142ms 0.49% 1.142ms 95.167us 0 b 0 b 288.00 Mb 288.00 Mb 12
aten::zeros 0.03% 84.000us 0.20% 589.000us 39.267us 0.000us 0.00% 438.000us 29.200us 0 b 0 b 251.80 Mb 0 b 15
aten::sub 0.08% 239.000us 0.11% 331.000us 25.462us 984.000us 0.42% 984.000us 75.692us 0 b 0 b 222.12 Mb 222.12 Mb 13
ExpBackward0 0.03% 89.000us 0.15% 430.000us 33.077us 0.000us 0.00% 1.173ms 90.231us 0 b 0 b 222.12 Mb 0 b 13
SubBackward0 0.01% 40.000us 0.12% 361.000us 27.769us 0.000us 0.00% 880.000us 67.692us 0 b 0 b 222.12 Mb 0 b 13
MaxBackward0 0.02% 47.000us 0.38% 1.113ms 85.615us 0.000us 0.00% 521.000us 40.077us 0 b 0 b 222.12 Mb 0 b 13
aten::value_selecting_reduction_backward 0.04% 108.000us 0.36% 1.042ms 80.154us 0.000us 0.00% 511.000us 39.308us 0 b 0 b 222.12 Mb 0 b 13
PowBackward0 0.09% 265.000us 0.97% 2.820ms 112.800us 0.000us 0.00% 1.976ms 79.040us 0 b -8 b 174.00 Mb -276.00 Mb 25
UnsafeViewBackward0 0.16% 473.000us 0.76% 2.216ms 16.662us 0.000us 0.00% 1.032ms 7.759us 0 b 0 b 144.00 Mb 12.00 Mb 133
ViewBackward0 0.16% 464.000us 0.75% 2.170ms 16.316us 0.000us 0.00% 792.000us 5.955us 0 b 0 b 144.00 Mb 12.00 Mb 133
MaskedFillBackward0 0.07% 205.000us 0.39% 1.136ms 94.667us 0.000us 0.00% 1.057ms 88.083us 0 b 0 b 144.00 Mb 24.00 Mb 12
GatherBackward0 0.00% 6.000us 0.03% 88.000us 88.000us 0.000us 0.00% 145.000us 145.000us 0 b 0 b 78.12 Mb 0 b 1
aten::gather_backward 0.00% 8.000us 0.03% 82.000us 82.000us 0.000us 0.00% 145.000us 145.000us 0 b 0 b 78.12 Mb 0 b 1
aten::new_zeros 0.00% 7.000us 0.01% 40.000us 40.000us 0.000us 0.00% 136.000us 136.000us 0 b 0 b 78.12 Mb 0 b 1
aten::new_empty 0.00% 3.000us 0.00% 13.000us 13.000us 0.000us 0.00% 0.000us 0.000us 0 b 0 b 78.12 Mb 0 b 1
aten::contiguous 0.01% 21.000us 0.18% 513.000us 42.750us 0.000us 0.00% 432.000us 36.000us 0 b 0 b 72.00 Mb 0 b 12
EmbeddingBackward0 0.00% 6.000us 0.04% 108.000us 54.000us 0.000us 0.00% 214.000us 107.000us 0 b 0 b 29.67 Mb 0 b 2
aten::embedding_backward 0.00% 5.000us 0.04% 102.000us 51.000us 0.000us 0.00% 214.000us 107.000us 0 b 0 b 29.67 Mb 0 b 2
aten::embedding_dense_backward 0.01% 29.000us 0.03% 97.000us 48.500us 161.000us 0.07% 214.000us 107.000us 0 b 0 b 29.67 Mb 0 b 2
autograd::engine::evaluate_function: EmbeddingBackwa... 0.01% 18.000us 0.04% 126.000us 63.000us 0.000us 0.00% 214.000us 107.000us 0 b 0 b 23.30 Mb -6.38 Mb 2
aten::resize_ 0.01% 34.000us 0.01% 34.000us 11.333us 0.000us 0.00% 0.000us 0.000us 0 b 0 b 6.38 Mb 6.38 Mb 3
aten::embedding 0.01% 36.000us 0.05% 156.000us 78.000us 0.000us 0.00% 80.000us 40.000us 0 b 0 b 6.38 Mb 0 b 2
aten::index_select 0.02% 50.000us 0.04% 111.000us 55.500us 80.000us 0.03% 80.000us 40.000us 0 b 0 b 6.38 Mb 0 b 2
autograd::engine::evaluate_function: MmBackward0 0.32% 921.000us 3.09% 8.971ms 122.890us 0.000us 0.00% 101.425ms 1.389ms 0 b 0 b 4.92 Mb -1020.12 Mb 73
aten::sum 0.99% 2.874ms 1.35% 3.923ms 34.412us 4.123ms 1.78% 4.123ms 36.167us 0 b 0 b 4.23 Mb 4.23 Mb 114
aten::max 0.15% 442.000us 0.19% 544.000us 41.846us 1.390ms 0.60% 1.390ms 106.923us 0 b 0 b 3.40 Mb 3.40 Mb 13
autograd::engine::evaluate_function: SubBackward0 0.06% 175.000us 0.32% 921.000us 70.846us 0.000us 0.00% 1.718ms 132.154us 0 b 0 b 1.13 Mb -222.12 Mb 13
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 290.215ms
Self CUDA time total: 231.841ms
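As a rough cross-check of the profiler's CUDA memory columns, the caching allocator's own counters can be queried directly with stable torch.cuda APIs; a minimal sketch:
# Quick cross-check of the profiler's memory columns against the caching
# allocator's own counters (values are bytes; converted to GiB here).
print(f"currently allocated : {torch.cuda.memory_allocated() / 2**30:.2f} GiB")
print(f"peak allocated      : {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
print(f"currently reserved  : {torch.cuda.memory_reserved() / 2**30:.2f} GiB")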