Refresher on machine learning optimizers

1. SGD - First-Order Methods

Stochastic Gradient Descent (SGD) uses first-order gradients to update parameters:

\[\theta = \theta - \eta \nabla f(\theta)\]

SGD treats all parameters equally, leading to slow convergence when parameters have different scales or when the loss landscape is ill-conditioned.
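
As a minimal sketch (plain tensors, an illustrative learning rate), the SGD update is a single line per step:

import torch

lr = 0.1                   # illustrative learning rate
theta = torch.zeros(3)     # parameters

def sgd_step(grad):
    # Same step size for every parameter: theta <- theta - lr * grad
    theta.sub_(lr * grad)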

2. AdaGrad - First Adaptive Learning Rates

AdaGrad introduced adaptive learning rates by accumulating squared gradients and using them as a diagonal preconditioner:

\[\begin{align} G_i &= \sum_t g_{t,i}^2 &&\text{accumulate squared gradient per parameter}\\ \theta_{i} &= \theta_{i} - \eta\frac{ g_{i}}{\sqrt{G_i + \epsilon}} &&\text{scale learning rate per parameter} \end{align}\]
  • Benefits: Automatic learning rate scaling per parameter

  • Limitation: Aggressive learning rate decay, treats parameters independently
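
A minimal sketch of the diagonal AdaGrad update (illustrative hyperparameters):

import torch

lr, eps = 0.1, 1e-8
theta = torch.zeros(3)
G = torch.zeros(3)   # running sum of squared gradients

def adagrad_step(grad):
    G.add_(grad ** 2)                          # accumulate per-parameter statistics
    theta.sub_(lr * grad / (G + eps).sqrt())   # per-parameter effective learning rate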

3. Adam - Exponential Moving Averages

Adam improved upon AdaGrad by using exponential moving averages of the gradient and its square as a diagonal preconditioner:

\[\begin{align} m &= \beta_1 m + (1-\beta_1) g &&\text{moving average of the first moment}\\ v &= \beta_2 v + (1-\beta_2) g^2 &&\text{moving average of the second moment}\\ \hat{m} &= \frac{m}{1-\beta_1^t} &&\text{bias correction}\\ \hat{v} &= \frac{v}{1-\beta_2^t} &&\text{bias correction} \\ \theta &= \theta - \eta \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} &&\text{parameter update} \end{align}\]
  • Benefits: Solves AdaGrad’s vanishing learning rate problem
  • Limitation: Still diagonal preconditioning - no parameter correlations
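
A minimal sketch of the Adam update with the common default hyperparameters:

import torch

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
theta = torch.zeros(3)
m, v = torch.zeros(3), torch.zeros(3)

def adam_step(grad, t):
    # t is the 1-based step count used for bias correction
    m.mul_(beta1).add_((1 - beta1) * grad)        # first-moment EMA
    v.mul_(beta2).add_((1 - beta2) * grad ** 2)   # second-moment EMA
    m_hat = m / (1 - beta1 ** t)                  # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta.sub_(lr * m_hat / (v_hat.sqrt() + eps))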

4. Full-Matrix AdaGrad

The theoretical ideal would be full-matrix preconditioning:

\[\begin{align} G_{full} &= \sum_t g_t g_t^{\top} &&\text{accumulate the full outer product (an approximation of the gradient covariance)} \\ \theta &= \theta -\eta\, G_{full}^{-\frac{1}{2}}\, g &&\text{precondition with the full matrix inverse square root} \end{align}\]

Why it’s impossible for larger models:

  • Memory: \(O(d^2)\) - for 1M parameters, the preconditioner alone has \(10^{12}\) entries (~4 TB in fp32; see the quick check below)
  • Computation: \(O(d^3)\) matrix inverse square root at every step
  • Example: GPT-3 has 175B parameters → impossible
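
A quick sanity check of the memory figure, assuming fp32 storage (4 bytes per entry):

d = 1_000_000
bytes_per_entry = 4                      # fp32
print(d * d * bytes_per_entry / 1e12)    # 4.0 -> roughly 4 TB for a single d x d matrix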

Shampoo: Structured Preconditioning

Shampoo was introduced by Gupta, Koren, and Singer in “Shampoo: Preconditioned Stochastic Tensor Optimization” (ICML 2018). It approximates full-matrix preconditioning while only maintaining much smaller per-dimension matrices.

Core Insight: Kronecker Product Approximation

Instead of treating a weight matrix \(\Theta \in \mathbb{R}^{m\times n}\) as a flat vector of \(m\cdot n\) parameters, Shampoo maintains separate statistics for each dimension:

\[\begin{align} L &=\sum_t G_t G_t^{\top} &&\text{left preconditioner } (m \times m)\text{: row correlations}\\ R &=\sum_t G_t^{\top} G_t &&\text{right preconditioner } (n \times n)\text{: column correlations}\\ \Theta &= \Theta -\eta\, L^{-\frac{1}{4}}\, G\, R^{-\frac{1}{4}} &&\text{preconditioned update} \end{align}\]

The approximation: This is equivalent to approximating the full \(mn×mn\) preconditioner as:

\[G_{full} \approx L^{\frac{1}{2}}\otimes R^{\frac{1}{2}} \quad \text{Kronecker Product Approximation}\]
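
This equivalence is easy to verify numerically on a small matrix. The sketch below uses a small `matrix_power` helper (the same eigendecomposition trick used in the implementation later in this post) and checks that the structured update \(L^{-\frac{1}{4}} G R^{-\frac{1}{4}}\) equals applying \((L^{\frac{1}{2}}\otimes R^{\frac{1}{2}})^{-\frac{1}{2}}\) to the flattened gradient:

import torch

def matrix_power(M, p, eps=1e-12):
    """Fractional power of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = torch.linalg.eigh(M)
    return vecs @ torch.diag(vals.clamp(min=eps) ** p) @ vecs.T

torch.manual_seed(0)
m, n = 3, 4
G = torch.randn(m, n, dtype=torch.float64)

# Shampoo statistics for a single gradient (a small ridge keeps them invertible)
L = G @ G.T + 1e-3 * torch.eye(m, dtype=torch.float64)
R = G.T @ G + 1e-3 * torch.eye(n, dtype=torch.float64)

# Structured update: L^(-1/4) G R^(-1/4)
structured = matrix_power(L, -0.25) @ G @ matrix_power(R, -0.25)

# Full-matrix view: precondition vec(G) with (L^(1/2) kron R^(1/2))^(-1/2)
H = torch.kron(matrix_power(L, 0.5), matrix_power(R, 0.5))   # (mn) x (mn)
full = (matrix_power(H, -0.5) @ G.reshape(-1)).reshape(m, n)

print(torch.allclose(structured, full))   # True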

Memory and Computation Savings

For a layer with a 1000×2000 weight matrix:

  • Full-matrix AdaGrad: 2M×2M = 4T entries
  • Shampoo: 1000² + 2000² = 5M entries
  • Savings: ~800,000-fold (see the quick check below)
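
A quick check of these numbers:

m, n = 1000, 2000
full_matrix = (m * n) ** 2     # 4e12 entries for the (mn) x (mn) preconditioner
shampoo = m ** 2 + n ** 2      # 5e6 entries for L and R combined
print(full_matrix // shampoo)  # 800000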

Simple Shampoo Implementation

Here’s a basic PyTorch implementation:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np


class SimpleShampoo(optim.Optimizer):
    def __init__(self, params, lr=0.01, eps=1e-7, update_freq=1):
        defaults = dict(lr=lr, eps=eps, update_freq=update_freq)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                preconditioned_grad = grad

                # Initialize state
                if len(state) == 0:
                    state["step"] = 0
                    if len(grad.shape) == 1:  # Vector - use diagonal
                        state["G"] = torch.zeros_like(grad) + group["eps"]
                    elif len(grad.shape) == 2:  # Matrix - use Shampoo
                        m, n = grad.shape
                        state["L"] = torch.eye(m, device=grad.device) * group["eps"]
                        state["R"] = torch.eye(n, device=grad.device) * group["eps"]
                    else:  # Higher order - fallback to diagonal
                        state["G"] = torch.zeros_like(grad)

                state["step"] += 1

                if len(grad.shape) == 2:  # Matrix case
                    # Update statistics
                    state["L"] += grad @ grad.T
                    state["R"] += grad.T @ grad

                    # Compute preconditioned gradient every update_freq steps
                    if state["step"] % group["update_freq"] == 0:
                        # Compute matrix power: M^(-1/4)
                        L_eig_vals, L_eig_vecs = torch.linalg.eigh(state["L"])
                        R_eig_vals, R_eig_vecs = torch.linalg.eigh(state["R"])
                        L_inv_quarter = (
                            L_eig_vecs
                            @ torch.diag(
                                torch.pow(
                                    torch.clamp(L_eig_vals, min=group["eps"]), -0.25
                                )
                            )
                            @ L_eig_vecs.T
                        )
                        R_inv_quarter = (
                            R_eig_vecs
                            @ torch.diag(
                                torch.pow(
                                    torch.clamp(R_eig_vals, min=group["eps"]), -0.25
                                )
                            )
                            @ R_eig_vecs.T
                        )

                        state["L_inv_quarter"] = L_inv_quarter
                        state["R_inv_quarter"] = R_inv_quarter

                    # Apply preconditioned update
                    if "L_inv_quarter" in state:
                        preconditioned_grad = (
                            state["L_inv_quarter"] @ grad @ state["R_inv_quarter"]
                        )
                    else:
                        preconditioned_grad = grad

                else:  # Vector or tensor - use diagonal
                    state["G"] += grad * grad
                    preconditioned_grad = grad / (torch.sqrt(state["G"]) + group["eps"])

                # Update parameters
                p.data -= group["lr"] * preconditioned_grad

        return loss


optimizers = {"Adam": optim.Adam, "AdaGrad": optim.Adagrad, "Shampoo": SimpleShampoo}


# Create an ill-conditioned quadratic (least-squares) problem
def create_ill_conditioned_data(num_samples=1000, condition_number=1000):
    """Create data where some features are much more important than others"""
    torch.manual_seed(42)
    input_dim = 50

    # Create ill-conditioned covariance matrix
    U, _ = torch.linalg.qr(
        torch.randn(input_dim, input_dim)
    )  # Random orthogonal matrix
    eigenvals = torch.logspace(
        0, np.log10(condition_number), input_dim
    )  # Large condition number
    cov_matrix = U @ torch.diag(eigenvals) @ U.T

    # Generate correlated features
    X = torch.randn(num_samples, input_dim) @ torch.linalg.cholesky(cov_matrix).T  # rows of X have covariance cov_matrix

    # True weights with different scales (some very important, some not)
    true_weights = torch.zeros(input_dim)
    true_weights[:5] = torch.randn(5) * 10  # Very important features
    true_weights[5:15] = torch.randn(10) * 1  # Moderately important
    true_weights[15:] = torch.randn(35) * 0.1  # Less important

    y = X @ true_weights + torch.randn(num_samples) * 0.1
    return X, y.unsqueeze(1)


def train_model(optimizer_name="Shampoo"):
    assert optimizer_name in optimizers, "Expecting a known optimizer in " + ", ".join(
        optimizers.keys()
    )

    # Create data
    X, y = create_ill_conditioned_data()
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Create model
    model = nn.Linear(X.shape[1], 1)

    optimizer = optimizers[optimizer_name](model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    print("Starting training...")
    print(f"Using optimizer: {type(optimizer).__name__}")

    # Training loop
    for epoch in range(50):
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (batch_x, batch_y) in enumerate(dataloader):
            # Forward pass
            optimizer.zero_grad()
            predictions = model(batch_x)
            loss = criterion(predictions, batch_y)

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        if epoch % 5 == 0:
            print(f"Epoch {epoch:2d}: Average Loss = {avg_loss:.6f}")

    print("Training completed!")
    return model, avg_loss


if __name__ == "__main__":
    for optimizer_name in optimizers.keys():
        print("Training with ", optimizer_name, "Optimizer:")
        model, final_loss = train_model(optimizer_name)

Improvements and Practical Considerations

1. Grafting for Stability

Google’s implementation uses “grafting” to set the per-layer scale of the Shampoo update from a simpler optimizer (e.g., diagonal AdaGrad):

# Compute both the Shampoo and the diagonal (AdaGrad-style) updates
shampoo_update = L_inv_quarter @ grad @ R_inv_quarter
diagonal_update = grad / (accumulated_grad_squares.sqrt() + eps)

# Graft: keep Shampoo's direction but rescale it to the diagonal update's magnitude
scale = diagonal_update.norm() / (shampoo_update.norm() + eps)
final_update = scale * shampoo_update

2. Delayed Preconditioning

Start with a simpler method, then gradually transition to Shampoo after a number of warm-up steps:

if step < start_preconditioning_steps:
    update = diagonal_update  # use the AdaGrad-style update initially
else:
    # Ramp in the Shampoo update over another start_preconditioning_steps steps
    frac = (step - start_preconditioning_steps) / start_preconditioning_steps
    warmup_factor = min(1.0, frac)
    update = warmup_factor * shampoo_update + (1 - warmup_factor) * diagonal_update

3. SOAP: Adam in Shampoo’s Eigenbasis

Recent work introduced SOAP (“SOAP: Improving and Stabilizing Shampoo using Adam”, 2024), which runs Adam in the preconditioner’s eigenbasis; a code sketch follows the steps below:

  1. Eigendecompose the preconditioner: \(P = Q \Lambda Q^{\top}\)
  2. Rotate the gradient into the eigenbasis: \(\tilde{G} = Q^{\top} G Q\) (for a matrix layer, \(Q_L^{\top} G Q_R\))
  3. Run Adam on \(\tilde{G}\)
  4. Transform the resulting Adam update back: \(\Delta\Theta = Q \tilde{\Delta} Q^{\top}\)
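
A conceptual sketch of these steps for one matrix-shaped layer (this is not the reference SOAP implementation; `adam_step` stands for any standard elementwise Adam update, and `L`, `R` are the Shampoo statistics from above):

import torch

def soap_like_update(grad, L, R, adam_step):
    # 1. Eigenbases of the left/right Shampoo statistics
    _, QL = torch.linalg.eigh(L)
    _, QR = torch.linalg.eigh(R)
    # 2. Rotate the gradient into the eigenbasis
    g_rot = QL.T @ grad @ QR
    # 3. Run a standard Adam update in that basis
    update_rot = adam_step(g_rot)
    # 4. Rotate the resulting update back to parameter space
    return QL @ update_rot @ QR.T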

Production Implementations

TensorFlow/Lingvo Implementation

The TensorFlow implementation focuses on practical deployment with CPU-based preconditioner computation:

Key Features:

  • Asynchronous preconditioning: Expensive matrix operations run on CPU while GPUs continue training
  • Simple partitioning: Splits large tensors when dimensions exceed thresholds
  • Grafting integration: Built-in support for scaling strategies

def invoke_async_preconditioner_computation(self, global_step):
    """Computes preconditioners asynchronously on CPU"""
    return x_ops.compute_preconditioners(
        stats, exponents, global_step,
        sync=self._synchronous_preconditioning,
        preconditioner_compute_graphdef=self._preconditioner_compute_graphdef)

JAX Distributed Implementation

The JAX version provides full distributed training support with advanced features:

Advanced Features:

  • Quantized statistics: Reduces memory usage through QuantizedValue storage
  • Sharded computation: Distributes preconditioner computation across devices
  • Global statistics aggregation: Coordinates statistics across multiple devices

@struct.dataclass
class ShardedShampooStats:
    global_stats: Any      # Statistics aggregated across all devices  
    local_stats: Any       # Device-local statistics

class LocalShardedParameterStats:
    index_start: int       # Starting index in global statistics array
    sizes: Any            # Partition sizes for this device

Distributed Training Flow:

  1. Local computation: Each device computes gradients and updates local statistics
  2. Periodic synchronization: Every N steps, aggregate statistics across devices
  3. Centralized preconditioning: Master device computes preconditioners
  4. Broadcast updates: Distribute preconditioners back to all devices (see the sketch below)
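
A conceptual sketch of this flow using torch.distributed primitives (illustrative only; the actual implementation is JAX-based and sharded, and `compute_inv_quarter` stands for the eigendecomposition routine shown earlier):

import torch
import torch.distributed as dist

def sync_and_precondition(L, step, compute_inv_quarter, sync_every=100):
    # 1. Local statistics L are accumulated on each rank before calling this
    if step % sync_every != 0:
        return None                               # keep using cached preconditioners
    dist.all_reduce(L, op=dist.ReduceOp.SUM)      # 2. aggregate statistics across devices
    if dist.get_rank() == 0:
        L_inv_quarter = compute_inv_quarter(L)    # 3. master computes the preconditioner
    else:
        L_inv_quarter = torch.empty_like(L)
    dist.broadcast(L_inv_quarter, src=0)          # 4. broadcast it back to all devices
    return L_inv_quarter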

References

The key papers and implementations: