How to Calculate the Number of FLOPs in GPT-2#


This notebook is adapted from Andrej Karpathy’s NanoGPT, which includes a back-of-the-envelope analysis of a Transformer: estimates of the number of FLOPs, parameters, peak memory footprint, checkpoint size, and so on.

Configurations, Constants and Enums#

from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Callable, Dict, Literal, Optional

import pandas as pd
import torch.nn as nn
from pydantic import BaseModel
from rich.pretty import pprint
from tabulate import tabulate
from transformers import GPT2LMHeadModel


@dataclass
class GPTConfig:
    num_decoder_blocks: int = 12
    context_length: int = 1024
    n_embd: int = 768
    ffw_size: int = 3072  # note, this is 4 * n_embd
    n_head: int = 12
    vocab_size: int = 50257
    bias: Literal[False] = False

    def __post_init__(self) -> None:
        assert self.ffw_size == 4 * self.n_embd, "ffw_size must be 4 * n_embd"
        assert self.bias is False, "bias must be False in this experiment."


class GPT2ModelType(Enum):
    GPT2 = "gpt2"
    GPT2_MEDIUM = "gpt2-medium"
    GPT2_LARGE = "gpt2-large"
    GPT2_XL = "gpt2-xl"


class ByteUnits(IntEnum):
    B = 1  # Byte = 1 byte
    KB = 1000  # Kilobyte = 10^3 bytes
    MB = 1000**2  # Megabyte = 10^6 bytes
    GB = 1000**3  # Gigabyte = 10^9 bytes


class FloatingPointPrecision(IntEnum):
    FP32 = 4  # 32-bit floating-point, 4 bytes
    FP16 = 2  # 16-bit floating-point, 2 bytes
    BFLOAT16 = 2  # bfloat16, 16-bit, 2 bytes (same value as FP16, so IntEnum treats it as an alias; harmless here since only the byte size matters)


class GPUMemory(Enum):
    A100_40GB = 40e9  # 40 GB for NVIDIA A100
    V100_16GB = 16e9  # 16 GB for NVIDIA V100
    V100_32GB = 32e9  # 32 GB for NVIDIA V100
    T4_16GB = 16e9  # 16 GB for NVIDIA T4
    P100_16GB = 16e9  # 16 GB for NVIDIA P100


class GPU:
    def __init__(self, name: str, flops: Dict[FloatingPointPrecision, float]) -> None:
        self.name = name
        self.flops = flops

class A100(GPU):
    def __init__(self) -> None:
        super().__init__("A100", {
            FloatingPointPrecision.FP32: 19.5e12,
            FloatingPointPrecision.FP16: 312e12,
            FloatingPointPrecision.BFLOAT16: 312e12
        })
gpt2_config = GPTConfig()
pprint(gpt2_config)
GPTConfig(
│   num_decoder_blocks=12,
│   context_length=1024,
│   n_embd=768,
│   ffw_size=3072,
│   n_head=12,
│   vocab_size=50257,
│   bias=False
)

Total Trainable Parameters#

def total_trainable_parameters(model: nn.Module, include_bias: bool = True) -> int:
    """Returns the number of trainable parameters in the model."""
    if not include_bias:
        return sum(p.numel() for name, p in model.named_parameters() if p.requires_grad and "bias" not in name)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2ModelType.GPT2.value)
gpt2_params_no_bias = total_trainable_parameters(gpt2, include_bias=False)
gpt2_params_with_bias = total_trainable_parameters(gpt2, include_bias=True)

print(
    f"Number of trainable parameters in GPT2 model: {gpt2_params_no_bias} (excluding bias) and {gpt2_params_with_bias} (including bias)."
)
Number of trainable parameters in GPT2 model: 124337664 (excluding bias) and 124439808 (including bias).

Since Karpathy’s analysis assumes no bias for simplicity, we will also assume that there is no bias in the linear layers. We confirmed that the resulting parameter count (124337664) for the smallest GPT-2 model indeed matches the number given by Karpathy.
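
As a quick check on what this simplification drops, the sketch below (reusing the gpt2 model loaded above) counts the parameters whose names contain "bias"; the count equals the gap between the two totals printed above.

bias_params = sum(p.numel() for name, p in gpt2.named_parameters() if p.requires_grad and "bias" in name)
print(f"Bias parameters excluded by our assumption: {bias_params}")  # 124439808 - 124337664 = 102144
assert bias_params == gpt2_params_with_bias - gpt2_params_no_bias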

In what follows, we will assume the smallest GPT-2 model and work out its theoretical parameter and FLOP accounting by hand.

# config_args = {
#     'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
#     'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
#     'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
#     'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
# }[model_type]


def params(
    num_decoder_blocks: int = 12,
    context_length: int = 1024,
    n_embd: int = 768,
    ffw_size: int = 3072,
    vocab_size: int = 50257,
) -> OrderedDict[str, int]:
    """estimates the number of parameters in the model"""
    out = OrderedDict()

    # token and position embeddings
    out["embedding/position"] = n_embd * context_length
    out["embedding/token"] = n_embd * vocab_size
    out["embedding"] = out["embedding/position"] + out["embedding/token"]

    # attention blocks
    out["attention/ln"] = n_embd  # note, bias=False in our LN
    out["attention/kqv"] = n_embd * 3 * n_embd
    out["attention/proj"] = n_embd**2
    out["attention"] = out["attention/ln"] + out["attention/kqv"] + out["attention/proj"]

    # MLP blocks
    assert ffw_size == 4 * n_embd, "ffw_size must be 4 * n_embd"
    out["mlp/ln"] = n_embd
    out["mlp/ffw"] = n_embd * ffw_size
    out["mlp/proj"] = ffw_size * n_embd
    out["mlp"] = out["mlp/ln"] + out["mlp/ffw"] + out["mlp/proj"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = num_decoder_blocks * out["block"]
    out["ln_f"] = n_embd  # final layernorm
    out["dense"] = 0  # 0 because of parameter sharing. This layer uses the weights from the embedding layer

    # total
    out["total"] = out["embedding"] + out["transformer"] + out["ln_f"] + out["dense"]

    return out
params_dict = params()
gpt2_params_no_bias_manual = params_dict["total"]

# Compare to expected PyTorch model parameter count
expected_params = gpt2_params_no_bias
comparison_result = gpt2_params_no_bias_manual == expected_params
comparison_msg = f"We see: {gpt2_params_no_bias_manual}, Expected: {expected_params}, Match: {comparison_result}"

data = {
    "Name": params_dict.keys(),
    "Parameters": params_dict.values(),
    "Ratio (%)": [value / gpt2_params_no_bias_manual * 100 for value in params_dict.values()],
}
df = pd.DataFrame(data)

# Printing comparison result and parameter distribution table
print(comparison_msg + "\n")
print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))
We see: 124337664, Expected: 124337664, Match: True

+--------------------+------------+-----------------------+
|        Name        | Parameters |       Ratio (%)       |
+--------------------+------------+-----------------------+
| embedding/position |   786432   |  0.6324970042866496   |
|  embedding/token   |  38597376  |  31.042384711361475   |
|     embedding      |  39383808  |  31.674881715648123   |
|    attention/ln    |    768     | 0.0006176728557486812 |
|   attention/kqv    |  1769472   |  1.4231182596449616   |
|   attention/proj   |   589824   |  0.47437275321498723  |
|     attention      |  2360064   |  1.8981086857156975   |
|       mlp/ln       |    768     | 0.0006176728557486812 |
|      mlp/ffw       |  2359296   |   1.897491012859949   |
|      mlp/proj      |  2359296   |   1.897491012859949   |
|        mlp         |  4719360   |   3.795599698575646   |
|       block        |  7079424   |   5.693708384291344   |
|    transformer     |  84953088  |   68.32450061149613   |
|        ln_f        |    768     | 0.0006176728557486812 |
|       dense        |     0      |          0.0          |
|       total        | 124337664  |         100.0         |
+--------------------+------------+-----------------------+
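
The same params function generalizes to the larger GPT-2 variants listed in the commented config_args above. As a hedged example (not checked here against the actual Hugging Face gpt2-medium checkpoint), plugging in gpt2-medium's n_layer=24 and n_embd=1024 (the head count does not affect the parameter count) gives roughly the 350M figure quoted in that comment:

gpt2_medium_params_dict = params(
    num_decoder_blocks=24,
    context_length=1024,
    n_embd=1024,
    ffw_size=4 * 1024,
    vocab_size=50257,
)
print(f"gpt2-medium (no bias, estimated): {gpt2_medium_params_dict['total']:,}")  # ~354.6M, close to the quoted ~350M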

Calculating Checkpoint Size and Fluff Ratio#

The functions below compute the size of a GPT-2 model checkpoint, both measured and estimated, along with the “fluff ratio” that compares the two. The goal is to see how closely the estimated checkpoint size matches the actual measured size, and to quantify any overhead or extra data in the checkpoint file as a percentage of the estimate.

def calculate_checkpoint_size(params_count: int, precision: FloatingPointPrecision, units: ByteUnits) -> float:
    """
    Calculate the estimated checkpoint size in specified units.

    This function estimates the checkpoint size for a model given the number
    of parameters, the precision of these parameters, and
    the desired units for the result. It accounts for the AdamW optimizer's
    storage requirements by adding two times the parameter bytes for the
    optimizer's first and second moment estimates.

    Parameters
    ----------
    params_count : int
        The number of parameters excluding biases.
    precision : FloatingPointPrecision
        The floating point precision of the parameters.
    units : ByteUnits
        The units for the resulting checkpoint size.

    Returns
    -------
    float
        The estimated checkpoint size in the specified units.

    Notes
    -----
    The AdamW optimizer requires additional storage for each parameter
    for maintaining momentum and variance vectors, hence the calculation
    includes 2 * params_bytes to accommodate these.
    """
    params_bytes = params_count * precision.value
    params_and_buffers_bytes = params_bytes + 2 * params_bytes  # AdamW optimizer buffers
    return params_and_buffers_bytes / units.value


def calculate_fluff_ratio(measured_bytes: int, estimated_bytes: float, units: ByteUnits) -> float:
    """
    Calculate the fluff ratio between measured and estimated checkpoint sizes.

    The fluff ratio is a measure of the overhead or additional data in the
    checkpoint file, expressed as a percentage of the estimated size. This
    function converts the estimated size from gigabytes (or specified units)
    to bytes before calculating the ratio to ensure consistency in units.

    Parameters
    ----------
    measured_bytes : int
        The actual size of the checkpoint file, in bytes.
    estimated_bytes : float
        The estimated size of the checkpoint file, in the specified units.
    units : ByteUnits
        The units in which the estimated bytes are provided.

    Returns
    -------
    float
        The fluff ratio, expressed as a percentage.
    """
    estimated_bytes_in_bytes = estimated_bytes * units.value
    return (measured_bytes / estimated_bytes_in_bytes) * 100
  1. Measured Checkpoint Size in Bytes:

    • gpt2_checkpoint_size_measured_in_bytes is assigned a numerical value that represents the actual size of a GPT-2 model checkpoint file in bytes. This value is obtained from the output of the Unix command wc -c ckpt.pt, which counts the number of bytes in the file ckpt.pt.

  2. Estimated Checkpoint Size in Bytes:

    • The calculate_checkpoint_size function is called with the number of parameters excluding biases (gpt2_params_no_bias), the precision of the model’s parameters (FloatingPointPrecision.FP32), and the unit of measurement (ByteUnits.B for bytes). This function calculates the estimated total size of the checkpoint in bytes, taking into account the parameters and the additional storage required for the AdamW optimizer’s buffers.

    • It is worth noting we are assuming floating-point precision of 32 bits (4 bytes) for the model’s parameters, and hence we are multiplying the number of parameters by 4 to obtain the size in bytes.

    • The AdamW optimizer, which is commonly used to train models like GPT-2, maintains two additional buffers for each parameter: the first moment estimate (m, an exponential moving average of the gradients) and the second moment estimate (v, an exponential moving average of the squared gradients). These buffers adapt the learning rate for each parameter during training. This is why the storage requirement triples (params_bytes + 2 * params_bytes): the original parameters plus the two buffers.

  3. Fluff Ratio Calculation:

    • The calculate_fluff_ratio function is called with the measured size in bytes, the estimated size in bytes, and the unit of measurement for the estimated size (bytes). This function calculates the fluff ratio, which indicates the percentage of overhead or additional data in the measured checkpoint file compared to the estimated size.

gpt2_checkpoint_size_measured_in_bytes = 1542470366  # from 'wc -c ckpt.pt'
gpt2_checkpoint_size_measured_in_gb = gpt2_checkpoint_size_measured_in_bytes / ByteUnits.GB

gpt2_checkpoint_size_estimated_in_bytes = calculate_checkpoint_size(
    params_count=gpt2_params_no_bias,
    precision=FloatingPointPrecision.FP32,
    units=ByteUnits.B,
)
gpt2_checkpoint_size_estimated_in_gb = gpt2_checkpoint_size_estimated_in_bytes / ByteUnits.GB


fluff_ratio = calculate_fluff_ratio(
    measured_bytes=gpt2_checkpoint_size_measured_in_bytes,
    estimated_bytes=gpt2_checkpoint_size_estimated_in_bytes,
    units=ByteUnits.B,
)

data = [
    ["Measured Checkpoint Size (bytes)", gpt2_checkpoint_size_measured_in_bytes],
    ["Measured Checkpoint Size (GB)", gpt2_checkpoint_size_measured_in_gb],
    ["Estimated Checkpoint Size (bytes)", gpt2_checkpoint_size_estimated_in_bytes],
    ["Estimated Checkpoint Size (GB)", gpt2_checkpoint_size_estimated_in_gb],
    ["Fluff Ratio", fluff_ratio],
]

print(tabulate(data, headers=["Metric", "Value"], tablefmt="pretty"))
+-----------------------------------+-------------------+
|              Metric               |       Value       |
+-----------------------------------+-------------------+
| Measured Checkpoint Size (bytes)  |    1542470366     |
|   Measured Checkpoint Size (GB)   |    1.542470366    |
| Estimated Checkpoint Size (bytes) |   1492051968.0    |
|  Estimated Checkpoint Size (GB)   |    1.492051968    |
|            Fluff Ratio            | 103.3791314968461 |
+-----------------------------------+-------------------+
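
As a hedged what-if using the same helper (in practice the optimizer state is often kept in fp32 even when training in mixed precision), here is the estimated checkpoint size if the parameters and optimizer buffers were stored in fp16:

gpt2_checkpoint_size_fp16_in_gb = calculate_checkpoint_size(
    params_count=gpt2_params_no_bias,
    precision=FloatingPointPrecision.FP16,
    units=ByteUnits.GB,
)
print(f"Estimated fp16 checkpoint size: {gpt2_checkpoint_size_fp16_in_gb:.3f} GB")  # exactly half of the fp32 estimate above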

GPU Memory Footprint of Loading Model and Optimizer#

Roughly speaking, a checkpoint captures not just the model itself (its weights) but also the optimizer state, so its size is a useful proxy for the memory needed to hold both on the GPU.

When loading a model from a checkpoint for further training or inference, the GPU memory must accommodate the model weights and the optimizer state (if continuing training).

Below, we estimate the ratio of our GPU memory that will be taken up by the model and optimizer state when loading a GPT-2 model from a checkpoint.

def calculate_memory_ratio(checkpoint_size: float, gpu_memory: GPUMemory) -> str:
    memory_ratio = checkpoint_size / gpu_memory.value * 100
    return f"Memory ratio taken up by parameters and optimizer state: {memory_ratio:.2f}%"


print(calculate_memory_ratio(checkpoint_size=gpt2_checkpoint_size_estimated_in_bytes, gpu_memory=GPUMemory.A100_40GB))
Memory ratio taken up by parameters and optimizer state: 3.73%

Assuming an A100 GPU with roughly 40 GB of memory, the code calculates the percentage of GPU memory that the estimated checkpoint size (in bytes) occupies. This gives a sense of how much of the GPU’s memory is dedicated to the model’s weights and the optimizer’s buffers, before considering other memory usage such as activations during the forward and backward passes.

This percentage is relatively small, implying that most of the GPU memory is actually used for activations. Activations are the intermediate outputs of layers during the forward pass and their gradients during the backward pass, which can consume significant amounts of memory, especially in deep models and with large batch sizes.
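
For comparison, the same helper can be applied to the other GPUs defined in the GPUMemory enum; this is only a sketch of relative headroom, not a full memory model:

for gpu in (GPUMemory.A100_40GB, GPUMemory.V100_32GB, GPUMemory.V100_16GB):
    print(f"{gpu.name}: {calculate_memory_ratio(checkpoint_size=gpt2_checkpoint_size_estimated_in_bytes, gpu_memory=gpu)}")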

Estimating FLOPs for a Single Forward Pass#

In order to estimate the FLOPs for a single forward pass, we first need to define what a FLOP is.

Basics of Floating Point Numbers#

  • Floating Point Representation: In computers, numbers can be represented in various formats, and one common format is floating point. This format is used to represent real numbers (numbers with fractions) using a fixed amount of memory, allowing for a wide range of values. A floating point number is composed of a sign, an exponent, and a mantissa (or significand). This representation can handle very large numbers, very small numbers, and fractions.

  • Operations on Floating Point Numbers: Operations on floating point numbers include addition, subtraction, multiplication, and division. Each of these operations takes one or more floating point numbers as input and produces a floating point number as output.

Floating Point Operations (FLOPs)#

Floating Point Operations, or FLOPs, refer to individual mathematical operations (additions, subtractions, multiplications, divisions) performed on floating point numbers. Each operation counts as one FLOP.

Counting FLOPs of Matrix Multiplications#

In deep learning, most of the compute is spent in matrix multiplications, so we next look at how to count the FLOPs of a matrix multiplication. A single matrix multiplication involves many floating point multiplications and additions.

Consider two matrices \(\mathbf{A}\) and \(\mathbf{B}\) of size \(m \times n\) and \(n \times p\):

\[\begin{split} \mathbf{A} = \begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{m1} & a_{m2} & \cdots & a_{mn} \end{bmatrix}_{m \times n} \quad \mathbf{B} = \begin{bmatrix} b_{11} & b_{12} & \cdots & b_{1p} \\ b_{21} & b_{22} & \cdots & b_{2p} \\ \vdots & \vdots & \ddots & \vdots \\ b_{n1} & b_{n2} & \cdots & b_{np} \end{bmatrix}_{n \times p} \end{split}\]

It is easy to see that if we want to compute the product \(\mathbf{C} = \mathbf{A} \mathbf{B}\), the element \(c_{ij}\) of \(\mathbf{C}\) is given by:

\[ c_{ij} = \sum_{k=1}^{n} a_{ik} b_{kj} \]

and therefore there are a total of \(m \times n \times p\) multiplications and \(m \times (n-1) \times p\) additions. This amounts to roughly:

(1)#\[ m \times n \times p + m \times (n-1) \times p \approx 2 \times m \times n \times p \]

FLOPs. Note this is basically because matrix multiplication is a series of dot products, and each dot product involves \(n\) multiplications and \(n-1\) additions.
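
As a small sanity check of this approximation (toy sizes, nothing GPT-2-specific), we can count the multiplications and additions explicitly and compare against \(2 \times m \times n \times p\):

m, n, p = 64, 128, 32
multiplications = m * n * p     # one multiply per (i, k, j) triple
additions = m * (n - 1) * p     # n - 1 additions per output element
exact_flops = multiplications + additions
approx_flops = 2 * m * n * p
print(exact_flops, approx_flops)  # 522240 vs 524288, within about 0.4%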

Estimating FLOPs for a Single Forward Pass of GPT-2#

def flops(
    num_decoder_blocks: int = 12,
    context_length: int = 1024,
    n_embd: int = 768,
    n_head: int = 12,
    ffw_size: int = 3072,
    vocab_size: int = 50257,
) -> OrderedDict[str, int]:
    # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant
    # we count actual FLOPs, not MACs. Hence 2* all over the place
    # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D

    out = OrderedDict()
    head_size = n_embd // n_head

    # attention blocks
    # 1) the projection to key, query, values
    out["attention/kqv"] = 2 * context_length * (n_embd * 3 * n_embd)
    # 2) calculating the attention scores
    out["attention/scores"] = 2 * context_length * context_length * n_embd
    # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out["attention/reduce"] = 2 * n_head * (context_length * context_length * head_size)
    # 4) the final linear projection
    out["attention/proj"] = 2 * context_length * (n_embd * n_embd)
    out["attention"] = sum(out["attention/" + k] for k in ["kqv", "scores", "reduce", "proj"])

    # MLP blocks
    ffw_size = 4 * n_embd  # feed forward size
    out["mlp/ffw1"] = 2 * context_length * (n_embd * ffw_size)
    out["mlp/ffw2"] = 2 * context_length * (ffw_size * n_embd)
    out["mlp"] = out["mlp/ffw1"] + out["mlp/ffw2"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = num_decoder_blocks * out["block"]
    out["dense"] = 2 * context_length * (n_embd * vocab_size)

    # forward,backward,total
    out["forward_total"] = out["transformer"] + out["dense"]
    out["backward_total"] = 2 * out["forward_total"]  # use common estimate of bwd = 2*fwd
    out["total"] = out["forward_total"] + out["backward_total"]

    return out

The flops function calculates the total number of floating point operations required to process a single sample \(\mathbf{x} = \left(x_1, x_2, \ldots, x_{T}\right)\) of length \(T\) through the entire model for a single forward pass.

We take one sample snippet of code to explain how the flops are calculated:

# 2) calculating the attention scores
out["attention/scores"] = 2 * context_length * context_length * n_embd

This is not difficult to see if one recalls the attention mechanism:

\[ \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d_k}}\right) \mathbf{V} \]

where \(\mathbf{Q}\), \(\mathbf{K}\), and \(\mathbf{V}\) are the query, key, and value matrices of size \(T \times d_q\), \(T \times d_k\), and \(T \times d_v\), respectively. For simplicity, we assume that \(d_q = d_k = d_v = D\) (which is n_embd in the code).

In particular, attention_scores is calculated as:

attention_scores  = torch.matmul(query, key.transpose(dim0=-2, dim1=-1)) / torch.sqrt(torch.tensor(d_q).float())

which is the product of the query matrix and the transposed key matrix, scaled by the square root of the query dimension, and has shape \(T \times T\). The underlying matrix multiplication is between \(\mathbf{Q}\) of shape \(T \times D\) and \(\mathbf{K}^{\top}\) of shape \(D \times T\), so by our earlier equation (1) it costs roughly \(2 \times T \times T \times D\) FLOPs, which matches the code.
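
A small torch sketch (single head, with the toy assumption \(d_q = d_k = D\) as in the snippet above) confirms the shape of the score matrix and the corresponding FLOP count:

import torch

T, D = 1024, 768
query, key = torch.randn(T, D), torch.randn(T, D)
attention_scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(float(D)))
print(attention_scores.shape)  # torch.Size([1024, 1024]), i.e. T x T
print(2 * T * T * D)           # 1610612736, matching out["attention/scores"]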

f = flops()
flops_total = f["forward_total"]

table = [("name", "flops", "ratio (%)")]
for k, v in f.items():
    table.append((k, v, v / flops_total * 100))

print(tabulate(table, headers="firstrow", tablefmt="pretty", numalign="right"))
+------------------+--------------+---------------------+
|       name       |    flops     |      ratio (%)      |
+------------------+--------------+---------------------+
|  attention/kqv   |  3623878656  | 1.2425508965889174  |
| attention/scores |  1610612736  | 0.5522448429284077  |
| attention/reduce |  1610612736  | 0.5522448429284077  |
|  attention/proj  |  1207959552  | 0.41418363219630583 |
|    attention     |  8053063680  | 2.7612242146420387  |
|     mlp/ffw1     |  4831838208  | 1.6567345287852233  |
|     mlp/ffw2     |  4831838208  | 1.6567345287852233  |
|       mlp        |  9663676416  | 3.3134690575704466  |
|      block       | 17716740096  |  6.074693272212485  |
|   transformer    | 212600881152 |  72.89631926654981  |
|      dense       | 79047426048  |  27.10368073345018  |
|  forward_total   | 291648307200 |        100.0        |
|  backward_total  | 583296614400 |        200.0        |
|      total       | 874944921600 |        300.0        |
+------------------+--------------+---------------------+

Sanity Check with the PaLM Paper’s FLOPs Calculation#

NOTE: The notation below follows the PaLM paper and is not what I usually use.

The palm_flops function, inspired by Google’s PaLM paper, estimates the model’s floating point operations (FLOPs) using the formula the paper introduces for computing model FLOPs utilization (MFU), both per token and for an entire sequence.

  1. Non-Embedding Model Parameters (N): This calculation starts by estimating the number of non-embedding model parameters. In Transformer models a significant portion of the parameters resides in the embedding layers. Here only the embedding/position parameters are subtracted from the total; the embedding/token parameters are kept because they are tied to the final output projection and therefore do participate in the matrix multiplications.

  2. Model Dimensions (L, H, Q, T):

    • L = Number of layers (n_layer)

    • H = Number of attention heads (n_head)

    • Q = Size of each attention head (n_embd // n_head)

    • T = Sequence length (block_size), also referred to as context length in other discussions.

  3. Model FLOPs Per Token (mf_per_token): This represents the estimated FLOPs for processing a single token through one forward and backward pass, calculated as 6*N + 12*L*H*Q*T.

  4. Total Model FLOPs for a sequence (mf_per_sequence): This is calculated by multiplying the per-token FLOPs estimate (mf_per_token) by the sequence length (block_size or T). This gives the total estimated FLOPs for processing a sequence of length T.

This is more of a sanity check: Karpathy confirms that using the PaLM formula (palm_flops) to estimate FLOPs for our GPT-2 model yields results very close to his own flops function.

# now here is an estimate copy pasted from the PaLM paper
# this formula is often used to calculate MFU (model flops utilization)
def palm_flops(
    params: Callable[..., OrderedDict[str, int]], num_decoder_blocks: int, n_head: int, n_embd: int, context_length: int
) -> int:
    """Estimate of the model flops following PaLM paper formula."""
    # non-embedding model parameters. note that we do not subtract the
    # embedding/token params because those are tied and get used in the last layer.
    N = params()["total"] - params()["embedding/position"]
    L, H, Q, T = num_decoder_blocks, n_head, n_embd // n_head, context_length
    mf_per_token = 6 * N + 12 * L * H * Q * T
    mf_per_sequence = mf_per_token * context_length
    return mf_per_sequence
gpt2_flops_using_palm_flops_calculation = palm_flops(
    params=params,
    num_decoder_blocks=gpt2_config.num_decoder_blocks,
    n_head=gpt2_config.n_head,
    n_embd=gpt2_config.n_embd,
    context_length=gpt2_config.context_length,
)
gpt2_flops_using_own_flops_calculation = f["total"]
print(
    f"PaLM paper estimate of GPT2 flops: {gpt2_flops_using_palm_flops_calculation}, Our estimate: {gpt2_flops_using_own_flops_calculation}"
)
print(f"Ratio: {gpt2_flops_using_palm_flops_calculation / gpt2_flops_using_own_flops_calculation * 100:.5f}%")
PaLM paper estimate of GPT2 flops: 875062886400, Our estimate: 874944921600
Ratio: 100.01348%

Floating Point Operations Per Second (FLOPS)#

In computational tasks and processor performance assessment, we measure the capacity for floating-point computation in terms of FLOPS, an acronym for Floating Point Operations Per Second. This metric indicates how many floating-point arithmetic operations (additions, subtractions, multiplications, and divisions) a computing system can perform every second. For example, a processor that can execute one trillion such operations per second has a computational performance of 1 teraFLOPS (TFLOPS).

Some Practical Considerations for FLOPs in Deep Learning#

  • In deep learning models, the number of FLOPs required for a forward pass through the network gives an indication of the model’s complexity and efficiency. Models with higher FLOPs require more computational resources, which can affect training and inference times, especially on large datasets.

  • On the hardware side, FLOPs also help us gauge what kind of GPU a model needs: a model with a high FLOP count will likely require a high-end GPU that can sustain a high FLOPS rate.

FLOPS Per Second in GPUs#

Given the total number of FLOPs required for a forward pass through a deep learning model and the theoretical FLOPS per second of a GPU, you can estimate the time it takes to perform the forward pass on the GPU by dividing the total FLOPs by the GPU’s FLOPS capacity. This calculation assumes ideal conditions where the model fully utilizes the GPU’s computational capabilities.

\[ \text{Time for Forward Pass (seconds)} = \frac{\text{Total FLOPs for Forward Pass}}{\text{Theoretical FLOPS of GPU}} \]

This assumes a single forward pass on a single sample. In practice, we use Model FLOPs Utilization (MFU) to measure efficiency: the ratio of the FLOPS actually achieved to the theoretical peak FLOPS of the GPU.

Note that different GPUs have different FLOPS, and the FLOPS of a GPU can be affected by what floating point precision is used (e.g. FP16, FP32, FP64).
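
As an illustration of the formula above (a best-case sketch that assumes perfect utilization of the A100 class defined earlier), we can divide GPT-2’s forward-pass FLOPs by the GPU’s peak FLOPS:

a100 = A100()
forward_flops = f["forward_total"]  # FLOPs for one forward pass of a single sequence
ideal_time_fp32 = forward_flops / a100.flops[FloatingPointPrecision.FP32]
ideal_time_bf16 = forward_flops / a100.flops[FloatingPointPrecision.BFLOAT16]
print(f"Ideal forward-pass time on A100: {ideal_time_fp32 * 1e3:.2f} ms (fp32), {ideal_time_bf16 * 1e3:.3f} ms (bfloat16)")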

Model FLOPs Utilization (MFU)#

Given:

  • Batch size with gradient accumulation factored in: \(\text{batch_size} = 20 \times 5\) samples.

  • Measured time per iteration: \(\text{measured_time} = 0.755\) seconds (note this is Karpathy’s own measured time for the forward and backward pass - 1 iteration)

  • Total FLOPs required by the model for one forward and backward pass for a single sample/sequence \(\mathbf{x} = \left(x_1, x_2, \ldots, x_{T}\right)\): \(f[\text{total}]\) having a unit of FLOPs/sample.

  • The calculation of total FLOPs for a model’s operations is independent of the floating-point precision used (e.g., FP32, FP16, BF16). However, when training models we may choose different floating-point precisions, and different GPUs deliver different peak FLOPS at different precisions.

We will quote throughput in FLOPS or TFLOPS (teraFLOPS), where \(1\) TFLOPS represents \(10^{12}\) FLOPS.

  1. Measured Throughput:

    The throughput, in terms of samples processed per second, can be calculated as follows:

    \[ \text{measured_throughput} = \frac{\text{batch_size}}{\text{measured_time}} \text{ samples/second} \]

    Substituting the given values:

    \[ \text{measured_throughput} = \frac{20 \times 5}{0.755} = 132.4503311258278 \text{ samples/second} \]
  2. FLOPs Achieved:

    The floating point operations per second (FLOPS) achieved, based on the measured throughput, is:

    \[ \text{flops_achieved_per_second} = f[\text{total}] \times \text{measured_throughput} \]

    Here, \(f[\text{total}]\) represents the total FLOPs required by the model for one complete pass (forward and backward) of a single sample, with a unit of FLOPs/sample. And multiplying it by the measured throughput gives the effective FLOPs achieved per second.

    \[ \text{flops_achieved_per_second} = \left( \frac{\text{FLOPs}}{\text{sample}} \right) \times \left( \frac{\text{sample}}{\text{second}} \right) \]

    On a side note, there should be no confusion in the cancellation of the sample term in the numerator and denominator even though I used sample and samples earlier. It is like saying if I can process \(100\) samples per second, and my model requires \(1000\) FLOPs per sample, then I can process \(100 \times 1000 = 100000\) FLOPs per second because I processed \(100\) samples in \(1\) second, and each sample required \(1000\) FLOPs.

  3. Fraction of A100 Utilization:

    Given the A100’s promised performance for bfloat16 operations:

    \[ \text{a100_bfloat16_flops_promised} = 312 \times 10^{12} \text{ FLOPS} \]

    The fraction of the A100 GPU utilized can be expressed as a percentage of the promised FLOPs:

    \[ \text{fraction of A100 used (\%)} = \left( \frac{\text{flops_achieved_per_second}}{\text{a100_bfloat16_flops_promised}} \right) \times 100 \]

    Substituting \(\text{flops_achieved_per_second}\) and \(\text{a100_bfloat16_flops_promised}\) with their respective values provides the percentage utilization of the A100 GPU’s computational capability.

# here is what we currently roughly measure
batch_size = 20 * 5  # 5 is grad_accum, so total batch size is 100
measured_time = 0.755  # in seconds per iteration
measured_throughput = batch_size / measured_time # number of samples processed per second
flops_achieved_per_second = f["total"] * measured_throughput

# A100 is cited to be 312 TFLOPS of bfloat16 running on tensor cores
a100_bfloat16_promised_flops = 312e12

# the fraction of the A100 that we are using:
print(f"fraction of A100 used: {flops_achieved_per_second / a100_bfloat16_promised_flops * 100:.2f}%")
fraction of A100 used: 37.14%

Theoretical Model FLOPs Utilization (MFU) Indicates a Rough Benchmark of Efficiency#

The Model FLOPs Utilization (MFU) of 37% in this context signifies that, for a specific model (like GPT-2), on a particular GPU (e.g., A100), and using a given floating point precision (e.g., bfloat16), the model’s training or inference process is utilizing 37% of the GPU’s theoretical maximum FLOPS capability. This percentage reflects the efficiency with which the model leverages the computational power of the GPU under those specific conditions.

If we treat this 37% figure as a reference benchmark, then an actual MFU significantly below it suggests there is room for optimization in how the model is executed on the hardware. It could indicate inefficiencies in data loading, a model architecture that does not fully leverage the GPU’s capabilities, or bottlenecks in the computation that prevent the GPU from being fully utilized. Consequently, MFU is a very common metric to monitor when training large language models like GPT-2, as it provides insight into training efficiency and helps identify areas for optimization. After all, GPUs are expensive to run, and we want to make sure we are getting the most out of them!

Relation of MFU and TFLOPS#

As we have seen earlier, we can denote the MFU formula as:

\[ \operatorname{MFU} = \frac{\operatorname{MODEL\_FLOPS}}{\operatorname{GPU\_FLOPS}}, \]

where \(\operatorname{MODEL\_FLOPS}\) is the number of floating point operations per second achieved by the model and \(\operatorname{GPU\_FLOPS}\) is the peak number of floating point operations per second that the GPU can perform.

Then we can easily see that \(\operatorname{MODEL\_FLOPS}\) is given by a re-arrangement of the MFU formula:

\[ \operatorname{MODEL\_FLOPS} = \operatorname{MFU} \times \operatorname{GPU\_FLOPS}. \]

This means if a report says they achieved \(250\) TFLOPS, and they are using NVIDIA A100 GPU with bfloat16 precision at a rate of \(312\) TFLOPS, then we can calculate the MFU as:

\[ \operatorname{MFU} = \frac{250}{312} \approx 0.8. \]
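
In code, this back-of-the-envelope rearrangement is a one-liner (numbers taken from the example above):

reported_model_flops = 250e12   # sustained throughput reported, in FLOPS
a100_bf16_peak_flops = 312e12   # A100 bfloat16 peak, in FLOPS
mfu = reported_model_flops / a100_bf16_peak_flops
print(f"MFU: {mfu:.2%}")        # ~80%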

Theoretical FLOPs in Transformer Models#

The excerpt from Appendix B, page 66 of the PaLM paper PaLM: Scaling Language Modeling with Pathways provides a detailed explanation of how Model FLOPs Utilization (MFU) is calculated for a dense Transformer language model, focusing on the relationship between observed throughput (tokens-per-second) and the theoretical maximum throughput based on the model’s computational demands and the hardware’s peak FLOPs.

Given a corpus \(\mathcal{S}\) with \(M\) sequences:

\[ \mathcal{S} = \left\{\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_M\right\}, \]

where \(\mathbf{x}_m\) is a sequence of length \(T\) consisting of tokens:

\[ \mathbf{x}_m = \left\{x_{m,1}, x_{m,2}, \ldots, x_{m,T}\right\}, \]

Then we can easily see that there are a total of \(D = M \times T\) tokens. If we further denote the model (in our case GPT-2) as \(\mathcal{G}\), and the model’s learnable parameters as \(\theta \in \Theta\), then the total number of learnable parameters in the model can be denoted as \(N = |\theta|\) (we use \(N\) rather than \(M\), since \(M\) already denotes the number of sequences).

  1. Basic Computation per Token for Non-Attention Components: The paper notes that, excluding self-attention, a decoder-only model with \(N\) parameters requires \(6N\) matrix multiplication (matmul) FLOPs per token — \(2N\) FLOPs for the forward pass and \(4N\) FLOPs for the backward pass. This doubling accounts for the additional computations required during backpropagation (gradient calculation and weight updates), where each matmul involves one multiplication and one addition per pair of input values.

  2. Self-Attention Computation: For the self-attention mechanism, an additional \(6LHQ(2T)\) FLOPs per token are needed, with \(L\), \(H\), \(Q\), and \(T\) representing the number of layers/blocks, heads, head dimension (embedding size of 1 head’s query matrix) and sequence/context length, respectively. This term accounts for the computations within the self-attention layers that are more complex due to their dependency on the sequence length and the architecture of the Transformer model.

  3. Total Computational Requirement for 1 Token: The total FLOPs per token, combining the matmuls of the non-attention components with the self-attention computations, is therefore \(6N + 6LHQ(2T) = 6N + 12LHQT\). However, the self-attention term contributes a much smaller share for large models, so the primary computational load comes from the non-attention components.

  4. Total Computational Requirement for the Entire Corpus: The total computational requirement for the entire corpus is then \(D(6N + 12LHQT)\), where \(D\) is the total number of tokens in the corpus.

    Sometimes people just use \(6ND\) as a rough estimate of the total FLOPs required for the entire corpus, which is a reasonable approximation for large models, in the sense that the total cost scales as \(\mathcal{O}(ND)\); the short check after this list shows how much the attention term adds for our configuration.
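
Before relying on the \(6ND\) shortcut in the cell below, here is a quick sketch (using GPT-2’s own \(L\), \(H\), \(Q\), \(T\) and the params() total from above) of how much the self-attention term adds per token on top of \(6N\):

N_check = params()["total"]
L_, H_, Q_, T_ = 12, 12, 768 // 12, 1024
per_token_full = 6 * N_check + 12 * L_ * H_ * Q_ * T_   # 6N + 12LHQT
per_token_approx = 6 * N_check                          # the 6N shortcut
print(f"(6N + 12LHQT) / 6N = {per_token_full / per_token_approx:.3f}")  # ~1.15 at T=1024

At this context length the attention term adds roughly 15 percent, which is consistent with the gap between the 3.46-day estimate below and the roughly 3.96-day figure computed later with the fuller formula.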

# Finally let's check out the 6ND approximation as total cost of training in FLOPs
N = params()["total"]  # this is number of parameters, N
D = 300e9  # 300B tokens, this is dataset size in tokens, D
a100_bfloat16_promised_flops = 312e12  # 312 TFLOPS
assumed_mfu = 0.3  # assume this model flops utilization (take the current 37% from above and add some DDP overhead)
flops_throughput = a100_bfloat16_promised_flops * 8 * assumed_mfu  # assume an 8XA100 node at 30% utilization
flops_needed = 6 * N * D  # 6ND
time_needed_over_all_tokens_in_seconds = flops_needed / flops_throughput  # in seconds
print(f"time needed to train the model: {time_needed_over_all_tokens_in_seconds/3600/24:.2f} days")
time needed to train the model: 3.46 days

Karpathy reported that this 3.46-day estimate is close to his actual training time of roughly 4 days. Below we package the same logic into a reusable estimate_mfu function.

class MFUEstimationResult(BaseModel):
    flops_per_token_per_fwdbwd: float
    flops_per_sequence_per_fwdbwd: float
    flops_per_iter_per_fwdbwd: float
    flops_achieved_per_second: float
    mfu: float


def estimate_mfu(
    num_decoder_blocks: int,
    num_heads: int,
    d_model: int,
    context_length: int,
    model_total_parameters: int,
    effective_batch_size_per_iter: int,
    time_taken_per_iter: float,
    gpu_promised_flops: float = 312e12,  # A100 GPU bfloat16/float16 peak flops is 312 TFLOPS
) -> MFUEstimationResult:
    """
    Estimate Model FLOPs Utilization (MFU) as a ratio of achieved FLOPs to
    the A100 GPU's peak FLOPs capability.

    Parameters
    ----------
    num_decoder_blocks : int
        Number of decoder blocks in the Transformer model.
    num_heads : int
        Number of attention heads in each Transformer block.
    d_model : int
        Dimension of the model's embeddings.
    context_length : int
        Number of tokens in each input sequence.
    model_total_parameters : int
        Total number of learnable parameters in the model.
    effective_batch_size_per_iter : int
        Effective batch size processed in one iteration, accounting for
        gradient accumulation.
    time_taken_per_iter : float
        Time taken per training iteration in seconds.
    gpu_promised_flops : float, optional
        Theoretical peak performance of the GPU in FLOPs (default is
        312e12 for A100 GPU bfloat16/float16 operations).

    Returns
    -------
    MFUEstimationResult:
        Pydantic model with the following fields:
        - flops_per_token_per_fwdbwd: FLOPs required for forward and backward pass of a single token
        - flops_per_sequence_per_fwdbwd: FLOPs required for forward and backward pass of a single sequence
        - flops_per_iter_per_fwdbwd: FLOPs required for forward and backward pass of the effective batch size
        - flops_achieved_per_second: FLOPs achieved per second
        - mfu: Model FLOPs Utilization (MFU) as a ratio of achieved FLOPs to the A100 GPU's peak FLOPs capability

    Example
    -------
    >>> estimate_mfu(
    ...     num_decoder_blocks=6,
    ...     num_heads=8,
    ...     d_model=512,
    ...     context_length=1024,
    ...     model_total_parameters=1_000_000,
    ...     effective_batch_size_per_iter=20 * 8,  # 20 sequences per GPU, 8 GPUs
    ...     time_taken_per_iter=0.1, # 0.1 seconds per iteration
    ...     gpu_promised_flops=312e12, # A100 GPU bfloat16/float16 peak flops
    ... )

    Notes
    -----
    This function utilizes the formula from the PaLM paper Appendix B
    (https://arxiv.org/abs/2204.02311) for estimating the FLOPs required
    for one forward and backward pass of a single token and scales it up to
    the effective batch size and the given model architecture to calculate MFU.
    You can likely use it as a callback in your `Trainer` to log the MFU during training.
    """
    # fmt: off
    N, L, H, Q, T = model_total_parameters, num_decoder_blocks, num_heads, d_model // num_heads, context_length
    flops_per_token_per_fwdbwd = 6 * N + 12 * L * H * Q * T # 1 token forward and backward flops
    flops_per_sequence_per_fwdbwd = flops_per_token_per_fwdbwd * T # 1 sequence = T tokens
    flops_per_iter_per_fwdbwd = flops_per_sequence_per_fwdbwd * effective_batch_size_per_iter # 1 iter means if batch size is 100, then 100 sequences are processed in 1 iter
    # express our flops throughput as ratio of A100 bfloat16 peak flops
    flops_achieved_per_second = flops_per_iter_per_fwdbwd * (1.0 / time_taken_per_iter)  # per second
    mfu = flops_achieved_per_second / gpu_promised_flops
    # fmt: on
    return MFUEstimationResult(
        flops_per_token_per_fwdbwd=flops_per_token_per_fwdbwd,
        flops_per_sequence_per_fwdbwd=flops_per_sequence_per_fwdbwd,
        flops_per_iter_per_fwdbwd=flops_per_iter_per_fwdbwd,
        flops_achieved_per_second=flops_achieved_per_second,
        mfu=mfu,
    )
gradient_accumulation = 5
batch_size = 20
effective_batch_size = gradient_accumulation * batch_size
measured_time = 0.755  # in seconds per iteration
model_total_parameters = gpt2_params_no_bias - params()["embedding/position"]

mfu_estimates = estimate_mfu(
    num_decoder_blocks=gpt2_config.num_decoder_blocks,
    num_heads=gpt2_config.n_head,
    d_model=gpt2_config.n_embd,
    context_length=gpt2_config.context_length,
    model_total_parameters=model_total_parameters,
    effective_batch_size_per_iter=effective_batch_size,
    time_taken_per_iter=measured_time,
    gpu_promised_flops=312e12,  # A100 GPU bfloat16/float16 peak flops
)

pprint(mfu_estimates)
MFUEstimationResult(
│   flops_per_token_per_fwdbwd=854553600.0,
│   flops_per_sequence_per_fwdbwd=875062886400.0,
│   flops_per_iter_per_fwdbwd=87506288640000.0,
│   flops_achieved_per_second=115902369059602.66,
│   mfu=0.37148195211411106
)
def estimate_training_days(
    total_tokens_in_corpus: float,
    mfu_result: MFUEstimationResult,
    gpu_promised_flops: float = 312e12,  # Default A100 GPU peak FLOPs
    num_gpus: int = 8,  # Default number of GPUs
    assumed_mfu: Optional[float] = None,  # Optional manual MFU override
) -> float:
    """
    Estimate the total training time in days based on the model FLOPs
    utilization and other training parameters.

    Parameters
    ----------
    total_tokens_in_corpus : float
        Total number of tokens to be processed during training.
    mfu_result : MFUEstimationResult
        The result from the estimate_mfu function, containing FLOPs
        metrics and Model FLOPs Utilization.
    gpu_promised_flops : float, optional
        Theoretical peak FLOPs performance of a single GPU.
    num_gpus : int, optional
        Number of GPUs used in the training setup.
    assumed_mfu : float, optional
        If provided, overrides the MFU calculated in mfu_result to
        manually adjust the utilization rate.

    Returns
    -------
    training_days : float
        Estimated total training time in days.
    """
    mfu = assumed_mfu if assumed_mfu is not None else mfu_result.mfu

    # Total FLOPs needed for the entire dataset
    total_flops_needed = total_tokens_in_corpus * mfu_result.flops_per_token_per_fwdbwd

    # Effective throughput considering MFU
    effective_flops_per_second = gpu_promised_flops * num_gpus * mfu

    # Total training time in seconds
    total_training_time_seconds = total_flops_needed / effective_flops_per_second
    total_training_time_days = total_training_time_seconds / (3600 * 24)

    return total_training_time_days
total_tokens_in_corpus = 300e9  # 300B tokens
assumed_mfu = 0.3  # assume this model flops utilization
num_gpus = 8  # 8 GPUs
training_days = estimate_training_days(
    total_tokens_in_corpus=total_tokens_in_corpus,
    mfu_result=mfu_estimates,
    gpu_promised_flops=A100().flops[FloatingPointPrecision.BFLOAT16],
    num_gpus=8,
    assumed_mfu=assumed_mfu,
)
training_days
3.9626068376068373

For practical usage in training, see the following:

References and Further Readings#