Introduction#
Neural networks are fundamentally stacks of matrix multiplications. A forward pass through a deep network is a product of weight matrices and activations. Each layer computes $A_{l+1} = \sigma(A_l W_l^\top)$ where $W_l \in \mathbb{R}^{d_{l+1} \times d_l}$ and $A_l \in \mathbb{R}^{n \times d_l}$ (batch size $n$, layer input dimension $d_l$, output dimension $d_{l+1}$), at a cost of $O(n d_l d_{l+1})$ FLOPs per layer. Transformers add attention: $\text{Attention}(Q, K, V) = \text{softmax}(Q K^\top / \sqrt{d_k}) V$, which involves two GEMMs (scores and output) plus a row-wise softmax. For a Transformer with sequence length $L$, hidden dimension $d$, and $H$ attention heads per layer (head dimension $d/H$), attention costs $O(L^2 d)$ per layer (quadratic in sequence length, a major bottleneck). Modern accelerators (GPUs, TPUs) are matrix-multiply engines delivering trillions of floating-point operations per second (TFLOP/s). Utilization depends on arithmetic intensity (FLOPs per byte): bandwidth-bound operations underutilize the accelerator, while compute-bound operations (high arithmetic intensity) approach peak throughput. Understanding how to write matrix products that achieve high arithmetic intensity, and how to distribute them across devices, determines whether you can train billion-parameter models.
Important ideas#
Matrix-matrix multiplication (GEMM) structure
Dense GEMM: $C \leftarrow AB$ with $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, $C \in \mathbb{R}^{m \times n}$.
Arithmetic: $mk + kmn + mn = O(mkn)$ floating-point operations (FLOPs).
Memory: $O(m + k + n)$ words (inputs + output); on GPU, $m, k, n$ can be $1000$s, requiring GB of memory.
Arithmetic intensity: $I = \frac{\text{FLOPs}}{\text{bytes}} = \frac{2mkn}{8(mk + kn + mn)} \approx \frac{mkn}{4(m + k + n)}$ (higher is better).
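As a quick check of these formulas, a sketch in pure Python (FP32 word size assumed; the matrix sizes are arbitrary):

```python
def gemm_intensity(m, k, n, bytes_per_word=4):
    """FLOPs per byte for a dense GEMM C <- A @ B, assuming each
    operand is read/written exactly once (ideal caching)."""
    flops = 2 * m * k * n                                   # one multiply + one add per inner-product term
    bytes_moved = bytes_per_word * (m * k + k * n + m * n)  # A, B read; C written
    return flops / bytes_moved

# Square FP32 GEMM: intensity grows like n/6
print(round(gemm_intensity(4096, 4096, 4096), 1))  # 682.7
```

This is why large GEMMs saturate compute while small ones (e.g., tiny batch sizes) remain bandwidth-bound.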
Blocking and cache efficiency
GEMM is blocked into $b \times b$ tiles sized to fit in fast on-chip memory (registers, shared memory, caches).
A tile is loaded once (in 64-byte cache lines, typically) and each element is reused $O(b)$ times, amortizing memory traffic.
Roofline model: attainable throughput is the lesser of peak FLOP rate and memory bandwidth × arithmetic intensity; if intensity falls below the machine balance (peak FLOP rate ÷ bandwidth), the algorithm is bandwidth-bound.
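The roofline is a one-line model; the peak and bandwidth figures below are hypothetical, chosen only for illustration:

```python
PEAK_FLOPS = 100e12   # hypothetical accelerator: 100 TFLOP/s peak compute
BANDWIDTH = 1e12      # hypothetical 1 TB/s memory bandwidth

def attainable(intensity):
    """Roofline: attainable FLOP/s at a given arithmetic intensity (FLOPs/byte)."""
    return min(PEAK_FLOPS, BANDWIDTH * intensity)

ridge = PEAK_FLOPS / BANDWIDTH   # machine balance: 100 FLOPs/byte
# intensity below `ridge` -> bandwidth-bound; above -> compute-bound
```

A kernel with intensity 10 on this machine tops out at 10 TFLOP/s regardless of how cleverly it is scheduled.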
Batch matrix multiplication (batched GEMM)
Forward pass: $C_i \leftarrow A_i B_i$ for $i = 1, \ldots, B$ (batch size).
Exploit parallelism: the $B$ independent products are dispatched across CPU cores or GPU SMs in a single kernel launch.
Highly efficient when the batch is large; small batches underutilize the accelerator.
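A sketch of batched GEMM, with NumPy's broadcasting matmul standing in for a library batched-GEMM kernel (shapes are arbitrary):

```python
import numpy as np

# Batched GEMM: C_i <- A_i @ B_i for i = 1..B, as one vectorized call.
B, m, k, n = 32, 64, 128, 96
A = np.random.rand(B, m, k).astype(np.float32)
Bmat = np.random.rand(B, k, n).astype(np.float32)

C = A @ Bmat            # matmul broadcasts over the leading batch dimension
assert C.shape == (B, m, n)
```

On a GPU the analogous call amortizes launch overhead across all $B$ products, which is exactly why small per-product sizes still run efficiently when batched.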
Convolution as matrix multiplication (im2col, Winograd)
Convolution unfolds as GEMM: reshape input patches into columns; multiply by filter matrix; reshape output.
im2col: input image to column matrix; allows use of highly optimized GEMM (cuBLAS, MKL).
Cost: $O(k_h k_w d_{\text{in}} d_{\text{out}} h_{\text{out}} w_{\text{out}})$ (kernel height/width, input/output channels, output spatial dims).
Winograd: fast convolution via a transformed domain; reduces the number of multiplications but worsens numerical conditioning and adds transform overhead.
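A minimal im2col sketch (stride 1, no padding; the memory layout here is illustrative, and real libraries choose layouts to match their GEMM kernels):

```python
import numpy as np

def im2col(x, kh, kw):
    """Unfold a (C_in, H, W) input into a (C_in*kh*kw, H_out*W_out) column
    matrix (stride 1, no padding) so convolution becomes one GEMM."""
    c, h, w = x.shape
    h_out, w_out = h - kh + 1, w - kw + 1
    cols = np.empty((c * kh * kw, h_out * w_out), dtype=x.dtype)
    row = 0
    for ci in range(c):
        for i in range(kh):
            for j in range(kw):
                # each row holds one (channel, kernel-offset) slice across all output positions
                cols[row] = x[ci, i:i + h_out, j:j + w_out].reshape(-1)
                row += 1
    return cols

x = np.random.rand(3, 8, 8).astype(np.float32)
filt = np.random.rand(16, 3 * 3 * 3).astype(np.float32)   # 16 filters, 3x3 kernel, 3 input channels
y = (filt @ im2col(x, 3, 3)).reshape(16, 6, 6)            # conv output: (C_out, H_out, W_out)
```

The duplication of input pixels across columns is the price paid for reusing a highly tuned GEMM.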
Scaled dot-product attention
Query-key-value paradigm: $Q, K \in \mathbb{R}^{L \times d_k}$, $V \in \mathbb{R}^{L \times d_v}$ (sequence length $L$, head dimensions $d_k$, $d_v$).
Attention: (1) $M = Q K^\top / \sqrt{d_k}$ (an $(L \times d_k)(d_k \times L)$ product: $O(L^2 d_k)$), (2) $A = \text{softmax}(M)$ (per-row normalization, no matrix products), (3) $O = AV$ (an $(L \times L)(L \times d_v)$ product: $O(L^2 d_v)$).
Total: $O(L^2 (d_k + d_v)) = O(L^2 d)$ (quadratic in sequence length).
Challenge: for $L = 4096$ (a typical context length), the attention matrix has $L^2 \approx 16.8\text{M}$ entries per head; across heads and layers this adds up to billions of FLOPs per forward pass.
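The three steps above for a single head in NumPy (the row-max subtraction is the standard trick for a numerically stable softmax):

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    M = Q @ K.T / np.sqrt(d_k)                 # (L, L) score matrix: O(L^2 d_k)
    M = M - M.max(axis=-1, keepdims=True)      # subtract row max: stable softmax
    A = np.exp(M)
    A = A / A.sum(axis=-1, keepdims=True)      # row-wise normalization
    return A @ V                               # (L, d_v) output: O(L^2 d_v)

L, d = 128, 64
Q, K, V = (np.random.rand(L, d).astype(np.float32) for _ in range(3))
O = attention(Q, K, V)
assert O.shape == (L, d)
```

Note that the $(L, L)$ matrix `A` is materialized here; avoiding exactly that allocation is the point of memory-efficient attention kernels.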
Mixed precision and numerical stability
FP32 (single precision, float32): 32 bits, ~7 significant digits; gradients, weights commonly stored in FP32.
FP16 (half precision, float16): 16 bits, ~3 significant digits; representable range roughly $[6 \times 10^{-8}, 6.5 \times 10^4]$; GPU matrix ops 2–10× faster than FP32.
BFloat16 (Brain Float): 16 bits with FP32's exponent range (8 exponent bits) but only 7 mantissa bits; trades precision for dynamic range, avoiding FP16's underflow issues.
Mixed precision: compute GEMM in FP16 (fast), accumulate in FP32 (stable); scale loss to prevent underflow.
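A sketch of why loss scaling is needed, using FP16's underflow threshold (the scale factor $2^{16}$ is a common but arbitrary choice):

```python
import numpy as np

# Gradients below FP16's smallest subnormal (~6e-8) silently become zero,
# so the loss is scaled up before the FP16 backward pass and the scale is
# divided back out in FP32 before the weight update.
g = 1e-8                                   # tiny gradient (true FP32 value)
assert np.float16(g) == 0.0                # underflows to zero in FP16

scale = 2.0 ** 16                          # typical loss-scale factor
g_scaled = np.float16(g * scale)           # now well inside FP16's normal range
g_recovered = np.float32(g_scaled) / scale # unscale in FP32
assert abs(g_recovered - g) / g < 0.01     # recovered to within FP16 rounding error
```

Dynamic loss scaling in real frameworks adjusts `scale` automatically, backing off when FP16 overflow is detected.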
Distributed matrix multiplication
Data parallelism: replicate model on each device; partition minibatches; synchronize gradients via all-reduce.
Model parallelism: partition matrix weights across devices; communication within matrix product (e.g., matmul followed by communication).
Pipeline parallelism: partition layers across devices into stages; split the minibatch into microbatches so stage $i$ computes on one microbatch while stage $i+1$ processes the previous one.
Cost: compute time + communication; at large scale communication often dominates (the interconnect plays the role that memory bandwidth plays in the roofline model).
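A sketch of tensor (model) parallelism under the assumption of column-wise weight partitioning, with NumPy arrays standing in for per-device shards and `concatenate` standing in for the all-gather:

```python
import numpy as np

# Split W column-wise across D "devices"; each computes its slice of the
# output with a local GEMM, then an all-gather reassembles the features.
D = 4
n, d_in, d_out = 8, 32, 64
X = np.random.randn(n, d_in).astype(np.float32)
W = np.random.randn(d_in, d_out).astype(np.float32)

shards = np.split(W, D, axis=1)           # each device holds d_out/D columns of W
partials = [X @ Wi for Wi in shards]      # local GEMMs, no communication
Y = np.concatenate(partials, axis=1)      # "all-gather" along the feature axis
assert np.allclose(Y, X @ W, atol=1e-4)   # matches the unpartitioned product
```

Row-wise partitioning is the dual: each device produces a partial sum of the full output, combined with an all-reduce instead of an all-gather.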
Relevance to ML#
Convolutional neural networks (CNNs): Forward and backward passes are GEMM-heavy; efficiency determines whether you can train on billion-pixel images or video.
Recurrent neural networks (RNNs), LSTMs, GRUs: Fully-connected layers between timesteps; matrix products per timestep.
Transformers and large language models: Attention is $O(L^2 d)$ matrix products; for GPT-3 ($L = 2048$, $d = 12288$), attention dominates forward/backward.
Graph neural networks (GNNs): Graph convolution is sparse matrix product; efficiency depends on sparsity and format.
Distributed training: Modern LLMs trained on thousands of GPUs/TPUs; communication cost (network bandwidth) often exceeds computation cost.
Algorithmic development (milestones)#
1969: Strassen algorithm: $O(n^{2.807})$ vs. $O(n^3)$ naive GEMM (theoretically significant; rarely used in practice due to constants).
1979–1990: Level-1/2/3 BLAS (Basic Linear Algebra Subprograms); standardized interface for matrix ops; LAPACK (1992) built on BLAS.
1995–2005: GPU era begins (NVIDIA GeForce, early GPGPU work); GPUs offer roughly 10× the memory bandwidth of CPUs, and GEMMs run 10–100× faster.
2006: CUDA (Compute Unified Device Architecture) released; enables general-purpose GPU computing; cuBLAS optimized GEMM for NVIDIA GPUs.
2012: AlexNet (Krizhevsky et al.) demonstrates deep CNN training on GPUs; training cost is dominated by GEMM-heavy convolution and fully connected layers.
2015: Batch normalization (Ioffe & Szegedy); reduces sensitivity to initialization and stabilizes training of much deeper networks.
2017: Transformer architecture (Vaswani et al.); attention is dense GEMM-based; quadratic in sequence length.
2017: Mixed-precision training (Micikevicius et al.); FP16 compute with FP32 master weights and loss scaling delivers 2–8× speedups on Tensor Core GPUs.
2018–2020: Distributed training frameworks mature (PyTorch DDP, Horovod); hundred-billion-parameter models (e.g., GPT-3 at 175B) trained via combined data, tensor, and pipeline parallelism.
2020–2023: FlashAttention (Dao et al., 2022) makes attention IO-aware, tiling the computation so the $L \times L$ score matrix is never materialized in HBM; Megatron-LM and DeepSpeed scale distributed GEMMs to petaFLOP/s clusters.
Definitions#
GEMM (General Matrix Multiply): $C \leftarrow \alpha A B + \beta C$ (standard matrix multiply with scaling/accumulation).
FLOP (floating-point operation): One add or multiply; GEMM $C \leftarrow AB$ is $2mkn$ FLOPs.
Arithmetic intensity: $I = \frac{\text{FLOPs}}{\text{bytes read/written}}$ (ops per byte); high $I$ means compute-bound; low $I$ means bandwidth-bound.
Roofline model: Peak achievable throughput = $\min(\text{peak FLOP rate}, \text{memory bandwidth} \times \text{arithmetic intensity})$.
Memory-bound: Algorithm where memory bandwidth is bottleneck; cannot achieve peak FLOP rate.
Compute-bound: Algorithm where compute is bottleneck; limited by FLOPs/cycle, not memory.
Mixed precision: Using multiple precision levels (e.g., FP16 for compute, FP32 for accumulation) to trade accuracy for speed.
All-reduce: Distributed operation in which every device contributes a value (e.g., a gradient tensor) and receives the elementwise global sum, replicated on all devices. Tree-based implementations take $O(\log D)$ communication rounds for $D$ devices; bandwidth-optimal ring implementations take $2(D-1)$ steps.
Collective communication: Broadcasting, all-reduce, reduce-scatter, all-gather operations in distributed training.
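All-reduce is commonly implemented as a reduce-scatter followed by an all-gather (the two phases of the ring algorithm); a sketch with in-memory arrays standing in for devices:

```python
import numpy as np

D = 4                                                # number of simulated devices
# device i holds its own gradient vector (arange + i, so they differ)
data = [np.arange(8, dtype=np.float32) + i for i in range(D)]

# Phase 1, reduce-scatter: device c ends up owning the sum of chunk c
# across all devices (in the ring algorithm this takes D-1 steps).
split = [np.split(v, D) for v in data]
reduced = [sum(split[dev][c] for dev in range(D)) for c in range(D)]

# Phase 2, all-gather: every device reassembles the full summed vector
# from the owned chunks (another D-1 ring steps).
result = np.concatenate(reduced)
assert np.allclose(result, sum(data))                # equals the naive global sum
```

Per device this moves $O(\text{vector size})$ bytes independent of $D$, which is why the ring variant is bandwidth-optimal.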