Shampoo Optimizer
Refresher on machine learning optimizers
1. SGD - First-Order Methods
Stochastic Gradient Descent (SGD) uses first-order gradients to update parameters:
\[\theta = \theta - \eta \nabla f(\theta)\]
SGD treats all parameters equally, leading to slow convergence when parameters have different scales or when the loss landscape is ill-conditioned.
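As a toy illustration (not from any of the papers below), a single SGD step on an ill-conditioned quadratic applies the same learning rate to every coordinate, regardless of curvature:
import torch

# Toy quadratic f(theta) = 0.5 * theta^T A theta with very different curvature per axis
A = torch.diag(torch.tensor([1.0, 100.0]))
theta = torch.tensor([1.0, 1.0], requires_grad=True)

loss = 0.5 * theta @ A @ theta
loss.backward()
with torch.no_grad():
    theta -= 0.01 * theta.grad  # one SGD step: the same eta for both coordinates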
2. AdaGrad - First Adaptive Learning Rates
AdaGrad introduced adaptive learning rates by accumulating squared gradients and using the accumulator as a diagonal preconditioner.
\[\begin{align} G_i &= \textstyle\sum_t g_{t,i}^2 &&\text{accumulate squared gradient per parameter}\\ \theta_{i} &= \theta_{i} - \eta\frac{g_i}{\sqrt{G_i + \epsilon}} &&\text{scale learning rate per parameter} \end{align}\]
- Benefits: Automatic learning rate scaling per parameter
- Limitation: Aggressive learning rate decay, treats parameters independently
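A minimal sketch of this diagonal update (illustrative, not torch.optim.Adagrad; the names are mine):
import torch

def adagrad_step(theta, grad, G, lr=0.1, eps=1e-10):
    # Accumulate squared gradients, then scale each coordinate's step by 1/sqrt(G)
    G += grad * grad
    theta -= lr * grad / (G.sqrt() + eps)
    return theta, G

theta = torch.zeros(3)
G = torch.zeros(3)
theta, G = adagrad_step(theta, torch.tensor([0.5, -2.0, 0.1]), G)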
3. Adam - Exponential Moving Averages
Adam improved upon AdaGrad by replacing the growing sum of squared gradients with exponential moving averages, still used as a diagonal preconditioner.
\[\begin{align} m &= \beta_1 m + (1-\beta_1) g &&\text{moving average of first moment}\\ v &= \beta_2 v + (1-\beta_2) g^2 &&\text{moving average of second moment}\\ \hat{m} &= \frac{m}{1-\beta_1^t} &&\text{bias correction}\\ \hat{v} &= \frac{v}{1-\beta_2^t} &&\text{bias correction} \\ \theta &= \theta - \eta \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} &&\text{update} \end{align}\]
- Benefits: Solves AdaGrad’s vanishing learning rate problem
- Limitation: Still diagonal preconditioning - no parameter correlations
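For comparison, the same style of sketch for Adam (illustrative, not torch.optim.Adam):
import torch

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * grad          # EMA of the gradient
    v = beta2 * v + (1 - beta2) * grad * grad   # EMA of the squared gradient
    m_hat = m / (1 - beta1 ** t)                # bias corrections
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v

theta, m, v = torch.zeros(3), torch.zeros(3), torch.zeros(3)
theta, m, v = adam_step(theta, torch.tensor([0.5, -2.0, 0.1]), m, v, t=1)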
4. Full-Matrix AdaGrad
The theoretical ideal would be full-matrix preconditioning:
\[\begin{align} G_{full} &= \textstyle\sum_t g_t g_t^T &&\text{accumulate gradient outer products (approximate covariance)} \\ \theta &= \theta - \eta\, G_{full}^{-\frac{1}{2}}\, g &&\text{precondition with the full inverse matrix root} \end{align}\]
Why it’s impossible for larger models:
- Memory: \(O(d^2)\) - for 1M parameters, the preconditioner has \(10^{12}\) entries (roughly 4 TB in fp32) just for one matrix
- Computation: \(O(d³)\) matrix inversion every step
- Example: GPT-3 has 175B parameters → impossible
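The arithmetic behind the memory claim, assuming fp32 storage:
d = 1_000_000                       # 1M parameters in a single layer
entries = d * d                     # full preconditioner is d x d
bytes_fp32 = 4 * entries
print(entries, bytes_fp32 / 1e12)   # 1e12 entries, ~4 TB for one matrix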
Shampoo: Structured Preconditioning
Shampoo was introduced in “Shampoo: Preconditioned Stochastic Tensor Optimization” (2018). It brings full-matrix-style preconditioning within reach by exploiting the matrix (or tensor) structure of each layer’s parameters.
Core Insight: Kronecker Product Approximation
Instead of treating a weight matrix \(\Theta \in \mathbb{R}^{m \times n}\) as a flat vector of \(m \cdot n\) parameters, Shampoo maintains separate statistics for each dimension:
\[\begin{align} L &= \textstyle\sum_t G_t G_t^T &&\text{left preconditioner } (m \times m)\text{: row correlations}\\ R &= \textstyle\sum_t G_t^T G_t &&\text{right preconditioner } (n \times n)\text{: column correlations}\\ \Theta &= \Theta - \eta\, L^{-\frac{1}{4}}\, G\, R^{-\frac{1}{4}} &&\text{update} \end{align}\]
The approximation: this is equivalent to approximating the full \(mn \times mn\) preconditioner as
\[G_{full} \approx L^{\frac{1}{2}} \otimes R^{\frac{1}{2}} \quad \text{(Kronecker product approximation)}\]
Memory and Computation Savings
For a layer with a 1000×2000 weight matrix:
- Full-matrix AdaGrad: 2M × 2M = 4 trillion preconditioner entries
- Shampoo: 1000² + 2000² = 5 million entries
- Savings: roughly 800,000-fold
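A small sanity check of the Kronecker structure (a sketch, using the identity \((L \otimes R)^{-1/4} = L^{-1/4} \otimes R^{-1/4}\) and the vec trick): preconditioning the matrix gradient with the two small factors gives the same update direction as applying the huge \(mn \times mn\) Kronecker preconditioner to the flattened gradient:
import torch

torch.manual_seed(0)
m, n = 3, 4
G = torch.randn(m, n)
L = G @ G.T + 1e-3 * torch.eye(m)   # left statistics (m x m)
R = G.T @ G + 1e-3 * torch.eye(n)   # right statistics (n x n)

def matrix_power(M, p):
    """Power of a symmetric positive-definite matrix via eigendecomposition."""
    vals, vecs = torch.linalg.eigh(M)
    return vecs @ torch.diag(vals.pow(p)) @ vecs.T

L_q, R_q = matrix_power(L, -0.25), matrix_power(R, -0.25)

shampoo_dir = L_q @ G @ R_q                                       # Shampoo: two small matmuls
kron_dir = (torch.kron(L_q, R_q) @ G.reshape(-1)).reshape(m, n)   # explicit mn x mn preconditioner
print(torch.allclose(shampoo_dir, kron_dir, atol=1e-5))           # True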
Simple Shampoo Implementation
Here’s a basic PyTorch implementation:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np


class SimpleShampoo(optim.Optimizer):
    def __init__(self, params, lr=0.01, eps=1e-7, update_freq=1):
        defaults = dict(lr=lr, eps=eps, update_freq=update_freq)
        super().__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                preconditioned_grad = grad

                # Initialize state
                if len(state) == 0:
                    state["step"] = 0
                    if len(grad.shape) == 1:  # Vector - use diagonal
                        state["G"] = torch.zeros_like(grad) + group["eps"]
                    elif len(grad.shape) == 2:  # Matrix - use Shampoo
                        m, n = grad.shape
                        state["L"] = torch.eye(m, device=grad.device) * group["eps"]
                        state["R"] = torch.eye(n, device=grad.device) * group["eps"]
                    else:  # Higher order - fallback to diagonal
                        state["G"] = torch.zeros_like(grad)

                state["step"] += 1

                if len(grad.shape) == 2:  # Matrix case
                    # Update statistics
                    state["L"] += grad @ grad.T
                    state["R"] += grad.T @ grad

                    # Compute preconditioned gradient every update_freq steps
                    if state["step"] % group["update_freq"] == 0:
                        # Compute matrix power: M^(-1/4)
                        L_eig_vals, L_eig_vecs = torch.linalg.eigh(state["L"])
                        R_eig_vals, R_eig_vecs = torch.linalg.eigh(state["R"])
                        L_inv_quarter = (
                            L_eig_vecs
                            @ torch.diag(
                                torch.pow(
                                    torch.clamp(L_eig_vals, min=group["eps"]), -0.25
                                )
                            )
                            @ L_eig_vecs.T
                        )
                        R_inv_quarter = (
                            R_eig_vecs
                            @ torch.diag(
                                torch.pow(
                                    torch.clamp(R_eig_vals, min=group["eps"]), -0.25
                                )
                            )
                            @ R_eig_vecs.T
                        )
                        state["L_inv_quarter"] = L_inv_quarter
                        state["R_inv_quarter"] = R_inv_quarter

                    # Apply preconditioned update
                    if "L_inv_quarter" in state:
                        preconditioned_grad = (
                            state["L_inv_quarter"] @ grad @ state["R_inv_quarter"]
                        )
                    else:
                        preconditioned_grad = grad
                else:  # Vector or tensor - use diagonal
                    state["G"] += grad * grad
                    preconditioned_grad = grad / (torch.sqrt(state["G"]) + group["eps"])

                # Update parameters
                p.data -= group["lr"] * preconditioned_grad
optimizers = {"Adam": optim.Adam, "AdaGrad": optim.Adagrad, "Shampoo": SimpleShampoo}
# Create an ill-conditioned regression problem
def create_ill_conditioned_data(num_samples=1000, condition_number=1000):
    """Create data where some features are much more important than others"""
    torch.manual_seed(42)
    input_dim = 50

    # Create ill-conditioned covariance matrix
    U, _ = torch.linalg.qr(
        torch.randn(input_dim, input_dim)
    )  # Random orthogonal matrix
    eigenvals = torch.logspace(
        0, np.log10(condition_number), input_dim
    )  # Large condition number
    cov_matrix = U @ torch.diag(eigenvals) @ U.T

    # Generate correlated features
    X = torch.randn(num_samples, input_dim) @ torch.linalg.cholesky(cov_matrix)

    # True weights with different scales (some very important, some not)
    true_weights = torch.zeros(input_dim)
    true_weights[:5] = torch.randn(5) * 10  # Very important features
    true_weights[5:15] = torch.randn(10) * 1  # Moderately important
    true_weights[15:] = torch.randn(35) * 0.1  # Less important

    y = X @ true_weights + torch.randn(num_samples) * 0.1
    return X, y.unsqueeze(1)
def train_model(optimizer_name="Shampoo"):
    assert optimizer_name in optimizers, "Expecting a known optimizer in " + ", ".join(
        optimizers.keys()
    )

    # Create data
    X, y = create_ill_conditioned_data()
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Create model
    model = nn.Linear(X.shape[1], 1)
    optimizer = optimizers[optimizer_name](model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    print("Starting training...")
    print(f"Using optimizer: {type(optimizer).__name__}")

    # Training loop
    for epoch in range(50):
        total_loss = 0.0
        num_batches = 0
        for batch_idx, (batch_x, batch_y) in enumerate(dataloader):
            # Forward pass
            optimizer.zero_grad()
            predictions = model(batch_x)
            loss = criterion(predictions, batch_y)

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        if epoch % 5 == 0:
            print(f"Epoch {epoch:2d}: Average Loss = {avg_loss:.6f}")

    print("Training completed!")
    return model, avg_loss


if __name__ == "__main__":
    for optimizer_name in optimizers.keys():
        print("Training with ", optimizer_name, "Optimizer:")
        model, final_loss = train_model(optimizer_name)
Improvements and Practical Considerations
1. Grafting for Stability
Google’s implementation uses “grafting” to set the layerwise scale of Shampoo updates, borrowing the step magnitude from a simpler diagonal method such as AdaGrad:
# Compute both candidate updates (illustrative fragment)
shampoo_update = L_inv_quarter @ grad @ R_inv_quarter
diagonal_update = grad / (accumulated_grad_squares.sqrt() + eps)

# Scale the Shampoo direction to match the diagonal update's magnitude
scale = diagonal_update.norm() / shampoo_update.norm()
final_update = scale * shampoo_update
2. Delayed Preconditioning
Start with a simpler diagonal method, then gradually blend in the Shampoo update after a warm-up period:
if step < start_preconditioning_steps:
    update = diagonal_update  # Use AdaGrad initially
else:
    warmup_factor = min(1.0, (step - start_preconditioning_steps) / start_preconditioning_steps)
    update = warmup_factor * shampoo_update + (1 - warmup_factor) * diagonal_update
3. SOAP: Adam in Shampoo’s Eigenbasis
Recent work introduced SOAP, which runs “Adam in the Preconditioner’s eigenbasis” (a minimal sketch follows the list):
- Decompose the preconditioner: \(P = Q \Lambda Q^\top\)
- Rotate the gradient into the eigenbasis: \(\tilde{G} = Q^\top G\, Q\)
- Run Adam on \(\tilde{G}\)
- Rotate the update back: \(\text{update} = Q\, \text{Adam}(\tilde{G})\, Q^\top\)
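A minimal, hypothetical sketch of this idea for one weight matrix, assuming L and R are Shampoo-style statistics and the diagonal Adam state (m, v) lives in the rotated space; it uses both the left and right eigenbases, and all names are mine rather than SOAP’s API:
import torch

def soap_like_step(grad, L, R, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    _, QL = torch.linalg.eigh(L)             # eigenbasis of the left statistics
    _, QR = torch.linalg.eigh(R)             # eigenbasis of the right statistics
    g_rot = QL.T @ grad @ QR                 # rotate the gradient into the eigenbasis
    m = beta1 * m + (1 - beta1) * g_rot      # diagonal Adam in the rotated space
    v = beta2 * v + (1 - beta2) * g_rot ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    update_rot = lr * m_hat / (v_hat.sqrt() + eps)
    return QL @ update_rot @ QR.T, m, v      # rotate the update back to parameter space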
Production Implementations
TensorFlow/Lingvo Implementation
The TensorFlow implementation focuses on practical deployment with CPU-based preconditioner computation:
Key Features:
- Asynchronous preconditioning: Expensive matrix operations run on CPU while GPUs continue training
- Simple partitioning: Splits large tensors when dimensions exceed thresholds
- Grafting integration: Built-in support for scaling strategies
def invoke_async_preconditioner_computation(self, global_step):
    """Computes preconditioners asynchronously on CPU"""
    return x_ops.compute_preconditioners(
        stats, exponents, global_step,
        sync=self._synchronous_preconditioning,
        preconditioner_compute_graphdef=self._preconditioner_compute_graphdef)
JAX Distributed Implementation
The JAX version provides full distributed training support with advanced features:
Advanced Features:
- Quantized statistics: Reduces memory usage through QuantizedValue storage
- Sharded computation: Distributes preconditioner computation across devices
- Global statistics aggregation: Coordinates statistics across multiple devices
@struct.dataclass
class ShardedShampooStats:
    global_stats: Any  # Statistics aggregated across all devices
    local_stats: Any   # Device-local statistics


class LocalShardedParameterStats:
    index_start: int   # Starting index in global statistics array
    sizes: Any         # Partition sizes for this device
Distributed Training Flow:
- Local computation: Each device computes gradients and updates local statistics
- Periodic synchronization: Every N steps, aggregate statistics across devices
- Centralized preconditioning: Master device computes preconditioners
- Broadcast updates: Distribute preconditioners back to all devices
References
The key papers and implementations:
- Preconditioning in iterative solvers
- “Shampoo: Preconditioned Stochastic Tensor Optimization” (2018)
- Google TensorFlow (Lingvo) Implementation
- Google JAX Implementation
- Google PyTorch(!) Implementation
- “Scalable Second Order Optimization for Deep Learning” (2020)
- “SOAP: Improving and Stabilizing Shampoo using Adam” (2024)