<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://ryanzhang.info/feed.xml" rel="self" type="application/atom+xml" /><link href="https://ryanzhang.info/" rel="alternate" type="text/html" /><updated>2025-09-16T02:12:10+00:00</updated><id>https://ryanzhang.info/feed.xml</id><title type="html">Ren’s Cabinet of Curiosities</title><subtitle>Learning is never cumulative, it is a movement of knowing which has no beginning and no end. – Bruce Lee</subtitle><author><name>Ren Zhang</name></author><entry><title type="html">Shampoo Optimizer</title><link href="https://ryanzhang.info/post/2025/06/17/Shampoo-Optimizer.html" rel="alternate" type="text/html" title="Shampoo Optimizer" /><published>2025-06-17T03:23:43+00:00</published><updated>2025-06-17T03:23:43+00:00</updated><id>https://ryanzhang.info/post/2025/06/17/Shampoo-Optimizer</id><content type="html" xml:base="https://ryanzhang.info/post/2025/06/17/Shampoo-Optimizer.html"><![CDATA[<h2 id="refresher-on-machine-learning-optimizers">Refresher on machine learning optimizers</h2>

<h3 id="1-sgd---first-order-methods">1. SGD - First-Order Methods</h3>

<p>Stochastic Gradient Descent (SGD) uses first-order gradients to update parameters:</p>

\[\theta = \theta - \eta \nabla f(\theta)\]

<p>SGD treats all parameters equally, leading to slow convergence when parameters have different scales or when the loss landscape is ill-conditioned.</p>
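
<p>As a minimal illustration (a toy sketch, not from any paper), here is gradient descent on a two-dimensional quadratic whose curvatures differ by a factor of 100. The learning rate must be set for the steepest direction, which leaves the flat direction crawling:</p>

```python
import numpy as np

# Illustrative sketch: SGD on f(x) = 0.5 * x^T diag(1, 100) x.
# The step size must suit the largest curvature (100), so progress
# along the low-curvature coordinate is roughly 100x slower.
def sgd_quadratic(steps=100, lr=0.009):
    h = np.array([1.0, 100.0])     # per-coordinate curvatures
    theta = np.array([1.0, 1.0])
    for _ in range(steps):
        grad = h * theta           # gradient of the quadratic
        theta = theta - lr * grad  # SGD update: theta <- theta - lr * grad
    return theta

theta = sgd_quadratic()
# The steep coordinate converges almost immediately;
# the flat coordinate has barely moved after 100 steps.
```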

<h3 id="2-adagrad---first-adaptive-learning-rates">2. AdaGrad - First Adaptive Learning Rates</h3>

<p>AdaGrad introduced adaptive learning rates by accumulating squared gradients and using them as a diagonal preconditioner.</p>

\[\begin{align}
G_i &amp;= \sum_t g_{t,i}^2 &amp;&amp;\text{accumulate squared gradient per parameter}\\
\theta_{i} &amp;= \theta_{i} - \eta\frac{ g_{i}}{\sqrt{G_i + \epsilon}} &amp;&amp;\text{scale learning rate per parameter}
\end{align}\]

<ul>
  <li>
    <p><strong>Benefits</strong>: Automatic learning rate scaling per parameter</p>
  </li>
  <li>
    <p><strong>Limitation</strong>: Aggressive learning rate decay, treats parameters independently</p>
  </li>
</ul>
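
<p>A minimal sketch of a single AdaGrad step (the helper below is illustrative, not a library function): given two gradients that differ in scale by 100×, the very first update moves both parameters by roughly the same amount:</p>

```python
import numpy as np

# Illustrative sketch of one AdaGrad step: accumulate squared gradients
# per parameter and divide the learning rate by their square root.
def adagrad_step(theta, grad, G, lr=0.1, eps=1e-8):
    G = G + grad ** 2                            # accumulate squared gradient
    theta = theta - lr * grad / (np.sqrt(G) + eps)
    return theta, G

theta, G = np.array([1.0, 1.0]), np.zeros(2)
grad = np.array([10.0, 0.1])                     # gradients differing 100x in scale
theta, G = adagrad_step(theta, grad, G)
# On the first step grad / sqrt(grad^2) reduces to the sign,
# so both parameters move by roughly lr in magnitude.
```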

<h3 id="3-adam---exponential-moving-averages">3. Adam - Exponential Moving Averages</h3>

<p>Adam improved upon AdaGrad by using exponential moving averages of gradient statistics as the preconditioner.</p>

\[\begin{align} 
m &amp;= \beta_1 m + (1-\beta_1) g &amp;&amp;\text{moving average of first moment}\\
v &amp;= \beta_2 v + (1-\beta_2) g^2 &amp;&amp;\text{moving average of second moment}\\ 
\hat{m} &amp;= \frac{m}{1-\beta_1^t} &amp;&amp;\text{bias correction}\\ 
\hat{v} &amp;= \frac{v}{1-\beta_2^t} &amp;&amp;\text{bias correction} \\
\theta &amp;= \theta - \eta \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} &amp;&amp;\text{update} 
\end{align}\]

<ul>
  <li><strong>Benefits</strong>: Solves AdaGrad’s vanishing learning rate problem</li>
  <li><strong>Limitation</strong>: Still diagonal preconditioning - no parameter correlations</li>
</ul>
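
<p>A self-contained sketch of a single Adam step following the equations above (the helper name and the sample values are illustrative):</p>

```python
import numpy as np

# Illustrative sketch of one Adam step with bias correction.
def adam_step(theta, grad, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad           # first-moment moving average
    v = b2 * v + (1 - b2) * grad ** 2      # second-moment moving average
    m_hat = m / (1 - b1 ** t)              # bias correction
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = np.array([1.0]), np.zeros(1), np.zeros(1)
theta, m, v = adam_step(theta, np.array([4.0]), m, v, t=1)
# After bias correction the very first step has magnitude ~lr,
# regardless of the raw gradient scale.
```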

<h3 id="4-full-matrix-adagrad">4. Full-Matrix AdaGrad</h3>

<p>The theoretical ideal would be <strong>full-matrix preconditioning</strong>:</p>

\[\begin{align} G_{full} &amp;= \sum g g^{T} &amp;&amp;\text{accumulate full outer product as an approximation to the gradient covariance} \\ 
\theta &amp;= \theta -\eta\, G_{full}^{-\frac{1}{2}} g &amp;&amp;\text{precondition with the full inverse matrix square root}
\end{align}\]

<p><strong>Why it’s infeasible for larger models</strong>:</p>

<ul>
  <li><strong>Memory</strong>: \(O(d^2)\) - for 1M parameters the preconditioner has \(10^{12}\) entries (about 4 TB in fp32)</li>
  <li><strong>Computation</strong>: \(O(d^3)\) matrix inversion every step</li>
  <li><strong>Example</strong>: GPT-3 has 175B parameters → clearly out of reach</li>
</ul>
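
<p>The memory figure is easy to sanity-check with back-of-envelope arithmetic (assuming fp32, i.e. 4 bytes per entry):</p>

```python
# Back-of-envelope cost of full-matrix AdaGrad (sketch, fp32 = 4 bytes/entry).
d = 1_000_000                 # 1M parameters
entries = d * d               # d^2 entries in the full preconditioner
bytes_fp32 = entries * 4
print(bytes_fp32 / 1e12)      # ~4 TB for a single matrix
```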

<h2 id="shampoo-structured-preconditioning">Shampoo: Structured Preconditioning</h2>

<p>Shampoo was introduced by Gupta, Koren, and Singer in “Shampoo: Preconditioned Stochastic Tensor Optimization” (2018).</p>

<h3 id="core-insight-kronecker-product-approximation">Core Insight: Kronecker Product Approximation</h3>

<p>Instead of treating a weight matrix \(\Theta ∈ R^{m×n}\) as a flat vector of \(m\cdot n\) parameters, Shampoo maintains separate statistics for each dimension:</p>

\[\begin{align}
L &amp;= \sum G G^{T} &amp;&amp;\text{left preconditioner } (m \times m)\text{: row correlations}\\
R &amp;= \sum G^{T} G &amp;&amp;\text{right preconditioner } (n \times n)\text{: column correlations}\\
\theta &amp;= \theta -\eta\, L^{-\frac{1}{4}} G R^{-\frac{1}{4}}
\end{align}\]

<p><strong>The approximation</strong>: This is equivalent to approximating the full \(mn×mn\) preconditioner as:</p>

\[G_{full} \approx L^{\frac{1}{2}}\otimes R^{\frac{1}{2}} \quad \text{Kronecker Product Approximation}\]
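
<p>This equivalence can be checked numerically on a small example: preconditioning \(G\) on the left and right with \(L^{-1/4}\) and \(R^{-1/4}\) matches applying the inverse square root of the Kronecker-factored preconditioner to the flattened gradient (a sketch; <code>spd_power</code> is an illustrative helper, not a library call):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def spd_power(M, p):
    """Power of a symmetric positive-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(M)
    return vecs @ np.diag(vals ** p) @ vecs.T

m, n = 3, 4
G = rng.standard_normal((m, n))
L = G @ G.T + np.eye(m)          # left statistics (SPD)
R = G.T @ G + np.eye(n)          # right statistics (SPD)

# Shampoo update direction: L^{-1/4} G R^{-1/4}, flattened row-major
lhs = (spd_power(L, -0.25) @ G @ spd_power(R, -0.25)).reshape(-1)

# Same thing via the mn x mn Kronecker-factored preconditioner
big = np.kron(spd_power(L, 0.5), spd_power(R, 0.5))
rhs = spd_power(big, -0.5) @ G.reshape(-1)

print(np.allclose(lhs, rhs))     # True
```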

<h3 id="memory-and-computation-savings">Memory and Computation Savings</h3>

<p>For a layer with a 1000×2000 weight matrix:</p>

<ul>
  <li>Full-matrix AdaGrad: a 2M×2M preconditioner ≈ 4 trillion entries</li>
  <li>Shampoo: 1000² + 2000² = 5 million entries</li>
  <li>Savings: a factor of 800,000</li>
</ul>
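
<p>The arithmetic behind these figures (entry counts, not bytes):</p>

```python
# Sanity-check the savings for a 1000 x 2000 weight matrix.
m, n = 1000, 2000
full = (m * n) ** 2              # full-matrix preconditioner: 4e12 entries
shampoo = m ** 2 + n ** 2        # L and R factors together: 5e6 entries
print(full // shampoo)           # 800000
```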

<h2 id="simple-shampoo-implementation">Simple Shampoo Implementation</h2>

<p>Here’s a basic PyTorch implementation:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">import</span> <span class="nn">torch.optim</span> <span class="k">as</span> <span class="n">optim</span>
<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">TensorDataset</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>


<span class="k">class</span> <span class="nc">SimpleShampoo</span><span class="p">(</span><span class="n">optim</span><span class="p">.</span><span class="n">Optimizer</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-7</span><span class="p">,</span> <span class="n">update_freq</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
        <span class="n">defaults</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">update_freq</span><span class="o">=</span><span class="n">update_freq</span><span class="p">)</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">param_groups</span><span class="p">:</span>
            <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s">"params"</span><span class="p">]:</span>
                <span class="k">if</span> <span class="n">p</span><span class="p">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
                    <span class="k">continue</span>

                <span class="n">grad</span> <span class="o">=</span> <span class="n">p</span><span class="p">.</span><span class="n">grad</span><span class="p">.</span><span class="n">data</span>
                <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">state</span><span class="p">[</span><span class="n">p</span><span class="p">]</span>

                <span class="n">preconditioned_grad</span> <span class="o">=</span> <span class="n">grad</span>

                <span class="c1"># Initialize state
</span>                <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                    <span class="n">state</span><span class="p">[</span><span class="s">"step"</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
                    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">grad</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>  <span class="c1"># Vector - use diagonal
</span>                        <span class="n">state</span><span class="p">[</span><span class="s">"G"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">grad</span><span class="p">)</span> <span class="o">+</span> <span class="n">group</span><span class="p">[</span><span class="s">"eps"</span><span class="p">]</span>
                    <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">grad</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>  <span class="c1"># Matrix - use Shampoo
</span>                        <span class="n">m</span><span class="p">,</span> <span class="n">n</span> <span class="o">=</span> <span class="n">grad</span><span class="p">.</span><span class="n">shape</span>
                        <span class="n">state</span><span class="p">[</span><span class="s">"L"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">grad</span><span class="p">.</span><span class="n">device</span><span class="p">)</span> <span class="o">*</span> <span class="n">group</span><span class="p">[</span><span class="s">"eps"</span><span class="p">]</span>
                        <span class="n">state</span><span class="p">[</span><span class="s">"R"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">grad</span><span class="p">.</span><span class="n">device</span><span class="p">)</span> <span class="o">*</span> <span class="n">group</span><span class="p">[</span><span class="s">"eps"</span><span class="p">]</span>
                    <span class="k">else</span><span class="p">:</span>  <span class="c1"># Higher order - fallback to diagonal
</span>                        <span class="n">state</span><span class="p">[</span><span class="s">"G"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">grad</span><span class="p">)</span>

                <span class="n">state</span><span class="p">[</span><span class="s">"step"</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span>

                <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">grad</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>  <span class="c1"># Matrix case
</span>                    <span class="c1"># Update statistics
</span>                    <span class="n">state</span><span class="p">[</span><span class="s">"L"</span><span class="p">]</span> <span class="o">+=</span> <span class="n">grad</span> <span class="o">@</span> <span class="n">grad</span><span class="p">.</span><span class="n">T</span>
                    <span class="n">state</span><span class="p">[</span><span class="s">"R"</span><span class="p">]</span> <span class="o">+=</span> <span class="n">grad</span><span class="p">.</span><span class="n">T</span> <span class="o">@</span> <span class="n">grad</span>

                    <span class="c1"># Compute preconditioned gradient every update_freq steps
</span>                    <span class="k">if</span> <span class="n">state</span><span class="p">[</span><span class="s">"step"</span><span class="p">]</span> <span class="o">%</span> <span class="n">group</span><span class="p">[</span><span class="s">"update_freq"</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                        <span class="c1"># Compute matrix power: M^(-1/4)
</span>                        <span class="n">L_eig_vals</span><span class="p">,</span> <span class="n">L_eig_vecs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigh</span><span class="p">(</span><span class="n">state</span><span class="p">[</span><span class="s">"L"</span><span class="p">])</span>
                        <span class="n">R_eig_vals</span><span class="p">,</span> <span class="n">R_eig_vecs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">eigh</span><span class="p">(</span><span class="n">state</span><span class="p">[</span><span class="s">"R"</span><span class="p">])</span>
                        <span class="n">L_inv_quarter</span> <span class="o">=</span> <span class="p">(</span>
                            <span class="n">L_eig_vecs</span>
                            <span class="o">@</span> <span class="n">torch</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span>
                                <span class="n">torch</span><span class="p">.</span><span class="nb">pow</span><span class="p">(</span>
                                    <span class="n">torch</span><span class="p">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">L_eig_vals</span><span class="p">,</span> <span class="nb">min</span><span class="o">=</span><span class="n">group</span><span class="p">[</span><span class="s">"eps"</span><span class="p">]),</span> <span class="o">-</span><span class="mf">0.25</span>
                                <span class="p">)</span>
                            <span class="p">)</span>
                            <span class="o">@</span> <span class="n">L_eig_vecs</span><span class="p">.</span><span class="n">T</span>
                        <span class="p">)</span>
                        <span class="n">R_inv_quarter</span> <span class="o">=</span> <span class="p">(</span>
                            <span class="n">R_eig_vecs</span>
                            <span class="o">@</span> <span class="n">torch</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span>
                                <span class="n">torch</span><span class="p">.</span><span class="nb">pow</span><span class="p">(</span>
                                    <span class="n">torch</span><span class="p">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">R_eig_vals</span><span class="p">,</span> <span class="nb">min</span><span class="o">=</span><span class="n">group</span><span class="p">[</span><span class="s">"eps"</span><span class="p">]),</span> <span class="o">-</span><span class="mf">0.25</span>
                                <span class="p">)</span>
                            <span class="p">)</span>
                            <span class="o">@</span> <span class="n">R_eig_vecs</span><span class="p">.</span><span class="n">T</span>
                        <span class="p">)</span>

                        <span class="n">state</span><span class="p">[</span><span class="s">"L_inv_quarter"</span><span class="p">]</span> <span class="o">=</span> <span class="n">L_inv_quarter</span>
                        <span class="n">state</span><span class="p">[</span><span class="s">"R_inv_quarter"</span><span class="p">]</span> <span class="o">=</span> <span class="n">R_inv_quarter</span>

                    <span class="c1"># Apply preconditioned update
</span>                    <span class="k">if</span> <span class="s">"L_inv_quarter"</span> <span class="ow">in</span> <span class="n">state</span><span class="p">:</span>
                        <span class="n">preconditioned_grad</span> <span class="o">=</span> <span class="p">(</span>
                            <span class="n">state</span><span class="p">[</span><span class="s">"L_inv_quarter"</span><span class="p">]</span> <span class="o">@</span> <span class="n">grad</span> <span class="o">@</span> <span class="n">state</span><span class="p">[</span><span class="s">"R_inv_quarter"</span><span class="p">]</span>
                        <span class="p">)</span>
                    <span class="k">else</span><span class="p">:</span>
                        <span class="n">preconditioned_grad</span> <span class="o">=</span> <span class="n">grad</span>

                <span class="k">else</span><span class="p">:</span>  <span class="c1"># Vector or tensor - use diagonal
</span>                    <span class="n">state</span><span class="p">[</span><span class="s">"G"</span><span class="p">]</span> <span class="o">+=</span> <span class="n">grad</span> <span class="o">*</span> <span class="n">grad</span>
                    <span class="n">preconditioned_grad</span> <span class="o">=</span> <span class="n">grad</span> <span class="o">/</span> <span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">state</span><span class="p">[</span><span class="s">"G"</span><span class="p">])</span> <span class="o">+</span> <span class="n">group</span><span class="p">[</span><span class="s">"eps"</span><span class="p">])</span>

                <span class="c1"># Update parameters
</span>                <span class="n">p</span><span class="p">.</span><span class="n">data</span> <span class="o">-=</span> <span class="n">group</span><span class="p">[</span><span class="s">"lr"</span><span class="p">]</span> <span class="o">*</span> <span class="n">preconditioned_grad</span>


<span class="n">optimizers</span> <span class="o">=</span> <span class="p">{</span><span class="s">"Adam"</span><span class="p">:</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">,</span> <span class="s">"AdaGrad"</span><span class="p">:</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adagrad</span><span class="p">,</span> <span class="s">"Shampoo"</span><span class="p">:</span> <span class="n">SimpleShampoo</span><span class="p">}</span>


<span class="c1"># Create a ill-conditioned quadratic problem
</span><span class="k">def</span> <span class="nf">create_ill_conditioned_data</span><span class="p">(</span><span class="n">num_samples</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">condition_number</span><span class="o">=</span><span class="mi">1000</span><span class="p">):</span>
    <span class="s">"""Create data where some features are much more important than others"""</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
    <span class="n">input_dim</span> <span class="o">=</span> <span class="mi">50</span>

    <span class="c1"># Create ill-conditioned covariance matrix
</span>    <span class="n">U</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">qr</span><span class="p">(</span>
        <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">input_dim</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">)</span>
    <span class="p">)</span>  <span class="c1"># Random orthogonal matrix
</span>    <span class="n">eigenvals</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">logspace</span><span class="p">(</span>
        <span class="mi">0</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">log10</span><span class="p">(</span><span class="n">condition_number</span><span class="p">),</span> <span class="n">input_dim</span>
    <span class="p">)</span>  <span class="c1"># Large condition number
</span>    <span class="n">cov_matrix</span> <span class="o">=</span> <span class="n">U</span> <span class="o">@</span> <span class="n">torch</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">eigenvals</span><span class="p">)</span> <span class="o">@</span> <span class="n">U</span><span class="p">.</span><span class="n">T</span>

    <span class="c1"># Generate correlated features
</span>    <span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">num_samples</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">)</span> <span class="o">@</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">cholesky</span><span class="p">(</span><span class="n">cov_matrix</span><span class="p">)</span>

    <span class="c1"># True weights with different scales (some very important, some not)
</span>    <span class="n">true_weights</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">input_dim</span><span class="p">)</span>
    <span class="n">true_weights</span><span class="p">[:</span><span class="mi">5</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> <span class="o">*</span> <span class="mi">10</span>  <span class="c1"># Very important features
</span>    <span class="n">true_weights</span><span class="p">[</span><span class="mi">5</span><span class="p">:</span><span class="mi">15</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span> <span class="o">*</span> <span class="mi">1</span>  <span class="c1"># Moderately important
</span>    <span class="n">true_weights</span><span class="p">[</span><span class="mi">15</span><span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">35</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span>  <span class="c1"># Less important
</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">X</span> <span class="o">@</span> <span class="n">true_weights</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">num_samples</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span>
    <span class="k">return</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">train_model</span><span class="p">(</span><span class="n">optimizer_name</span><span class="o">=</span><span class="s">"Shampoo"</span><span class="p">):</span>
    <span class="k">assert</span> <span class="n">optimizer_name</span> <span class="ow">in</span> <span class="n">optimizers</span><span class="p">,</span> <span class="s">"Expecting a known optimizer in "</span> <span class="o">+</span> <span class="s">", "</span><span class="p">.</span><span class="n">join</span><span class="p">(</span>
        <span class="n">optimizers</span><span class="p">.</span><span class="n">keys</span><span class="p">()</span>
    <span class="p">)</span>

    <span class="c1"># Create data
</span>    <span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">create_ill_conditioned_data</span><span class="p">()</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
    <span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

    <span class="c1"># Create model
</span>    <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>

    <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizers</span><span class="p">[</span><span class="n">optimizer_name</span><span class="p">](</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
    <span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>

    <span class="k">print</span><span class="p">(</span><span class="s">"Starting training..."</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Using optimizer: </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">optimizer</span><span class="p">).</span><span class="n">__name__</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

    <span class="c1"># Training loop
</span>    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">50</span><span class="p">):</span>
        <span class="n">total_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="n">num_batches</span> <span class="o">=</span> <span class="mi">0</span>

        <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_x</span><span class="p">,</span> <span class="n">batch_y</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span>
            <span class="c1"># Forward pass
</span>            <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
            <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">batch_x</span><span class="p">)</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">batch_y</span><span class="p">)</span>

            <span class="c1"># Backward pass
</span>            <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
            <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>

            <span class="n">total_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span>
            <span class="n">num_batches</span> <span class="o">+=</span> <span class="mi">1</span>

        <span class="n">avg_loss</span> <span class="o">=</span> <span class="n">total_loss</span> <span class="o">/</span> <span class="n">num_batches</span>
        <span class="k">if</span> <span class="n">epoch</span> <span class="o">%</span> <span class="mi">5</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">:</span><span class="mi">2</span><span class="n">d</span><span class="si">}</span><span class="s">: Average Loss = </span><span class="si">{</span><span class="n">avg_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">6</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

    <span class="k">print</span><span class="p">(</span><span class="s">"Training completed!"</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">model</span><span class="p">,</span> <span class="n">avg_loss</span>


<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="s">"__main__"</span><span class="p">:</span>
    <span class="k">for</span> <span class="n">optimizer_name</span> <span class="ow">in</span> <span class="n">optimizers</span><span class="p">.</span><span class="n">keys</span><span class="p">():</span>
        <span class="k">print</span><span class="p">(</span><span class="s">"Training with "</span><span class="p">,</span> <span class="n">optimizer_name</span><span class="p">,</span> <span class="s">"Optimizer:"</span><span class="p">)</span>
        <span class="n">model</span><span class="p">,</span> <span class="n">final_loss</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span><span class="n">optimizer_name</span><span class="p">)</span>

</code></pre></div></div>

<h2 id="improvements-and-practical-considerations">Improvements and Practical Considerations</h2>

<h3 id="1-grafting-for-stability">1. Grafting for Stability</h3>

<p><a href="https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py">Google’s implementation</a> uses “grafting” to fix the layerwise scale of Shampoo updates:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Compute both updates
</span><span class="n">shampoo_update</span> <span class="o">=</span> <span class="n">L_inv</span> <span class="o">@</span> <span class="n">grad</span> <span class="o">@</span> <span class="n">R_inv</span>
<span class="n">diagonal_update</span> <span class="o">=</span> <span class="n">grad</span> <span class="o">/</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">accumulated_grad_squares</span><span class="p">)</span>

<span class="c1"># Scale Shampoo to match diagonal magnitude
</span><span class="n">scale</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">diagonal_update</span><span class="p">)</span> <span class="o">/</span> <span class="n">norm</span><span class="p">(</span><span class="n">shampoo_update</span><span class="p">)</span>
<span class="n">final_update</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">shampoo_update</span>
</code></pre></div></div>
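<p>As a concrete illustration, here is a self-contained NumPy sketch of the grafting step (function and variable names are mine for illustration, not the ones in Google’s implementation):</p>

```python
import numpy as np

def grafted_update(grad, L_inv, R_inv, accum_sq, eps=1e-12):
    """Take the *direction* from Shampoo but the *step size* from AdaGrad."""
    shampoo_update = L_inv @ grad @ R_inv
    diagonal_update = grad / np.sqrt(accum_sq + eps)
    # scale the Shampoo direction to the diagonal update's magnitude
    scale = np.linalg.norm(diagonal_update) / (np.linalg.norm(shampoo_update) + eps)
    return scale * shampoo_update

# sanity check: with identity preconditioners and unit accumulators,
# the grafted update is (numerically) just the gradient
g = np.arange(6, dtype=float).reshape(2, 3) + 1.0
u = grafted_update(g, np.eye(2), np.eye(3), np.ones_like(g))
```

By construction the grafted update always has the same norm as the diagonal update, so learning-rate schedules tuned for AdaGrad carry over.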

<h3 id="2-delayed-preconditioning">2. Delayed Preconditioning</h3>

<p>Start with a simpler method (e.g. diagonal AdaGrad), then gradually transition to Shampoo over a warm-up period:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">if</span> <span class="n">step</span> <span class="o">&lt;</span> <span class="n">start_preconditioning_steps</span><span class="p">:</span>
    <span class="n">update</span> <span class="o">=</span> <span class="n">diagonal_update</span>  <span class="c1"># Use AdaGrad initially
</span><span class="k">else</span><span class="p">:</span>
    <span class="n">warmup_factor</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="p">(</span><span class="n">step</span> <span class="o">-</span> <span class="n">start_preconditioning_steps</span><span class="p">)</span> <span class="o">/</span> <span class="n">start_preconditioning_steps</span><span class="p">)</span>
    <span class="n">update</span> <span class="o">=</span> <span class="n">warmup_factor</span> <span class="o">*</span> <span class="n">shampoo_update</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">warmup_factor</span><span class="p">)</span> <span class="o">*</span> <span class="n">diagonal_update</span>
</code></pre></div></div>
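<p>A tiny runnable version of that schedule (scalar “updates” stand in for the real tensors; the names are illustrative):</p>

```python
def blended_update(step, start_preconditioning_steps, shampoo_update, diagonal_update):
    """Linearly ramp from the diagonal (AdaGrad-style) update to full Shampoo."""
    if step < start_preconditioning_steps:
        return diagonal_update
    warmup_factor = min(1.0, (step - start_preconditioning_steps) / start_preconditioning_steps)
    return warmup_factor * shampoo_update + (1 - warmup_factor) * diagonal_update

# before warm-up: pure diagonal; halfway through the ramp: a 50/50 blend
early = blended_update(50, 100, shampoo_update=1.0, diagonal_update=0.0)
mid = blended_update(150, 100, shampoo_update=1.0, diagonal_update=0.0)
late = blended_update(400, 100, shampoo_update=1.0, diagonal_update=0.0)
```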

<h3 id="3-soap-adam-in-shampoos-eigenbasis">3. SOAP: Adam in Shampoo’s Eigenbasis</h3>

<p>Recent work introduced SOAP, which runs “Adam in the preconditioner’s eigenbasis”:</p>

<ol>
  <li>Decompose the preconditioner: \(P = Q \Lambda Q^\top\)</li>
  <li>Rotate the gradient into the eigenbasis: \(\tilde{G} = Q^\top G Q\)</li>
  <li>Run Adam on the rotated gradient \(\tilde{G}\), yielding \(\tilde{U}\)</li>
  <li>Rotate the update back: \(\text{update} = Q \tilde{U} Q^\top\)</li>
</ol>
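<p>The four steps above can be sketched in NumPy for a single matrix parameter. This sketch uses one eigenbasis per side (from the two Shampoo factors \(L\) and \(R\)), and the Adam step is a bare-bones version without bias correction; all names are illustrative:</p>

```python
import numpy as np

def soap_step(grad, L, R, m, v, beta1=0.9, beta2=0.999, eps=1e-8):
    """One SOAP-style step: Adam in the eigenbasis of the Shampoo factors."""
    _, QL = np.linalg.eigh(L)            # step 1: eigendecompose the factors
    _, QR = np.linalg.eigh(R)
    g_rot = QL.T @ grad @ QR             # step 2: rotate gradient into eigenbasis
    m = beta1 * m + (1 - beta1) * g_rot  # step 3: Adam moments on rotated gradient
    v = beta2 * v + (1 - beta2) * g_rot ** 2
    adam_dir = m / (np.sqrt(v) + eps)
    update = QL @ adam_dir @ QR.T        # step 4: rotate the update back
    return update, m, v

g = np.array([[1.0, -2.0], [0.5, 3.0]])
L, R = np.diag([1.0, 2.0]), np.diag([3.0, 4.0])
upd, m, v = soap_step(g, L, R, np.zeros_like(g), np.zeros_like(g))
```

The payoff is that the expensive eigendecompositions can be refreshed infrequently while the cheap Adam moments update every step.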

<h2 id="production-implementations">Production Implementations</h2>

<h3 id="tensorflowlingvo-implementation">TensorFlow/Lingvo Implementation</h3>

<p>The TensorFlow implementation focuses on practical deployment with CPU-based preconditioner computation:</p>

<p><strong>Key Features</strong>:</p>

<ul>
  <li><strong>Asynchronous preconditioning</strong>: Expensive matrix operations run on CPU while GPUs continue training</li>
  <li><strong>Simple partitioning</strong>: Splits large tensors when dimensions exceed thresholds</li>
  <li><strong>Grafting integration</strong>: Built-in support for scaling strategies</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">invoke_async_preconditioner_computation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">global_step</span><span class="p">):</span>
    <span class="s">"""Computes preconditioners asynchronously on CPU"""</span>
    <span class="k">return</span> <span class="n">x_ops</span><span class="p">.</span><span class="n">compute_preconditioners</span><span class="p">(</span>
        <span class="n">stats</span><span class="p">,</span> <span class="n">exponents</span><span class="p">,</span> <span class="n">global_step</span><span class="p">,</span>
        <span class="n">sync</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">_synchronous_preconditioning</span><span class="p">,</span>
        <span class="n">preconditioner_compute_graphdef</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">_preconditioner_compute_graphdef</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="jax-distributed-implementation">JAX Distributed Implementation</h3>

<p>The JAX version provides full distributed training support with advanced features:</p>

<p><strong>Advanced Features</strong>:</p>

<ul>
  <li><strong>Quantized statistics</strong>: Reduces memory usage through <code class="language-plaintext highlighter-rouge">QuantizedValue</code> storage</li>
  <li><strong>Sharded computation</strong>: Distributes preconditioner computation across devices</li>
  <li><strong>Global statistics aggregation</strong>: Coordinates statistics across multiple devices</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">struct</span><span class="p">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">ShardedShampooStats</span><span class="p">:</span>
    <span class="n">global_stats</span><span class="p">:</span> <span class="n">Any</span>      <span class="c1"># Statistics aggregated across all devices  
</span>    <span class="n">local_stats</span><span class="p">:</span> <span class="n">Any</span>       <span class="c1"># Device-local statistics
</span>
<span class="k">class</span> <span class="nc">LocalShardedParameterStats</span><span class="p">:</span>
    <span class="n">index_start</span><span class="p">:</span> <span class="nb">int</span>       <span class="c1"># Starting index in global statistics array
</span>    <span class="n">sizes</span><span class="p">:</span> <span class="n">Any</span>            <span class="c1"># Partition sizes for this device
</span></code></pre></div></div>

<p><strong>Distributed Training Flow</strong>:</p>

<ol>
  <li><strong>Local computation</strong>: Each device computes gradients and updates local statistics</li>
  <li><strong>Periodic synchronization</strong>: Every N steps, aggregate statistics across devices</li>
  <li><strong>Centralized preconditioning</strong>: Master device computes preconditioners</li>
  <li><strong>Broadcast updates</strong>: Distribute preconditioners back to all devices</li>
</ol>
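<p>A toy single-process simulation of this flow (pure NumPy; a real deployment would use collective ops such as all-reduce instead of the Python-level sum, and the names here are illustrative):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
num_devices, dim, sync_every = 4, 3, 10
local_stats = [np.zeros((dim, dim)) for _ in range(num_devices)]
precond = np.eye(dim)

for step in range(1, 31):
    # 1. local computation: each device accumulates G += g g^T from its shard
    for d in range(num_devices):
        g = rng.normal(size=dim)
        local_stats[d] += np.outer(g, g)
    # 2. periodic synchronization: every `sync_every` steps, aggregate stats
    if step % sync_every == 0:
        global_stat = sum(local_stats)
        # 3. centralized preconditioning: compute G^{-1/4} via eigendecomposition
        w, Q = np.linalg.eigh(global_stat + 1e-6 * np.eye(dim))
        precond = Q @ np.diag(w ** -0.25) @ Q.T
        # 4. broadcast: here `precond` would be sent back to every device
```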

<h2 id="references">References</h2>

<p>The key papers and implementations:</p>
<ul>
  <li><a href="http://www.math.iit.edu/~fass/477577_Chapter_16.pdf">Preconditioning in iterative solvers</a></li>
  <li><a href="https://arxiv.org/abs/1802.09568">Shampoo: Preconditioned Stochastic Tensor Optimization (2018)</a></li>
  <li><a href="https://github.com/tensorflow/lingvo/blob/master/lingvo/core/distributed_shampoo.py">Google TensorFlow (Lingvo) Implementation</a></li>
  <li><a href="https://github.com/google-research/google-research/blob/master/scalable_shampoo/jax/shampoo.py">Google JAX Implementation</a></li>
  <li><a href="https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py">Google PyTorch(!) Implementation</a></li>
  <li><a href="https://arxiv.org/abs/2002.09018">Scalable Second Order Optimization for Deep Learning (2020)</a></li>
  <li><a href="https://arxiv.org/abs/2409.11321">SOAP: Improving and Stabilizing Shampoo using Adam (2024)</a></li>
</ul>]]></content><author><name>Ren Zhang</name></author><category term="post" /><category term="Shampoo Optimizer" /><category term="Optimizer" /><category term="python" /><summary type="html"><![CDATA[Refresher on machine learning optimizers]]></summary></entry><entry><title type="html">BPE Tokenizer Implementation Exercise</title><link href="https://ryanzhang.info/post/2025/02/14/BPE-implementation-exercise.html" rel="alternate" type="text/html" title="BPE Tokenizer Implementation Exercise" /><published>2025-02-14T19:03:43+00:00</published><updated>2025-02-14T19:03:43+00:00</updated><id>https://ryanzhang.info/post/2025/02/14/BPE-implementation-exercise</id><content type="html" xml:base="https://ryanzhang.info/post/2025/02/14/BPE-implementation-exercise.html"><![CDATA[<!--excerpt.start-->
<p>Partial solution to <a href="https://github.com/karpathy/minbpe/blob/master/exercise.md">BPE Tokenizer Implementation Exercise from Andrej Karpathy</a>.</p>

<p>Corresponding <a href="https://www.youtube.com/watch?v=zduSFxRajkE">YouTube video</a> on the tokenizer topic.
<!--excerpt.end--></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">regex</span>
<span class="kn">import</span> <span class="nn">requests</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">Counter</span>


<span class="n">GPT4_SPLIT_PATTERN</span> <span class="o">=</span> <span class="sa">r</span><span class="s">"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""</span>
<span class="n">SHAKESPEAR_TEXT_URL</span> <span class="o">=</span> <span class="s">"https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"</span>


<span class="k">class</span> <span class="nc">BPETokenizer</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">pattern</span><span class="o">=</span><span class="n">GPT4_SPLIT_PATTERN</span><span class="p">,</span> <span class="n">special_tokens</span><span class="o">=</span><span class="p">[]):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">pattern</span> <span class="o">=</span> <span class="n">pattern</span>  <span class="c1"># regex pattern to split text into words
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">special_tokens</span> <span class="o">=</span> <span class="n">special_tokens</span>  <span class="c1"># pre-allocated special tokens
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">vocab</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_init_vocab</span><span class="p">()</span>  <span class="c1"># map byte to token id
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">itob</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># the reverse map of vocab, map token id to byte
</span>
    <span class="k">def</span> <span class="nf">_init_vocab</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
        <span class="n">vocab</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="o">**</span><span class="mi">8</span><span class="p">):</span>
            <span class="n">vocab</span><span class="p">[</span><span class="nb">bytes</span><span class="p">([</span><span class="n">i</span><span class="p">])]</span> <span class="o">=</span> <span class="n">i</span>
        <span class="k">for</span> <span class="n">special_token</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">special_tokens</span><span class="p">:</span>
            <span class="n">vocab</span><span class="p">[</span><span class="nb">bytes</span><span class="p">(</span><span class="n">special_token</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="s">"utf-8"</span><span class="p">))]</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">vocab</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">vocab</span>

    <span class="k">def</span> <span class="nf">_get_stats</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">bytes_of_words</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="nb">bytes</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="n">Counter</span><span class="p">:</span>
        <span class="n">counts</span> <span class="o">=</span> <span class="n">Counter</span><span class="p">()</span>
        <span class="c1"># count the frequency of each adjacent byte pair; the resulting stats are used to find merge rules.
</span>        <span class="k">for</span> <span class="n">bytes_of_word</span> <span class="ow">in</span> <span class="n">bytes_of_words</span><span class="p">:</span>
            <span class="k">for</span> <span class="n">byte_pair</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bytes_of_word</span><span class="p">,</span> <span class="n">bytes_of_word</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
                <span class="n">counts</span><span class="p">[</span><span class="n">byte_pair</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span>
        <span class="k">return</span> <span class="n">counts</span>

    <span class="c1"># split the text into words then for each word further split into bytes.
</span>    <span class="k">def</span> <span class="nf">_parse_text</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">text</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="nb">bytes</span><span class="p">]]:</span>
        <span class="k">return</span> <span class="p">[</span>
            <span class="p">[</span><span class="nb">bytes</span><span class="p">([</span><span class="n">b</span><span class="p">])</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">word</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="s">"utf-8"</span><span class="p">)]</span>
            <span class="k">for</span> <span class="n">word</span> <span class="ow">in</span> <span class="n">regex</span><span class="p">.</span><span class="n">findall</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">pattern</span><span class="p">,</span> <span class="n">text</span><span class="p">)</span>
        <span class="p">]</span>

    <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">text</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">verbose</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
        <span class="n">bytes_of_words</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_parse_text</span><span class="p">(</span><span class="n">text</span><span class="p">)</span>
        <span class="n">num_merges</span> <span class="o">=</span> <span class="n">vocab_size</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">vocab</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"total </span><span class="si">{</span><span class="n">num_merges</span><span class="si">}</span><span class="s"> merges to learn"</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_merges</span><span class="p">):</span>
            <span class="c1"># find the merge
</span>            <span class="n">counts</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_get_stats</span><span class="p">(</span><span class="n">bytes_of_words</span><span class="p">)</span>
            <span class="n">pair_to_merge</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">counts</span><span class="p">.</span><span class="n">keys</span><span class="p">(),</span> <span class="n">key</span><span class="o">=</span><span class="n">counts</span><span class="p">.</span><span class="n">get</span><span class="p">)</span>
            <span class="n">byte_pair</span> <span class="o">=</span> <span class="sa">b</span><span class="s">""</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">pair_to_merge</span><span class="p">)</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">vocab</span><span class="p">[</span><span class="n">byte_pair</span><span class="p">]</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">vocab</span><span class="p">)</span>

            <span class="c1"># apply the merge to training data
</span>            <span class="n">temp_bytes_of_words</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="k">for</span> <span class="n">bytes_of_word</span> <span class="ow">in</span> <span class="n">bytes_of_words</span><span class="p">:</span>
                <span class="n">temp_bytes_of_word</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="n">just_merged</span> <span class="o">=</span> <span class="bp">False</span>
                <span class="k">for</span> <span class="n">first</span><span class="p">,</span> <span class="n">second</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bytes_of_word</span><span class="p">,</span> <span class="n">bytes_of_word</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
                    <span class="k">if</span> <span class="n">just_merged</span><span class="p">:</span>
                        <span class="n">just_merged</span> <span class="o">=</span> <span class="bp">False</span>
                        <span class="k">continue</span>
                    <span class="k">if</span> <span class="p">(</span><span class="n">first</span><span class="p">,</span> <span class="n">second</span><span class="p">)</span> <span class="o">==</span> <span class="n">pair_to_merge</span><span class="p">:</span>
                        <span class="n">temp_bytes_of_word</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">byte_pair</span><span class="p">)</span>
                        <span class="n">just_merged</span> <span class="o">=</span> <span class="bp">True</span>
                    <span class="k">else</span><span class="p">:</span>
                        <span class="n">temp_bytes_of_word</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">first</span><span class="p">)</span>
                <span class="k">if</span> <span class="ow">not</span> <span class="n">just_merged</span><span class="p">:</span>
                    <span class="n">temp_bytes_of_word</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">bytes_of_word</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
                <span class="n">temp_bytes_of_words</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">temp_bytes_of_word</span><span class="p">)</span>
            <span class="n">bytes_of_words</span> <span class="o">=</span> <span class="n">temp_bytes_of_words</span>

            <span class="k">if</span> <span class="n">verbose</span> <span class="ow">and</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">verbose</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="k">print</span><span class="p">(</span>
                    <span class="sa">f</span><span class="s">"merge discovered at step </span><span class="si">{</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="si">}</span><span class="s"> is : "</span><span class="p">,</span>
                    <span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">pair_to_merge</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s"> + </span><span class="si">{</span><span class="n">pair_to_merge</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s"> -&gt; </span><span class="si">{</span><span class="n">byte_pair</span><span class="si">}</span><span class="s">"</span><span class="p">,</span>
                <span class="p">)</span>

    <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">text</span><span class="p">):</span>
        <span class="n">bytes_of_words</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_parse_text</span><span class="p">(</span><span class="n">text</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">bytes_of_word</span> <span class="ow">in</span> <span class="n">bytes_of_words</span><span class="p">:</span>
            <span class="c1"># speed this up? only one instance of the lowest rank pair gets updated each time
</span>            <span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
                <span class="n">min_idx</span> <span class="o">=</span> <span class="n">min_rank</span> <span class="o">=</span> <span class="n">merged_bytes</span> <span class="o">=</span> <span class="bp">None</span>
                <span class="c1"># find the mergeable byte pairs with the lowest rank
</span>                <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">byte_pair</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">bytes_of_word</span><span class="p">,</span> <span class="n">bytes_of_word</span><span class="p">[</span><span class="mi">1</span><span class="p">:])):</span>
                    <span class="n">rank</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">vocab</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="n">byte_pair</span><span class="p">,</span> <span class="bp">None</span><span class="p">)</span>
                    <span class="k">if</span> <span class="n">rank</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
                        <span class="k">continue</span>

                    <span class="k">if</span> <span class="n">min_rank</span> <span class="ow">is</span> <span class="bp">None</span> <span class="ow">or</span> <span class="n">min_rank</span> <span class="o">&gt;</span> <span class="n">rank</span><span class="p">:</span>
                        <span class="n">min_rank</span> <span class="o">=</span> <span class="n">rank</span>
                        <span class="n">min_idx</span> <span class="o">=</span> <span class="n">i</span>
                        <span class="n">merged_bytes</span> <span class="o">=</span> <span class="sa">b</span><span class="s">""</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">byte_pair</span><span class="p">)</span>

                <span class="k">if</span> <span class="n">min_rank</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
                    <span class="k">break</span>
                <span class="n">bytes_of_word</span><span class="p">[</span><span class="n">min_idx</span><span class="p">:</span><span class="n">min_idx</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">merged_bytes</span><span class="p">]</span>  <span class="c1"># mutate in place so the merge is reflected in bytes_of_words</span>
        <span class="n">token_ids</span> <span class="o">=</span> <span class="p">[</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">vocab</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="k">for</span> <span class="n">bytes_of_word</span> <span class="ow">in</span> <span class="n">bytes_of_words</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">bytes_of_word</span>
        <span class="p">]</span>
        <span class="k">return</span> <span class="n">token_ids</span>

    <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">token_ids</span><span class="p">):</span>
        <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="p">.</span><span class="n">itob</span><span class="p">:</span> <span class="bp">self</span><span class="p">.</span><span class="n">itob</span> <span class="o">=</span> <span class="p">{</span><span class="n">v</span><span class="p">:</span><span class="n">k</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span><span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">vocab</span><span class="p">.</span><span class="n">items</span><span class="p">()}</span>
        <span class="k">return</span> <span class="sa">b</span><span class="s">""</span><span class="p">.</span><span class="n">join</span><span class="p">((</span><span class="bp">self</span><span class="p">.</span><span class="n">itob</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">token_ids</span><span class="p">)).</span><span class="n">decode</span><span class="p">(</span><span class="s">"utf-8"</span><span class="p">)</span> 

    <span class="k">def</span> <span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">pass</span>

    <span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">pass</span>

<span class="k">def</span> <span class="nf">read_text_from_url</span><span class="p">(</span><span class="n">url</span><span class="p">):</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="n">response</span> <span class="o">=</span> <span class="n">requests</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="n">url</span><span class="p">)</span>
        <span class="n">response</span><span class="p">.</span><span class="n">raise_for_status</span><span class="p">()</span>  <span class="c1"># Raise HTTPError for bad responses (4xx or 5xx)
</span>        <span class="k">return</span> <span class="n">response</span><span class="p">.</span><span class="n">text</span>
    <span class="k">except</span> <span class="n">requests</span><span class="p">.</span><span class="n">exceptions</span><span class="p">.</span><span class="n">RequestException</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Error fetching data from URL: </span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">None</span>


<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="s">"__main__"</span><span class="p">:</span>
    <span class="n">text</span> <span class="o">=</span> <span class="n">read_text_from_url</span><span class="p">(</span><span class="n">SHAKESPEAR_TEXT_URL</span><span class="p">)</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">BPETokenizer</span><span class="p">(</span><span class="n">special_tokens</span><span class="o">=</span><span class="p">[</span><span class="s">"&lt;bos&gt;"</span><span class="p">,</span> <span class="s">"&lt;eos&gt;"</span><span class="p">,</span> <span class="s">"&lt;pad&gt;"</span><span class="p">,</span> <span class="s">"&lt;unk&gt;"</span><span class="p">])</span>
    <span class="n">a</span><span class="p">.</span><span class="n">train</span><span class="p">(</span><span class="n">text</span><span class="p">,</span> <span class="n">vocab_size</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
    <span class="n">encoded</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">text</span><span class="p">[:</span><span class="mi">512</span><span class="p">])</span>
    <span class="k">print</span><span class="p">(</span><span class="n">text</span><span class="p">[:</span><span class="mi">512</span><span class="p">],</span> <span class="s">"</span><span class="se">\n</span><span class="s"> encoded as: </span><span class="se">\n</span><span class="s">"</span><span class="p">,</span> <span class="n">encoded</span><span class="p">)</span>
    <span class="n">decoded</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">decode</span><span class="p">(</span><span class="n">encoded</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"decoded: "</span><span class="p">,</span> <span class="n">decoded</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"equal to original text? "</span><span class="p">,</span> <span class="n">decoded</span> <span class="o">==</span> <span class="n">text</span><span class="p">[:</span><span class="mi">512</span><span class="p">])</span>
</code></pre></div></div>

<p>Some random after thoughts:</p>
<ol>
  <li>The text used to train the tokenizer should ideally match the training/inference text distribution. If the two distributions are quite different, maybe use a separate tokenizer for each side. For example, in code generation the output may be English comments and code only, while the input can be multilingual and more descriptive of the code we want to generate. Could we use a tokenizer with a smaller vocab size for the output?</li>
  <li>If a token id is never seen during the training run, its embedding will remain random, and prompting the model with such a token can cause undefined behavior, e.g., <a href="https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation">“solidgoldmagikarp”</a>. Maybe run a frequency counter over the token ids seen during training, reject bad input, or reserve an unk token and map bad tokens to it?</li>
  <li>For multilingual models, the tokenizer can be an important factor in which languages underperform. Balancing the language mixture in the tokenizer training data may help.</li>
  <li>A larger vocabulary leads to shorter encoded sequences, which allows more information to be retained in the limited context window and can therefore improve performance. On the flip side, it requires more memory for training and makes the softmax more expensive at inference.</li>
</ol>]]></content><author><name>Ren Zhang</name></author><category term="post" /><category term="Byte Pair Encoding" /><category term="Tokenizer" /><category term="LLM" /><category term="python" /><summary type="html"><![CDATA[Partial solution to BPE Tokenizer Implementation Exercise from Andrej Karpathy. Corresponding youtube video on the tokenizer topic. import regex import requests from collections import Counter GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" SHAKESPEAR_TEXT_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" class BPETokenizer: def __init__(self, *, pattern=GPT4_SPLIT_PATTERN, special_tokens=[]): self.pattern = pattern # regex pattern to split text into words self.special_tokens = special_tokens # pre-allocated special tokens self.vocab = self._init_vocab() # map byte to token id self.itob = {} # the reverse map of vocab, map token id to byte def _init_vocab(self) -&gt; dict: vocab = {} for i in range(2**8): vocab[bytes([i])] = i for special_token in self.special_tokens: vocab[bytes(special_token.encode("utf-8"))] = len(vocab) return vocab def _get_stats(self, bytes_of_words: list[list[bytes]]) -&gt; Counter: counts = Counter() # count the frequencey of each adjacent byte pairs result stat will be used to find merge rules. for bytes_of_word in bytes_of_words: for byte_pair in zip(bytes_of_word, bytes_of_word[1:]): counts[byte_pair] += 1 return counts # split the text into words then for each word further split into bytes. 
def _parse_text(self, text: str) -&gt; list[list[bytes]]: return [ [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(self.pattern, text) ] def train(self, text: str, vocab_size: int, verbose: int = 0): bytes_of_words = self._parse_text(text) num_merges = vocab_size - len(self.vocab) if verbose: print(f"total {num_merges} merges to learn") for step in range(num_merges): # find the merge counts = self._get_stats(bytes_of_words) pair_to_merge = max(counts.keys(), key=counts.get) byte_pair = b"".join(pair_to_merge) self.vocab[byte_pair] = len(self.vocab) # apply the merge to training data temp_bytes_of_words = [] for bytes_of_word in bytes_of_words: temp_bytes_of_word = [] just_merged = False for first, second in zip(bytes_of_word, bytes_of_word[1:]): if just_merged: just_merged = False continue if (first, second) == pair_to_merge: temp_bytes_of_word.append(byte_pair) just_merged = True else: temp_bytes_of_word.append(first) if not just_merged: temp_bytes_of_word.append(bytes_of_word[-1]) temp_bytes_of_words.append(temp_bytes_of_word) bytes_of_words = temp_bytes_of_words if verbose and (step + 1) % verbose == 0: print( f"merge discovered at step {step + 1} is : ", f"{pair_to_merge[0]} + {pair_to_merge[1]} -&gt; {byte_pair}", ) def encode(self, text): bytes_of_words = self._parse_text(text) for bytes_of_word in bytes_of_words: # speed this up? 
only one instance of the lowest rank pair gets updated each time while True: min_idx = min_rank = merged_bytes = None # find the mergeable byte pairs with the lowest rank for i, byte_pair in enumerate(zip(bytes_of_word, bytes_of_word[1:])): rank = self.vocab.get(byte_pair, None) if rank is None: continue if min_rank is None or min_rank &gt; rank: min_rank = rank min_idx = i merged_bytes = b"".join(byte_pair) if min_rank is None: break bytes_of_word = ( bytes_of_word[:min_idx] + [merged_bytes] + bytes_of_word[min_idx + 2:] ) token_ids = [ self.vocab[b] for bytes_of_word in bytes_of_words for b in bytes_of_word ] return token_ids def decode(self, token_ids): if not self.itob: self.itob = {v:k for k,v in self.vocab.items()} return b"".join((self.itob[i] for i in token_ids)).decode("utf-8") def save(self): pass def load(self): pass def read_text_from_url(url): try: response = requests.get(url) response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx) return response.text except requests.exceptions.RequestException as e: print(f"Error fetching data from URL: {e}") return None if __name__ == "__main__": text = read_text_from_url(SHAKESPEAR_TEXT_URL) a = BPETokenizer(special_tokens=["&lt;bos&gt;", "&lt;eos&gt;", "&lt;pad&gt;", "&lt;unk&gt;"]) a.train(text, vocab_size=1024, verbose=100) encoded = a.encode(text[:512]) print(text[:512], "\n encoded as: \n", encoded) decoded = a.decode(encoded) print("decoded: ", decoded) print("equal to original text? ", decoded == text[:512]) Some random after thoughts: The text used to train the tokenizer should ideally match the training/inference text distribution. If the training and inference distribution are quite different, maybe use a separated tokenizer. For example, the output is English comment and code only, while the input can be multi-language and more descriptive of the code we want to generate. Can we use a tokenizer of a smaller vocab size for output? 
If a token id is never seen during the training run, its embedding will be random, prompting the model with such a token will cause undefined behavior. Eg., “solidgoldmagikarp”. Maybe run frequency counter on the token id seen during training, reject bad input, or reserve a unk token and map bad token to it? For multilingual models, the tokenizer might be an important factor in determining the less performant language. Balancing the language mixture in tokenizer training data may help. Larger vocabulary size leads to shorter encoded sequences, which allows more information to be retrained in the limited context window and, therefore, improves performance. On the flip side, it will require more memory for training and make the softmax more expensive at inference.]]></summary></entry><entry><title type="html">RLHF Reading Notes 1</title><link href="https://ryanzhang.info/post/2025/02/11/RLHF-reading-notes-1.html" rel="alternate" type="text/html" title="RLHF Reading Notes 1" /><published>2025-02-11T03:53:43+00:00</published><updated>2025-02-11T03:53:43+00:00</updated><id>https://ryanzhang.info/post/2025/02/11/RLHF-reading-notes-1</id><content type="html" xml:base="https://ryanzhang.info/post/2025/02/11/RLHF-reading-notes-1.html"><![CDATA[<h1 id="table-of-contents">Table of Contents</h1>
<ol>
  <li><a href="#glossed-overview-rlhf-for-llm">Glossed Overview: RLHF for LLM</a></li>
  <li><a href="#further-background-readings">Further Background Readings</a>
    <ol>
      <li><a href="#rlhf-in-deep-reinforcement-learning">RLHF in Deep Reinforcement Learning</a></li>
      <li><a href="#learning-to-summarize-from-human-feedback">Learning to summarize from human feedback</a></li>
      <li><a href="#webgpt-browser-assisted-question-answering-with-human-feedback">Webgpt: Browser-assisted question-answering with human feedback</a></li>
      <li><a href="#training-language-models-to-follow-instructions-with-human-feedback">Training language models to follow instructions with human feedback</a></li>
      <li><a href="#training-a-helpful-and-harmless-assistant-with-reinforcement-learning-from-human-feedback"> Training a helpful and harmless assistant with reinforcement learning from human feedback</a></li>
    </ol>
  </li>
</ol>

<h2 id="glossed-overview-rlhf-for-llm">Glossed Overview: RLHF for LLM</h2>

<p>Reinforcement Learning from Human Feedback (RLHF) is a technique to integrate human preferences into AI systems, particularly for problems that are difficult to define explicitly.</p>

<p>The core RLHF process involves three steps:</p>

<ol>
  <li>training a capable language model</li>
  <li>collecting human preference data to train a reward model</li>
  <li>optimizing the language model using reinforcement learning guided by the reward model.</li>
</ol>
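<p>The three-step loop above can be sketched in code. This is a toy illustration with hypothetical stand-ins for every component (no real model, reward model, or PPO), purely to make the data flow concrete:</p>

```python
# Toy sketch of the three-step RLHF pipeline; every name here is a
# hypothetical stand-in, not a real training implementation.

def pretrain_language_model():
    # Step 1: start from a capable base model (here, a trivial stand-in
    # that proposes two candidate completions per prompt).
    return lambda prompt: [prompt + " short", prompt + " a longer answer"]

def train_reward_model(preferences):
    # Step 2: fit a reward model on human comparisons; this toy version
    # simply rewards completions containing words from preferred replies.
    preferred = {word for _, chosen in preferences for word in chosen.split()}
    return lambda text: sum(word in preferred for word in text.split())

def rl_finetune(model, reward_model, prompt):
    # Step 3: "optimize" the policy against the reward model -- here by
    # picking the highest-reward sample instead of running PPO.
    return max(model(prompt), key=reward_model)

model = pretrain_language_model()
prefs = [("q", "a longer answer")]   # the human picked the longer reply
rm = train_reward_model(prefs)
best = rl_finetune(model, rm, "q")
print(best)  # the completion the reward model scores highest
```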

<p>RLHF is a crucial part of “post-training” for LLMs, a set of techniques to enhance model usability, including:</p>

<ol>
  <li>Supervised Instruction Finetuning, for learning the surface features of language that underlie the desired output format and the ability to follow instructions;</li>
  <li>Preference Finetuning, for learning output style and subtle alignment with human preferences; and</li>
  <li>Reinforcement Finetuning, for further performance boosts in verifiable domains.</li>
</ol>

<h2 id="further-background-readings">Further Background Readings</h2>

<h3 id="rlhf-in-deep-reinforcement-learning">RLHF in Deep Reinforcement Learning</h3>

<p><a href="https://arxiv.org/abs/1706.03741">P. F. Christiano, J. Leike, T. Brown, M. Martic, S. Legg, and D. Amodei, “Deep reinforcement learning from human preferences,” <em>Advances in neural information processing systems</em>, vol. 30, 2017.</a></p>

<p><strong>Challenge</strong>: Difficulty in Specifying Reward Functions</p>

<p>Manually designing reward functions for complex tasks is incredibly difficult and often leads to unintended or suboptimal agent behavior.</p>

<p><strong>Proposal</strong>: Learning from Human Preferences Instead of Explicit Rewards</p>

<p>Instead of trying to define a reward function directly, learn a reward function <em>from human judgments</em> about which behavior is better.</p>

<p><strong>Details</strong>:</p>

<ol>
  <li><strong>Generate Trajectory Pairs:</strong> The agent performs the task and generates pairs of trajectories (sequences of actions and states).</li>
  <li><strong>Human Preference Judgments:</strong> Humans are presented with these pairs of trajectories and asked to choose which one they prefer (which trajectory is “better” according to some criteria). Crucially, humans don’t need to explicitly define <em>why</em> one is better, just to indicate their preference.</li>
  <li><strong>Reward Model Training:</strong> These human preference judgments are used to train a <strong>reward model</strong>. This reward model learns to predict which trajectory a human would prefer. Essentially, it learns to approximate the underlying, implicit reward function based on human feedback.</li>
  <li><strong>Reinforcement Learning with Learned Reward Model:</strong> The trained reward model is then used as the reward signal for a standard deep reinforcement learning algorithm (like policy gradients or Q-learning). The agent is trained to maximize the reward predicted by the reward model, which in turn is aligned with human preferences.</li>
</ol>
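<p>Step 3 is typically implemented with a Bradley-Terry style objective: the reward model is trained so that the preferred trajectory's reward wins a softmax comparison against the other, minimizing the negative log-likelihood of the human's choice. A minimal numeric sketch (scalar rewards standing in for per-segment sums):</p>

```python
import math

def preference_loss(r_preferred, r_other):
    # Bradley-Terry style objective for fitting the reward model:
    # P(preferred > other) = exp(r_p) / (exp(r_p) + exp(r_o)),
    # minimized as the negative log-likelihood of the human's choice.
    p = math.exp(r_preferred) / (math.exp(r_preferred) + math.exp(r_other))
    return -math.log(p)

# The loss shrinks as the reward model separates the preferred
# trajectory from the rejected one.
print(preference_loss(1.0, 1.0))  # ~0.693: model is indifferent
print(preference_loss(3.0, 0.0))  # small: agrees with the human
```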

<h3 id="learning-to-summarize-from-human-feedback">Learning to summarize from human feedback</h3>

<p><a href="https://arxiv.org/abs/2009.01325">N. Stiennon <em>et al.</em>, “Learning to summarize with human feedback,” <em>Advances in Neural Information Processing Systems</em>, vol. 33, pp. 3008–3021, 2020.</a></p>

<p><strong>Challenge</strong>: Limitations of Traditional Summarization Metrics and Methods.</p>

<p>Traditional automatic summarization methods, often optimized using metrics like ROUGE, don’t always align well with human preferences for good summaries. ROUGE primarily measures n-gram overlap with reference summaries, which can be a crude proxy for summary quality. Furthermore, directly optimizing for metrics like ROUGE can lead to models that generate summaries that are grammatically correct but lack coherence, focus, or truly capture the essence of the original text as a human would.</p>
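<p>To see why n-gram overlap is a crude proxy, here is a minimal ROUGE-1 recall computation (a simplification of real ROUGE, which adds stemming, stopword options, and multiple references); note that a fully scrambled summary still scores perfectly:</p>

```python
from collections import Counter

def rouge1_recall(candidate: str, reference: str) -> float:
    # ROUGE-1 recall: fraction of reference unigrams that also appear
    # in the candidate (clipped by count), ignoring word order entirely.
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum(min(cand[w], ref[w]) for w in ref)
    return overlap / sum(ref.values())

ref = "the cat sat on the mat"
print(rouge1_recall("the cat sat on the mat", ref))   # 1.0
print(rouge1_recall("mat the on sat cat the", ref))   # 1.0: order ignored
```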

<p><strong>Proposal</strong>: Training Summarization Models with Human Preference Feedback.</p>

<p>Similar to the Christiano et al. (2017) paper, this work proposes to move away from solely relying on automatic metrics and instead train summarization models using direct human feedback on the quality of generated summaries. The idea is to teach the model to generate summaries that humans <em>prefer</em>, rather than just those that score well on automatic metrics.</p>

<p><strong>Details:</strong></p>
<ol>
  <li><strong>Pre-training a Summarization Model:</strong>
    Pre-train a sequence-to-sequence model for summarization.</li>
  <li><strong>Collecting Human Preference Data (Comparison Data):</strong>
    Collect human judgments by presenting human annotators with pairs of summaries generated by different models (or different versions of the same model). The annotators are asked to choose which summary is better based on criteria like:
    <ul>
      <li><strong>Helpfulness:</strong> Is the summary informative and useful?</li>
      <li><strong>Relevance:</strong> Does the summary accurately reflect the content of the original document?</li>
      <li><strong>Readability:</strong> Is the summary well-written and easy to understand?</li>
      <li><strong>Non-redundancy:</strong> Does the summary avoid unnecessary repetition?</li>
    </ul>
  </li>
  <li><strong>Training a Reward Model:</strong> 
    The collected human preference data (pairs of summaries and the preferred one) is used to train a <strong>reward model</strong>. This reward model learns to predict which summary a human would prefer given an input document. The reward model is trained to assign higher scores to summaries that humans tend to prefer.</li>
  <li><strong>Fine-tuning the Summarization Model with Reinforcement Learning:</strong> 
    The pre-trained summarization model is then fine-tuned using reinforcement learning. The reward signal for RL is provided by the trained reward model. The RL objective is to generate summaries that maximize the score given by the reward model, effectively guiding the summarization model towards generating summaries that are more human-preferred. They used Proximal Policy Optimization (PPO) algorithm for this RL fine-tuning stage.</li>
</ol>
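<p>One detail worth keeping in mind from this line of work: the RL stage does not maximize the reward-model score alone, but subtracts a KL penalty that keeps the fine-tuned policy close to the supervised model. A minimal sketch, using a per-sample log-probability difference as the KL estimate and a hypothetical <code>beta</code> coefficient:</p>

```python
def rlhf_reward(rm_score, logprob_policy, logprob_sft, beta=0.1):
    # Reward actually optimized during RL fine-tuning: the reward-model
    # score minus a KL penalty keeping the policy near the supervised
    # model, estimated per sample as log pi(y|x) - log pi_sft(y|x).
    return rm_score - beta * (logprob_policy - logprob_sft)

# A high RM score is discounted when the policy drifts far from SFT.
print(rlhf_reward(2.0, -5.0, -5.0))   # 2.0: no drift, full reward
print(rlhf_reward(2.0, -2.0, -10.0))  # ~1.2: drift is penalized
```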

<h3 id="webgpt-browser-assisted-question-answering-with-human-feedback">Webgpt: Browser-assisted question-answering with human feedback</h3>
<p><a href="https://arxiv.org/abs/2112.09332">R. Nakano et al., “Webgpt: Browser-assisted question-answering with human feedback,” arXiv preprint arXiv:2112.09332, 2021.</a></p>

<p><strong>Challenge</strong>: Limitations of Traditional Question Answering and the Need for Browser Assistance:</p>

<p>Traditional question-answering (QA) models rely solely on their internal knowledge or pre-indexed datasets. Many real-world questions require accessing and processing information from the open web to provide comprehensive and up-to-date answers. Furthermore, simply retrieving documents isn’t enough; the model needs to effectively browse, extract relevant information, and synthesize it into a coherent answer.</p>

<p><strong>Proposal</strong>: WebGPT - A Browser-Assisted QA Model Trained with Human Feedback</p>

<p>The paper proposes <strong>WebGPT</strong>, a model trained to use a web browser to answer questions. It is not just a language model; it is an agent that can interact with the web in a controlled manner, including searching, clicking links, scrolling, and reading web pages. Crucially, WebGPT is trained using Reinforcement Learning from Human Feedback (RLHF) to generate answers that are helpful, truthful, and harmless.</p>

<p><strong>Details</strong>: Browser-in-the-Loop Question Answering with RLHF:</p>
<ol>
  <li><strong>Browser Environment:</strong> 
    They created a simulated browser environment that WebGPT can interact with. This environment provides actions like searching, clicking links, scrolling, and observing the rendered web page content.</li>
  <li><strong>WebGPT Agent:</strong> 
    WebGPT is a Transformer-based language model trained to act as an agent within this browser environment. Given a question, it decides on a sequence of browser actions to gather information and ultimately generate an answer.</li>
  <li><strong>Human Feedback Collection:</strong> 
    Human evaluators are crucial. They are asked to compare pairs of answers generated by different models (including WebGPT and baseline models) and indicate which answer is better based on criteria like:
    <ul>
      <li><strong>Helpfulness:</strong> Is the answer useful and informative?</li>
      <li><strong>Truthfulness/Accuracy:</strong> Is the answer factually correct and supported by evidence?</li>
      <li><strong>Harmlessness:</strong> Is the answer safe and avoids harmful or biased content?</li>
      <li><strong>Browser Usage Quality:</strong> Was the browsing process efficient and effective in finding relevant information?</li>
    </ul>
  </li>
  <li><strong>Reward Model Training:</strong> 
    The human preference data is used to train a <strong>reward model</strong>. This reward model learns to predict which answer a human would prefer, based on the quality criteria. It also learns to reward efficient and effective browser usage.</li>
  <li><strong>Reinforcement Learning Fine-tuning:</strong> 
    WebGPT’s policy (how it decides to act in the browser and generate answers) is then fine-tuned using reinforcement learning (Proximal Policy Optimization). The reward signal comes from the trained reward model. The RL objective is to train WebGPT to perform browser actions and generate answers that maximize the reward predicted by the reward model, thus aligning with human preferences for helpful, truthful, and harmless answers.</li>
</ol>
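<p>The interaction in steps 1-2 is an ordinary agent-environment loop. The sketch below uses toy stand-ins for both the policy and the browser, with an action vocabulary loosely modeled on the kinds of commands described for WebGPT (search, click, answer):</p>

```python
# Minimal sketch of a browser-in-the-loop episode; the policy and the
# environment here are toy stand-ins, not the real WebGPT interfaces.

def toy_policy(observation):
    # A real policy is a language model; this stub follows a fixed script.
    script = {"start": "search rlhf", "results": "click 0", "page": "answer"}
    return script[observation]

def toy_browser(action):
    # Maps each action to the next observation the agent sees.
    if action.startswith("search"):
        return "results"
    if action.startswith("click"):
        return "page"
    return "done"   # "answer" ends the episode

obs, trace = "start", []
while obs != "done":
    action = toy_policy(obs)
    trace.append(action)
    obs = toy_browser(action)

print(trace)  # ['search rlhf', 'click 0', 'answer']
```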

<h3 id="training-language-models-to-follow-instructions-with-human-feedback">Training language models to follow instructions with human feedback</h3>

<p><a href="https://arxiv.org/abs/2203.02155">L. Ouyang et al., “Training language models to follow instructions with human feedback,” Advances in neural information processing systems, vol. 35, pp. 27730–27744, 2022.</a></p>

<p><strong>Challenge</strong>: Mismatch between Language Model Objectives and User Intent</p>

<p>A key problem with standard language models trained for next-token prediction: they are good at generating text that is statistically likely but not necessarily helpful, truthful, or harmless (the “alignment problem”). These models often generate outputs that are:</p>
<ul>
  <li><strong>Unhelpful:</strong> Not actually answering the user’s question or fulfilling the user’s request.</li>
  <li><strong>Untruthful:</strong> Generating factually incorrect or misleading information.</li>
  <li><strong>Harmful:</strong> Producing biased, toxic, or unsafe content.</li>
</ul>

<p>The core issue is that optimizing for next-token prediction alone doesn’t incentivize models to align with human intent and values.</p>

<p><strong>Proposal</strong>: InstructGPT - Training Language Models to Follow Instructions via RLHF</p>

<p>The central solution proposed is <strong>InstructGPT</strong>, a language model specifically trained to follow instructions using Reinforcement Learning from Human Feedback (RLHF). The goal is to directly train the model to be helpful, truthful, and harmless, aligning its behavior with what humans actually want.</p>

<p><strong>Details</strong>: A Three-Step RLHF Pipeline for Instruction Following:</p>
<ol>
  <li><strong>Supervised Fine-tuning (SFT) on Instruction Data:</strong> 
    First, fine-tune a pre-trained language model (in this case, a GPT-3 model) on a dataset of human-written demonstrations of instruction following. This dataset consists of prompts (instructions) and desired responses. This step teaches the model to initially understand and attempt to follow instructions.</li>
  <li><strong>Reward Model Training from Human Preference Data:</strong> 
    Next, collect human preference data. Humans are presented with multiple responses generated by the SFT model for a given instruction. They are asked to rank these responses based on which one is better, considering factors like helpfulness, truthfulness, and harmlessness. This preference data is used to train a <strong>reward model</strong>. The reward model learns to predict which response a human would prefer for a given instruction. It essentially learns to score responses based on alignment with human values.</li>
  <li><strong>Reinforcement Learning Fine-tuning with the Reward Model:</strong>
    Finally, the SFT model is further fine-tuned using reinforcement learning (Proximal Policy Optimization). The reward signal for RL is provided by the trained reward model. The RL objective is to train the model to generate responses that maximize the reward predicted by the reward model. This step directly optimizes the language model for alignment with human preferences as captured by the reward model.</li>
</ol>
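<p>A concrete detail of step 2: when a human ranks K responses, the ranking is turned into all K-choose-2 pairwise comparisons, and the reward model is trained with a log-sigmoid loss on each pair's score difference. A minimal sketch:</p>

```python
import math
from itertools import combinations

def ranking_rm_loss(scores_in_preference_order):
    # InstructGPT-style reward-model loss: every ordered pair from a
    # human ranking contributes -log sigmoid(r_better - r_worse),
    # averaged over the K-choose-2 pairs.
    pairs = list(combinations(scores_in_preference_order, 2))
    sigmoid = lambda x: 1 / (1 + math.exp(-x))
    return -sum(math.log(sigmoid(hi - lo)) for hi, lo in pairs) / len(pairs)

# Scores that agree with the human ranking give a lower loss.
print(ranking_rm_loss([3.0, 1.0, -1.0]))  # well-ordered: small loss
print(ranking_rm_loss([-1.0, 1.0, 3.0]))  # reversed: large loss
```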

<h3 id="training-a-helpful-and-harmless-assistant-with-reinforcement-learning-from-human-feedback">Training a helpful and harmless assistant with reinforcement learning from human feedback</h3>

<p><a href="https://arxiv.org/abs/2204.05862">Y. Bai et al., “Training a helpful and harmless assistant with reinforcement learning from human feedback,” arXiv preprint arXiv:2204.05862, 2022.</a></p>

<p><strong>Challenge</strong>: Ensuring Harmlessness in AI Assistants trained with RLHF</p>

<p>While previous RLHF work (like InstructGPT) focused on helpfulness and truthfulness, this paper specifically tackles the challenge of ensuring <strong>harmlessness</strong> in AI assistants. They argue that directly relying on human feedback for <em>all</em> aspects of harmlessness can be problematic and potentially lead to inconsistent or biased judgments. It’s difficult for humans to consistently and comprehensively define “harmlessness” in all situations.</p>

<p><strong>Proposal</strong>: Constitutional AI (CAI) - Using a Constitution to Guide Harmlessness Learning</p>

<p>Instead of directly asking humans to rate harmlessness in every instance, they propose to use a <strong>set of principles, or a “constitution,” to define and guide what constitutes harmless behavior.</strong> This constitution is used to:</p>
<ol>
  <li><strong>Self-Critique:</strong> The AI assistant itself uses the constitution to critique its own responses and identify potentially harmful outputs.</li>
  <li><strong>Guide Reward Model Training:</strong> The constitution informs the training of the reward model, so the model learns to penalize responses that violate the constitutional principles.</li>
</ol>
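<p>The critique-and-revise control flow can be sketched as follows; the "constitution", critic, and reviser here are keyword-based toys, standing in for what are really model-generated critiques and rewrites:</p>

```python
# Toy sketch of the constitutional self-critique loop; the "constitution"
# is reduced to keyword rules and the "model" to string edits, purely to
# show the control flow (critique -> revise -> re-check).

CONSTITUTION = ["no insults"]   # stand-in for written principles

def critique(response):
    # Returns the list of principles the response violates.
    return [p for p in CONSTITUTION
            if p == "no insults" and "stupid" in response]

def revise(response, violations):
    # A real system asks the model to rewrite; here we just substitute.
    return response.replace("stupid", "misguided") if violations else response

draft = "That is a stupid question."
final = revise(draft, critique(draft))
print(final)  # "That is a misguided question."
```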

<p><strong>Details</strong>: Two-Phase RLHF with Constitutional Guidance</p>
<ol>
  <li><strong>Constitutional Reinforcement Learning (Constitutional RL):</strong>
    <ul>
      <li><strong>Agent Generates Responses:</strong> The AI assistant generates responses to prompts.</li>
      <li><strong>Constitutional Critique:</strong> The assistant then uses the pre-defined constitution to critique its <em>own</em> generated responses. This critique identifies potential violations of the constitutional principles.</li>
      <li><strong>Self-Correction:</strong> Based on the critique, the assistant refines or regenerates its response to better align with the constitution.</li>
      <li><strong>Reward based on Constitutional Alignment:</strong> A reward signal is generated based on how well the response aligns with the constitution (i.e., how few constitutional violations it has). This phase trains the assistant to be <em>constitutionally aligned</em>.</li>
    </ul>
  </li>
  <li><strong>Human Preference Reinforcement Learning (Preference RL):</strong>
    <ul>
      <li><strong>Agent Generates Pairs of Responses:</strong> The constitutionally trained assistant generates pairs of responses (often one from the constitutional RL phase and one from a baseline model, or variations of constitutionally aligned responses).</li>
      <li><strong>Human Preference Judgments (Helpfulness):</strong> Humans are then asked to compare these pairs of responses and choose which one is <em>more helpful</em> (ignoring harmlessness at this stage, as harmlessness is already addressed in phase 1).</li>
      <li><strong>Reward Model Training (Helpfulness Reward):</strong> Human preference data is used to train a reward model that specifically focuses on predicting human preferences for <em>helpfulness</em>.</li>
      <li><strong>RL Fine-tuning with Helpfulness Reward:</strong> The constitutionally aligned assistant is further fine-tuned using reinforcement learning, but now with the reward signal from the <em>helpfulness</em> reward model. This phase trains the assistant to be <em>helpful</em>, while retaining the harmlessness learned in phase 1.</li>
    </ul>
  </li>
</ol>]]></content><author><name>Ren Zhang</name></author><category term="post" /><category term="Large Language Models" /><category term="Reinforcement Learning from Human Feedback" /><category term="Reading" /><category term="Notes" /><category term="RLHF" /><category term="LLM" /><summary type="html"><![CDATA[Table of Contents Glossed Overview: RLHF for LLM Further Background Readings RLHF in Deep Reinforcement Learning Learning to summarize from human feedback Webgpt: Browser-assisted question-answering with human feedback Training language models to follow instructions with human feedback Training a helpful and harmless assistant with reinforcement learning from human feedback Glossed Overview: RLHF for LLM Reinforcement Learning from Human Feedback (RLHF) is a technique to integrate human preferences into AI systems, particularly for problems that are difficult to define explicitly. The core RLHF process involves three steps: training a capable language model collecting human preference data to train a reward model optimizing the language model using reinforcement learning guided by the reward model. RLHF is a crucial part of “post-training” for LLM, a set of techniques to enhance model usability, including: Supervised Instructional Finetuning for learning features of language that form the basis of the desired output format and the ability of instruction following. Preference Finetuning for learning the output style and subtle alignment with human preferences and Reinforcement Finetuning for further performance boosts in verifiable domains Further Background Readings RLHF in Deep Reinforcement Learning P. F. Christiano, J. Leike, T. Brown, M. Martic, S. Legg, and D. Amodei, “Deep reinforcement learning from human preferences,” Advances in neural information processing systems, vol. 30, 2017. 
Challenge: Difficulty in Specifying Reward Functions Manually designing reward functions for complex tasks is incredibly difficult and often leads to unintended or suboptimal agent behavior. Proposal: Learning from Human Preferences Instead of Explicit Rewards Instead of trying to define a reward function directly, learn a reward function from human judgments about which behavior is better. Details: Generate Trajectory Pairs: The agent performs the task and generates pairs of trajectories (sequences of actions and states). Human Preference Judgments: Humans are presented with these pairs of trajectories and asked to choose which one they prefer (which trajectory is “better” according to some criteria). Crucially, humans don’t need to explicitly define why one is better, just to indicate their preference. Reward Model Training: These human preference judgments are used to train a reward model. This reward model learns to predict which trajectory a human would prefer. Essentially, it learns to approximate the underlying, implicit reward function based on human feedback. Reinforcement Learning with Learned Reward Model: The trained reward model is then used as the reward signal for a standard deep reinforcement learning algorithm (like policy gradients or Q-learning). The agent is trained to maximize the reward predicted by the reward model, which in turn is aligned with human preferences. Learning to summarize from human feedback N. Stiennon et al., “Learning to summarize with human feedback,” Advances in Neural Information Processing Systems, vol. 33, pp. 3008–3021, 2020. Challenge: Limitations of Traditional Summarization Metrics and Methods. Traditional automatic summarization methods, often optimized using metrics like ROUGE, don’t always align well with human preferences for good summaries. ROUGE primarily measures n-gram overlap with reference summaries, which can be a crude proxy for summary quality. 
Furthermore, directly optimizing for metrics like ROUGE can lead to models that generate summaries that are grammatically correct but lack coherence, focus, or truly capture the essence of the original text as a human would. Proposal: Training Summarization Models with Human Preference Feedback. Similar to the Christiano et al. (2017) paper, this work proposes to move away from solely relying on automatic metrics and instead train summarization models using direct human feedback on the quality of generated summaries. The idea is to teach the model to generate summaries that humans prefer, rather than just those that score well on automatic metrics. Details: Pre-training a Summarization Model: Pre-train a sequence-to-sequence model for summarization. Collecting Human Preference Data (Comparison Data): Collect human judgments by presenting human annotators with pairs of summaries generated by different models (or different versions of the same model). The annotators are asked to choose which summary is better based on criteria like: Helpfulness: Is the summary informative and useful? Relevance: Does the summary accurately reflect the content of the original document? Readability: Is the summary well-written and easy to understand? Non-redundancy: Does the summary avoid unnecessary repetition? Training a Reward Model: The collected human preference data (pairs of summaries and the preferred one) is used to train a reward model. This reward model learns to predict which summary a human would prefer given an input document. The reward model is trained to assign higher scores to summaries that humans tend to prefer. Fine-tuning the Summarization Model with Reinforcement Learning: The pre-trained summarization model is then fine-tuned using reinforcement learning. The reward signal for RL is provided by the trained reward model. 
The RL objective is to generate summaries that maximize the score given by the reward model, effectively guiding the summarization model towards generating summaries that are more human-preferred. They used Proximal Policy Optimization (PPO) algorithm for this RL fine-tuning stage. Webgpt: Browser-assisted question-answering with human feedback R. Nakano et al., “Webgpt: Browser-assisted question-answering with human feedback,” arXiv preprint arXiv:2112.09332, 2021. Challenge: Limitations of Traditional Question Answering and the Need for Browser Assistance: Traditional question-answering (QA) models rely solely on their internal knowledge or pre-indexed datasets. Many real-world questions require accessing and processing information from the open web to provide comprehensive and up-to-date answers. Furthermore, simply retrieving documents isn’t enough; the model needs to effectively browse, extract relevant information, and synthesize it into a coherent answer. Proposal: WebGPT - A Browser-Assisted QA Model Trained with Human Feedback WebGPT, a model that is trained to use a web browser to answer questions. It’s not just a language model; it’s an agent that can interact with the web in a controlled manner, including searching, clicking links, scrolling, and reading web pages. Crucially, WebGPT is trained using Reinforcement Learning from Human Feedback (RLHF) to generate answers that are helpful, truthful, and harmless. Details: Browser-in-the-Loop Question Answering with RLHF: Browser Environment: They created a simulated browser environment that WebGPT can interact with. This environment provides actions like searching, clicking links, scrolling, and observing the rendered web page content. WebGPT Agent: WebGPT is a Transformer-based language model trained to act as an agent within this browser environment. Given a question, it decides on a sequence of browser actions to gather information and ultimately generate an answer. 
Human Feedback Collection: Human evaluators are crucial. They are asked to compare pairs of answers generated by different models (including WebGPT and baseline models) and indicate which answer is better based on criteria like: Helpfulness: Is the answer useful and informative? Truthfulness/Accuracy: Is the answer factually correct and supported by evidence? Harmlessness: Is the answer safe, avoiding harmful or biased content? Browser Usage Quality: Was the browsing process efficient and effective in finding relevant information?

Reward Model Training: The human preference data is used to train a reward model. This reward model learns to predict which answer a human would prefer, based on the quality criteria. It also learns to reward efficient and effective browser usage.

Reinforcement Learning Fine-tuning: WebGPT’s policy (how it decides to act in the browser and generate answers) is then fine-tuned using reinforcement learning (Proximal Policy Optimization). The reward signal comes from the trained reward model. The RL objective is to train WebGPT to perform browser actions and generate answers that maximize the reward predicted by the reward model, thus aligning with human preferences for helpful, truthful, and harmless answers.

Training language models to follow instructions with human feedback. L. Ouyang et al., “Training language models to follow instructions with human feedback,” Advances in neural information processing systems, vol. 35, pp. 27730–27744, 2022.

Challenge: Mismatch between Language Model Objectives and User Intent. A key problem with standard language models trained for next-token prediction is that they are good at generating text that is statistically likely but not necessarily helpful, truthful, or harmless (the “alignment problem”). These models often generate outputs that are: Unhelpful: not actually answering the user’s question or fulfilling the user’s request. Untruthful: generating factually incorrect or misleading information. Harmful: producing biased, toxic, or unsafe content. The core issue is that optimizing for next-token prediction alone doesn’t incentivize models to align with human intent and values.

Proposal: InstructGPT - Training Language Models to Follow Instructions via RLHF. The central solution proposed is InstructGPT, a language model specifically trained to follow instructions using Reinforcement Learning from Human Feedback (RLHF). The goal is to directly train the model to be helpful, truthful, and harmless, aligning its behavior with what humans actually want.

Details: A Three-Step RLHF Pipeline for Instruction Following.

Supervised Fine-tuning (SFT) on Instruction Data: First, fine-tune a pre-trained language model (in this case, a GPT-3 model) on a dataset of human-written demonstrations of instruction following. This dataset consists of prompts (instructions) and desired responses. This step teaches the model to initially understand and attempt to follow instructions.

Reward Model Training from Human Preference Data: Next, collect human preference data. Humans are presented with multiple responses generated by the SFT model for a given instruction. They are asked to rank these responses based on which one is better, considering factors like helpfulness, truthfulness, and harmlessness. This preference data is used to train a reward model. The reward model learns to predict which response a human would prefer for a given instruction; it essentially learns to score responses based on alignment with human values.

Reinforcement Learning Fine-tuning with the Reward Model: Finally, the SFT model is further fine-tuned using reinforcement learning (Proximal Policy Optimization). The reward signal for RL is provided by the trained reward model. The RL objective is to train the model to generate responses that maximize the reward predicted by the reward model. This step directly optimizes the language model for alignment with human preferences as captured by the reward model.
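The reward-model step in these pipelines is commonly fit with a pairwise (Bradley-Terry style) objective: maximize the log-probability that the human-preferred response scores higher. A minimal sketch of that loss (my illustration, not code from the paper):

```python
import math

def reward_pref_loss(r_chosen, r_rejected):
    # -log sigmoid(r_chosen - r_rejected): the loss is small when the
    # reward model scores the human-preferred response well above the other
    margin = r_chosen - r_rejected
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

# a larger margin in favor of the preferred response gives a smaller loss
assert reward_pref_loss(3.0, 1.0) < reward_pref_loss(1.5, 1.0)
```

Averaged over the collected comparison pairs, minimizing this loss teaches the reward model to rank responses the way the annotators did.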
Training a helpful and harmless assistant with reinforcement learning from human feedback. Y. Bai et al., “Training a helpful and harmless assistant with reinforcement learning from human feedback,” arXiv preprint arXiv:2204.05862, 2022.

Challenge: Ensuring Harmlessness in AI Assistants Trained with RLHF. While previous RLHF work (like InstructGPT) focused on helpfulness and truthfulness, this paper specifically tackles the challenge of ensuring harmlessness in AI assistants. They argue that directly relying on human feedback for all aspects of harmlessness can be problematic and potentially lead to inconsistent or biased judgments. It’s difficult for humans to consistently and comprehensively define “harmlessness” in all situations.

Proposal: Constitutional AI (CAI) - Using a Constitution to Guide Harmlessness Learning. Instead of directly asking humans to rate harmlessness in every instance, they propose to use a set of principles, or a “constitution,” to define and guide what constitutes harmless behavior. This constitution is used for: Self-Critique: the AI assistant itself uses the constitution to critique its own responses and identify potentially harmful outputs. Guiding Reward Model Training: the constitution informs the training of the reward model, so the model learns to penalize responses that violate the constitutional principles.

Details: Two-Phase RLHF with Constitutional Guidance.

Constitutional Reinforcement Learning (Constitutional RL): The AI assistant generates responses to prompts, then uses the pre-defined constitution to critique its own generated responses; this critique identifies potential violations of the constitutional principles. Based on the critique, the assistant refines or regenerates its response to better align with the constitution. A reward signal is generated based on how well the response aligns with the constitution (i.e., how few constitutional violations it has). This phase trains the assistant to be constitutionally aligned.

Human Preference Reinforcement Learning (Preference RL): The constitutionally trained assistant generates pairs of responses (often one from the constitutional RL phase and one from a baseline model, or variations of constitutionally aligned responses). Humans are then asked to compare these pairs of responses and choose which one is more helpful (ignoring harmlessness at this stage, as harmlessness is already addressed in phase 1). The human preference data is used to train a reward model that specifically focuses on predicting human preferences for helpfulness. The constitutionally aligned assistant is then further fine-tuned using reinforcement learning, with the reward signal from the helpfulness reward model. This phase trains the assistant to be helpful, while retaining the harmlessness learned in phase 1.]]></summary></entry><entry><title type="html">Python dictionary get with default value</title><link href="https://ryanzhang.info/post/2022/03/10/python-dictionary-get-with-default-value.html" rel="alternate" type="text/html" title="Python dictionary get with default value" /><published>2022-03-10T15:43:43+00:00</published><updated>2022-03-10T15:43:43+00:00</updated><id>https://ryanzhang.info/post/2022/03/10/python-dictionary-get-with-default-value</id><content type="html" xml:base="https://ryanzhang.info/post/2022/03/10/python-dictionary-get-with-default-value.html"><![CDATA[<p>Yesterday I spent way more time than necessary debugging a piece of python code. 
It has something to do with how the python dictionary’s <code class="language-plaintext highlighter-rouge">get</code> method works with its <code class="language-plaintext highlighter-rouge">default</code> argument. TLDR: it is fine when the default is a simple value, but not recommended when the default is a function call.</p>

<p>Here is what the problematic code looks like:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_option</span><span class="p">(</span><span class="n">option_name</span><span class="p">):</span>
    <span class="n">option</span> <span class="o">=</span> <span class="bp">...</span>  <span class="c1"># some logic
</span>    <span class="k">return</span> <span class="n">option</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">d</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">"option_1"</span><span class="p">,</span> <span class="n">get_option</span><span class="p">(</span><span class="s">"option_1"</span><span class="p">))</span>
</code></pre></div></div>

<p>The problem here is that the default is the return value of a function call; like any argument, it is evaluated before the body of the <code class="language-plaintext highlighter-rouge">get</code> method runs. So no matter whether the key exists or not, the function will always be evaluated. (Note that <code class="language-plaintext highlighter-rouge">dict.get</code> only accepts its default positionally; passing <code class="language-plaintext highlighter-rouge">default=</code> as a keyword raises a <code class="language-plaintext highlighter-rouge">TypeError</code>.) It can break if:</p>

<ol>
  <li>the function is not defined.</li>
  <li>the function call raises an Error or Exception.</li>
</ol>
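<p>A small demonstration of the eager evaluation, using a stand-in <code class="language-plaintext highlighter-rouge">get_option</code> that just records that it ran:</p>

```python
calls = []

def get_option(name):
    # stand-in for the real lookup; records that it was invoked
    calls.append(name)
    return f"computed-{name}"

d = {"option_1": "cached"}

# the second argument is evaluated BEFORE get() runs,
# so get_option is called even though the key exists
x = d.get("option_1", get_option("option_1"))
assert x == "cached"
assert calls == ["option_1"]  # the fallback ran anyway
```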

<p>The fix is to change it to:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x</span> <span class="o">=</span> <span class="n">d</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">"option_1"</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span> <span class="ow">or</span> <span class="n">get_option</span><span class="p">(</span><span class="s">"option_1"</span><span class="p">)</span>
</code></pre></div></div>
<p>This way, thanks to the short-circuiting <code class="language-plaintext highlighter-rouge">or</code>, the function call is only evaluated as a last resort, when the key is not in the dictionary.</p>

<p>But one drawback is that, if the stored value is logically <code class="language-plaintext highlighter-rouge">False</code>, like the int value <code class="language-plaintext highlighter-rouge">0</code>, it will still fall back to the function call.</p>

<p>So a better rewrite is:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x</span> <span class="o">=</span> <span class="n">d</span><span class="p">[</span><span class="s">"option_1"</span><span class="p">]</span> <span class="k">if</span> <span class="s">"option_1"</span> <span class="ow">in</span> <span class="n">d</span> <span class="k">else</span> <span class="n">get_option</span><span class="p">(</span><span class="s">"option_1"</span><span class="p">)</span>
</code></pre></div></div>
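<p>To see the difference between the two rewrites, consider a stored value that happens to be falsy (a sketch with a hypothetical <code class="language-plaintext highlighter-rouge">get_option</code>):</p>

```python
def get_option(name):
    # hypothetical fallback that computes a default
    return 99

d = {"retries": 0}  # 0 is a legitimate stored value, but falsy

# the `or` version discards the stored 0 and falls back
assert (d.get("retries") or get_option("retries")) == 99

# the explicit key check preserves the stored falsy value
assert (d["retries"] if "retries" in d else get_option("retries")) == 0
```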

<p>This is slightly different from <code class="language-plaintext highlighter-rouge">collections.defaultdict</code>, in which the fallback <code class="language-plaintext highlighter-rouge">default_factory</code> callable is only invoked after the key lookup fails. But of course, <code class="language-plaintext highlighter-rouge">default_factory</code> does not accept any arguments, so it does not suit the use case here.</p>]]></content><author><name>Ren Zhang</name></author><category term="post" /><category term="programming language" /><category term="python" /><summary type="html"><![CDATA[Yesterday I spent way more time than necessary debugging a piece of python code. It has something to do with how the python dictionary’s get method works with default arguments. TLDR: it is ok when default is a simple value, not recommended if default is a function call.]]></summary></entry><entry><title type="html">Python sequentially unpacks tuple with assignment expression</title><link href="https://ryanzhang.info/post/2021/04/14/python-sequentially-unpacks-tuple-with-assignment-expression.html" rel="alternate" type="text/html" title="Python sequentially unpacks tuple with assignment expression" /><published>2021-04-14T15:43:43+00:00</published><updated>2021-04-14T15:43:43+00:00</updated><id>https://ryanzhang.info/post/2021/04/14/python-sequentially-unpacks-tuple-with-assignment-expression</id><content type="html" xml:base="https://ryanzhang.info/post/2021/04/14/python-sequentially-unpacks-tuple-with-assignment-expression.html"><![CDATA[<p>The other day, I shadowed an interview with a data science candidate. The primary focus is obviously not on coding skills, but we do want to assess basic knowledge of the programming language of his choice. So, my colleague asked a very simple python question to warm him up. The question is: ‘how do you swap the values of two variables without using a temporary variable?’. 
To my surprise, the candidate had no clue that it is as simple as <code class="language-plaintext highlighter-rouge">a, b = b, a</code>.</p>

<p>In most other languages this is not a valid statement. The reason it works in python is as follows:</p>
<ol>
  <li>the expression on the right-hand side gets evaluated. As a result, a tuple of two elements <code class="language-plaintext highlighter-rouge">(b, a)</code> is created.</li>
  <li>then python unpacks this tuple, assigning the values to the variables on the left-hand side sequentially, in left-to-right order.</li>
</ol>

<p>So suppose <code class="language-plaintext highlighter-rouge">a = 5; b = 3</code>. When python evaluates <code class="language-plaintext highlighter-rouge">a, b = b, a</code>:</p>
<ol>
  <li>It first creates a tuple of <code class="language-plaintext highlighter-rouge">(3, 5)</code></li>
  <li>Then it assigns <code class="language-plaintext highlighter-rouge">3</code> to <code class="language-plaintext highlighter-rouge">a</code></li>
  <li>Finally it assigns <code class="language-plaintext highlighter-rouge">5</code> to <code class="language-plaintext highlighter-rouge">b</code></li>
</ol>
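<p>The evaluate-then-unpack rule is easy to verify, and it powers more than swaps (a small sketch):</p>

```python
a, b = 5, 3
a, b = b, a            # the tuple (3, 5) is built first, then unpacked
assert (a, b) == (3, 5)

# the same rule makes one-line state updates safe: x + y uses the
# OLD x, because the right-hand side is evaluated before any assignment
x, y = 0, 1
x, y = y, x + y        # one Fibonacci step
assert (x, y) == (1, 1)
```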

<p>This is very handy and readable. However, users may think that as long as they align elements on the left with the corresponding elements on the right, the swap should always work. Indeed, if there is no error, the code will be executed, but not always as intended. The sequential order of unpacking should be considered when we write multiple assignments via tuple unpacking, to avoid any ‘surprising’ behavior.</p>

<p>Let’s say that we want to do a simple linked list reversal. In other programming languages, we usually use a temporary variable to hold the <code class="language-plaintext highlighter-rouge">next</code> node after <code class="language-plaintext highlighter-rouge">curr</code>, to make sure we can advance to it after we rewire <code class="language-plaintext highlighter-rouge">curr.next</code> to <code class="language-plaintext highlighter-rouge">prev</code>:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">reverse_linked_list</span><span class="p">(</span><span class="n">head</span><span class="p">):</span>
    <span class="n">prev</span><span class="p">,</span> <span class="n">curr</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span> <span class="n">head</span>
    <span class="k">while</span> <span class="n">curr</span><span class="p">:</span>
        <span class="n">next_</span> <span class="o">=</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span>
        <span class="n">curr</span><span class="p">.</span><span class="nb">next</span> <span class="o">=</span> <span class="n">prev</span>
        <span class="n">prev</span> <span class="o">=</span> <span class="n">curr</span>
        <span class="n">curr</span> <span class="o">=</span> <span class="n">next_</span>
    <span class="k">return</span> <span class="n">prev</span>
</code></pre></div></div>
<p>With a quick drawing, you can picture how this rewiring works and verify it is correct.</p>

<p>With tuple unpacking, we can do the same thing in python like this:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">while</span> <span class="n">curr</span><span class="p">:</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span><span class="p">,</span> <span class="n">prev</span><span class="p">,</span> <span class="n">curr</span> <span class="o">=</span> <span class="n">prev</span><span class="p">,</span> <span class="n">curr</span><span class="p">,</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span>
</code></pre></div></div>
<p>This follows the assignment order of the temporary-variable pattern above.</p>

<p>However, this will also work:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">while</span> <span class="n">curr</span><span class="p">:</span> <span class="n">prev</span><span class="p">,</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span><span class="p">,</span> <span class="n">curr</span> <span class="o">=</span> <span class="n">curr</span><span class="p">,</span> <span class="n">prev</span><span class="p">,</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span>
</code></pre></div></div>
<p>At first read, one may feel that the first two elements on both sides are out of order. But it does not matter, because the references to the node objects are already stored in the tuple before the unpacking happens. However, this does not mean any order will work:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">while</span> <span class="n">curr</span><span class="p">:</span> <span class="n">curr</span><span class="p">,</span> <span class="n">prev</span><span class="p">,</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span> <span class="o">=</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span><span class="p">,</span> <span class="n">curr</span><span class="p">,</span> <span class="n">prev</span>
<span class="k">while</span> <span class="n">curr</span><span class="p">:</span> <span class="n">prev</span><span class="p">,</span> <span class="n">curr</span><span class="p">,</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span> <span class="o">=</span> <span class="n">curr</span><span class="p">,</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span><span class="p">,</span> <span class="n">prev</span>
</code></pre></div></div>
<p>Neither of the above works, because of the sequential unpacking. In both versions, after the first two assignments, <code class="language-plaintext highlighter-rouge">curr</code> already references the second node in the linked list. The last assignment then writes this node’s <code class="language-plaintext highlighter-rouge">next</code> pointer to what <code class="language-plaintext highlighter-rouge">prev</code> was at the beginning of the unpacking, which is <code class="language-plaintext highlighter-rouge">None</code>. As a result, the loop throws an error in the second iteration, when we try to access <code class="language-plaintext highlighter-rouge">.next</code> on <code class="language-plaintext highlighter-rouge">None</code>.</p>
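<p>The behavior is easy to confirm with a tiny <code class="language-plaintext highlighter-rouge">Node</code> class (my sketch, not from the original post):</p>

```python
class Node:
    def __init__(self, val, nxt=None):
        self.val, self.next = val, nxt

def build(vals):
    head = None
    for v in reversed(vals):
        head = Node(v, head)
    return head

def to_list(head):
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out

def reverse(head):
    prev, curr = None, head
    while curr:  # one of the working unpacking orders
        curr.next, prev, curr = prev, curr, curr.next
    return prev

assert to_list(reverse(build([1, 2, 3]))) == [3, 2, 1]

# one of the broken orders: curr advances before curr.next is assigned,
# so the second iteration tries to set None.next and raises AttributeError
prev, curr = None, build([1, 2, 3])
try:
    while curr:
        curr, prev, curr.next = curr.next, curr, prev
    raised = False
except AttributeError:
    raised = True
assert raised
```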

<p>If the above is hard to wrap your head around, the one-liner is roughly equivalent to:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">while</span> <span class="n">curr</span><span class="p">:</span>
    <span class="n">snapshot</span> <span class="o">=</span> <span class="p">(</span><span class="n">prev</span><span class="p">,</span> <span class="n">curr</span><span class="p">,</span> <span class="n">curr</span><span class="p">.</span><span class="nb">next</span><span class="p">)</span>
    <span class="n">curr</span><span class="p">.</span><span class="nb">next</span> <span class="o">=</span> <span class="n">snapshot</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">prev</span> <span class="o">=</span> <span class="n">snapshot</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    <span class="n">curr</span> <span class="o">=</span> <span class="n">snapshot</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
</code></pre></div></div>]]></content><author><name>Ren Zhang</name></author><category term="post" /><category term="programming language" /><category term="python" /><summary type="html"><![CDATA[The other day, I shadowed an interview with a data science candidate. The primary focus is obviously not on coding skills, but we do want to assess basic knowledge of the programming language of his choice. So, my colleague asked a very simple python question to warm him up. The question is: ‘how do you swap values of two variables wtihout using a temprary variable?’. To my surprise the candidate had no clue it is as simple as a, b = b, a.]]></summary></entry><entry><title type="html">Learning to rank</title><link href="https://ryanzhang.info/post/2019/10/31/learning-to-rank.html" rel="alternate" type="text/html" title="Learning to rank" /><published>2019-10-31T15:43:43+00:00</published><updated>2019-10-31T15:43:43+00:00</updated><id>https://ryanzhang.info/post/2019/10/31/learning-to-rank</id><content type="html" xml:base="https://ryanzhang.info/post/2019/10/31/learning-to-rank.html"><![CDATA[<h2 id="task">Task</h2>
<p>We want to learn a function \(f(q, D)\) which takes in a query \(q\) and a list of documents \(D=\{d_1, d_2, ..., d_n\}\), and produces scores using which we can rank/order the list of documents.</p>

<h2 id="types">Types</h2>
<p>There are multiple ways we can formulate the problem:</p>
<ol>
  <li>Pointwise</li>
  <li>Pairwise</li>
  <li>Listwise</li>
</ol>

<h3 id="pointwise">Pointwise</h3>
<p>In this approach we learn \(f(q,d)\), which scores how well a document matches the query, independently. When scoring a data point, the function does not take the other documents in the list into consideration.</p>

<p>To train a model in this approach, the data would be in the long format, where each row contains a \((q,d)\) pair and we need a label for every row: either a binary label (classification) or a relevance score (regression).</p>

<h3 id="pairwise">Pairwise</h3>
<p>In this approach we learn \(Pr(rank(d_i,q)\succ rank(d_j,q))\); that is, we learn to determine the relative preference between two documents given a query.</p>

<p>It can be treated as a binary classification problem: the data would be in a format where each row contains a triplet \((q,d_i,d_j)\), and we need a binary label for each row. We can hand-craft features that capture the difference between \(d_i\) and \(d_j\) with respect to \(q\) and feed that difference to a binary classifier.</p>

<p>More often, though, we learn it by learning an intermediate rank function. Let \(rank(d_i,q)=s_i\); the pairwise classification problem then becomes classification on the difference between rank scores. That is:</p>

\[Pr(rank(d_i,q) &gt; rank(d_j,q))=\frac{1}{1+exp(-(s_i-s_j))}\]

<p>The loss would be the negative log of this likelihood, which is: \(L_{ij}=log(1+exp(s_j-s_i))\) and we can train the rank function to minimize this loss.</p>

<p>If we work out the gradient with respect to the parameters \(\theta\) of the rank function, it is:</p>

\[\begin{aligned}
\frac{\partial L_{ij}}{\partial \theta}&amp;=\frac{\partial L_{ij}}{\partial s_i}\frac{\partial s_i}{\partial \theta} + \frac{\partial L_{ij}}{\partial s_j}\frac{\partial s_j}{\partial \theta} \\
&amp;=-\frac{1}{1 + exp(s_i-s_j)}(\frac{\partial s_i}{\partial \theta} - \frac{\partial s_j}{\partial \theta}) \\
&amp;=\lambda_{ij}(\frac{\partial s_i}{\partial \theta} - \frac{\partial s_j}{\partial \theta})
\end{aligned}\]

<p>As a result, a single gradient descent step with this gradient performs a gradient ascent on \(s_i\) and a gradient descent on \(s_j\), weighted by \(\lambda_{ij}\). That is, for a given pair of documents, we push the score of the more relevant document higher and the score of the less relevant document lower, and the magnitude of the update is determined by the score difference.</p>
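<p>The loss and its gradient weight can be sanity-checked numerically (a sketch; here \(\lambda_{ij}\) is exactly \(\partial L_{ij}/\partial s_i\)):</p>

```python
import math

def pairwise_loss(s_i, s_j):
    # L_ij = log(1 + exp(s_j - s_i)); small when s_i is well above s_j
    return math.log(1.0 + math.exp(s_j - s_i))

def lambda_ij(s_i, s_j):
    # lambda_ij = -1 / (1 + exp(s_i - s_j)) = dL_ij / ds_i
    return -1.0 / (1.0 + math.exp(s_i - s_j))

# central-difference check that lambda_ij matches dL/ds_i
s_i, s_j, eps = 2.0, 1.0, 1e-6
num = (pairwise_loss(s_i + eps, s_j) - pairwise_loss(s_i - eps, s_j)) / (2 * eps)
assert abs(num - lambda_ij(s_i, s_j)) < 1e-6
```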

<h3 id="listwise">Listwise</h3>
<p>To be continued.</p>]]></content><author><name>Ren Zhang</name></author><category term="post" /><category term="machine_learning" /><summary type="html"><![CDATA[Task We want to learn a function \(f(q, D)\) which takes in a query \(q\) and a list of documents \(D=\{d_1, d_2, ..., d_n\}\), and produces scores using which we can rank/order the list of documents.]]></summary></entry><entry><title type="html">Iterate over an iterable multiple times</title><link href="https://ryanzhang.info/post/2019/08/22/iterate-over-an-iterable-multiple-times.html" rel="alternate" type="text/html" title="Iterate over an iterable multiple times" /><published>2019-08-22T19:50:00+00:00</published><updated>2019-08-22T19:50:00+00:00</updated><id>https://ryanzhang.info/post/2019/08/22/iterate-over-an-iterable-multiple-times</id><content type="html" xml:base="https://ryanzhang.info/post/2019/08/22/iterate-over-an-iterable-multiple-times.html"><![CDATA[<p>I was working on a piece of code today and needed to iterate over an iterable multiple times to do some computation. The body of the code is the same for all passes. One obvious thing I can do is a double loop:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">results</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_repeats</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">get_iterable</span><span class="p">():</span>
        <span class="n">results</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">.</span><span class="n">process</span><span class="p">(</span><span class="n">item</span><span class="p">))</span>
</code></pre></div></div>

<p>The reason that I need to loop over it multiple times, rather than loop once and duplicate <code class="language-plaintext highlighter-rouge">results</code>, is that <code class="language-plaintext highlighter-rouge">f</code> has internal state: it gives a different output based on the number of times it has seen the input. The above code did work, but is not as nice as I’d like. I utilized <code class="language-plaintext highlighter-rouge">chain</code> and <code class="language-plaintext highlighter-rouge">repeat</code> from <code class="language-plaintext highlighter-rouge">itertools</code> to clean it up a bit:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">itertools</span> <span class="kn">import</span> <span class="n">chain</span><span class="p">,</span> <span class="n">repeat</span>
<span class="n">results</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">chain</span><span class="p">(</span><span class="o">*</span><span class="n">repeat</span><span class="p">(</span><span class="n">get_iterable</span><span class="p">(),</span> <span class="n">n_repeats</span><span class="p">)):</span>
    <span class="n">results</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">.</span><span class="n">process</span><span class="p">(</span><span class="n">item</span><span class="p">))</span>
</code></pre></div></div>
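<p>One caveat worth noting (my addition, not in the original post): <code class="language-plaintext highlighter-rouge">repeat</code> yields the same object each time, so this rewrite matches the double loop only when <code class="language-plaintext highlighter-rouge">get_iterable()</code> returns a re-iterable container such as a list; a one-shot iterator would be exhausted after the first pass:</p>

```python
from itertools import chain, repeat

data = [1, 2, 3]

# a list can be iterated repeatedly, so chaining repeats of it works
assert list(chain(*repeat(data, 2))) == [1, 2, 3, 1, 2, 3]

# a one-shot iterator is the SAME exhausted object on the second pass
it = iter(data)
assert list(chain(*repeat(it, 2))) == [1, 2, 3]
```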

<p>It does exactly the same thing, but the code reads more like straight English.</p>]]></content><author><name>Ren Zhang</name></author><category term="post" /><category term="python" /><category term="iterator" /><summary type="html"><![CDATA[I was working on a piece of code today and needed to iterate over an iterable multiple times to do some computation. The body of the code is the same for all passes. One obvious thing I can do is a double loop:]]></summary></entry><entry><title type="html">Clean python code to get a reverse mapping</title><link href="https://ryanzhang.info/post/2019/08/21/clean-code-to-get-a-reverse-mapping-in-python.html" rel="alternate" type="text/html" title="Clean python code to get a reverse mapping" /><published>2019-08-21T15:40:00+00:00</published><updated>2019-08-21T15:40:00+00:00</updated><id>https://ryanzhang.info/post/2019/08/21/clean-code-to-get-a-reverse-mapping-in-python</id><content type="html" xml:base="https://ryanzhang.info/post/2019/08/21/clean-code-to-get-a-reverse-mapping-in-python.html"><![CDATA[<!--excerpt.start-->
<p>When working with machine learning problems, often I use a python dictionary to map categorical values to their integer-encoded values. <!--excerpt.end--> Something like:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">string</span>
<span class="n">feature_encoder</span> <span class="o">=</span> <span class="p">{</span><span class="n">v</span><span class="p">:</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">string</span><span class="p">.</span><span class="n">ascii_lowercase</span><span class="p">)}</span>
</code></pre></div></div>
<p>To get back the original value, I need to have a reverse mapping, I used to create it with:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">feature_decoder</span> <span class="o">=</span> <span class="p">{</span><span class="n">v</span><span class="p">:</span><span class="n">k</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">feature_encoder</span><span class="p">.</span><span class="n">items</span><span class="p">()}</span>
</code></pre></div></div>
<p>This is fine, but I dislike that I have to spend some mental effort reading it to figure out what I was doing the next time I revisit my code. And today I found a nicer way to get the reverse mapping.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">feature_decoder</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">reversed</span><span class="p">,</span> <span class="n">feature_encoder</span><span class="p">.</span><span class="n">items</span><span class="p">()))</span>
</code></pre></div></div>
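<p>A quick round-trip check confirms that the two mappings invert each other:</p>

```python
import string

feature_encoder = {v: i for i, v in enumerate(string.ascii_lowercase)}
feature_decoder = dict(map(reversed, feature_encoder.items()))

# reversed() flips each (key, value) pair, so decoding undoes encoding
assert feature_decoder[0] == "a"
assert all(feature_decoder[i] == v for v, i in feature_encoder.items())
```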
<p>It is doing exactly the same thing, but I can just read it as an English sentence to know what is going on.</p>]]></content><author><name>Ren Zhang</name></author><category term="post" /><category term="python" /><summary type="html"><![CDATA[When working with machine learning problems, often I use python dictionary to map categorical values to its integer encoded values. Something like: import string feature_encoder = {v:i for i, v in enumerate(string.ascii_lowercase)} To get back the original value, I need to have a reverse mapping, I used to create it with: feature_decoder = {v:k for k, v in feature_encoder.items()} This is fine, but I dislike it for that I have to spend some mental effort to read it to know what I was doing the next time I read my code. And today I found a nicer way to get the reverse mapping. feature_decoder = dict(map(reversed, feature_encoder.items())) It is doing exactly the same thing, but I can just read it as an English sentence to know what is going on.]]></summary></entry></feed>