Assessment Focus: Python 3 · PyTorch · Matrix manipulations (not model training)
Level: Beginner → Advanced — read end-to-end before your test
Time to complete: ~90 minutes reading + practice
Table of Contents
- What is PyTorch and Why Does It Matter?
- Tensor Fundamentals — Creation & Inspection
- Data Types, Devices & Casting
- Indexing, Slicing & Boolean Masking
- Reshaping, Views & Squeezing
- Element-wise Operations & Reductions
- Broadcasting — The Silent Multiplier
- Matrix Multiplication — The Core Operation
- Batched Matrix Operations
- Einstein Summation (einsum)
- Linear Algebra — linalg Module
- Advanced Indexing — gather & scatter
- Memory Layout, Strides & Contiguity
- Practice Problems with Full Solutions
- Assessment Cheat Sheet
1. What is PyTorch and Why Does It Matter?
PyTorch is a numerical computation library built on top of C++/CUDA, exposing a Python API that feels like NumPy but with two killer features:
- GPU acceleration — tensors can live on CUDA devices and perform operations at GPU speed
- Autograd — automatic differentiation through any computation graph (used for training, but not the focus of matrix-only assessments)
For this assessment, think of PyTorch as NumPy with superpowers: the same mental model of n-dimensional arrays, but with broadcasting, batched operations, and linear algebra built in at native speed.
import torch
import torch.nn.functional as F
# Confirm version
print(torch.__version__) # e.g. 2.2.0
print(torch.cuda.is_available()) # True if GPU available
Key mindset: Every operation in PyTorch either returns a new tensor (functional) or modifies one in place (trailing underscore, e.g. add_). Always know which one you're calling.
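A minimal sketch of that distinction — add returns a new tensor while add_ mutates its receiver:

```python
import torch

x = torch.tensor([1., 2., 3.])

# Functional: returns a NEW tensor, x is untouched
y = x.add(10)                        # y = [11., 12., 13.]
print(x)                             # tensor([1., 2., 3.]) — unchanged

# In-place: trailing underscore mutates x itself
x.add_(10)
print(x)                             # tensor([11., 12., 13.]) — changed

# The two now hold equal values but live in separate memory
print(torch.equal(x, y))             # True
print(x.data_ptr() == y.data_ptr())  # False — different storage
```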
2. Tensor Fundamentals — Creation & Inspection
2.1 Ways to Create Tensors
Understanding how a tensor is created determines whether it shares memory with Python/NumPy data or owns its own memory.
import torch
import numpy as np
# ── From Python data ────────────────────────────────────────────
a = torch.tensor([1.0, 2.0, 3.0]) # copies data → always safe
b = torch.tensor([[1, 2], [3, 4]]) # 2D from nested list
# ── Sharing memory with NumPy ───────────────────────────────────
arr = np.array([1.0, 2.0, 3.0])
t = torch.from_numpy(arr) # SHARES memory! mutation propagates both ways
t2 = torch.as_tensor(arr) # same — shares when possible
t3 = torch.tensor(arr) # COPIES — fully independent
arr[0] = 99.0
print(t[0]) # tensor(99.) — changed!
print(t3[0]) # tensor(1.) — unchanged
# ── Structured tensors ──────────────────────────────────────────
zeros = torch.zeros(3, 4) # all 0.0, shape (3,4)
ones = torch.ones(2, 3, 4) # all 1.0, shape (2,3,4)
eye = torch.eye(4) # 4×4 identity matrix
empty = torch.empty(3, 3) # uninitialized (garbage values!)
full = torch.full((2, 3), fill_value=7.0) # fill with 7.0
# ── Sequence tensors ────────────────────────────────────────────
arange = torch.arange(0, 10, step=2) # [0, 2, 4, 6, 8]
linspace = torch.linspace(0, 1, steps=5) # [0.0, 0.25, 0.5, 0.75, 1.0]
# ── Random tensors ──────────────────────────────────────────────
torch.manual_seed(42) # reproducibility
uniform = torch.rand(3, 3) # Uniform[0, 1)
normal = torch.randn(3, 3) # Normal(0, 1)
randint = torch.randint(0, 10, (3, 3)) # integers in [0, 10)
perm = torch.randperm(8) # random permutation of [0..7]
# ── Tensors like another tensor (same shape/dtype/device) ───────
x = torch.randn(2, 3)
torch.zeros_like(x)
torch.ones_like(x)
torch.rand_like(x)
torch.full_like(x, fill_value=5.0)
2.2 Inspecting a Tensor
Every tensor carries metadata you should query habitually:
t = torch.randn(3, 4, 5)
print(t.shape) # torch.Size([3, 4, 5]) — like a tuple
print(t.size()) # same as .shape
print(t.size(0)) # 3 — size of dimension 0
print(t.ndim) # 3 — number of dimensions
print(t.dtype) # torch.float32
print(t.device) # device(type='cpu') or device(type='cuda', index=0)
print(t.numel()) # 60 — total element count (3 * 4 * 5)
print(t.is_contiguous()) # True (freshly created tensors are contiguous)
print(t.stride()) # (20, 5, 1) — elements to skip to reach next in each dim
✏️ Quick Check 2.2
Q: What does torch.empty(3, 3) contain?
A: Uninitialized memory — whatever bytes happen to be at those addresses. It's the fastest creation method (no fill), but never use its values without first writing to the tensor. Use torch.zeros() when you need a defined starting state.
3. Data Types, Devices & Casting
3.1 The dtype Hierarchy
# Floating point (most ML operations use these)
torch.float16 # fp16 — half precision (2 bytes)
torch.float32 # fp32 — DEFAULT for tensors (4 bytes)
torch.float64 # fp64 — double precision (8 bytes)
torch.bfloat16 # brain float — better range than fp16, same size
# Integer
torch.int8 # (1 byte, signed)
torch.int16 # (2 bytes)
torch.int32 # (4 bytes)
torch.int64 # (8 bytes) — DEFAULT for integer tensors
# Boolean
torch.bool # True/False (1 byte)
# Explicit dtype at creation
t = torch.ones(3, dtype=torch.float64)
i = torch.tensor([1, 2, 3], dtype=torch.int32)
# Casting (creates a new tensor with converted values)
t_fp32 = t.float() # → float32
t_fp16 = t.half() # → float16
t_fp64 = t.double() # → float64
t_int = t.long() # → int64
t_bool = t.bool() # → bool
# Generic cast
t.to(torch.float16)
t.to(dtype=torch.int32)
3.2 Devices
cpu_tensor = torch.randn(3, 3) # on CPU (default)
gpu_tensor = torch.randn(3, 3, device='cuda') # on GPU
gpu_tensor = torch.randn(3, 3, device='cuda:0') # explicit GPU index
# Moving between devices
on_gpu = cpu_tensor.to('cuda') # CPU → GPU
on_cpu = gpu_tensor.to('cpu') # GPU → CPU
on_gpu = cpu_tensor.cuda() # shorthand
on_cpu = gpu_tensor.cpu() # shorthand
# Best practice: use device variable so code works on CPU or GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
t = torch.randn(3, 3, device=device)
Assessment rule: All tensors in an operation must be on the same device and use compatible dtypes. Mixing CPU + GPU, or int + float in some operations, raises errors.
3.3 Mixed Operations & Promotion
a = torch.tensor([1.0]) # float32
b = torch.tensor([2.0], dtype=torch.float64)
# PyTorch promotes to the higher-precision type
c = a + b # float64 — float32 was promoted
# Integer + float → float
d = torch.tensor([1]) + torch.tensor([1.5]) # float32
✏️ Quick Check 3
Q: You have a = torch.randn(3) (float32) and b = torch.randint(0, 5, (3,)) (int64). What happens when you compute a + b?
A: PyTorch promotes int64 to float32, so the result is float32. The rule is: if one operand is floating point, the result is floating point.
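You can query the promotion rules directly rather than memorising them — torch.promote_types works on dtypes and torch.result_type on actual tensors:

```python
import torch

# Float beats integer, regardless of integer width
print(torch.promote_types(torch.float32, torch.int64))    # torch.float32
# Higher-precision float wins
print(torch.promote_types(torch.float32, torch.float64))  # torch.float64

# result_type inspects actual tensors
a = torch.randn(3)             # float32
b = torch.randint(0, 5, (3,))  # int64
print(torch.result_type(a, b)) # torch.float32
print((a + b).dtype)           # torch.float32
```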
4. Indexing, Slicing & Boolean Masking
This is one of the most tested areas because the distinction between views (shared memory) and copies (independent memory) is subtle and important.
4.1 Basic Indexing and Slicing
t = torch.arange(24).reshape(2, 3, 4)
# Shape: (2, 3, 4) — think: 2 matrices, each 3 rows × 4 cols
# t[batch, row, col]
# ── Single-element indexing ─────────────────────────────────────
t[0] # first matrix → shape (3, 4)
t[-1] # last along dim 0 → shape (3, 4)
t[0, 1] # row 1 of first matrix → shape (4,)
t[0, 1, 2] # scalar element → tensor(6)
t[0][1][2] # equivalent but slower (three indexing ops)
# ── Slicing (Python-style, half-open intervals) ─────────────────
t[:, 0, :] # all batches, row 0, all cols → shape (2, 4)
t[..., -1] # last column across all dims → shape (2, 3)
t[:, 1:3, :] # rows 1 and 2 from both batches → shape (2, 2, 4)
t[:, ::2, :] # every other row → shape (2, 2, 4)
# The Ellipsis (...) fills in as many : as needed
# t[..., 2] ≡ t[:, :, 2] for a 3D tensor
4.2 Views vs Copies — Critical Distinction
t = torch.arange(12, dtype=torch.float32).reshape(3, 4)
# [[0,1,2,3],[4,5,6,7],[8,9,10,11]]
# ── SLICES return VIEWS (shared memory) ────────────────────────
s = t[0, :] # first row — this is a VIEW
s[0] = 999.0 # modifies t as well!
print(t[0, 0]) # tensor(999.) — YES, t changed
# Verify they share storage
print(t.storage().data_ptr() == s.storage().data_ptr()) # True
# ── Get an independent copy ─────────────────────────────────────
c = t[0, :].clone() # COPY — fully independent
c[0] = -1.0
print(t[0, 0]) # tensor(999.) — t unchanged
# ── .detach() vs .clone() ──────────────────────────────────────
# .detach() — shares memory, removes from autograd graph
# .clone() — copies memory, preserves grad_fn
# .clone().detach() — copy + remove from graph (common pattern)
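The comments above can be verified directly — a small sketch showing that detach shares storage while clone does not:

```python
import torch

t = torch.arange(4, dtype=torch.float32)

d = t.detach()   # shares memory with t (only leaves the autograd graph)
c = t.clone()    # independent copy

t[0] = 100.0
print(d[0])      # tensor(100.) — detach sees the change
print(c[0])      # tensor(0.)   — clone is unaffected

# Common pattern: an independent tensor outside the autograd graph
snapshot = t.clone().detach()
```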
4.3 Fancy Indexing (Always Returns a Copy)
t = torch.arange(20).reshape(4, 5)
# Index with a list of indices — picks specific rows
rows = t[[0, 2, 3]] # rows 0, 2, 3 → shape (3, 5)
# Index with parallel lists for each dimension
rows_idx = torch.tensor([0, 1, 3])
cols_idx = torch.tensor([2, 4, 1])
t[rows_idx, cols_idx] # picks (0,2), (1,4), (3,1) — shape (3,)
# t[0,2]=2, t[1,4]=9, t[3,1]=16
# These ALWAYS produce a copy — safe to mutate
4.4 Boolean Masking
t = torch.tensor([[1., -2., 3.], [-4., 5., -6.]])
# Create a boolean mask
mask = t > 0
# tensor([[True, False, True], [False, True, False]])
# Apply mask — returns 1D tensor of selected values
positives = t[mask] # tensor([1., 3., 5.])
# Count True values
mask.sum() # tensor(3)
mask.float().mean() # tensor(0.5) — 50% positive
# Conditional replacement (in-place)
t[t < 0] = 0.0 # ReLU!
print(t)
# tensor([[1., 0., 3.], [0., 5., 0.]])
# torch.where — vectorised ternary (NOT in-place, returns new tensor)
result = torch.where(mask, t, torch.zeros_like(t))
✏️ Practice Problem 4
Problem: Given t = torch.arange(1, 26).reshape(5, 5).float(), use slicing (no loops, no copies) to:
- Selects the bottom-right 3×3 submatrix
- Sets all values in that submatrix above 20 to 0
Solution:
t = torch.arange(1, 26).reshape(5, 5).float()
sub = t[2:, 2:] # bottom-right 3×3: [[13,14,15],[18,19,20],[23,24,25]]
# sub is a VIEW — mutations propagate to t!
sub[sub > 20] = 0.0
# t is now:
# [[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10],
# [11, 12, 13, 14, 15],
# [16, 17, 18, 19, 20],
# [21, 22, 0, 0, 0]]
Key insight: The slice t[2:, 2:] is a view, so sub[sub > 20] = 0.0 modifies t directly. This is intentional and efficient — no copy needed.
5. Reshaping, Views & Squeezing
Shape manipulation is the connective tissue of tensor math — you'll do it constantly to align tensors for operations.
5.1 reshape vs view
t = torch.arange(12, dtype=torch.float32) # shape (12,)
# reshape — safe, handles non-contiguous, may or may not share memory
r1 = t.reshape(3, 4) # (3, 4)
r2 = t.reshape(2, 2, 3) # (2, 2, 3)
r3 = t.reshape(4, -1) # (4, 3) — -1 infers the size
# view — faster but requires contiguous memory; always shares memory
v = t.view(3, 4) # fine — t is contiguous
t_T = t.reshape(3, 4).T # transpose is non-contiguous
# t_T.view(12) # RuntimeError! non-contiguous
t_T.contiguous().view(12) # fix: make contiguous first, then view
# When in doubt: use reshape (safe), not view
5.2 Adding and Removing Dimensions
t = torch.tensor([1., 2., 3.]) # shape (3,)
# unsqueeze — add a dimension of size 1
t.unsqueeze(0) # (1, 3) — new batch dimension
t.unsqueeze(1) # (3, 1) — new column dimension
t.unsqueeze(-1) # (3, 1) — same as unsqueeze(1) for 1D
# squeeze — remove dimensions of size 1
x = torch.randn(1, 3, 1, 4)
x.squeeze() # (3, 4) — removes ALL size-1 dims
x.squeeze(0) # (3, 1, 4) — removes only dim 0
x.squeeze(2) # (1, 3, 4) — removes only dim 2
x.squeeze(-1) # (1, 3, 1, 4) — NO-OP! dim -1 has size 4, not 1
# Careful: squeeze(dim) only removes that dim IF its size is 1; no-op otherwise
# None indexing — alternative unsqueeze
t[None, :] # (1, 3) — same as t.unsqueeze(0)
t[:, None] # (3, 1) — same as t.unsqueeze(1)
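To nail down the caution above — squeeze(dim) is a no-op when that dimension's size is not 1:

```python
import torch

x = torch.randn(1, 3, 1, 4)

print(x.squeeze(-1).shape)  # torch.Size([1, 3, 1, 4]) — dim -1 has size 4: no-op
print(x.squeeze(-2).shape)  # torch.Size([1, 3, 4])    — dim -2 has size 1: removed
print(x.squeeze().shape)    # torch.Size([3, 4])       — all size-1 dims removed
```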
5.3 Concatenation and Stacking
a = torch.randn(3, 4)
b = torch.randn(3, 4)
# cat — join along an EXISTING dimension
torch.cat([a, b], dim=0) # (6, 4) — stack vertically
torch.cat([a, b], dim=1) # (3, 8) — stack horizontally
# stack — create a NEW dimension
torch.stack([a, b], dim=0) # (2, 3, 4) — batch of 2 matrices
torch.stack([a, b], dim=1) # (3, 2, 4)
torch.stack([a, b], dim=2) # (3, 4, 2)
# vstack/hstack — convenience wrappers
torch.vstack([a, b]) # (6, 4) — same as cat(dim=0)
torch.hstack([a, b]) # (3, 8) — same as cat(dim=1)
# split — inverse of cat
parts = torch.split(torch.cat([a, b], dim=0), split_size_or_sections=3, dim=0)
# (tensor(3,4), tensor(3,4)) — splits evenly
# chunk — split into N equal pieces
chunks = torch.chunk(torch.cat([a, b], dim=0), chunks=2, dim=0)
5.4 Transpose and Permute
t = torch.randn(2, 3, 4)
# transpose — swap exactly two dimensions
t.transpose(0, 1) # (3, 2, 4)
t.transpose(1, 2) # (2, 4, 3)
t.T # full reversal of dims — only well-defined for 2D!
# permute — reorder ALL dimensions at once
t.permute(2, 0, 1) # (4, 2, 3) — dim 2 goes to position 0, etc.
t.permute(0, 2, 1) # (2, 4, 3) — swap dims 1 and 2
# IMPORTANT: transpose and permute return non-contiguous tensors
print(t.transpose(0,1).is_contiguous()) # False
# Fix with .contiguous() when needed:
t.permute(2, 0, 1).contiguous()
✏️ Practice Problem 5
Problem: You receive a batch of images in NHWC format (batch, height, width, channels) with shape (8, 32, 32, 3). PyTorch convolutions expect NCHW format (batch, channels, height, width). Convert it.
Solution:
images_nhwc = torch.randn(8, 32, 32, 3)
# Axes: N H W C
# Desired: N C H W
# So we move axis 3 (C) to position 1
images_nchw = images_nhwc.permute(0, 3, 1, 2)
print(images_nchw.shape) # torch.Size([8, 3, 32, 32])
# Or equivalently:
images_nchw = images_nhwc.transpose(1, 3).transpose(2, 3)
# After permute, you likely need contiguous for convolution:
images_nchw = images_nchw.contiguous()
6. Element-wise Operations & Reductions
6.1 Arithmetic Operations
All arithmetic operators work element-wise and require tensors of compatible shapes (or broadcastable shapes — covered in section 7).
a = torch.tensor([1., 2., 3., 4.])
b = torch.tensor([2., 2., 2., 2.])
# Standard arithmetic — all return NEW tensors
a + b # [3., 4., 5., 6.]
a - b # [-1., 0., 1., 2.]
a * b # [2., 4., 6., 8.] — ELEMENT-WISE multiply (not dot!)
a / b # [0.5, 1., 1.5, 2.]
a ** 2 # [1., 4., 9., 16.] — element-wise power
a // b # [0., 1., 1., 2.] — floor division
a % b # [1., 0., 1., 0.] — modulo
# In-place variants (trailing underscore — modifies a!)
a.add_(b) # a += b
a.mul_(2) # a *= 2
a.sub_(1) # a -= 1
# Functional API (equivalent, often needed for clarity)
torch.add(a, b)
torch.mul(a, b)
torch.div(a, b)
torch.pow(a, 2)
# Scalar operations — broadcast the scalar to all elements
a + 10
a * 0.5
a ** 0.5 # element-wise square root via **, equivalent to torch.sqrt(a)
6.2 Math Functions
t = torch.tensor([-2., -1., 0., 1., 2.])
torch.abs(t) # [2., 1., 0., 1., 2.]
torch.sqrt(torch.abs(t)) # requires non-negative input
torch.exp(t) # e^t element-wise
torch.log(t + 3) # ln — watch out: log(0) = -inf, log(-x) = nan
torch.log2(t + 3)
torch.log10(t + 3)
torch.sin(t); torch.cos(t); torch.tanh(t)
torch.sigmoid(t) # 1/(1+e^(-t)) — prefer this; F.sigmoid is deprecated
torch.relu(t) # max(0, t) — equivalent to t.clamp(min=0)
torch.clamp(t, min=-1, max=1) # clip values to [-1, 1]
torch.floor(t); torch.ceil(t); torch.round(t)
torch.sign(t) # -1., -1., 0., 1., 1.
torch.reciprocal(t + 3) # 1 / (t+3) element-wise
# Useful for numerical stability: clamp any tensor x before taking its log
torch.clamp(x, min=1e-8) # prevents log(0) = -inf
6.3 Reduction Operations
Reductions collapse dimensions. Understanding the dim parameter is crucial.
t = torch.tensor([[1., 2., 3.],
[4., 5., 6.]])
# Shape: (2, 3)
# Global reductions — collapse everything to a scalar
t.sum() # tensor(21.)
t.mean() # tensor(3.5)
t.max() # tensor(6.) — the value
t.min() # tensor(1.)
t.std() # tensor(1.8708...)
t.var() # tensor(3.5)
t.prod() # tensor(720.) — product of all elements
t.norm() # Frobenius norm √(sum of squares)
# Reduce along a dimension — 'dim' specifies which axis to collapse
t.sum(dim=0) # tensor([5., 7., 9.]) — shape (3,): sum each column
t.sum(dim=1) # tensor([6., 15.]) — shape (2,): sum each row
t.sum(dim=1, keepdim=True) # tensor([[6.],[15.]]) — shape (2,1): keeps dims!
# argmax / argmin — return INDEX of max/min, not the value
t.argmax() # tensor(5) — flat index (row-major)
t.argmax(dim=0) # tensor([1, 1, 1]) — row index of max in each column
t.argmax(dim=1) # tensor([2, 2]) — col index of max in each row
# Returns both value and index
vals, idxs = t.max(dim=1) # vals=[3.,6.], idxs=[2, 2]
vals, idxs = t.min(dim=0) # vals=[1.,2.,3.], idxs=[0,0,0]
# Cumulative reductions
t.cumsum(dim=1) # cumulative sum along rows
t.cumprod(dim=0) # cumulative product along columns
# Any / All — boolean reductions
(t > 3).any() # tensor(True) — any element > 3?
(t > 0).all() # tensor(True) — all elements > 0?
(t > 3).any(dim=1) # tensor([False, True]) — per row
✏️ Practice Problem 6
Problem: Given a matrix scores of shape (100, 10) representing 100 students' scores on 10 tests, compute:
- Each student's average score
- Each test's average score
- Number of students who passed (average ≥ 60)
- Normalise the scores so each student has zero mean and unit std
Solution:
torch.manual_seed(0)
scores = torch.randint(30, 100, (100, 10)).float()
# 1. Each student's average — collapse the test dimension
student_avg = scores.mean(dim=1) # shape (100,)
# 2. Each test's average — collapse the student dimension
test_avg = scores.mean(dim=0) # shape (10,)
# 3. Count passing students
passed = (student_avg >= 60).sum() # scalar tensor
print(f"Passed: {passed.item()}")
# 4. Normalise each student's scores independently
# Need keepdim=True so broadcasting works across the 10-test dimension
mean = scores.mean(dim=1, keepdim=True) # (100, 1)
std = scores.std(dim=1, keepdim=True) # (100, 1)
std = torch.clamp(std, min=1e-8) # avoid division by zero
normalised = (scores - mean) / std # (100, 10) — broadcast!
# Verify: each student now has approx mean=0, std=1
print(normalised.mean(dim=1)[:5]) # near zero
print(normalised.std(dim=1)[:5]) # near one
7. Broadcasting — The Silent Multiplier
Broadcasting is PyTorch's mechanism to perform operations between tensors of different shapes without copying data. Mastering it eliminates a huge class of shape errors.
7.1 The Broadcasting Rules
Two shapes are broadcast-compatible if, aligning from the right (trailing dimensions), every pair of sizes is either:
- Equal, or
- One of them is 1 (the size-1 dimension expands to match the other)
If one tensor has fewer dimensions, it's treated as if it has leading size-1 dimensions prepended.
Shape (4, 3) + Shape (3,)
Right-align: (4, 3)
             (   3) ← padded to (1, 3) implicitly
             ──────
Dim 0: 4 vs 1 → expand the (1, 3) row to fill 4 rows
Dim 1: 3 vs 3 → equal
Result: (4, 3) ✓
Shape (4, 1) + Shape (1, 3)
Dim 0: 4 vs 1 → expand to 4
Dim 1: 1 vs 3 → expand to 3
Result: (4, 3) ✓
Shape (2, 3) + Shape (4, 3)
Dim 0: 2 vs 4 → NEITHER is 1 → ERROR ✗
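The three cases above can be checked programmatically with torch.broadcast_shapes, which applies exactly these rules without allocating any data:

```python
import torch

print(torch.broadcast_shapes((4, 3), (3,)))   # torch.Size([4, 3])
print(torch.broadcast_shapes((4, 1), (1, 3))) # torch.Size([4, 3])

try:
    torch.broadcast_shapes((2, 3), (4, 3))    # dim 0: 2 vs 4 — neither is 1
except RuntimeError as e:
    print("incompatible:", e)
```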
7.2 Practical Broadcasting Examples
# ── Example 1: add a bias to every row ─────────────────────────
# A bias of shape (out_features,) broadcasts against every row of the
# (batch, out_features) output — no loop, no copy.
# Real neural layer: (batch, in_features) @ (in_features, out_features) + (out_features,)
X = torch.randn(32, 10) # 32 samples, 10 features
W = torch.randn(10, 4) # weights
b = torch.randn(4) # one bias per output neuron
out = X @ W + b # (32,4) + (4,) — broadcast adds b to every row of X@W
# ── Example 2: column-wise normalisation ───────────────────────
data = torch.randn(100, 5) # 100 samples, 5 features
col_mean = data.mean(dim=0) # (5,) — mean of each feature
col_std = data.std(dim=0) # (5,)
normalised = (data - col_mean) / col_std # (100,5) - (5,) → broadcast!
# ── Example 3: pairwise distances ──────────────────────────────
# Points: A shape (N, D), B shape (M, D)
# Distance matrix: shape (N, M)
A = torch.randn(10, 3) # 10 points in 3D
B = torch.randn(7, 3) # 7 points in 3D
# A.unsqueeze(1): (10, 1, 3)
# B.unsqueeze(0): (1, 7, 3)
# Subtraction: (10, 7, 3)
diff = A.unsqueeze(1) - B.unsqueeze(0) # (10, 7, 3)
dist = torch.sqrt((diff ** 2).sum(dim=2)) # (10, 7) — distance matrix
# ── Example 4: outer product via broadcasting ───────────────────
a = torch.tensor([1., 2., 3.]) # (3,)
b = torch.tensor([10., 20.]) # (2,)
outer = a.unsqueeze(1) * b.unsqueeze(0) # (3,1) * (1,2) → (3,2)
# [[10,20],[20,40],[30,60]]
7.3 Common Broadcasting Mistakes
a = torch.randn(3, 4)
b = torch.randn(4, 3)
# This will FAIL — shapes are not broadcast-compatible in the wrong orientation
# a + b # RuntimeError!
# Fix 1: Transpose one
a + b.T # (3,4) + (3,4) ✓
# Fix 2: Explicit reshape
a + b.reshape(3, 4) # if that's what you meant
# Mistake: missing unsqueeze
row_mean = a.mean(dim=1) # (3,) — mean of each ROW
# a - row_mean # tries (3,4) - (3,) → (3,) right-aligns as (1,3) → ERROR
# Fix: keep the dimension
row_mean = a.mean(dim=1, keepdim=True) # (3, 1)
a - row_mean # (3,4) - (3,1) → (3,4) ✓ subtracts each row's mean from that row
✏️ Practice Problem 7
Problem: You have embeddings E of shape (vocab_size=1000, embed_dim=64) and a query vector q of shape (64,). Compute the dot product (similarity score) of the query against every embedding, producing a score vector of shape (1000,).
Solution:
E = torch.randn(1000, 64)
q = torch.randn(64)
# Method 1: Matrix-vector multiply
scores = E @ q # (1000, 64) @ (64,) → (1000,)
# Method 2: Explicit sum reduction (broadcasting)
scores = (E * q).sum(dim=1) # (1000,64)*(64,) broadcast → (1000,64) → sum → (1000,)
# Method 3: Using einsum (elegant and explicit)
scores = torch.einsum('vd,d->v', E, q) # v=vocab, d=dim
# All three produce identical results
# Now find the top-5 most similar vocabulary items
topk_scores, topk_indices = torch.topk(scores, k=5)
print(topk_scores) # 5 highest scores
print(topk_indices) # their indices in the vocabulary
8. Matrix Multiplication — The Core Operation
This is the central operation of all linear algebra in PyTorch. Know every variant cold.
8.1 The Full Operator Taxonomy
A = torch.randn(3, 4)
B = torch.randn(4, 5)
# ── 2D matrix multiplication ────────────────────────────────────
C = torch.mm(A, B) # (3,5) — 2D ONLY, most explicit
C = torch.matmul(A, B) # (3,5) — works for any ndim
C = A @ B # (3,5) — operator alias for matmul
# ── Matrix-vector product ───────────────────────────────────────
v = torch.randn(4)
torch.mv(A, v) # (3,) — matrix × vector, 2D only
A @ v # (3,) — matmul auto-handles (4,) → column vec
# ── Vector dot product (1D only) ────────────────────────────────
a = torch.randn(4)
b = torch.randn(4)
torch.dot(a, b) # scalar — ONLY for 1D vectors
a @ b # scalar — same
# ── Outer product ───────────────────────────────────────────────
torch.outer(a, b) # (4, 4)
a.unsqueeze(1) @ b.unsqueeze(0) # same: (4,1) @ (1,4) → (4,4)
# ── matmul is the universal operator ───────────────────────────
# 1D × 1D → scalar (dot product)
# 2D × 2D → matrix multiply
# 2D × 1D → matrix-vector (treats 1D as column vector, result is 1D)
# 1D × 2D → vector-matrix (treats 1D as row vector, result is 1D)
8.2 matmul Broadcasting Rules
torch.matmul (and @) batch-broadcasts when tensors have more than 2 dimensions. The last two dimensions are the matrix dimensions; all leading dimensions are broadcast.
# Batch of matrices × single matrix
bA = torch.randn(10, 3, 4) # 10 matrices of shape (3,4)
B = torch.randn(4, 5) # 1 matrix of shape (4,5)
torch.matmul(bA, B) # (10,3,5) — B broadcast to each of 10 batches
# Batch × Batch (broadcast)
bA = torch.randn(10, 3, 4)
bB = torch.randn(10, 4, 5)
torch.matmul(bA, bB) # (10,3,5) — one matmul per batch element
# Higher-rank broadcasting
bA = torch.randn(2, 10, 3, 4)
bB = torch.randn( 10, 4, 5) # implicit (1,10,...) prepended
torch.matmul(bA, bB) # (2,10,3,5) — broadcast over dim 0
8.3 Understanding the Shapes
The most common source of bugs is shape confusion. Develop this instinct:
(M, K) × (K, N) → (M, N)
    ^     ^
    These must match — the "inner dimensions"
For batched: (B, M, K) × (B, K, N) → (B, M, N)
(*, M, K) × (*, K, N) → (*, M, N)
(all leading dims broadcast)
# Worked example: multi-head attention shapes
# Q: (batch, seq_len, d_k)
# K: (batch, seq_len, d_k)
# Attention scores: (batch, seq_len, seq_len)
B, S, d_k = 4, 16, 32
Q = torch.randn(B, S, d_k)
K = torch.randn(B, S, d_k)
# Q @ K^T — K must be transposed to (B, d_k, S)
scores = Q @ K.transpose(-2, -1) # (4, 16, 16)
# Scale by sqrt(d_k) for numerical stability
scores = scores / (d_k ** 0.5)
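A natural completion of this example (beyond the matrix-only scope — sketched here with an assumed value tensor V of the same shape as Q and K): softmax the scaled scores per query row, then weight V with one more batched matmul:

```python
import torch

B, S, d_k = 4, 16, 32
Q = torch.randn(B, S, d_k)
K = torch.randn(B, S, d_k)
V = torch.randn(B, S, d_k)  # assumed value tensor for illustration

scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)  # (4, 16, 16)
attn = scores.softmax(dim=-1)                      # each query row sums to 1
out = attn @ V                                     # (4, 16, 32)

print(out.shape)               # torch.Size([4, 16, 32])
print(attn.sum(dim=-1)[0, :3]) # ≈ 1 for every row
```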
8.4 Performance Tips
# ── Use @ or matmul, NOT manual loops ───────────────────────────
A = torch.randn(1000, 1000)
B = torch.randn(1000, 1000)
# SLOW — pure Python loop
result = torch.zeros(1000, 1000)
for i in range(1000):
result[i] = (A[i].unsqueeze(0) @ B).squeeze()
# FAST — single BLAS call
result = A @ B # milliseconds vs seconds
# ── Reuse transposed tensors ─────────────────────────────────────
# If you need A.T multiple times, store it
At = A.T.contiguous() # make contiguous once
for _ in range(100):
result = At @ B # faster than (A.T @ B) * 100 times
# ── Choose dtype based on precision needs ───────────────────────
A16 = A.half() # float16 — 2x faster matmul on GPU, 2x less memory
B16 = B.half()
result = A16 @ B16 # fast but less precise
✏️ Practice Problem 8
Problem: Implement linear regression prediction from scratch using only tensor operations:
- Input X has shape (N=100, features=5)
- Weights W has shape (5,) and bias b is a scalar
- Compute predictions y_hat of shape (N,)
- Compute Mean Squared Error loss as a scalar
Solution:
torch.manual_seed(42)
N, F = 100, 5
X = torch.randn(N, F) # design matrix
W = torch.randn(F) # weight vector
b = torch.tensor(0.5) # bias
y = torch.randn(N) # true labels
# Predictions: X @ W + b
y_hat = X @ W + b # (100,5)@(5,) → (100,) + scalar → (100,)
# MSE: mean over all samples
residuals = y_hat - y # (100,)
mse = (residuals ** 2).mean() # scalar
# Equivalent:
mse = torch.nn.functional.mse_loss(y_hat, y)
print(f"y_hat shape: {y_hat.shape}") # torch.Size([100])
print(f"MSE: {mse.item():.4f}")
9. Batched Matrix Operations
When you have collections of matrices to operate on simultaneously, batched operations are critical.
9.1 bmm vs matmul
# bmm — STRICTLY 3D (batch, M, K) × (batch, K, N)
bA = torch.randn(8, 3, 4)
bB = torch.randn(8, 4, 5)
torch.bmm(bA, bB) # (8, 3, 5) ✓
# bmm does NOT broadcast — both inputs must share the same batch size
bC = torch.randn(1, 4, 5)
# torch.bmm(bA, bC) # RuntimeError — batch sizes 8 vs 1
# matmul — handles broadcast over any leading dims
torch.matmul(bA, bC) # (8, 3, 5) — bC broadcast to 8 batches ✓
torch.matmul(bA, bB) # (8, 3, 5) ✓ — works like bmm when 3D
9.2 Batched Applications
# ── Batched covariance matrices ──────────────────────────────────
# Given B batches of N samples with D features: (B, N, D)
# Compute covariance: (B, D, D)
data = torch.randn(4, 100, 8) # 4 batches, 100 samples, 8 features
mean = data.mean(dim=1, keepdim=True) # (4, 1, 8)
centered = data - mean # (4, 100, 8)
# Covariance: (1/N) * X^T X
# centered.transpose(-2,-1) : (4, 8, 100)
# centered : (4, 100, 8)
cov = torch.bmm(centered.transpose(-2, -1), centered) / 100 # (4, 8, 8)
# ── Applying different transforms to different batches ───────────
transforms = torch.randn(8, 4, 4) # 8 different 4×4 transform matrices
points = torch.randn(8, 4, 100) # 8 batches of 100 4D points
transformed = torch.bmm(transforms, points) # (8, 4, 100)
# ── Solving multiple linear systems simultaneously ───────────────
# Ax = b for A: (B, N, N), b: (B, N, K) → solve for x: (B, N, K)
A = torch.randn(5, 4, 4)
A = A @ A.transpose(-2,-1) + torch.eye(4) # make positive definite
b = torch.randn(5, 4, 3)
x = torch.linalg.solve(A, b) # (5, 4, 3) — solves all 5 systems
✏️ Practice Problem 9
Problem: You have a batch of 16 square matrices, each 6×6 (shape (16, 6, 6)). Compute the trace (sum of diagonal elements) of each matrix.
Solution:
batch = torch.randn(16, 6, 6)
# Method 1: diagonal + sum
# torch.diagonal returns shape (6, 16) by default — need dim1, dim2
diags = torch.diagonal(batch, dim1=-2, dim2=-1) # (16, 6)
traces = diags.sum(dim=-1) # (16,)
# Method 2: einsum — trace is contraction of diagonal indices
traces = torch.einsum('bii->b', batch) # b=batch, i=diagonal index
# Method 3: vmap over torch.trace (PyTorch 2.0+)
traces = torch.vmap(torch.trace)(batch)
print(traces.shape) # torch.Size([16])
print(traces[:3])
10. Einstein Summation (einsum)
torch.einsum is the most expressive way to write tensor contractions. Once you understand the notation, it replaces nested loops, transposes, and complex matmul chains with a single readable string.
10.1 The Einsum Grammar
The notation is: 'input_subscripts -> output_subscripts'
- Each letter represents one dimension
- Repeated letters on the left are contracted (summed over)
- Letters appearing only in the output are kept
- Letters in the input but missing from the output are summed away
'ij,jk->ik'
 ^^ ^^  ^^
 A  B   C
i → output dimension (rows of A / rows of C)
j → contracted dimension (must match: cols of A == rows of B)
k → output dimension (cols of B / cols of C)
Result: C[i,k] = Σ_j A[i,j] * B[j,k]
10.2 Einsum for Common Operations
import torch
A = torch.randn(3, 4)
B = torch.randn(4, 5)
a = torch.randn(3)
b = torch.randn(3)
v = torch.randn(4)
# ── Linear Algebra ──────────────────────────────────────────────
torch.einsum('ij,jk->ik', A, B) # (3,5) — matrix multiply
torch.einsum('ij,j->i', A, v) # (3,) — matrix-vector multiply
torch.einsum('i,i->', a, b) # scalar — dot product
torch.einsum('i,j->ij', a, b) # (3,3) — outer product
torch.einsum('ij->ji', A) # (4,3) — transpose
torch.einsum('ii->', A[:3, :3]) # scalar — trace of a square submatrix
torch.einsum('ij->', A) # scalar — sum all elements
torch.einsum('ij->i', A) # (3,) — row sums
torch.einsum('ij->j', A) # (4,) — column sums
# ── Element-wise / Hadamard ─────────────────────────────────────
C = torch.randn(3, 4)
torch.einsum('ij,ij->ij', A, C) # (3,4) — element-wise multiply (A*C)
torch.einsum('ij,ij->', A, C) # scalar — Frobenius inner product (A*C).sum()
# ── Batched operations ──────────────────────────────────────────
bA = torch.randn(8, 3, 4)
bB = torch.randn(8, 4, 5)
torch.einsum('bij,bjk->bik', bA, bB) # (8,3,5) — batched matmul
# ── Diagonal extraction ─────────────────────────────────────────
sq = torch.randn(4, 4)
torch.einsum('ii->i', sq) # (4,) — diagonal elements
# ── Bilinear form: x^T A y ─────────────────────────────────────
x, y = torch.randn(4), torch.randn(4)
torch.einsum('i,ij,j->', x, sq, y) # scalar
# ── Multi-head attention score computation ─────────────────────
# Q: (batch, heads, seq, d_k) K: (batch, heads, seq, d_k)
B, H, S, d = 2, 4, 16, 32
Q = torch.randn(B, H, S, d)
K = torch.randn(B, H, S, d)
# Attention scores: (batch, heads, seq_q, seq_k)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) # (2,4,16,16)
10.3 Understanding Contraction Step by Step
# Let's trace 'ij,jk->ik' by hand for small tensors
A = torch.tensor([[1., 2.],
[3., 4.]]) # (2,2) — i=2, j=2
B = torch.tensor([[5., 6., 7.],
[8., 9., 10.]]) # (2,3) — j=2, k=3
result = torch.einsum('ij,jk->ik', A, B)
# result[0,0] = A[0,0]*B[0,0] + A[0,1]*B[1,0] = 1*5 + 2*8 = 21
# result[0,1] = A[0,0]*B[0,1] + A[0,1]*B[1,1] = 1*6 + 2*9 = 24
# result[0,2] = A[0,0]*B[0,2] + A[0,1]*B[1,2] = 1*7 + 2*10 = 27
# result[1,0] = A[1,0]*B[0,0] + A[1,1]*B[1,0] = 3*5 + 4*8 = 47
print(result)
# tensor([[21., 24., 27.],
# [47., 54., 61.]])
✏️ Practice Problem 10
Problem: Using einsum, implement:
- Batch dot products: given A (shape (B, D)) and B (shape (B, D)), compute the dot product of each corresponding pair → shape (B,)
- Hadamard product summed: element-wise product of two (M, N) matrices, then sum → scalar
- Tensor contraction: for T of shape (A, B, C) and M of shape (C, D), contract the last dim → shape (A, B, D)
Solution:
B_size, D = 8, 16
A_t = torch.randn(B_size, D)
B_t = torch.randn(B_size, D)
M_t = torch.randn(3, 4)
N_t = torch.randn(3, 4)
T = torch.randn(5, 6, 7)
M = torch.randn(7, 9)
# 1. Batch dot products
batch_dots = torch.einsum('bd,bd->b', A_t, B_t)
print(batch_dots.shape) # (8,)
# Verify against loop
for i in range(B_size):
assert torch.allclose(batch_dots[i], torch.dot(A_t[i], B_t[i]))
# 2. Hadamard product summed (Frobenius inner product)
fro = torch.einsum('ij,ij->', M_t, N_t)
print(fro.shape) # scalar = torch.Size([])
# 3. Tensor-matrix contraction over last dim
result = torch.einsum('abc,cd->abd', T, M)
print(result.shape) # (5, 6, 9)
# Verify with matmul (equivalent)
result2 = (T.reshape(-1, 7) @ M).reshape(5, 6, 9)
print(torch.allclose(result, result2)) # True
11. Linear Algebra — linalg Module
torch.linalg is the modern (PyTorch 1.9+) home for all linear algebra operations. Prefer it over older scattered functions.
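A quick sanity check of the old-vs-new split (a sketch; `torch.norm` is one legacy entry point that still works but whose maintained replacement lives in `linalg`):

```python
import torch

t = torch.tensor([[3., 4.], [1., 2.]])

old = torch.norm(t)         # legacy top-level function (Frobenius by default for 2D)
new = torch.linalg.norm(t)  # modern equivalent, same default
print(torch.allclose(old, new))  # True
```

Watch the return conventions when porting older code: for example, the legacy `torch.svd` returned `V`, while `torch.linalg.svd` returns `Vh`.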
11.1 Norms
t = torch.tensor([[3., 4.], [1., 2.]])
# Vector norms (when applied to a 1D tensor or with dim specified)
torch.linalg.norm(torch.tensor([3., 4.])) # L2 norm: 5.0 (√(9+16))
torch.linalg.norm(t, ord=2) # spectral norm (max singular value)
torch.linalg.norm(t, ord='fro') # Frobenius norm: √(9+16+1+4) = √30
torch.linalg.norm(t, ord=1) # max column sum
torch.linalg.norm(t, ord=float('inf')) # max row sum
# Per-row or per-column norms (using dim parameter)
torch.linalg.norm(t, dim=1) # L2 norm of each row: [5.0, 2.236]
torch.linalg.norm(t, dim=0) # L2 norm of each column: [3.162, 4.472]
# L1 norm of rows (used in normalisation)
torch.linalg.norm(t, ord=1, dim=1) # sum of abs values per row
# Normalise rows to unit L2 norm
normed = t / torch.linalg.norm(t, dim=1, keepdim=True)
print(torch.linalg.norm(normed, dim=1)) # [1., 1.] — unit vectors
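The same row normalisation is available as a one-liner in `torch.nn.functional` (imported as `F` in section 1); a minimal check:

```python
import torch
import torch.nn.functional as F

t = torch.tensor([[3., 4.], [1., 2.]])
manual = t / torch.linalg.norm(t, dim=1, keepdim=True)
built_in = F.normalize(t, p=2, dim=1)  # clamps the norm with a small eps internally
print(torch.allclose(manual, built_in))  # True
```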
11.2 Determinant and Inverse
A = torch.tensor([[4., 7.], [2., 6.]], dtype=torch.float32)
# Determinant
det = torch.linalg.det(A) # tensor(10.) (4*6 - 7*2)
logdet = torch.linalg.slogdet(A) # (sign, log|det|) — numerically stable
# Inverse — use only when you need the matrix explicitly
A_inv = torch.linalg.inv(A)
print(A @ A_inv) # should be close to identity
# BETTER: when solving Ax = b, never invert!
b = torch.tensor([1., 0.])
x = torch.linalg.solve(A, b) # more stable and faster than inv(A) @ b
# Batched versions work the same way
batch_A = torch.randn(10, 4, 4)
batch_b = torch.randn(10, 4, 2)
batch_A = batch_A @ batch_A.transpose(-2,-1) + torch.eye(4) # positive definite
x = torch.linalg.solve(batch_A, batch_b) # (10, 4, 2)
11.3 Eigendecomposition
# For symmetric/Hermitian matrices (more stable)
A_sym = torch.randn(4, 4)
A_sym = A_sym @ A_sym.T # A_sym is now symmetric positive semi-definite
eigenvalues = torch.linalg.eigvalsh(A_sym) # real eigenvalues (sorted!)
eigenvalues, eigenvectors = torch.linalg.eigh(A_sym) # both
# For general matrices
eigenvalues_complex = torch.linalg.eigvals(A_sym) # may return complex
eigenvalues_c, eigenvectors_c = torch.linalg.eig(A_sym)
# Reconstruction: A = V @ diag(Λ) @ V^{-1}
# For symmetric: A = V @ diag(Λ) @ V.T (V is orthogonal)
A_reconstructed = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.T
print(torch.allclose(A_sym, A_reconstructed, atol=1e-5)) # True
11.4 Singular Value Decomposition (SVD)
SVD is one of the most important decompositions. It works for any matrix (not just square), and is the foundation of PCA, low-rank approximation, and pseudo-inverse.
A = U @ diag(S) @ Vh — exact with the economy factors; with the full factors, keep only the first min(M,N) columns of U
A: (M, N) — the original matrix
U: (M, M) full, or (M, min(M,N)) economy — left singular vectors (orthonormal columns)
S: (min(M,N),) — singular values (non-negative, sorted descending)
Vh: (N, N) full, or (min(M,N), N) economy — Vh is V^H, the conjugate transpose of V
A = torch.randn(5, 3) # (5, 3)
# Full SVD
U, S, Vh = torch.linalg.svd(A)
print(U.shape) # (5, 5)
print(S.shape) # (3,) — min(5,3) = 3
print(Vh.shape) # (3, 3)
# Reconstruct A
A_recon = U[:, :3] @ torch.diag(S) @ Vh # need only first 3 cols of U
print(torch.allclose(A, A_recon, atol=1e-5)) # True
# Truncated SVD (economy/thin) — more efficient
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
print(U.shape) # (5, 3) — economy
print(S.shape) # (3,)
print(Vh.shape) # (3, 3)
# Reconstruct directly
A_recon = U @ torch.diag(S) @ Vh
print(torch.allclose(A, A_recon, atol=1e-5)) # True
# ── Low-rank approximation (keep top k singular values) ───────────
def low_rank_approx(A, k):
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
return U[:, :k] @ torch.diag(S[:k]) @ Vh[:k, :]
rank2_approx = low_rank_approx(A, k=2)
print("Frobenius error:", torch.linalg.norm(A - rank2_approx, ord='fro'))
# ── Pseudo-inverse (for non-square or singular matrices) ──────────
A_pinv = torch.linalg.pinv(A) # (3, 5) — Moore–Penrose pseudo-inverse
print(A_pinv.shape)
# Satisfies: A @ A_pinv @ A ≈ A
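A short check of that property, plus the main practical use of the pseudo-inverse — least squares (a sketch; `torch.linalg.lstsq` is shown only for comparison):

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 3)
A_pinv = torch.linalg.pinv(A)

# Moore–Penrose property
print(torch.allclose(A @ A_pinv @ A, A, atol=1e-5))  # True

# For a tall full-rank A, pinv(A) @ b is the least-squares solution of Ax = b
b = torch.randn(5, 1)
x = A_pinv @ b
x_lstsq = torch.linalg.lstsq(A, b).solution
print(torch.allclose(x, x_lstsq, atol=1e-4))  # True
```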
11.5 QR and Cholesky Decompositions
# QR Decomposition: A = Q @ R
# Q: orthogonal, R: upper triangular
A = torch.randn(5, 3)
Q, R = torch.linalg.qr(A)
print(Q.shape) # (5, 3) — economy QR by default
print(R.shape) # (3, 3)
print(torch.allclose(A, Q @ R, atol=1e-5)) # True
# Q is orthogonal: Q.T @ Q ≈ I
print(torch.allclose(Q.T @ Q, torch.eye(3), atol=1e-5)) # True
# Cholesky: A = L @ L.T (for positive definite A)
A = torch.randn(4, 4)
A = A @ A.T + torch.eye(4) # make positive definite
L = torch.linalg.cholesky(A)
print(torch.allclose(L @ L.T, A, atol=1e-5)) # True
✏️ Practice Problem 11 — PCA from Scratch
Problem: Implement PCA (Principal Component Analysis) using SVD to project a data matrix X of shape (200, 10) down to 3 principal components.
Solution:
torch.manual_seed(0)
X = torch.randn(200, 10) # 200 samples, 10 features
# Step 1: Center the data (subtract column means)
mean = X.mean(dim=0) # (10,)
X_centered = X - mean # (200, 10) — broadcast
# Step 2: Compute SVD of the centred data
# Covariance structure is encoded in the right singular vectors Vh
U, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)
# U: (200, 10)
# S: (10,) — singular values; S**2 / (N-1) are the eigenvalues of the covariance matrix
# Vh: (10, 10) — rows are principal components
# Step 3: Project onto top-3 principal components
# Principal components are the ROWS of Vh (right singular vectors)
components = Vh[:3, :] # (3, 10) — top 3 PCs
X_projected = X_centered @ components.T # (200, 10) @ (10, 3) → (200, 3)
print(f"Original: {X.shape}") # (200, 10)
print(f"Projected: {X_projected.shape}") # (200, 3)
# Explained variance ratio
total_var = (S ** 2).sum()
explained = (S[:3] ** 2) / total_var
print(f"Variance explained: {explained.tolist()}")
print(f"Cumulative: {explained.cumsum(0).tolist()}")
# Reconstruction from top-3 (lossy)
X_reconstructed = X_projected @ components + mean # (200, 3) @ (3, 10) + (10,)
reconstruction_error = torch.linalg.norm(X - X_reconstructed, ord='fro')
print(f"Reconstruction error: {reconstruction_error:.4f}")
12. Advanced Indexing — gather & scatter
These operations are fundamental for NLP (token embedding lookups, masked attention), RL (action selection), and any problem requiring non-uniform selection or writing.
12.1 torch.gather — Selective Read
gather collects values from input at specified index positions along a dimension.
output[i][j][k] = input[i][index[i][j][k]][k] (for dim=1)
Crucially: output has the same shape as index.
t = torch.tensor([[10., 20., 30.],
[40., 50., 60.],
[70., 80., 90.]]) # (3, 3)
# For each row, select specific columns
idx = torch.tensor([[0, 2], # row 0: pick columns 0 and 2
[1, 1], # row 1: pick columns 1 and 1 (same)
[2, 0]]) # row 2: pick columns 2 and 0
result = torch.gather(t, dim=1, index=idx)
# result[0] = [t[0,0], t[0,2]] = [10, 30]
# result[1] = [t[1,1], t[1,1]] = [50, 50]
# result[2] = [t[2,2], t[2,0]] = [90, 70]
print(result)
# tensor([[10., 30.],
# [50., 50.],
# [90., 70.]])
12.2 Practical gather — One-Hot, Token Lookup
# ── Token embedding lookup (what nn.Embedding does internally) ─
vocab_size, embed_dim = 1000, 64
embeddings = torch.randn(vocab_size, embed_dim) # (V, D)
token_ids = torch.randint(0, vocab_size, (8, 16)) # (B, S) batch of token sequences
# Method 1: fancy indexing (simple, works)
embedded = embeddings[token_ids] # (8, 16, 64)
# ── Gathering logits for specific classes ──────────────────────
# batch_size=8, num_classes=10
logits = torch.randn(8, 10)
targets = torch.tensor([3, 7, 1, 0, 9, 4, 2, 6]) # one target per sample
# Gather the logit corresponding to the true class
# index must match logits.ndim — need shape (8, 1)
idx = targets.unsqueeze(1) # (8, 1)
correct_logits = torch.gather(logits, dim=1, index=idx) # (8, 1)
correct_logits = correct_logits.squeeze(1) # (8,)
# ── NLL loss from scratch ──────────────────────────────────────
log_probs = torch.log_softmax(logits, dim=1) # (8, 10)
nll = -torch.gather(log_probs, 1, targets.unsqueeze(1)).squeeze() # (8,)
loss = nll.mean() # scalar
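The RL use case from the section intro — picking the Q-value of the action actually taken — is the same pattern (sketch with toy shapes; `q_values` and `actions` are made-up names):

```python
import torch

torch.manual_seed(0)
q_values = torch.randn(5, 3)             # (batch, num_actions)
actions = torch.tensor([2, 0, 1, 1, 2])  # action taken for each sample

chosen_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)  # (5,)
print(chosen_q.shape)  # torch.Size([5])

# Loop check
for i in range(5):
    assert chosen_q[i] == q_values[i, actions[i]]
```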
12.3 torch.scatter_ and torch.scatter — Selective Write
scatter_ is the inverse of gather: instead of reading from specific positions, it writes to them. Note the trailing underscore — it's in-place by default.
output[i][index[i][j][k]][k] = src[i][j][k] (for dim=1)
# ── One-hot encoding ──────────────────────────────────────────
batch_size, num_classes = 8, 5
labels = torch.tensor([0, 3, 1, 4, 2, 0, 3, 1]) # (8,)
one_hot = torch.zeros(batch_size, num_classes) # (8, 5)
one_hot.scatter_(1, labels.unsqueeze(1), 1.0) # write 1.0 at label positions
print(one_hot)
# tensor([[1,0,0,0,0],
# [0,0,0,1,0],
# [0,1,0,0,0], ...])
# ── Scatter add (accumulate) ──────────────────────────────────
# Count occurrences of each token in a batch of sequences
vocab_size = 10
tokens = torch.tensor([3, 1, 3, 5, 1, 2, 3]) # token IDs
counts = torch.zeros(vocab_size)
counts.scatter_add_(0, tokens, torch.ones_like(tokens, dtype=torch.float))
print(counts) # [0,2,1,3,0,1,0,0,0,0] — token 3 appears 3 times
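For plain counting, `torch.bincount` produces the same histogram more directly; `scatter_add_` is the general tool when you accumulate arbitrary values rather than ones:

```python
import torch

tokens = torch.tensor([3, 1, 3, 5, 1, 2, 3])
counts = torch.bincount(tokens, minlength=10).float()
print(counts)  # tensor([0., 2., 1., 3., 0., 1., 0., 0., 0., 0.])
```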
# ── Functional scatter (non-in-place) ─────────────────────────
src = torch.ones(2, 3)
idx = torch.tensor([[0, 1, 2], [0, 1, 4]])
out = torch.zeros(2, 5).scatter(1, idx, src) # returns a new tensor; the zeros are untouched
12.4 torch.index_select, torch.index_put
t = torch.randn(5, 4)
# index_select — select rows or columns by index tensor
selected_rows = torch.index_select(t, 0, torch.tensor([0, 2, 4])) # (3, 4)
selected_cols = torch.index_select(t, 1, torch.tensor([1, 3])) # (5, 2)
# index_put — batch assignment
indices = (torch.tensor([0, 2, 4]), torch.tensor([1, 3, 0]))
values = torch.tensor([99., 99., 99.])
t.index_put_(indices, values) # t[0,1]=99, t[2,3]=99, t[4,0]=99
✏️ Practice Problem 12
Problem: Implement a "masked softmax" without any loops:
- Input: `logits` of shape `(B=4, S=8)` — attention logits
- `mask` of shape `(B=4, S=8)` — boolean, `True` = position to IGNORE
- Compute softmax over `logits`, setting ignored positions to 0 in the output
Solution:
B, S = 4, 8
logits = torch.randn(B, S)
mask = torch.zeros(B, S, dtype=torch.bool)
mask[:, -2:] = True # ignore last 2 positions in every sequence
# Fill ignored positions with -infinity BEFORE softmax
# This makes them contribute 0 after exp(x) / sum(exp(x))
masked_logits = logits.masked_fill(mask, float('-inf')) # (B, S)
# Numerically stable softmax
attention = torch.softmax(masked_logits, dim=-1) # (B, S)
# Verify: masked positions are 0
print(attention[:, -2:]) # should be [[0,0],[0,0],[0,0],[0,0]]
# Verify: non-masked rows sum to 1
print(attention.sum(dim=-1)) # [1., 1., 1., 1.]
# Alternative: torch.where approach
NEG_INF = torch.full_like(logits, float('-inf'))
masked_logits2 = torch.where(mask, NEG_INF, logits)
attention2 = torch.softmax(masked_logits2, dim=-1)
print(torch.allclose(attention, attention2)) # True
13. Memory Layout, Strides & Contiguity
Understanding this makes you dangerous in code reviews and debugging sessions.
13.1 How Tensors Are Stored
A tensor's data is a flat 1D array in memory. The strides tell PyTorch how many elements to skip to advance by 1 in each dimension.
t = torch.arange(12).reshape(3, 4)
Memory: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
Shape: (3, 4)
Strides: (4, 1)
t[i, j] is at memory offset: i * 4 + j * 1
t[0,0]=0, t[0,1]=1, t[1,0]=4, t[2,3]=11 ✓
t = torch.arange(12).reshape(3, 4)
print(t.stride()) # (4, 1) — C-contiguous (row-major)
t_T = t.T # transpose — SWAP strides, no copy
print(t_T.stride()) # (1, 4) — Fortran-contiguous (column-major)
print(t_T.is_contiguous()) # False
# Strides for a sliced tensor
s = t[::2, ::2] # every other row and column
print(s.stride()) # (8, 2) — skipping over elements
# Make contiguous (allocates new memory, rewrites data)
s_c = s.contiguous()
print(s_c.stride()) # (2, 1) — back to C-contiguous
13.2 Why Contiguity Matters
t = torch.randn(4, 4)
t_T = t.T
# Some operations require contiguous memory
# t_T.view(16) # RuntimeError: non-contiguous
t_T.reshape(16) # OK — reshape handles non-contiguous with an internal copy
# Storage_offset: where in the underlying storage does the tensor start
s = t[1:] # starts at element 4 in storage
print(s.storage_offset()) # 4
# Checking if two tensors share storage
print(t.storage().data_ptr() == t_T.storage().data_ptr()) # True — shared!
print(t.storage().data_ptr() == t_T.contiguous().storage().data_ptr()) # False — copy
13.3 Performance Implications
# Contiguous tensors are cache-friendly — sequential memory access
t = torch.randn(1000, 1000)
# Fast: row-major access matches C-contiguous layout
row_sums = t.sum(dim=1) # summing over the last dim reads memory sequentially
# Operations that force contiguous copies are expensive
# Avoid in tight loops:
t_T = t.T.contiguous() # forces copy
# Do instead: plan your layout so you don't need to transpose repeatedly
# float16 vs float32 memory
t32 = torch.randn(1000, 1000) # 4 MB
t16 = t32.half() # 2 MB — half the memory
t16 = t32.to(torch.float16) # equivalent
# In-place ops save memory (no intermediate tensor)
t.add_(1.0) # modifies t, no new allocation
t.relu_() # same
14. Practice Problems with Full Solutions
These problems are calibrated to the style and difficulty of your assessment.
Problem A — Shape Surgery
Problem: Given x of shape (B=4, C=3, H=8, W=8) (a batch of images), produce a tensor where each image's pixels are flattened into a row, resulting in shape (4, 192).
x = torch.randn(4, 3, 8, 8)
# Method 1: reshape
out = x.reshape(4, -1) # (4, 192)
# Method 2: flatten (cleaner API)
out = x.flatten(start_dim=1) # flatten all dims except batch
# Method 3: view (only if contiguous)
out = x.contiguous().view(4, -1)
print(out.shape) # torch.Size([4, 192])
Problem B — Efficient Softmax from Scratch
Problem: Implement a numerically stable softmax (without using torch.softmax). For input x of shape (B, C), compute softmax over the class dimension.
def stable_softmax(x):
# Subtract max for numerical stability: softmax(x) = softmax(x - c)
x_max = x.max(dim=1, keepdim=True).values # (B, 1)
x_shifted = x - x_max # (B, C) — broadcast
exp_x = torch.exp(x_shifted) # (B, C)
sum_exp = exp_x.sum(dim=1, keepdim=True) # (B, 1)
return exp_x / sum_exp # (B, C) — broadcast
x = torch.randn(4, 10)
out = stable_softmax(x)
# Verify: rows sum to 1
print(out.sum(dim=1)) # [1., 1., 1., 1.]
# Verify: all values non-negative
print((out >= 0).all()) # True
# Verify: matches PyTorch's softmax
print(torch.allclose(out, torch.softmax(x, dim=1), atol=1e-6)) # True
Problem C — Pairwise Cosine Similarity Matrix
Problem: Given a matrix A of shape (N, D) where each row is a vector, compute the full (N, N) cosine similarity matrix without any Python loops.
def cosine_similarity_matrix(A):
# Normalise each row to unit L2 norm
norms = torch.linalg.norm(A, dim=1, keepdim=True) # (N, 1)
norms = torch.clamp(norms, min=1e-8) # avoid division by zero
A_norm = A / norms # (N, D)
# Cosine similarity = dot product of unit vectors = A_norm @ A_norm.T
return A_norm @ A_norm.T # (N, N)
N, D = 10, 16
A = torch.randn(N, D)
sim = cosine_similarity_matrix(A)
print(sim.shape) # (10, 10)
print(sim.diagonal()) # all ≈1.0 — each vector has similarity 1 with itself
print(sim.min(), sim.max()) # values in [-1, 1]
# Verify diagonal is exactly 1.0
print(torch.allclose(sim.diagonal(), torch.ones(N), atol=1e-5)) # True
Problem D — Sliding Window Sum (Strided Trick)
Problem: Given a 1D signal x of length N=12, compute the sum of every window of size k=4 without a Python loop, producing output of length N-k+1=9.
x = torch.arange(1., 13.) # [1,2,3,4,5,6,7,8,9,10,11,12]
k = 4
# Method 1: unfold + sum (most elegant)
windows = x.unfold(dimension=0, size=k, step=1) # (9, 4)
window_sums = windows.sum(dim=1) # (9,)
print(window_sums)
# [10, 14, 18, 22, 26, 30, 34, 38, 42]
# Method 2: cumulative sum trick — O(N) not O(N*k)
cumsum = torch.cat([torch.zeros(1), x.cumsum(0)]) # (13,)
window_sums2 = cumsum[k:] - cumsum[:-k] # (9,)
print(torch.allclose(window_sums, window_sums2)) # True
# What does unfold give us?
print(windows)
# tensor([[ 1, 2, 3, 4],
# [ 2, 3, 4, 5],
# ...
# [ 9, 10, 11, 12]])
Problem E — Masked Fill and Causal Attention Mask
Problem: Create a causal (lower-triangular) mask for a sequence of length S=6 and apply it to an attention score matrix, setting future positions to -inf so softmax ignores them.
S = 6
# Create causal mask — True where we should MASK (upper triangle, excluding diagonal)
causal_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
print(causal_mask.int())
# [[0,1,1,1,1,1],
# [0,0,1,1,1,1],
# [0,0,0,1,1,1],
# [0,0,0,0,1,1],
# [0,0,0,0,0,1],
# [0,0,0,0,0,0]]
# Apply to attention scores
attention_scores = torch.randn(2, 4, S, S) # (batch, heads, S, S)
# Broadcast mask: (S, S) → applied to all batches and heads automatically
masked_scores = attention_scores.masked_fill(causal_mask, float('-inf'))
# Compute attention weights
attention_weights = torch.softmax(masked_scores, dim=-1) # (2,4,S,S)
# Verify: upper triangle is 0 (softmax of -inf)
print(attention_weights[0, 0]) # lower triangular structure
print((attention_weights[0, 0] > 0).int()) # should be lower triangular
Problem F — Batch Matrix Inversion with Validation
Problem: Given a batch of B=5 random 4×4 matrices, make them invertible (positive definite), compute their inverses, and verify A @ A_inv ≈ I for each.
B, N = 5, 4
# Generate random positive definite matrices
A_raw = torch.randn(B, N, N)
A = torch.bmm(A_raw, A_raw.transpose(-2, -1)) + torch.eye(N) # (B,N,N)
# Compute batch inverse
A_inv = torch.linalg.inv(A) # (B, N, N)
# Verify: A @ A_inv should be identity for each batch element
product = torch.bmm(A, A_inv) # (B, N, N)
identity = torch.eye(N).expand(B, N, N)
max_error = (product - identity).abs().max()
print(f"Max deviation from identity: {max_error:.2e}") # should be ~1e-6
# Better approach for Ax=b: don't invert, just solve
b = torch.randn(B, N, 2)
x = torch.linalg.solve(A, b) # (B, N, 2) — numerically superior
print(torch.allclose(torch.bmm(A, x), b, atol=1e-5)) # True
Problem G — Custom Distance Matrix (Assessment Style)
Problem: Implement L2 (Euclidean) pairwise distance between two sets of points without any Python loops:
- `A`: shape `(N=50, D=8)`
- `B`: shape `(M=30, D=8)`
- Output: `(N, M)` where `out[i,j] = ||A[i] - B[j]||_2`
def pairwise_l2(A, B):
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
# This avoids the (N, M, D) intermediate tensor from naive subtraction
A_sq = (A ** 2).sum(dim=1, keepdim=True) # (N, 1)
B_sq = (B ** 2).sum(dim=1) # (M,) → need (1, M)
# (N,1) + (1,M) - 2*(N,M) → (N,M)
sq_dist = A_sq + B_sq.unsqueeze(0) - 2 * (A @ B.T)
# Clamp to avoid negative values from floating point errors near 0
return torch.sqrt(torch.clamp(sq_dist, min=0.0))
N, M, D = 50, 30, 8
A = torch.randn(N, D)
B = torch.randn(M, D)
dist = pairwise_l2(A, B)
print(dist.shape) # (50, 30)
# Verify one distance manually
i, j = 3, 7
manual = torch.linalg.norm(A[i] - B[j])
print(torch.allclose(dist[i, j], manual, atol=1e-5)) # True
# Find nearest neighbour for each point in A
nn_indices = dist.argmin(dim=1) # (N,) — index in B closest to each point in A
nn_dists = dist.min(dim=1).values
print(nn_indices.shape, nn_dists.shape) # (50,) (50,)
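Extending the lookup: `torch.topk` with `largest=False` returns the k closest points per row, in ascending order. A sketch on a stand-in random distance matrix with the shapes above:

```python
import torch

torch.manual_seed(0)
dist = torch.rand(50, 30)  # stand-in for the pairwise distance matrix

k = 3
knn_dists, knn_idx = torch.topk(dist, k=k, dim=1, largest=False)  # ascending
print(knn_dists.shape, knn_idx.shape)  # torch.Size([50, 3]) torch.Size([50, 3])

# Column 0 agrees with the single nearest neighbour from argmin
print(torch.equal(knn_idx[:, 0], dist.argmin(dim=1)))  # True
```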
Problem H — Full Einsum Challenge
Problem: Use einsum only (no @, mm, matmul) to:
- Compute `W @ X.T` where `W: (4,3)`, `X: (5,3)` → `(4,5)`
- Compute the batch outer product of rows of `A: (B,D)` → `(B,D,D)`
- Compute a weighted sum: `weights: (B,S)`, `values: (B,S,D)` → `(B,D)`
W = torch.randn(4, 3)
X = torch.randn(5, 3)
A = torch.randn(8, 6)
B_size, S, D_size = 4, 10, 16
weights = torch.softmax(torch.randn(B_size, S), dim=1)
values = torch.randn(B_size, S, D_size)
# 1. W @ X.T: (4,3) × (3,5) → (4,5)
result1 = torch.einsum('ij,kj->ik', W, X) # j is contracted, k appears in output
print(result1.shape) # (4, 5)
print(torch.allclose(result1, W @ X.T)) # True
# 2. Batch outer product: for each b, outer(A[b], A[b]) → (B,D,D)
result2 = torch.einsum('bi,bj->bij', A, A) # no contraction — pure outer
print(result2.shape) # (8, 6, 6)
# result2[b] = outer product of A[b] with itself
# 3. Weighted sum: for each batch, sum_s weight[b,s]*value[b,s,:] → (B,D)
result3 = torch.einsum('bs,bsd->bd', weights, values)
print(result3.shape) # (4, 16)
# Verify against manual:
result3_manual = (weights.unsqueeze(-1) * values).sum(dim=1) # (B,S,1)*(B,S,D), sum over S
print(torch.allclose(result3, result3_manual, atol=1e-6)) # True
15. Assessment Cheat Sheet
A one-stop reference for the most common operations you'll need in 60 minutes.
Tensor Creation
| Operation | Code | Output Shape |
|---|---|---|
| From list | `torch.tensor([[1,2],[3,4]])` | `(2,2)` |
| Zeros | `torch.zeros(3,4)` | `(3,4)` |
| Identity | `torch.eye(N)` | `(N,N)` |
| Range | `torch.arange(0, 10, 2)` | `(5,)` |
| Linspace | `torch.linspace(0, 1, 5)` | `(5,)` |
| Normal | `torch.randn(B, D)` | `(B,D)` |
| Like | `torch.zeros_like(x)` | same as `x` |
Shape Operations
| Operation | Code | Notes |
|---|---|---|
| Reshape | `t.reshape(2, -1)` | `-1` = infer |
| View | `t.view(N)` | must be contiguous |
| Transpose (2D) | `t.T` or `t.t()` | shares memory |
| Transpose (any) | `t.transpose(1, 2)` | swap two dims |
| Permute | `t.permute(2,0,1)` | reorder all dims |
| Add dim | `t.unsqueeze(0)` | insert size-1 |
| Remove dim | `t.squeeze(0)` | remove size-1 |
| Flatten | `t.flatten(1)` | flatten from dim |
| Cat | `torch.cat([a,b], 0)` | join existing dim |
| Stack | `torch.stack([a,b], 0)` | create new dim |
Indexing
| Pattern | Code | Notes |
|---|---|---|
| Slice (view) | `t[1:3, :]` | shares memory |
| Bool mask (copy) | `t[t > 0]` | 1D result |
| Fancy (copy) | `t[[0,2,4]]` | new tensor |
| Clone | `t.clone()` | independent copy |
| Masked fill | `t.masked_fill(mask, -inf)` | in-place variant: `masked_fill_` |
| Where | `torch.where(c, a, b)` | vectorised ternary |
Matrix Operations
| Operation | Code | Dims | Notes |
|---|---|---|---|
| Matrix multiply | `A @ B` | 2D+ | universal |
| 2D only | `torch.mm(A, B)` | 2D | explicit |
| Batched | `torch.bmm(A, B)` | 3D | no broadcast |
| Batch broadcast | `torch.matmul(A, B)` | any | broadcasts |
| Einsum | `torch.einsum('ij,jk->ik', A, B)` | any | most expressive |
| Outer product | `torch.outer(a, b)` | 1D | returns 2D |
| Dot product | `torch.dot(a, b)` | 1D | scalar |
Reductions
| Operation | Code | Notes |
|---|---|---|
| Sum all | `t.sum()` | scalar |
| Sum axis | `t.sum(dim=1)` | collapses dim 1 |
| Keep dim | `t.sum(dim=1, keepdim=True)` | for broadcasting |
| Argmax | `t.argmax(dim=1)` | index of max |
| Topk | `torch.topk(t, k=5)` | values + indices |
| Norm | `torch.linalg.norm(t, dim=1)` | L2 per row |
Linear Algebra
| Operation | Code | Notes |
|---|---|---|
| Inverse | `torch.linalg.inv(A)` | square only |
| Solve Ax=b | `torch.linalg.solve(A, b)` | prefer over inv |
| SVD | `torch.linalg.svd(A, full_matrices=False)` | returns U,S,Vh |
| Eigenvalues | `torch.linalg.eigvalsh(A)` | symmetric |
| QR | `torch.linalg.qr(A)` | orthogonal + upper tri |
| Det | `torch.linalg.det(A)` | scalar |
| Norm | `torch.linalg.norm(A, ord='fro')` | Frobenius |
| Pseudo-inv | `torch.linalg.pinv(A)` | for non-square |
Advanced Indexing
| Operation | Code | Notes |
|---|---|---|
| Gather | `torch.gather(t, dim=1, index=idx)` | output shape = `idx.shape` |
| Scatter | `out.scatter_(1, idx, src)` | in-place write |
| Scatter add | `out.scatter_add_(0, idx, src)` | accumulate |
| Index select | `t.index_select(0, idx)` | select rows/cols |
Common Broadcasting Patterns
# Row-wise operation: subtract column mean
mean = t.mean(dim=0) # (C,)
t - mean # (N,C) - (C,) ✓ (right-aligns)
# Col-wise operation: subtract row mean
mean = t.mean(dim=1, keepdim=True) # (N,1) — keepdim is CRITICAL
t - mean # (N,C) - (N,1) → (N,C) ✓
# 3D broadcast: add bias to each batch
bias = torch.randn(C) # (C,)
t = torch.randn(B, N, C)
t + bias # (B,N,C) + (C,) ✓
# Outer product via unsqueeze
a = torch.randn(M) # (M,)
b = torch.randn(N) # (N,)
a.unsqueeze(1) * b.unsqueeze(0) # (M,1) * (1,N) → (M,N) ✓
Critical Gotchas to Avoid in the Assessment
- `a * b` is element-wise, NOT matrix multiply — use `a @ b` for matmul
- `t.T` is non-contiguous — if the next op needs contiguity, call `.contiguous()` first
- `torch.dot` is 1D only — use `@` or `einsum` for higher dimensions
- `reshape` vs `view`: `view` fails on non-contiguous tensors, `reshape` doesn't
- Slices are views — mutating a slice mutates the original tensor
- `keepdim=True` in reductions — critical when you need to broadcast the result back
- `bmm` does NOT broadcast — use `matmul` or `@` for broadcasting batch operations
- Float/int mixing — always cast before mixed operations to avoid type errors
- `squeeze()` with no args removes ALL size-1 dims — be specific: `squeeze(0)`
- `gather` output shape = index shape — size the index tensor as the output you want
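A few of these gotchas condensed into a runnable sanity check (worth executing once before the assessment):

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[5., 6.], [7., 8.]])

# * is element-wise; @ is matrix multiply
print(torch.equal(a * b, a @ b))  # False

# view fails on non-contiguous tensors; reshape copies when it must
try:
    a.T.view(4)
except RuntimeError:
    print("view on a.T failed (non-contiguous)")
print(a.T.reshape(4))  # works

# slices are views — mutating the slice mutates the original
row = a[0]
row[0] = 99.
print(a[0, 0])  # tensor(99.)

# squeeze() with no args removes ALL size-1 dims
t = torch.zeros(1, 3, 1)
print(t.squeeze().shape, t.squeeze(0).shape)  # torch.Size([3]) torch.Size([3, 1])
```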
Happy studying! Focus on the Practice Problems (section 14) — work through each one without looking at the solution first, then check your answer. That's the fastest way to build the pattern recognition you'll need under time pressure.