Chapter 16
Matrix Products in DL & Transformers
Key ideas: Introduction

Introduction#

Neural networks are fundamentally stacks of matrix multiplications: a forward pass through a deep network is a chain of products between weight matrices and activations. Each layer computes $A_{l+1} = \sigma(A_l W_l^\top)$ where $W_l \in \mathbb{R}^{d_{l+1} \times d_l}$ and $A_l \in \mathbb{R}^{n \times d_l}$ (batch size $n$, layer input dimension $d_l$, output dimension $d_{l+1}$). The cost is $O(n d_l d_{l+1})$ per layer. Transformers add attention: $\text{Attention}(Q, K, V) = \text{softmax}(Q K^\top / \sqrt{d_k}) V$, which involves two GEMM operations ($QK^\top$ and the product with $V$) and a row-wise softmax. For a Transformer with sequence length $L$, hidden dimension $d$, and $H$ attention heads per layer, attention cost is $O(L^2 d)$, quadratic in sequence length and a major bottleneck for long sequences. Modern accelerators (GPUs, TPUs) are matrix-multiply engines that deliver trillions of floating-point operations per second (TFLOP/s). Utilization depends on arithmetic intensity (FLOPs per byte moved): bandwidth-bound operations underutilize the accelerator, while compute-bound operations (high arithmetic intensity) can approach peak performance. Understanding how to write matrix products that achieve high arithmetic intensity, and how to distribute them across devices, determines whether you can train billion-parameter models.
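As a minimal sketch of one dense layer as a GEMM, the NumPy snippet below assumes activations are stored batch-major (one row per example), so the layer product is written `A @ W.T`. The sizes, the ReLU nonlinearity, and the helper names `dense_layer` and `dense_layer_flops` are illustrative choices, not part of the chapter.

```python
import numpy as np

def dense_layer(A, W):
    """One fully connected layer: A is (n, d_in), W is (d_out, d_in)."""
    Z = A @ W.T                 # GEMM: (n, d_in) x (d_in, d_out) -> (n, d_out)
    return np.maximum(Z, 0.0)   # ReLU as an example nonlinearity

def dense_layer_flops(n, d_in, d_out):
    """FLOPs of the GEMM: one multiply and one add per (row, input, output) triple."""
    return 2 * n * d_in * d_out

rng = np.random.default_rng(0)
n, d_in, d_out = 32, 512, 256            # illustrative sizes
A = rng.standard_normal((n, d_in))
W = rng.standard_normal((d_out, d_in))
A_next = dense_layer(A, W)
print(A_next.shape, dense_layer_flops(n, d_in, d_out))
```

The FLOP count grows linearly in each of the three dimensions, which is the $O(n d_l d_{l+1})$ per-layer cost stated above.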

Important ideas#

  1. Matrix-matrix multiplication (GEMM) structure

    • Dense GEMM: $C \leftarrow AB$ with $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$.

    • Arithmetic: $mkn$ multiplications and $mkn$ additions, i.e. $2mkn = O(mkn)$ floating-point operations (FLOPs).

    • Memory: $mk + kn + mn$ words (inputs + output); on GPU, $m, k, n$ can be in the tens of thousands, requiring gigabytes of memory.

    • Arithmetic intensity: $I = \frac{\text{FLOPs}}{\text{bytes}} = \frac{2mkn}{8(mk + kn + mn)} = \frac{mkn}{4(mk + kn + mn)}$ for 8-byte (FP64) elements, which is $n/12$ when $m = k = n$ (higher is better).

  2. Blocking and cache efficiency

    • GEMM blocked into $b \times b$ tiles; each tile multiplied using fast cache.

    • Cache line length (64 bytes typical); GEMM loads tile once, reuses it $O(b)$ times.

    • Roofline model: peak FLOP rate vs. memory bandwidth; if arithmetic intensity $< I_{\text{roof}} = \text{peak FLOP rate} / \text{memory bandwidth}$, the algorithm is bandwidth-bound.

  3. Batch matrix multiplication (batched GEMM)

    • Forward pass: $C_i \leftarrow A_i B_i$ for $i = 1, \ldots, B$ (batch size).

    • Exploit parallelism: process multiple batches on multiple cores/GPU SMs.

    • Highly efficient when batch size is large; small batches underutilize accelerator.

  4. Convolution as matrix multiplication (im2col, Winograd)

    • Convolution unfolds as GEMM: reshape input patches into columns; multiply by filter matrix; reshape output.

    • im2col: input image to column matrix; allows use of highly optimized GEMM (cuBLAS, MKL).

    • Cost: $O(k_h k_w d_{\text{in}} d_{\text{out}} h_{\text{out}} w_{\text{out}})$ (kernel height/width, input/output channels, output spatial dims).

    • Winograd: fast convolution via a transformed domain; reduces the number of multiplications at the cost of extra additions and reduced numerical accuracy.

  5. Scaled dot-product attention

    • Query-key-value paradigm: $Q, K \in \mathbb{R}^{L \times d_k}$, $V \in \mathbb{R}^{L \times d_v}$ (sequence length $L$, head dimensions $d_k, d_v$).

    • Attention: (1) $M = Q K^\top / \sqrt{d_k}$ (an $(L \times d_k)(d_k \times L)$ product, $O(L^2 d_k)$ FLOPs), (2) $A = \text{softmax}(M)$ (per-row normalization, no matrix products), (3) $O = AV$ (an $(L \times L)(L \times d_v)$ product, $O(L^2 d_v)$ FLOPs).

    • Total: $O(L^2 (d_k + d_v))$ per head, $O(L^2 d)$ across all heads (quadratic in sequence length).

    • Challenge: for $L = 4096$, the score matrix has $L^2 \approx 16.8$M entries per attention head, each costing $O(d_k)$ FLOPs; summed over heads this is billions of FLOPs per layer.

  6. Mixed precision and numerical stability

    • FP32 (single precision, float32): 32 bits, ~7 significant digits; gradients, weights commonly stored in FP32.

    • FP16 (half precision, float16): 16 bits, ~3 significant digits; positive range roughly $[6 \times 10^{-8}, 6.5 \times 10^4]$ including subnormals; often 2–10× faster on GPUs with hardware support.

    • BFloat16 (Brain Float): 16 bits, same exponent range as FP32, reduced mantissa; intermediate between FP32 and FP16.

    • Mixed precision: compute GEMM in FP16 (fast), accumulate in FP32 (stable); scale loss to prevent underflow.

  7. Distributed matrix multiplication

    • Data parallelism: replicate model on each device; partition minibatches; synchronize gradients via all-reduce.

    • Model parallelism: partition matrix weights across devices; communication within matrix product (e.g., matmul followed by communication).

    • Pipeline parallelism: partition layers across devices; overlap computation on layer $i$ with communication on layer $i-1$.

    • Cost: compute + communication latency; communication often dominates at large scale (Roofline model).
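The blocking idea in item 2 above can be sketched in NumPy. This is only an illustration of the loop structure: the tile size `b` and the function name `blocked_matmul` are assumptions for the example, and a Python loop will not be faster than a tuned BLAS call.

```python
import numpy as np

def blocked_matmul(A, B, b=64):
    """Compute A @ B tile by tile; b is the (assumed) tile edge length."""
    m, k = A.shape
    k2, n = B.shape
    assert k == k2, "inner dimensions must match"
    C = np.zeros((m, n), dtype=np.result_type(A, B))
    for i in range(0, m, b):            # tile rows of C
        for j in range(0, n, b):        # tile columns of C
            for p in range(0, k, b):    # accumulate over the inner dimension
                # Each b x b tile of A and B is reused across a whole tile of C
                C[i:i+b, j:j+b] += A[i:i+b, p:p+b] @ B[p:p+b, j:j+b]
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((200, 150))
B = rng.standard_normal((150, 130))
print(np.allclose(blocked_matmul(A, B, b=64), A @ B))
```

In an optimized kernel the three tiles touched by the innermost statement are sized to fit in cache (or GPU shared memory), so each loaded element participates in $O(b)$ multiply-adds before it is evicted.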

Relevance to ML#

  • Convolutional neural networks (CNNs): Forward and backward passes are GEMM-heavy; efficiency determines whether you can train on billion-pixel images or video.

  • Recurrent neural networks (RNNs), LSTMs, GRUs: Fully-connected layers between timesteps; matrix products per timestep.

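  • Transformers and large language models: Attention adds $O(L^2 d)$ matrix-product FLOPs on top of the $O(L d^2)$ projections and MLP; for GPT-3 ($L = 2048$, $d = 12288$) the $O(L d^2)$ terms still dominate, and attention takes over only when $L$ grows to several times $d$.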

  • Graph neural networks (GNNs): Graph convolution is sparse matrix product; efficiency depends on sparsity and format.

  • Distributed training: Modern LLMs are trained on thousands of GPUs/TPUs; communication cost (network bandwidth) can rival or exceed computation cost.
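How the quadratic attention term compares with the linear layers depends on $L$ relative to $d$. The sketch below counts forward GEMM FLOPs for one Transformer layer under stated assumptions: four $d \times d$ projections (Q, K, V, output) and a two-layer MLP with hidden width $4d$. The function name and the second configuration are hypothetical.

```python
def transformer_layer_flops(L, d):
    """Forward FLOPs per layer for one sequence, counting GEMMs only.

    Assumes Q, K, V and output projections (d x d each) and a
    two-layer MLP with hidden width 4d.
    """
    projections = 4 * 2 * L * d * d          # Q, K, V, output
    mlp = 2 * 2 * L * d * (4 * d)            # up- and down-projection
    attention = 2 * 2 * L * L * d            # Q K^T and A V, summed over heads
    return {"linear": projections + mlp, "attention": attention}

for L, d in [(2048, 12288), (32768, 4096)]:
    f = transformer_layer_flops(L, d)
    share = f["attention"] / (f["linear"] + f["attention"])
    print(f"L={L:6d} d={d:6d} attention share of FLOPs: {share:.1%}")
```

Under these counting assumptions the attention share simplifies to $L / (6d + L)$, so it is small when $L \ll d$ and exceeds one half only once $L > 6d$.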

Algorithmic development (milestones)#

  • 1969: Strassen algorithm: $O(n^{2.807})$ vs. $O(n^3)$ naive GEMM (theoretically significant; rarely used in practice due to constants).

  • 1979–1990: Level-1/2/3 BLAS (Basic Linear Algebra Subprograms); standardized interface for matrix ops; LAPACK (1992) built on BLAS.

  • 1999–2007: GPU era begins: NVIDIA GeForce (1999), Tesla (2007); GPUs offer several times the memory bandwidth of contemporary CPUs; large GEMMs run roughly 10–100× faster.

  • 2006: CUDA (Compute Unified Device Architecture) released; enables general-purpose GPU computing; cuBLAS optimized GEMM for NVIDIA GPUs.

  • 2012: AlexNet (Krizhevsky et al.) demonstrates deep CNN training on GPUs; FLOPs dominate; GEMM-heavy.

  • 2015: Batch normalization (Ioffe & Szegedy); reduces sensitivity to initialization and stabilizes training of deep networks.

  • 2017: Transformer architecture (Vaswani et al.); attention is dense GEMM-based; quadratic in sequence length.

  • 2017–2018: Mixed precision training (Micikevicius et al. 2018); FP16 compute with FP32 master weights and loss scaling gives roughly 2–3× speedups on Tensor Core GPUs.

  • 2018–2020: Distributed training frameworks mature (PyTorch DDP, Horovod); models with billions to hundreds of billions of parameters are trained via model parallelism.

  • 2020–2023: Flash Attention (Dao et al. 2022) reduces attention memory traffic via tiling and recomputation while still computing exact attention; Megatron-LM and DeepSpeed enable distributed GEMMs at petaflop scales.

Definitions#

  • GEMM (General Matrix Multiply): $C \leftarrow \alpha A B + \beta C$ (standard matrix multiply with scaling/accumulation).

  • FLOP (floating-point operation): One add or multiply; GEMM $C \leftarrow AB$ is $2mkn$ FLOPs.

  • Arithmetic intensity: $I = \frac{\text{FLOPs}}{\text{bytes read/written}}$ (ops per byte); high $I$ means compute-bound; low $I$ means bandwidth-bound.

  • Roofline model: Peak achievable throughput = $\min(\text{peak FLOP rate}, \text{memory bandwidth} \times \text{arithmetic intensity})$.

  • Memory-bound: Algorithm where memory bandwidth is bottleneck; cannot achieve peak FLOP rate.

  • Compute-bound: Algorithm where compute is bottleneck; limited by FLOPs/cycle, not memory.

  • Mixed precision: Using multiple precision levels (e.g., FP16 for compute, FP32 for accumulation) to trade accuracy for speed.

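  • All-reduce: Distributed operation: each device sums its values with all others; result replicated on all devices. Cost: $O(\log D)$ communication rounds for $D$ devices with tree-based algorithms; ring algorithms take $2(D-1)$ steps but move a near-optimal amount of data per device.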

  • Collective communication: Broadcasting, all-reduce, reduce-scatter, all-gather operations in distributed training.
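The roofline definition above translates directly into code. The sketch below uses a hypothetical accelerator (10 TFLOP/s peak, 1 TB/s bandwidth) chosen only to give round numbers; the function names are illustrative.

```python
def roofline(intensity, peak_flops, bandwidth):
    """Attainable FLOP/s under the roofline model."""
    return min(peak_flops, bandwidth * intensity)

def ridge_point(peak_flops, bandwidth):
    """Intensity (FLOPs/byte) above which a kernel is compute-bound."""
    return peak_flops / bandwidth

# Hypothetical accelerator: 10 TFLOP/s peak, 1 TB/s memory bandwidth
P, BW = 10e12, 1e12
for I in [0.25, 2.0, 10.0, 50.0]:
    bound = "compute-bound" if I >= ridge_point(P, BW) else "memory-bound"
    print(f"I = {I:5.2f} FLOPs/byte -> {roofline(I, P, BW) / 1e12:5.2f} TFLOP/s ({bound})")
```

Kernels to the left of the ridge point gain throughput only from higher intensity or more bandwidth; kernels to the right gain only from more compute.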

Essential vs Optional: Theoretical ML

Theoretical (essential)#

  • GEMM arithmetic and complexity: $O(mkn)$ FLOPs, memory $O(mk + kn + mn)$. Reference: Golub & Van Loan (2013).

  • Arithmetic intensity and Roofline model: $I = \text{FLOPs/bytes}$; peak rate is $\min(\text{FLOP rate}, \text{bandwidth} \times I)$. Reference: Williams et al. (2009).

  • Cache-oblivious algorithms: Block-recursive GEMM achieves near-optimal cache behavior independent of cache size. Reference: Frigo et al. (1999).

  • Batched GEMM: Independent products $C_i \leftarrow A_i B_i$; parallelism across batch dimension. Reference: Dongarra et al. (1990) for Level-3 GEMM; batched interfaces are later library extensions (e.g., cuBLAS batched GEMM).

  • Attention complexity: Scaled dot-product $O(L^2 d)$ without optimizations; challenges for long contexts. Reference: Vaswani et al. (2017).

  • Distributed GEMM: Communication cost for gradient all-reduce, model/data parallelism. Reference: Thakur et al. (2005) (MPI Collective Communications).
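As a worked instance of the intensity formula, assume each matrix moves between memory and the processor exactly once and each element occupies $s$ bytes (the symbol $s$ is introduced here for illustration). The square and matrix-vector cases then reduce to:

```latex
\begin{align*}
I_{\text{square}} &= \frac{2n^3}{s\,(n^2 + n^2 + n^2)} = \frac{2n}{3s}
  && (m = k = n), \\
I_{\text{matvec}} &= \frac{2mk}{s\,(mk + k + m)} \approx \frac{2}{s}
  && (n = 1,\ m, k \gg 1).
\end{align*}
```

With $s = 8$ (FP64) this gives $n/12$ for square GEMM and about $0.25$ FLOPs/byte for a matrix-vector product, which is why the latter is bandwidth-bound on essentially any accelerator.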

Applied (landmark systems)#

  • Level-3 BLAS (cuBLAS, MKL): Industry-standard GEMM implementations; peak performance on CPUs/GPUs. Implementation: NVIDIA cuBLAS, Intel MKL. Reference: Dongarra et al. (1990) (BLAS 3).

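  • Convolution as GEMM (im2col): Popularized by Caffe and available in most frameworks; enables reuse of optimized GEMM. Implementation: Caffe lowers convolution to im2col + GEMM; other frameworks may dispatch to im2col, direct, FFT, or Winograd kernels depending on the backend. Reference: Jia et al. (2014).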

  • Mixed precision training: Automatic mixed precision in PyTorch (torch.cuda.amp), TensorFlow (tf.keras.mixed_precision). Achieves 2–3× speedup on V100/A100. Reference: NVIDIA Automatic Mixed Precision Training Guide (2020).

  • Distributed GEMM (Megatron-LM, DeepSpeed): Tensor parallelism partitions GEMM across devices; pipeline parallelism overlaps layers. Implementation: Microsoft DeepSpeed, NVIDIA Megatron-LM. Reference: Shoeybi et al. (2019); Rasley et al. (2020).

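  • Flash Attention: IO-aware exact attention via blocked matrix products; avoids materializing the $L \times L$ matrix in GPU main memory, so attention memory grows linearly in $L$ and memory traffic drops substantially. Implementation: Tri Dao’s flash-attention library. Reference: Dao et al. (2022).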
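The blocked-attention idea can be illustrated with a running (online) softmax over key/value blocks. This NumPy sketch is not the FlashAttention kernel: it tiles only the keys and values, runs on the CPU, and uses hypothetical names (`attention_blocked`, `block`). It only shows that exact attention can be computed without ever holding the full $L \times L$ matrix.

```python
import numpy as np

def attention_reference(Q, K, V):
    """Standard attention that materializes the full L x L score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    S = S - S.max(axis=-1, keepdims=True)
    P = np.exp(S)
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def attention_blocked(Q, K, V, block=64):
    """Exact attention over key/value blocks with a running softmax."""
    L, d_k = Q.shape
    out = np.zeros((L, V.shape[1]))
    row_max = np.full((L, 1), -np.inf)   # running max of scores per query row
    row_sum = np.zeros((L, 1))           # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d_k)      # only an (L, block) slice of scores
        new_max = np.maximum(row_max, S.max(axis=-1, keepdims=True))
        scale = np.exp(row_max - new_max)   # rescale earlier partial results
        P = np.exp(S - new_max)
        row_sum = row_sum * scale + P.sum(axis=-1, keepdims=True)
        out = out * scale + P @ Vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
Q = rng.standard_normal((200, 16))
K = rng.standard_normal((200, 16))
V = rng.standard_normal((200, 8))
print(np.allclose(attention_blocked(Q, K, V), attention_reference(Q, K, V)))
```

The rescaling by `scale` is what keeps the result exact: whenever a later block raises the row maximum, earlier numerators and denominators are multiplied by the same correction factor.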

Key ideas: Where it shows up
  1. Convolutional neural networks and image classification

    • Forward pass: convolutional layers (im2col GEMM), batch norm (element-wise), pooling (no GEMM).

    • Backward: weight gradient, input gradient via GEMM.

    • Achievements: ResNet-50 trains on ImageNet in about 1 hour on 256 GPUs with a minibatch of 8,192 (Goyal et al. 2017); mixed precision shortens training further on Tensor Core GPUs. References: Krizhevsky et al. (2012) (AlexNet); He et al. (2015) (ResNet); Goyal et al. (2017) (accurate large-batch SGD).

  2. Transformer models and large language models

    • Per-layer: projection QKV (3 GEMMs), attention (2 GEMMs), output projection (1 GEMM), MLP (2 GEMMs) = ~8 GEMMs per layer.

    • Attention cost: $O(L^2 d)$ (quadratic in sequence length); dominates for long sequences.

    • Achievements: GPT-3 (Brown et al. 2020) was trained on roughly 300 billion tokens, about $3.14 \times 10^{23}$ FLOPs, on a cluster of V100 GPUs using model parallelism. Flash Attention (Dao et al. 2022) makes attention memory linear in sequence length. References: Vaswani et al. (2017) (Transformer); Brown et al. (2020) (GPT-3); Dao et al. (2022) (Flash Attention).

  3. Distributed training and synchronization

    • Data parallelism: gradient all-reduce after each minibatch.

    • Model parallelism: activation and gradient exchanges within matrix products.

    • Achievements: LAMB optimizer (You et al. 2019) enables BERT training with batch sizes of 32k and above on a TPUv3 Pod in 76 minutes. Megatron-LM (Shoeybi et al. 2019) trains GPT models with tensor parallelism. References: You et al. (2019) (LAMB); Shoeybi et al. (2019) (Megatron-LM).

  4. Mixed precision training

    • Automatic mixed precision (AMP): dynamically select FP16/FP32 for operations.

    • Loss scaling: prevent FP16 gradient underflow.

    • Achievements: NVIDIA Automatic Mixed Precision reduces training time by 2–3× on V100/A100 while maintaining accuracy. References: NVIDIA Mixed Precision Training guide; Micikevicius et al. (2018).

  5. Graph neural networks and sparse matrix products

    • Graph convolution: $X' = \sigma(AXW)$ where $A$ is sparse adjacency matrix.

    • Sparse-dense GEMM: $O(\text{nnz}(A) \cdot d)$ FLOPs; arithmetic intensity is lower than for dense GEMM, but the total cost is feasible for sparse graphs.

    • Achievements: DGL, PyG enable billion-node GNNs via optimized sparse GEMMs. References: Kipf & Welling (2017) (GCN); Wang et al. (2019) (DGL); Fey et al. (2019) (PyG).
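A minimal NumPy sketch of the graph-convolution product $\sigma(AXW)$ with $A$ stored as an edge list. It assumes an unnormalized adjacency matrix and ReLU for $\sigma$; GCN as published uses a normalized adjacency with self-loops, and production libraries use dedicated sparse kernels. The function name `gcn_layer` and the tiny graph are hypothetical.

```python
import numpy as np

def gcn_layer(edges, num_nodes, X, W):
    """sigma(A X W) with A given as (src, dst) edge pairs; sigma is ReLU here."""
    src, dst = edges[:, 0], edges[:, 1]
    XW = X @ W                                  # dense GEMM: (N, d_in) x (d_in, d_out)
    out = np.zeros((num_nodes, W.shape[1]))
    np.add.at(out, dst, XW[src])                # sparse product: one row-add per edge
    return np.maximum(out, 0.0)

rng = np.random.default_rng(0)
edges = np.array([[0, 1], [1, 2], [2, 0], [0, 2]])   # (src, dst) pairs
X = rng.standard_normal((3, 4))
W = rng.standard_normal((4, 2))

# Dense reference: A[dst, src] = 1 for every edge
A = np.zeros((3, 3))
np.add.at(A, (edges[:, 1], edges[:, 0]), 1.0)
print(np.allclose(gcn_layer(edges, 3, X, W), np.maximum(A @ X @ W, 0.0)))
```

The aggregation touches one feature row per edge, which is the $O(\text{nnz}(A) \cdot d)$ cost; multiplying by $W$ first keeps the rows being gathered as narrow as possible when $d_{\text{out}} < d_{\text{in}}$.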

Notation
  • Matrix product: $C \leftarrow A B$ with $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$.

  • Batched product: $C_i \leftarrow A_i B_i$ for $i = 1, \ldots, B$ (batch size); vectorization across batch.

  • Attention: $\text{Attention}(Q, K, V) = \text{softmax}(QK^\top / \sqrt{d_k}) V$ with $Q, K, V \in \mathbb{R}^{L \times d}$ (sequence length $L$, dimension $d$).

  • Complexity: Attention is $O(L^2 d)$ FLOPs; dense GEMM is $O(n d_{\text{in}} d_{\text{out}})$ per layer (batch size $n$).

  • Arithmetic intensity: $I = \frac{2mkn}{8(mk + kn + mn)}$ (depends on matrix shapes; higher $I$ achieves better GPU utilization).

  • FLOP rate: Peak: $P$ (e.g., about 15.7 TFLOP/s for V100 in FP32); practical: $P \times \text{efficiency}$ (typically 50–80%).

  • Memory bandwidth: $B$ (e.g., about 900 GB/s for V100 HBM2); roofline: achieved throughput $= \min(P, I \times B)$.

  • Example: ResNet-50 forward pass: ~8 GFLOPs per image; batch size 256 ≈ 2 TFLOPs; an A100 at ~80% of its 19.5 TFLOP/s FP32 peak sustains ~16 TFLOP/s; time ≈ 0.13 s per batch.
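The throughput arithmetic in the ResNet-50 example can be checked in a few lines. The per-image FLOP count and the achieved rate are the approximate figures quoted in the text, not measurements.

```python
flops_per_image = 8e9        # ~8 GFLOPs per ResNet-50 forward pass (from the text)
batch = 256
achieved = 16e12             # ~16 TFLOP/s sustained (from the text)

total_flops = flops_per_image * batch
time_s = total_flops / achieved
print(f"{total_flops / 1e12:.2f} TFLOPs per batch, {time_s * 1e3:.0f} ms per batch")
```

Dividing total work by sustained rate gives time per batch; keeping the units explicit (FLOPs vs. FLOP/s) avoids the common factor-of-1000 slips.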

Pitfalls & sanity checks
  • Batch size too small: GPUs underutilized; poor arithmetic intensity. Typical minimum: 32–64 per device.

  • Tall-skinny GEMM: Low arithmetic intensity; underutilize accelerator. Prefer square or batched products.

  • Ignoring data layout: Row-major vs. column-major access can change GEMM performance by up to an order of magnitude.

  • Mixed precision without loss scaling: Small FP16 gradients underflow (values below about $6 \times 10^{-8}$ become zero); loss scaling prevents this (multiply the loss by, e.g., $2^{16}$, then divide the gradients by the same factor before the optimizer step).

  • Attention without length limits: Quadratic memory; even with batch size 1, $L = 8192$ requires 256 MiB (FP32) for the score matrix of a single head.

  • Synchronous all-reduce without compression: Communication time can dominate; gradient compression (sparsification, quantization) and overlap with computation help at scale.

  • Assuming linear scaling: Communication cost breaks linear scaling; parallel efficiency can fall from around 90% on a few devices to well below 50% on hundreds, depending on model and interconnect.

  • Convolution without im2col: Naive loops 100–1000× slower than GEMM-based implementation.
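The loss-scaling pitfall can be reproduced with NumPy's `float16`. The gradient value `1e-8` is an illustrative choice; the scale $2^{16}$ is the one mentioned above. Scaling the loss scales every gradient by the same factor (by linearity of differentiation), so the sketch scales the gradient directly.

```python
import numpy as np

grad = 1e-8                                 # a small gradient value (illustrative)
scale = 2.0 ** 16                           # loss scale from the text

unscaled_fp16 = np.float16(grad)            # too small for FP16: rounds to zero
scaled_fp16 = np.float16(grad * scale)      # representable after scaling
recovered = np.float32(scaled_fp16) / np.float32(scale)  # unscale in FP32

print(unscaled_fp16, scaled_fp16, recovered)
```

The unscaled value becomes zero in FP16, while the scaled value survives and unscales in FP32 to approximately the original gradient.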

References

Matrix multiplication theory

  1. Golub, G. H., & Van Loan, C. F. (2013). Matrix Computations (4th ed.).

  2. Strassen, V. (1969). Gaussian elimination is not optimal.

  3. Frigo, M., Leiserson, C. E., Prokop, H., & Ramachandran, S. (1999). Cache-oblivious algorithms.

Performance modeling and BLAS

  1. Dongarra, J. J., Du Croz, J., Hammarling, S., & Duff, I. S. (1990). A set of level 3 basic linear algebra subprograms.

  2. Williams, S., Waterman, A., & Patterson, D. (2009). Roofline: an insightful visual performance model for floating-point programs.

  3. Grigori, L., Demmel, J. W., & Xiang, H. (2008). Communication avoiding Gaussian elimination.

Deep learning and convolution

  1. Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). ImageNet classification with deep convolutional neural networks.

  2. He, K., Zhang, X., Ren, S., & Sun, J. (2015). Deep residual learning for image recognition.

  3. Jia, Y., Shelhamer, E., Donahue, J., et al. (2014). Caffe: convolutional architecture for fast feature embedding.

Transformer and attention

  1. Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). Attention is all you need.

  2. Brown, T. B., Mann, B., Ryder, N., et al. (2020). Language models are few-shot learners.

  3. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Re, C. (2022). FlashAttention: fast and memory-efficient exact attention with IO-awareness.

Mixed precision and numerical stability

  1. Micikevicius, P., Narang, S., Alben, J., et al. (2018). Mixed precision training.

  2. Ioffe, S., & Szegedy, C. (2015). Batch normalization: accelerating deep network training by reducing internal covariate shift.

  3. NVIDIA Automatic Mixed Precision Training Guide. (2020).

Distributed training

  1. Thakur, R., Rabenseifner, R., & Gropp, W. (2005). Optimization of collective communication operations in MPICH.

  2. Goyal, P., Dollár, P., Girshick, R., et al. (2017). Accurate, large minibatch SGD: training ImageNet in 1 hour.

  3. Shoeybi, M., Patwary, M., Puri, R., et al. (2019). Megatron-LM: training multi-billion parameter language models using model parallelism.

  4. Rasley, J., Rajbhandari, S., Ruwase, O., & He, Y. (2020). DeepSpeed: system optimizations enable training deep learning models with over 100 billion parameters.

  5. You, Y., Li, J., Reddi, S., et al. (2019). Large batch optimization for deep learning: training BERT in 76 minutes.

Attention optimization

  1. Choromanski, K., Likhosherstov, V., Dohan, D., et al. (2021). Rethinking attention with performers.

  2. Child, A., Gray, S., Radford, A., & Sutskever, I. (2019). Generating long sequences with sparse transformers.

  3. Fan, A., Grave, E., & Joulin, A. (2020). Reducing transformer depth on demand with structured dropout.

Five worked examples

Worked Example 1: GEMM efficiency and arithmetic intensity#

Introduction#

Model dense matrix multiplication with a roofline estimate; count FLOPs and memory traffic; demonstrate how matrix shape affects arithmetic intensity and the attainable FLOP rate.

Purpose#

Understand relationship between GEMM dimensions and arithmetic intensity; show how to achieve peak GPU performance.

Importance#

Foundation for understanding deep learning performance; shapes (batch size, hidden dimensions) directly impact training time.

What this example demonstrates#

  • Define tall-skinny vs. square GEMM shapes.

  • Count FLOPs and bytes moved for each.

  • Compute arithmetic intensity $I = \text{FLOPs/bytes}$.

  • Compare roofline-attainable FLOP rate vs. peak.

  • Predict speedup from roofline model.

Background#

GEMM efficiency depends on matrix shape: square matrices have high arithmetic intensity; tall-skinny have low intensity.

Historical context#

Roofline model (Williams et al. 2009) formalizes this trade-off; guides architecture and algorithm design.

History#

Standard framework for performance modeling in HPC and ML systems.

Prevalence in ML#

Every deep learning practitioner adjusts batch size, layer dimensions to maximize GPU utilization.

Notes#

  • Arithmetic intensity: $I = \frac{2mkn}{8(mk + kn + mn)}$; maximized when $m \approx k \approx n$ (cube).

  • For fixed $k$, varying $m, n$ (batch size, hidden dims) changes $I$ by 10×.

Connection to ML#

Batch size and hidden dimension choices affect both accuracy and training speed; understanding trade-offs is critical.

Connection to Linear Algebra Theory#

GEMM is fundamental linear algebra operation; efficiency is determined by cache locality (blocking theory).

Pedagogical Significance#

Demonstrates practical performance modeling; connects theory (arithmetic intensity) to practice (measured FLOP rates).

References#

  1. Williams, S., Waterman, A., & Patterson, D. (2009). Roofline: an insightful visual performance model for floating-point programs.

  2. Golub, G. H., & Van Loan, C. F. (2013). Matrix Computations (4th ed.).

  3. Frigo, M., Leiserson, C. E., Prokop, H., & Ramachandran, S. (1999). Cache-oblivious algorithms.

Solution (Python)#

# Test different matrix shapes (keeping k fixed)
k = 1024
shapes = [
    (8, k, 1024),     # Tall-skinny A (few rows): low intensity
    (128, k, 128),    # Small output, long inner dimension: moderate intensity
    (1024, k, 1024),  # Square: high intensity
    (4096, k, 4096),  # Large output: even higher
]

# Roofline assumptions (illustrative): V100 FP64 peak ~7.8 TFLOP/s, HBM2 ~900 GB/s
peak_flops = 7.8e12
bandwidth = 900e9
bytes_per_element = 8  # FP64, matching the 8 in the intensity formula

print("GEMM Efficiency Analysis")
print("=" * 80)
print(f"{'m x n':15} {'FLOPs (M)':>15} {'Memory (MB)':>15} {'Intensity':>15} {'Est. GFLOPs':>15}")
print("-" * 80)

intensities = []
for m, k_dim, n in shapes:
    # Arithmetic: one multiply and one add per (i, j, p) triple
    flops = 2 * m * k_dim * n
    # Memory: read A (m*k), read B (k*n), write C (m*n), each moved once
    mem_bytes = bytes_per_element * (m * k_dim + k_dim * n + m * n)
    intensity = flops / mem_bytes
    intensities.append(intensity)

    # Attainable rate from the roofline model
    roofline = min(peak_flops, bandwidth * intensity)

    print(f"{f'{m}x{n}':15} {flops/1e6:>15.0f} {mem_bytes/1e6:>15.1f} {intensity:>15.2f} {roofline/1e9:>15.1f}")

print("\n" + "=" * 80)
print("Key insight: Higher arithmetic intensity -> higher roofline GFLOPs (up to the peak)")
print(f"Largest vs. smallest intensity in this table: {max(intensities) / min(intensities):.1f}x")

Worked Example 2: Batched GEMM and GPU parallelism#

Introduction#

Implement batched matrix multiplication; measure performance as batch size varies; show speedup from batch parallelism.

Purpose#

Demonstrate how batch dimension enables parallelism; show relationship between batch size and GPU utilization.

Importance#

Batch size is a key hyperparameter; understanding its impact on performance guides training setup.

What this example demonstrates#

  • Generate batched matrices $A_i, B_i$ for $i = 1, \ldots, B$.

  • Time batched GEMM vs. sequential.

  • Measure speedup; show scaling with batch size.

  • Explain why small batches underutilize GPU.

Background#

GPUs have thousands of cores; small batches can’t keep all cores busy; large batches achieve better utilization.

Historical context#

Batched GEMM builds on the Level-3 BLAS GEMM (1990); batched interfaces were added later by vendor libraries and are essential for CNN/RNN training.

History#

Modern frameworks (PyTorch, TensorFlow) automatically batch GEMMs; rarely needs manual tuning.

Prevalence in ML#

Every training loop uses batched GEMM; batch size choice directly impacts throughput.

Notes#

  • Batch size $B = 1$: a single GEMM must fill the whole device; throughput is limited when the matrices are small.

  • $B = 32$: better utilization; GPUs have 80+ SMs (streaming multiprocessors).

  • $B = 256$: excellent utilization; typical for modern training.

Connection to ML#

Batch size affects both convergence (larger batches can have worse generalization) and speed; practical sweet spot is usually 32–256.

Connection to Linear Algebra Theory#

Batched GEMM exploits structure (independent problems); vectorization across batch dimension.

Pedagogical Significance#

Shows interplay between algorithm structure and hardware parallelism.

References#

  1. Dongarra, J. J., Du Croz, J., Hammarling, S., & Duff, I. S. (1990). A set of level 3 basic linear algebra subprograms.

  2. Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). ImageNet classification with deep convolutional neural networks.

  3. Goyal, P., Dollár, P., Girshick, R., et al. (2017). Accurate, large minibatch SGD: training ImageNet in 1 hour.

Solution (Python)#

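import numpy as np
import time

rng = np.random.default_rng(36)

# Batched GEMM performance (CPU/NumPy stand-in for a GPU batched GEMM).
# Matrices are kept at 256 x 256 so the largest batch fits comfortably in memory.
batch_sizes = [1, 4, 16, 64, 256]
m, k, n = 256, 256, 256
iterations = 10

print("Batched GEMM Performance (m=k=n={}, {} iterations)".format(m, iterations))
print("=" * 75)
print(f"{'Batch Size':>12} {'Batched (s)':>15} {'Looped (s)':>15} {'Batched GFLOP/s':>18}")
print("-" * 75)

for B in batch_sizes:
    # Create batch of matrices directly in FP32
    A = rng.standard_normal((B, m, k), dtype=np.float32)
    B_mat = rng.standard_normal((B, k, n), dtype=np.float32)

    # Batched matmul: one call over the leading batch dimension
    t0 = time.perf_counter()
    for _ in range(iterations):
        C = np.matmul(A, B_mat)
    t_batched = time.perf_counter() - t0

    # Sequential baseline: one GEMM call per batch element
    t0 = time.perf_counter()
    for _ in range(iterations):
        C_seq = np.stack([A[i] @ B_mat[i] for i in range(B)])
    t_looped = time.perf_counter() - t0

    # FLOPs: 2mkn per matrix, B matrices, repeated `iterations` times
    flops = iterations * B * 2 * m * k * n
    gflops = flops / (t_batched * 1e9)

    print(f"{B:>12} {t_batched:>15.4f} {t_looped:>15.4f} {gflops:>18.1f}")

print("\n" + "=" * 75)
print("Note: on a GPU, larger batches usually raise GFLOP/s by exposing more parallelism;")
print("on a CPU with NumPy the effect depends on the BLAS library and thread count")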

Worked Example 3: Convolution as GEMM (im2col)#

Introduction#

Implement convolution using naive loops, then via im2col GEMM; measure speedup from optimized GEMM.

Purpose#

Show how convolution is equivalent to matrix multiplication; demonstrate efficiency gain from reusing optimized GEMM.

Importance#

Foundational for understanding why GPUs excel at CNNs; im2col is standard in production frameworks.

What this example demonstrates#

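  • Naive nested-loop convolution implementation (explicit loops over batch, output position, and output channel, with an inner reduction over the kernel window).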

  • im2col transformation: reshape patches into columns.

  • GEMM on im2col matrix; reshape output.

  • Compare naive vs. GEMM time.

Background#

Convolution unfolds into GEMM; allows reuse of highly-tuned BLAS kernels; 10–100× speedup.

Historical context#

im2col technique developed for efficient convolution implementations in early deep learning (Caffe, 2013).

History#

Standard in all deep learning frameworks; sometimes augmented by Winograd for further speedup.

Prevalence in ML#

Most CNN frameworks include im2col or a similar GEMM-based convolution path.

Notes#

  • im2col memory overhead: the column matrix is up to $k_h k_w$ times larger than the input (9× for a $3 \times 3$ kernel with stride 1); trade memory for speed.

  • Winograd convolution (for $3 \times 3$ kernels): fewer multiplications, but reduced numerical accuracy.

Connection to ML#

Convolutional layers dominate image classification and detection models; efficiency here directly impacts training speed.

Connection to Linear Algebra Theory#

Convolution is linear transformation; im2col exploits structure to reduce to GEMM.

Pedagogical Significance#

Demonstrates how abstract operations (convolution) map to concrete linear algebra (GEMM).

References#

  1. Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). ImageNet classification with deep convolutional neural networks.

  2. Jia, Y., Shelhamer, E., Donahue, J., et al. (2014). Caffe: convolutional architecture for fast feature embedding.

  3. Lavin, A., & Gray, S. (2016). Fast algorithms for convolutional neural networks.

Solution (Python)#

import numpy as np
import time

np.random.seed(37)

# Convolution parameters
batch_size, in_height, in_width, in_channels = 32, 64, 64, 3
out_channels, kernel_h, kernel_w, stride = 16, 3, 3, 1
pad = 1

# Padded input
X_padded = np.pad(np.random.randn(batch_size, in_height, in_width, in_channels),
                   ((0,0), (pad,pad), (pad,pad), (0,0)), mode='constant')
W = np.random.randn(out_channels, kernel_h, kernel_w, in_channels)

# Output dimensions
out_height = (in_height + 2*pad - kernel_h) // stride + 1
out_width = (in_width + 2*pad - kernel_w) // stride + 1

# Naive convolution (slow)
print("Naive convolution (nested-loop implementation):")
t0 = time.time()
Y_naive = np.zeros((batch_size, out_height, out_width, out_channels))
for b in range(batch_size):
    for h in range(out_height):
        for w in range(out_width):
            for c in range(out_channels):
                h_start = h * stride
                w_start = w * stride
                patch = X_padded[b, h_start:h_start+kernel_h, w_start:w_start+kernel_w, :]
                Y_naive[b, h, w, c] = np.sum(patch * W[c])
t_naive = time.time() - t0
print(f"  Time: {t_naive:.4f} s")

# im2col GEMM (fast)
print("\nim2col GEMM (optimized convolution):")
t0 = time.time()

# im2col: extract patches
X_col = np.zeros((batch_size * out_height * out_width, kernel_h * kernel_w * in_channels))
idx = 0
for b in range(batch_size):
    for h in range(out_height):
        for w in range(out_width):
            h_start = h * stride
            w_start = w * stride
            patch = X_padded[b, h_start:h_start+kernel_h, w_start:w_start+kernel_w, :]
            X_col[idx] = patch.reshape(-1)
            idx += 1

# Weight matrix (reshape filters)
W_mat = W.reshape(out_channels, -1).T  # (kernel_h*kernel_w*in_channels, out_channels)

# GEMM
Y_col = X_col @ W_mat  # (batch*out_h*out_w, out_channels)

# Reshape to output
Y_gemm = Y_col.reshape(batch_size, out_height, out_width, out_channels)

t_gemm = time.time() - t0
print(f"  Time: {t_gemm:.4f} s")

print(f"\nSpeedup: {t_naive / t_gemm:.1f}x")
print(f"Results match: {np.allclose(Y_naive, Y_gemm, atol=1e-5)}")
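The patch-extraction loop in the solution above is itself slow Python. One possible vectorized variant, sketched here under the same NHWC layout, uses `numpy.lib.stride_tricks.sliding_window_view` (available in NumPy 1.20 and later). The helper names `im2col_nhwc` and `conv2d_gemm` and the small test sizes are illustrative.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def im2col_nhwc(X_padded, kernel_h, kernel_w, stride=1):
    """Return patches as rows: (batch*out_h*out_w, kernel_h*kernel_w*channels)."""
    # windows has shape (batch, H', W', channels, kernel_h, kernel_w)
    windows = sliding_window_view(X_padded, (kernel_h, kernel_w), axis=(1, 2))
    windows = windows[:, ::stride, ::stride]
    # Reorder to (batch, out_h, out_w, kernel_h, kernel_w, channels) to match W
    windows = windows.transpose(0, 1, 2, 4, 5, 3)
    b, oh, ow = windows.shape[:3]
    return windows.reshape(b * oh * ow, -1), (b, oh, ow)

def conv2d_gemm(X_padded, W, stride=1):
    """Convolution as one GEMM; W has shape (out_channels, kh, kw, in_channels)."""
    out_channels, kh, kw, _ = W.shape
    X_col, (b, oh, ow) = im2col_nhwc(X_padded, kh, kw, stride)
    Y = X_col @ W.reshape(out_channels, -1).T
    return Y.reshape(b, oh, ow, out_channels)

rng = np.random.default_rng(0)
X_small = rng.standard_normal((2, 7, 7, 3))      # already padded, illustrative size
W_small = rng.standard_normal((4, 3, 3, 3))
print(conv2d_gemm(X_small, W_small, stride=2).shape)
```

The window view itself copies nothing; the copy happens in `reshape`, which is where the im2col memory overhead appears.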

Worked Example 4: Scaled dot-product attention complexity#

Introduction#

Implement attention operation; measure memory and time complexity; show quadratic dependence on sequence length.

Purpose#

Understand why attention is a bottleneck for long sequences; motivate approximate attention methods.

Importance#

Attention scales as $O(L^2 d)$; for long sequences (4K tokens), this dominates; critical for efficiency research.

What this example demonstrates#

  • Implement attention: QK^T, softmax, output.

  • Measure memory (intermediate softmax matrix is $L \times L$).

  • Time scaling with $L$; show quadratic growth.

  • Compare attention time vs. other layers.

Background#

Quadratic attention complexity is fundamental limitation of transformer architecture; many proposed approximations.

Historical context#

Vaswani et al. (2017) introduce the Transformer and scaled dot-product self-attention; the quadratic cost mattered little at the sequence lengths then in common use ($L \le 512$).

History#

Post-2020, attention optimization becomes major research area: Flash Attention, sparse attention, linear attention variants.

Prevalence in ML#

Every transformer model suffers from quadratic attention; common workaround is to limit context length or use approximations.

Notes#

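  • Attention FLOPs: about $4L^2 d$ per layer across heads ($2L^2 d$ each for $QK^\top$ and $AV$); this exceeds the layer's projection and MLP FLOPs only when $L$ is several times $d$.

  • Memory: $O(L^2)$ for attention matrix; for $L = 4096$ in FP32: 64 MiB per head, or 768 MiB per sequence with 12 heads.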

Connection to ML#

Limiting context length ($L = 512$ vs. $L = 4096$) is common trade-off between expressiveness and efficiency.

Connection to Linear Algebra Theory#

Attention is polynomial in sequence length; matrix products scale quadratically in one dimension.

Pedagogical Significance#

Shows concrete example of how algorithmic bottleneck (quadratic) impacts practical ML.

References#

  1. Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). Attention is all you need.

  2. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Re, C. (2022). FlashAttention: fast and memory-efficient exact attention with IO-awareness.

  3. Choromanski, K., Likhosherstov, V., Dohan, D., et al. (2021). Rethinking attention with performers.

Solution (Python)#

import numpy as np
import time

np.random.seed(38)

# Attention parameters
d = 768  # Hidden dimension
num_heads = 12
d_k = d // num_heads
L_values = [128, 256, 512, 1024, 2048]  # Sequence lengths

print("Attention Complexity Analysis (d={}, num_heads={})".format(d, num_heads))
print("=" * 70)
print(f"{'Seq Len L':15} {'FLOPs (M)':15} {'Memory (MB)':15} {'Time (ms)':15}")
print("-" * 70)

for L in L_values:
    batch_size = 1
    
    # Create Q, K, V
    Q = np.random.randn(batch_size, num_heads, L, d_k).astype(np.float32)
    K = np.random.randn(batch_size, num_heads, L, d_k).astype(np.float32)
    V = np.random.randn(batch_size, num_heads, L, d_k).astype(np.float32)
    
    # Measure time and memory
    t0 = time.time()
    
    # Attention: QK^T / sqrt(d_k)
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2))  # (batch, heads, L, L)
    scores = scores / np.sqrt(d_k)
    
    # Softmax
    scores = scores - np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores)
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    
    # Output
    output = np.matmul(weights, V)  # (batch, heads, L, d_k)
    
    t_attn = time.time() - t0
    
    # FLOPs per head: QK^T = 2*L^2*d_k and weights @ V = 2*L^2*d_k (softmax's ~L^2 ops not counted)
    flops = batch_size * num_heads * (2 * L * L * d_k + 2 * L * L * d_k)
    
    # Memory: one L x L FP32 score matrix per head (exp_scores and weights add two more copies)
    mem_bytes = batch_size * num_heads * L * L * 4
    
    print(f"{L:>14} {flops/1e6:>14.0f} {mem_bytes/1e6:>14.1f} {t_attn*1e3:>14.2f}")

print("\n" + "=" * 70)
print("Key insight: FLOPs and memory scale quadratically with sequence length")
print("For L=4096 (batch 1, 12 heads, FP32): ~0.8 GB of scores, ~52 GFLOPs per layer")

Worked Example 5: Distributed GEMM and communication cost#

Introduction#

Simulate data-parallel training with gradient synchronization using an analytic cost model; compare computation vs. communication time; show communication overhead.

Purpose#

Understand communication bottleneck in distributed training; motivate communication-efficient algorithms.

Importance#

Modern LLMs trained on 1000s of GPUs; communication often dominates; critical for scaling.

What this example demonstrates#

  • Simulate distributed GEMM (matmul on local device).

  • Simulate all-reduce for gradient synchronization.

  • Measure computation time vs. communication time.

  • Show how communication latency scales with number of devices.

Background#

Distributed training divides minibatches across devices; after each minibatch, devices exchange gradients via all-reduce.

Historical context#

Large-batch SGD and gradient compression (2017–2019) driven by communication bottleneck.

History#

Modern frameworks (PyTorch DDP, Horovod) optimize communication; mixed precision + gradient compression reduce overhead.

Prevalence in ML#

Every distributed training uses all-reduce; communication cost is well-studied bottleneck.

Notes#

  • Computation time: $O(B \cdot d_{\text{in}} \cdot d_{\text{out}})$ (linear in batch size, dimensions).

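  • Communication time: $O(\alpha \log D + \beta\, n_{\text{grad}})$ with per-message latency $\alpha$ and per-byte cost $\beta$ (logarithmic in device count $D$, linear in gradient size $n_{\text{grad}}$).

  • For 1000 devices: all-reduce with $\log D \approx 10$ rounds; if each round takes 10 μs, the latency term is ~100 μs; the bandwidth term for large gradients and the computation itself often take milliseconds.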
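To make the all-reduce step concrete, the sketch below simulates a ring all-reduce (reduce-scatter followed by all-gather) on in-memory NumPy arrays. It is one common algorithm, not the only one, and it models only the data movement pattern: no real network, no timing. The function name `ring_all_reduce` is hypothetical.

```python
import numpy as np

def ring_all_reduce(local_vectors):
    """Simulate ring all-reduce; returns each device's final vector."""
    D = len(local_vectors)
    # chunks[r][c] is device r's copy of chunk c
    chunks = [np.array_split(v.astype(np.float64), D) for v in local_vectors]

    # Reduce-scatter: after D-1 steps device r holds the full sum of chunk (r+1) % D
    for s in range(D - 1):
        sent = [chunks[r][(r - s) % D].copy() for r in range(D)]  # simultaneous sends
        for r in range(D):
            dst, c = (r + 1) % D, (r - s) % D
            chunks[dst][c] = chunks[dst][c] + sent[r]

    # All-gather: circulate the completed chunks around the ring
    for s in range(D - 1):
        sent = [chunks[r][(r + 1 - s) % D].copy() for r in range(D)]
        for r in range(D):
            dst, c = (r + 1) % D, (r + 1 - s) % D
            chunks[dst][c] = sent[r]

    return [np.concatenate(ch) for ch in chunks]

rng = np.random.default_rng(0)
grads = [rng.standard_normal(10) for _ in range(4)]   # one gradient per device
results = ring_all_reduce(grads)
print(all(np.allclose(r, sum(grads)) for r in results))
```

Each device sends one chunk of size $n_{\text{grad}}/D$ per step for $2(D-1)$ steps, so the data volume per device stays close to $2\,n_{\text{grad}}$ regardless of $D$, while the number of latency-bound steps grows linearly rather than logarithmically.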

Connection to ML#

Large-batch training requires communication efficiency; gradient compression and other tricks essential for practical scaling.

Connection to Linear Algebra Theory#

All-reduce is a collective communication primitive; tree-based algorithms achieve the optimal $O(\log D)$ latency term.

Pedagogical Significance#

Shows distributed systems aspect of linear algebra; explains why scaling beyond certain point is challenging.

References#

  1. Thakur, R., Rabenseifner, R., & Gropp, W. (2005). Optimization of collective communication operations in MPICH.

  2. Shoeybi, M., Patwary, M., Puri, R., et al. (2019). Megatron-LM: training multi-billion parameter language models using model parallelism.

  3. Rasley, J., Rajbhandari, S., Ruwase, O., & He, Y. (2020). DeepSpeed: system optimizations enable training deep learning models with over 100 billion parameters.

Solution (Python)#

import numpy as np

# Distributed training simulation (analytic cost model; no real devices involved)
num_devices = [1, 4, 8, 16, 32]
batch_size = 256
hidden_dim = 2048

print("Distributed GEMM: Computation vs. Communication")
print("=" * 75)
print(f"{'Devices':>8} {'Comp Time (ms)':>18} {'Comm Time (ms)':>18} {'Comp/Comm Ratio':>18}")
print("-" * 75)

# Assumptions (illustrative, not measured):
# - Computation: 15.7 TFLOP/s per device (V100 FP32 peak)
# - Communication: 25 GB/s interconnect, 10 us latency per round
compute_flops_per_device = 15.7e12  # FLOP/s
comm_bandwidth = 25e9               # bytes/s
comm_latency = 10e-6                # seconds per communication round

for D in num_devices:
    # Local batch per device
    local_batch = batch_size // D

    # Forward GEMM (local_batch x hidden_dim x hidden_dim) plus two backward GEMMs
    flops_local = 3 * 2 * local_batch * hidden_dim * hidden_dim

    # Computation time
    t_compute = flops_local / compute_flops_per_device

    # Communication: all-reduce of the weight gradient (hidden_dim x hidden_dim, FP32)
    gradient_bytes = hidden_dim * hidden_dim * 4
    if D == 1:
        t_comm = 0.0  # nothing to synchronize
    else:
        # Latency term: O(log D) rounds; bandwidth term: 2(D-1)/D of the gradient
        comm_rounds = 2 * int(np.ceil(np.log2(D)))
        t_comm = comm_rounds * comm_latency + 2 * (D - 1) / D * gradient_bytes / comm_bandwidth

    ratio = t_compute / t_comm if t_comm > 0 else float("inf")

    print(f"{D:>8} {t_compute*1e3:>18.3f} {t_comm*1e3:>18.3f} {ratio:>17.2f}x")

print("\n" + "=" * 75)
print("Key insight: per-device compute shrinks as 1/D, but all-reduce time does not")
print("Compute/comm ratio decreases -> inefficiency at scale")
