Sample Extraction from RLWE to LWE

In this article I’ll derive a trick used in FHE called sample extraction. In brief, it allows one to partially convert a ciphertext in the Ring Learning With Errors (RLWE) scheme to the Learning With Errors (LWE) scheme.

Here are some other articles I’ve written about other FHE building blocks, though they are not prerequisites for this article.

LWE and RLWE

The first two articles in the list above define the Learning With Errors problem (LWE). I will repeat the definition here:

LWE: The LWE encryption scheme has the following parameters:

  • A plaintext space $ \mathbb{Z}/q\mathbb{Z}$, where $ q \geq 2$ is a positive integer. This is the space that the underlying message $m$ comes from.
  • An LWE dimension $ n \in \mathbb{N}$.
  • A discrete Gaussian error distribution $ D$ with a mean of zero and a fixed standard deviation.

An LWE secret key is defined as a vector $s \in \{0, 1\}^n$ (uniformly sampled). An LWE ciphertext is defined as a vector $ a = (a_1, \dots, a_n)$, sampled uniformly over $ (\mathbb{Z} / q\mathbb{Z})^n$, and a scalar $ b = \langle a, s \rangle + m + e$, where $m$ is the message, $e$ is drawn from $D$ and all arithmetic is done modulo $q$. Note: the message $m$ is usually represented by placing an even smaller message (say, a 4-bit message) in the highest-order bits of a 32-bit unsigned integer. Decryption then corresponds to computing $b - \langle a, s \rangle = m + e$ and rounding the result to recover $m$ while discarding $e$.
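To make the definition concrete, here is a minimal toy sketch of LWE encryption and decryption. The parameter values (dimension 16, a 4-bit message, standard deviation 4) and all function names are my own illustrative assumptions, not secure or standard choices, and a rounded continuous Gaussian stands in for the discrete Gaussian $D$.

```python
import random
from typing import List, Tuple

Q = 2**32                 # ciphertext modulus q
N_LWE = 16                # toy LWE dimension n (far too small to be secure)
MSG_BITS = 4              # the message lives in the top 4 bits
SHIFT = 32 - MSG_BITS


def sample_error(stddev: float = 4.0) -> int:
    # Assumption: a rounded Gaussian as a stand-in for the discrete Gaussian D.
    return round(random.gauss(0, stddev))


def lwe_keygen() -> List[int]:
    return [random.randrange(2) for _ in range(N_LWE)]


def lwe_encrypt(msg: int, key: List[int]) -> Tuple[List[int], int]:
    a = [random.randrange(Q) for _ in range(N_LWE)]
    dot = sum(a_i * s_i for a_i, s_i in zip(a, key))
    # b = <a, s> + m + e, with m stored in the highest-order bits.
    b = (dot + (msg << SHIFT) + sample_error()) % Q
    return a, b


def lwe_decrypt(ct: Tuple[List[int], int], key: List[int]) -> int:
    a, b = ct
    phase = (b - sum(a_i * s_i for a_i, s_i in zip(a, key))) % Q
    # Round to the nearest multiple of 2^SHIFT, discarding the error.
    return ((phase + (1 << (SHIFT - 1))) >> SHIFT) % (1 << MSG_BITS)
```

Decrypting a fresh encryption should return the original 4-bit message, since the error is tiny compared to $2^{28}$.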

Without the error term, an attacker could determine the secret key from a polynomial-sized collection of LWE ciphertexts with something like Gaussian elimination. The set of samples looks like a linear (or affine) system, where the secret key entries are the unknown variables. With an error term, the problem of solving the system is believed to be hard, and only exponential time/space algorithms are known.

RLWE: The Ring Learning With Errors (RLWE) problem is the natural analogue of LWE, where all scalars involved are replaced with polynomials over a (carefully) chosen ring.

Formally, the RLWE encryption scheme has the following parameters:

  • A ring $R = \mathbb{Z}/q\mathbb{Z}$, where $ q \geq 2$ is a positive integer. This is the space of coefficients of all polynomials in the scheme. I usually think of $q$ as $2^{32}$, i.e., unsigned 32-bit integers.
  • A plaintext space $R[x] / (x^N + 1)$, where $N$ is a power of 2. This is the space that the underlying message $m(x)$ comes from, and it is encoded as a list of $N$ integers forming the coefficients of the polynomial.
  • An RLWE dimension $n \in \mathbb{N}$.
  • A discrete Gaussian error distribution $D$ with a mean of zero and a fixed standard deviation.

An RLWE secret key $s$ is defined as a list of $n$ polynomials with binary coefficients in $\mathbb{B}[x] / (x^N+1)$, where $\mathbb{B} = \{0, 1\}$. The coefficients are uniformly sampled, like in LWE. An RLWE ciphertext is defined as a vector of $n$ polynomials $a = (a_1(x), \dots, a_n(x))$, sampled uniformly over $(R[x] / (x^N+1))^n$, and a polynomial $b(x) = \langle a, s \rangle + m(x) + e(x)$, where $m(x)$ is the message (with a similar “store it in the top bits” trick as LWE), $e(x)$ is a polynomial with coefficients drawn from $D$ and all the products of the inner product are done in $R[x] / (x^N+1)$. Decryption in RLWE involves computing $b(x) - \langle a, s \rangle$ and rounding appropriately to recover $m(x)$. Just like with LWE, the message is “hidden” in the noise added to an equation corresponding to the polynomial products (i.e., without the noise and with enough sample encryptions of the same message/secret key, you can solve the system and recover the message). For more notes on how polynomial multiplication ends up being trickier in this ring, see my negacyclic polynomial multiplication article.

The most common version of RLWE you will see in the literature sets the vector dimension $n=1$, and so the secret key $s$ is a single polynomial, the ciphertext is a pair of polynomials $(a(x), b(x))$, and RLWE can be viewed as directly replacing the vector dot product in LWE with a polynomial product. However, making $n$ larger is believed to provide more security, and it can be traded off against making the polynomial degree smaller, which can be useful when tweaking parameters for performance (keeping the security level constant).

Sample Extraction

Sample extraction is the trick of taking an RLWE encryption of $m(x) = m_0 + m_1 x + \dots + m_{N-1}x^{N-1}$, and outputting an LWE encryption of $m_0$. In our case, the degree $N$ and the dimension $n_{\textup{RLWE}}$ of the input RLWE ciphertext scheme are fixed, but we may pick the dimension $n_{\textup{LWE}}$ of the LWE scheme as we choose to make this trick work.

This is one of those times in math when it is best to “just work it out with a pencil.” It turns out there are no serious obstacles to our goal. We start with polynomials $a = (a_1(x), \dots, a_n(x))$ and $b(x) = \langle a, s \rangle + m(x) + e(x)$, and we want to produce a vector of scalars $x = (x_1, \dots, x_D)$ of some dimension $D$, a corresponding secret key $s_{\textup{LWE}}$, and a scalar $b = \langle x, s_{\textup{LWE}} \rangle + m_0 + e'$, where $e'$ may be different from the input error $e(x)$, but is hopefully not too much larger.

As with many of the articles in this series, we employ the so-called “phase function” to help with the analysis, which is just the partial decryption of an RLWE ciphertext without the rounding step: $\varphi(x) = b(x) - \langle a, s \rangle = m(x) + e(x)$. The idea is as follows: inspect the structure of the constant term of $\varphi(x)$, oh look, it’s an LWE encryption.

So let’s expand the constant term of $b(x) - \langle a, s \rangle$. Given a polynomial expression, I will use the notation $(-)[0]$ to denote the constant coefficient, and $(-)[k]$ for the $k$-th coefficient.

$$ \begin{aligned}(b(x) - \langle a, s \rangle)[0] &= b[0] - \left ( (a_1s_1)[0] + \dots + (a_n s_n)[0] \right ) \end{aligned}$$

Each entry in the dot product is a negacyclic polynomial product, so its constant term requires summing all the pairs of coefficients of $a_i$ and $s_i$ whose degrees sum to zero mod $N$, and flipping signs when there’s wraparound. In particular, a single product above for $a_i s_i$ has the form:

$$(a_is_i)[0] = s_i[0]a_i[0] - s_i[1]a_i[N-1] - s_i[2]a_i[N-2] - \dots - s_i[N-1]a_i[1]$$
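If you would rather check this formula with a computer than a pencil, here is a small sketch that compares it against a schoolbook negacyclic product. The function names are hypothetical, and the quadratic-time multiplication is only meant for small test sizes.

```python
from typing import List


def negacyclic_multiply(a: List[int], s: List[int], q: int) -> List[int]:
    """Schoolbook product in (Z/qZ)[x] / (x^N + 1), where N = len(a)."""
    n = len(a)
    result = [0] * n
    for i in range(n):
        for j in range(n):
            if i + j < n:
                result[i + j] = (result[i + j] + a[i] * s[j]) % q
            else:
                # Wraparound: x^N = -1, so the term flips sign.
                result[i + j - n] = (result[i + j - n] - a[i] * s[j]) % q
    return result


def constant_term(a: List[int], s: List[int], q: int) -> int:
    """The constant coefficient of a*s, computed via the formula above."""
    n = len(a)
    total = s[0] * a[0]
    for k in range(1, n):
        total -= s[k] * a[n - k]
    return total % q
```

The two should agree on the constant coefficient for any inputs of the same length.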

Notice that I wrote the coefficients of $s_i$ in increasing order. This was on purpose, because if we re-write this expression $(a_is_i)[0]$ as a dot product, we get

$$(a_is_i)[0] = \left \langle (s_i[0], s_i[1], \dots, s_i[N-1]), (a_i[0], -a_i[N-1], \dots, -a_i[1])\right \rangle$$

In particular, the $a_i[k]$ are public, so we can sign-flip and reorder them easily in our conversion trick. But $s_i$ is unknown at the time the sample extraction needs to occur, so it helps if we can leave the secret key untouched. And indeed, when we apply the above expansion to all of the terms in the computation of $\varphi(x)[0]$, we end up manipulating the $a_i$’s a lot, but merely “flattening” the coefficients of $s = (s_1(x), \dots, s_n(x))$ into a single long vector.

So combining all of the above products, we see that $(b(x) - \langle a, s \rangle)[0]$ is already an LWE encryption with $(x, y) = ((x_1, \dots, x_D), b[0])$, and $x$ being the very long ($D = n \cdot N$) vector

$$\begin{aligned} x = (& a_1[0], -a_1[N-1], \dots, -a_1[1], \\ &a_2[0], -a_2[N-1], \dots, -a_2[1], \\ &\dots , \\ &a_n[0], -a_n[N-1], \dots, -a_n[1] ) \end{aligned}$$

And the corresponding secret key is

$$\begin{aligned} s_{\textup{LWE}} = (& s_1[0], s_1[1], \dots, s_1[N-1], \\ &s_2[0], s_2[1], \dots, s_2[N-1], \\ &\dots , \\ &s_n[0], s_n[1], \dots, s_n[N-1] ) \end{aligned}$$

And the error in this ciphertext is exactly the constant coefficient of the error polynomial $e(x)$ from the RLWE encryption, which is independent of the error of all the other coefficients.
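Putting the derivation together, here is a toy end-to-end sketch: encrypt with RLWE, extract, and compute the LWE phase with the flattened key. The parameters ($N = 8$, $n = 2$, a rounded Gaussian for the error) and all names are illustrative assumptions on my part, not values from any real scheme.

```python
import random
from typing import List, Tuple

Q = 2**32   # coefficient modulus q
N = 8       # toy polynomial degree, a power of 2
DIM = 2     # toy RLWE dimension n

Poly = List[int]


def negacyclic_multiply(a: Poly, s: Poly) -> Poly:
    """Schoolbook product in (Z/QZ)[x] / (x^N + 1)."""
    result = [0] * N
    for i in range(N):
        for j in range(N):
            sign = 1 if i + j < N else -1  # x^N wraps around to -1
            k = (i + j) % N
            result[k] = (result[k] + sign * a[i] * s[j]) % Q
    return result


def rlwe_encrypt(msg: Poly, key: List[Poly]) -> Tuple[List[Poly], Poly]:
    a = [[random.randrange(Q) for _ in range(N)] for _ in range(DIM)]
    # Start from m(x) + e(x), then add each product a_i * s_i.
    b = [(m + round(random.gauss(0, 4.0))) % Q for m in msg]
    for a_i, s_i in zip(a, key):
        product = negacyclic_multiply(a_i, s_i)
        b = [(u + v) % Q for u, v in zip(b, product)]
    return a, b


def sample_extract(ct: Tuple[List[Poly], Poly]) -> Tuple[List[int], int]:
    """Return the LWE ciphertext (x, b[0]) encrypting the constant term."""
    a, b = ct
    x: List[int] = []
    for a_i in a:
        # (a_i[0], -a_i[N-1], ..., -a_i[1]), matching the order of s_i.
        x.append(a_i[0])
        x.extend((-a_i[N - k]) % Q for k in range(1, N))
    return x, b[0]


def flatten_key(key: List[Poly]) -> List[int]:
    # The LWE key is the RLWE key's coefficients, untouched, in one vector.
    return [coeff for s_i in key for coeff in s_i]


def lwe_phase(x: List[int], b: int, lwe_key: List[int]) -> int:
    return (b - sum(x_i * s_i for x_i, s_i in zip(x, lwe_key))) % Q
```

The LWE phase of the extracted ciphertext should equal the constant coefficient of the RLWE phase exactly, which is the “no error growth” claim in code form.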

Commentary

This trick is a best case scenario. Unlike with key switching, we don’t need to encrypt the output LWE secret key to perform the conversion. And unlike modulus switching, there is no impact on the error growth in the conversion from RLWE to LWE. So in a sense, this trick is “perfect,” though it loses information about the other coefficients of $m(x)$ in the process. As it happens, the CGGI FHE scheme that these articles are building toward only uses the constant coefficient.

The only twist to think about is that the output LWE ciphertext is dependent on the RLWE scheme parameters. What if you wanted to get a smaller-dimensional LWE ciphertext as output? This is a realistic concern, as in the CGGI FHE scheme one starts from an LWE ciphertext of one dimension, goes to RLWE of another (larger) dimension, and needs to get back to LWE of the original dimension by the end.

To do this, you have two options: one is to pick the RLWE ciphertext parameters $n, N$, so that their product is the value you need. A second is to allow the RLWE parameters to be whatever you need for performance/security, and then employ a key switching operation after the sample extraction to get back to the LWE parameters you need.

It is worth mentioning (though I am far from fully understanding the methods) that there are other ways to convert between LWE and RLWE. One can go from LWE to RLWE, or from a collection of LWEs to RLWE. Some methods can be found in this paper and its references.

Until next time!

The Gadget Decomposition in FHE

Lately I’ve been studying Fully Homomorphic Encryption, which is the miraculous ability to perform arbitrary computations on encrypted data without learning any information about the underlying message. It’s the most comprehensive private computing solution that can exist (and it does exist!).

The first FHE scheme by Craig Gentry was based on ideal lattices and was considered very complex (I never took the time to learn how it worked). Some later schemes (GSW = Gentry-Sahai-Waters) are based on matrix multiplication, and are conceptually much simpler. Even more recent FHE schemes build on GSW or use it as a core subroutine.

All of these schemes inject random noise into the ciphertext, and each homomorphic operation increases noise. Once the noise gets too big, you can no longer decrypt the message, and so every now and then you must apply a process called “bootstrapping” that reduces noise. It also tends to be the performance bottleneck of any FHE scheme, and this bottleneck is why FHE is not considered practical yet.

To help reduce noise growth, many FHE schemes like GSW use a technical construction dubbed the gadget decomposition. Despite the terribly vague name, it’s a crucial limitation on noise growth. When it shows up in a paper, it’s usually remarked as “well known in the literature,” and the details you’d need to implement it are omitted. It’s one of those topics.

So I’ll provide some details. The code from this post is on GitHub.

Binary digit decomposition

To create an FHE scheme, you need to apply two homomorphic operations to ciphertexts: addition and multiplication. Most FHE schemes admit one of the two operations trivially. If the ciphertexts are numbers as in RSA, you multiply them as numbers and that multiplies the underlying messages, but addition is not known to be possible. If ciphertexts are vectors as in the “Learning With Errors” scheme (LWE), the basis of many FHE schemes, you add them as vectors and that adds the underlying messages. (Here the “Error” in LWE is synonymous with “random noise”; I will use the term “noise.”) In LWE and most FHE schemes, a ciphertext hides the underlying message by adding random noise, and addition of two ciphertexts adds the corresponding noise. After too many unmitigated additions, the noise will grow so large it obstructs the message. So you stop computing, or you apply a bootstrapping operation to reduce the noise.

Most FHE schemes also allow you to multiply a ciphertext by an unencrypted constant $ A$, but then the noise scales by a factor of $ A$, which is undesirable if $ A$ is large. So you either need to limit the coefficients of your linear combinations by some upper bound, or use a version of the gadget decomposition.

The simplest version of the gadget decomposition works like this. Instead of encrypting a message $ m \in \mathbb{Z}$, you would encrypt $ m, 2m, 4m, \dots, 2^{k-1} m$ for some choice of $ k$, and then to multiply by $ A < 2^k$ you write the binary digits of $ A = \sum_{i=0}^{k-1} a_i 2^i$ and you compute $ \sum_{i=0}^{k-1} a_i \textup{Enc}(2^i m)$. If the noise in each encryption is $ E$, and summing ciphertexts sums noise, then this trick reduces the noise growth from $ O(AE)$ to $ O(kE) = O(\log(A)E)$, at the cost of tracking $ k$ ciphertexts. (Calling the noise $ E$ is a bit of an abuse, since in reality the error is sampled from a random distribution, but hopefully you see my point.)

Some folks call the mapping $ \textup{PowersOf2}(m) = m \cdot (2^0, 2^1, 2^2, \dots, 2^{k-1})$, and for the sake of this article let’s call the operation of writing a number $ A$ in terms of its binary digits $ \textup{Bin}(A) = (a_0, \dots, a_{k-1})$ (note, the first digit is the least-significant bit, i.e., it’s a little-endian representation). Then PowersOf2 and Bin expand an integer product into a dot product, while shifting powers of 2 from one side to the other.

$ \displaystyle A \cdot m = \langle \textup{Bin}(A), \textup{PowersOf2}(m) \rangle$

This inspired the following “proof by meme” that I can’t resist including.

Working out an example, if the message is $ m=7$ and $ A = 100, k=8$, then $ \textup{PowersOf2}(7) = (7, 14, 28, 56, 112, 224, 448, 896)$ and $ \textup{Bin}(A) = (0,0,1,0,0,1,1,0)$ (again, little-endian), and the dot product is

$ \displaystyle 28 \cdot 1 + 224 \cdot 1 + 448 \cdot 1 = 700 = 7 \cdot 2^2 + 7 \cdot 2^5 + 7 \cdot 2^6$
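The same example in a few lines of Python, as a quick sketch. The names `bin_digits` and `powers_of_2` are mine, and I use 8 digits to match the length of the vectors above.

```python
from typing import List


def bin_digits(a: int, k: int) -> List[int]:
    """Little-endian binary digits of a, padded to k digits."""
    return [(a >> i) & 1 for i in range(k)]


def powers_of_2(m: int, k: int) -> List[int]:
    """The vector m * (2^0, 2^1, ..., 2^(k-1))."""
    return [m << i for i in range(k)]


def dot(u: List[int], v: List[int]) -> int:
    return sum(x * y for x, y in zip(u, v))


# A * m as a dot product of small digits with scaled messages.
product = dot(bin_digits(100, 8), powers_of_2(7, 8))
```

Here `product` works out to 700, the same as $ 100 \cdot 7$. In the FHE setting the entries of `powers_of_2(m, k)` would be ciphertexts rather than plain integers, and only the 0/1 digits multiply them.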

A generalized gadget construction

One can generalize the binary digit decomposition to different bases, or to vectors of messages instead of a single message, or to include a subset of the digits for varying approximations. I’ve been puzzling over an FHE scheme that does all three. In my search for clarity I came across a nice paper of Genise, Micciancio, and Polyakov called “Building an Efficient Lattice Gadget Toolkit: Subgaussian Sampling and More”, in which they state a nice general definition.

Definition: For any finite additive group $ A$, an $ A$-gadget of size $ w$ and quality $ \beta$ is a vector $ \mathbf{g} \in A^w$ such that any group element $ u \in A$ can be written as an integer combination $ u = \sum_{i=1}^w g_i x_i$ where $ \mathbf{x} = (x_1, \dots , x_w)$ has norm at most $ \beta$.

The main groups considered in my case are $ A = (\mathbb{Z}/q\mathbb{Z})^n$, where $ q$ is usually $ 2^{32}$ or $ 2^{64}$, i.e., unsigned int sizes on computers for which we get free modulus operations. In this case, a $ (\mathbb{Z}/q\mathbb{Z})^n$-gadget is a matrix $ G \in (\mathbb{Z}/q\mathbb{Z})^{n \times w}$, and the representation $ x \in \mathbb{Z}^w$ of $ u \in (\mathbb{Z}/q\mathbb{Z})^n$ satisfies $ Gx = u$.

Here $ n$ and $ q$ are fixed, and $ w, \beta$ are traded off to make the chosen gadget scheme more efficient (smaller $ w$) or better at reducing noise (smaller $ \beta$). An example of how this could work is shown in the next section by generalizing the binary digit decomposition to an arbitrary base $ B$. This allows you to use fewer digits to represent the number $ A$, but each digit may be as large as $ B$ and so the quality is $ \beta = O(B\sqrt{w})$.

One commonly-used construction is to convert an $ A$-gadget to an $ A^n$-gadget using the Kronecker product. Let $ g \in A^w$ be an $ A$-gadget of quality $ \beta$. Then the following matrix is an $ A^n$-gadget of size $ nw$ and quality $ \sqrt{n} \beta$:

$ \displaystyle G = I_n \otimes \mathbf{g}^\top = \begin{pmatrix} g_1 & \dots & g_w & & & & & & & \\ & & & g_1 & \dots & g_w & & & & \\ & & & & & & \ddots & & & \\ & & & & & & & g_1 & \dots & g_w \end{pmatrix}$

Blank spaces represent zeros, for clarity.

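An example with $ A = (\mathbb{Z}/16\mathbb{Z})$. The $ A$-gadget is $ \mathbf{g} = (1,2,4,8)$. This has size $ 4 = \log_2(q)$ and quality $ \beta = 2 = \sqrt{1+1+1+1}$. Then for an $ A^3$-gadget, we construct

$ \displaystyle G = I_3 \otimes \mathbf{g}^\top = \begin{pmatrix} 1 & 2 & 4 & 8 & & & & & & & & \\ & & & & 1 & 2 & 4 & 8 & & & & \\ & & & & & & & & 1 & 2 & 4 & 8 \end{pmatrix}$

Now given a vector $ (15, 4, 7) \in A^3$ we write it as follows, where each little-endian representation is concatenated into a single vector.

$ \displaystyle \mathbf{x} = \begin{pmatrix} 1\\1\\1\\1\\0\\0\\1\\0\\1\\1\\1\\0 \end{pmatrix}$

And finally,

$ \displaystyle G\mathbf{x} = \begin{pmatrix} 15 \\ 4 \\ 7 \end{pmatrix}$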

To use the definition more rigorously, if we had to write the matrix above as a gadget “vector”, it would be in column order from left to right, $ \mathbf{g} = ((1,0,0), (2,0,0), \dots, (0,0,8)) \in A^{wn}$. Since the vector $ \mathbf{x}$ can be at worst all 1’s, its norm is at most $ \sqrt{12} = \sqrt{nw} = \sqrt{n} \beta = 2 \sqrt{3}$, as claimed above.
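Here is a small pure-Python sketch of the Kronecker construction for this example, so you can check the arithmetic. The helper names are hypothetical, and the decomposition shown is specific to the powers-of-2 gadget.

```python
from typing import List


def kronecker_gadget(g: List[int], n: int) -> List[List[int]]:
    """Build I_n (x) g^T as an n-by-(n*w) matrix, stored as a list of rows."""
    w = len(g)
    return [
        [g[j - i * w] if i * w <= j < (i + 1) * w else 0 for j in range(n * w)]
        for i in range(n)
    ]


def binary_decompose(u: List[int], w: int) -> List[int]:
    """Concatenate the w-digit little-endian binary digits of each entry."""
    return [(entry >> i) & 1 for entry in u for i in range(w)]


def mat_vec(G: List[List[int]], x: List[int], q: int) -> List[int]:
    return [sum(g_ij * x_j for g_ij, x_j in zip(row, x)) % q for row in G]


G = kronecker_gadget([1, 2, 4, 8], 3)
x = binary_decompose([15, 4, 7], 4)
u = mat_vec(G, x, 16)
```

With these inputs `x` is the 12-entry 0/1 vector from the example and `u` recovers `[15, 4, 7]`.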

A signed representation in base B

As we’ve seen, the gadget decomposition trades reducing noise for a larger ciphertext size. With integers modulo $ q = 2^{32}$, this can be fine-tuned a bit more by using a larger base. Instead of PowersOf2 we could define PowersOfB, where $ B = 2^b$, such that $ B$ divides $ 2^{32}$. For example, with $ b = 8, B = 256$, we would only need to track 4 ciphertexts. And the gadget decomposition of the number we’re multiplying by would be the little-endian digits of its base-$ B$ representation. The cost here is that the maximum entry of the decomposed representation is 255.
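As a sketch of the unsigned version, here is a base-$ B$ digit decomposition and its inverse. The function names and the default of 32 bits are my own choices for illustration.

```python
from typing import List


def unsigned_decomposition(
        x: int, base_log: int, total_num_bits: int = 32) -> List[int]:
    """Little-endian base-2^base_log digits of x."""
    digit_mask = (1 << base_log) - 1
    return [
        (x >> (i * base_log)) & digit_mask
        for i in range(total_num_bits // base_log)
    ]


def recompose(digits: List[int], base_log: int) -> int:
    """Inverse map: sum of digit_i * B^i. Also valid for negative digits."""
    return sum(d * (1 << (i * base_log)) for i, d in enumerate(digits))
```

For example, `unsigned_decomposition(2047, 8)` gives `[255, 7, 0, 0]`, and `recompose` maps that back to 2047. Since `recompose` is just a weighted sum, it does not care whether the digits are signed, which is what makes the refinement in the next paragraphs possible.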

We can fine tune this a little bit more by using a signed base-$ B$ representation. To my knowledge this is not the same thing as what computer programmers normally refer to as a signed integer, nor does it have anything to do with the two’s complement representation of negative numbers. Rather, instead of the normal base-$ B$ digits $ n_i \in \{ 0, 1, \dots, B-1 \}$ for a number $ N = \sum_{i=0}^k n_i B^i$, the signed representation chooses $ n_i \in \{ -B/2, -B/2 + 1, \dots, -1, 0, 1, \dots, B/2 - 1 \}$.

Computing the digits is slightly more involved, and it works by shifting large coefficients (those that are at least $ B/2$) down by $ B$, and “absorbing” the impact of that shift into the next more significant digit. E.g., if $ B = 256$ and $ N = 2^{11} - 1$ (all 1s up to the 10th digit), then the unsigned little-endian base-$ B$ representation of $ N$ is $ (255, 7) = 255 + 7 \cdot 256$. The corresponding signed base-$ B$ representation subtracts $ B$ from the first digit, and adds 1 to the second digit, resulting in $ (-1, 8) = -1 + 8 \cdot 256$. This works in general because of the following “add zero” identity, where $ p$ and $ q$ are two successive unsigned digits in the unsigned base-$ B$ representation of a number.

$ \displaystyle \begin{aligned} pB^{k-1} + qB^k &= pB^{k-1} - B^k + qB^k + B^k \\ &= (p-B)B^{k-1} + (q+1)B^k \end{aligned}$

Then if $ q+1 \geq B/2$, you’d repeat and carry the 1 to the next higher coefficient.

The result of all this is that the maximum absolute value of a coefficient of the signed representation is halved from the unsigned representation, which reduces the noise growth at the cost of a slightly more complex representation (from an implementation standpoint). Another side effect is that the largest representable number is less than $ 2^{32}-1$. If you try to apply this algorithm to such a large number, the largest digit would need to be shifted, but there is no successor to carry to. Rather, if there are $ k$ digits in the unsigned base-$ B$ representation, the maximum number representable in the signed version has all digits set to $ B/2 - 1$. In our example with $ B=256$ and 32 bits, the largest digit is 127. The formula for the max representable integer is $ \sum_{i=0}^{k-1} (B/2 - 1) B^i = (B/2 - 1)\frac{B^k - 1}{B-1}$.

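# Example parameters: B = 256 with 32-bit integers.
base_log = 8
num_bits = 32
base = 1 << base_log
max_digit = base // 2 - 1
max_representable = (max_digit
    * (base ** (num_bits // base_log) - 1) // (base - 1)
)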

A simple Python implementation computes the signed representation, with code copied below, in which $ B=2^b$ is the base, and $ b = \log_2(B)$ is base_log.

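from typing import List


def signed_decomposition(
        x: int, base_log: int, total_num_bits: int = 32) -> List[int]:
    """Return the little-endian signed base-2^base_log digits of x."""
    result = []
    base = 1 << base_log
    digit_mask = (1 << base_log) - 1
    base_over_2_threshold = 1 << (base_log - 1)
    carry = 0

    for i in range(total_num_bits // base_log):
        # Extract the i-th unsigned digit and absorb any pending carry.
        unsigned_digit = (x >> (i * base_log)) & digit_mask
        if carry:
            unsigned_digit += carry
            carry = 0

        signed_digit = unsigned_digit
        if signed_digit >= base_over_2_threshold:
            # Shift the digit into [-B/2, B/2) and carry 1 to the next digit.
            signed_digit -= base
            carry = 1
        result.append(signed_digit)

    # A carry left over here is dropped: x was above the max representable.
    return result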

In a future article I demonstrate the gadget decomposition in action in a practical setting called key switching, which allows one to convert an LWE ciphertext encrypted with key $ s_1$ into an LWE ciphertext encrypted with a different key $ s_2$. This operation increases noise, and so the gadget decomposition is used to reduce noise growth. Key switching is used in FHE because some operations (like bootstrapping) have the side effect of switching the encryption key.

Until then!

Regression and Linear Combinations

Recently I’ve been helping out with a linear algebra course organized by Tai-Danae Bradley and Jack Hidary, and one of the questions that came up a few times was, “why should programmers care about the concept of a linear combination?”

For those who don’t know, given vectors $ v_1, \dots, v_n$, a linear combination of the vectors is a choice of some coefficients $ a_i$ with which to weight the vectors in a sum $ v = \sum_{i=1}^n a_i v_i$.

I must admit, math books do a poor job of painting the concept as more than theoretical—perhaps linear combinations are only needed for proofs, while the real meat is in matrix multiplication and cross products. But no, linear combinations truly lie at the heart of many practical applications.

In some cases, the entire goal of an algorithm is to find a “useful” linear combination of a set of vectors. The vectors are the building blocks (often a vector space or subspace basis), and the set of linear combinations are the legal ways to combine the blocks. Simpler blocks admit easier and more efficient algorithms, but their linear combinations are less expressive. Hence, a tradeoff.

A concrete example is regression. Most people think of regression in terms of linear regression. You’re looking for a linear function like $ y = mx+b$ that approximates some data well. For multiple variables, you have, e.g., $ \mathbf{x} = (x_1, x_2, x_3)$ as a vector of input variables, and $ \mathbf{w} = (w_1, w_2, w_3)$ as a vector of weights, and the function is $ y = \mathbf{w}^T \mathbf{x} + b$.

To avoid the shift by $ b$ (which makes the function affine instead of purely linear; formulas of purely linear functions are easier to work with because the shift is like a pesky special case you have to constantly account for), authors often add a fake input variable $ x_0$ which is always fixed to 1, and relabel $ b$ as $ w_0$ to get $ y = \mathbf{w}^T \mathbf{x} = \sum_i w_i x_i$ as the final form. The optimization problem to solve becomes the following, where your data set to approximate is $ \{ \mathbf{x}_1, \dots, \mathbf{x}_k \}$.

$ \displaystyle \min_w \sum_{i=1}^k (y_i - \mathbf{w}^T \mathbf{x}_i)^2$

In this case, the function being learned—the output of the regression—doesn’t look like a linear combination. Technically it is, just not in an interesting way.

It becomes more obviously related to linear combinations when you try to model non-linearity. The idea is to define a class of functions called basis functions $ B = \{ f_1, \dots, f_m \mid f_i: \mathbb{R}^n \to \mathbb{R} \}$, and allow your approximation to be any linear combination of functions in $ B$, i.e., any function in the span of $ B$.

$ \displaystyle \hat{f}(\mathbf{x}) = \sum_{i=1}^m w_i f_i(\mathbf{x})$

Again, instead of weighting each coordinate of the input vector with a $ w_i$, we’re weighting each basis function’s contribution (when given the whole input vector) to the output. If the basis functions were to output a single coordinate ($ f_i(\mathbf{x}) = x_i$), we would be back to linear regression.

Then the optimization problem is to choose the weights to minimize the error of the approximation.

$ \displaystyle \min_w \sum_{j=1}^k (y_j - \hat{f}(\mathbf{x}_j))^2$

As an example, let’s say that we wanted to do regression with a basis of quadratic polynomials. Our basis for three input variables might look like

$ \displaystyle \{ 1, x_1, x_2, x_3, x_1x_2, x_1x_3, x_2x_3, x_1^2, x_2^2, x_3^2 \}$

Any quadratic polynomial in three variables can be written as a linear combination of these basis functions. Also note that if we treat this as the basis of a vector space, then a vector is a tuple of 10 numbers—the ten coefficients in the polynomial. It’s the same as $ \mathbb{R}^{10}$, just with a different interpretation of what the vector’s entries mean. With that, we can see how we would compute dot products, projections, and other nice things, though they may not have quite the same geometric sensibility.

These are not the usual basis functions used for polynomial regression in practice (see the note at the end of this article), but we can already do some damage in writing regression algorithms.

A simple stochastic gradient descent

Although there is a closed form solution to many regression problems (including the quadratic regression problem, though with a slight twist), gradient descent is a simple enough solution to showcase how an optimization solver can find a useful linear combination. This code will be written in Python 3.9. It’s on GitHub.

First we start with some helpful type aliases

from typing import Callable, Tuple, List

Input = Tuple[float, float, float]
Coefficients = List[float]
Gradient = List[float]
Hypothesis = Callable[[Input], float]
Dataset = List[Tuple[Input, float]]

Then define a simple wrapper class for our basis functions

class QuadraticBasisPolynomials:
    def __init__(self):
        self.basis_functions = [
            lambda x: 1,
            lambda x: x[0],
            lambda x: x[1],
            lambda x: x[2],
            lambda x: x[0] * x[1],
            lambda x: x[0] * x[2],
            lambda x: x[1] * x[2],
            lambda x: x[0] * x[0],
            lambda x: x[1] * x[1],
            lambda x: x[2] * x[2],
        ]

    def __getitem__(self, index):
        return self.basis_functions[index]

    def __len__(self):
        return len(self.basis_functions)

    def linear_combination(self, weights: Coefficients) -> Hypothesis:
        def combined_function(x: Input) -> float:
            return sum(
                w * f(x)
                for (w, f) in zip(weights, self.basis_functions)
            )

        return combined_function

basis = QuadraticBasisPolynomials()

The linear_combination function returns a function that computes the weighted sum of the basis functions. Now we can define the error on a dataset, as well as for a single point

def total_error(weights: Coefficients, data: Dataset) -> float:
    hypothesis = basis.linear_combination(weights)
    return sum(
        (actual_output - hypothesis(example)) ** 2
        for (example, actual_output) in data
    )


def single_point_error(
        weights: Coefficients, point: Tuple[Input, float]) -> float:
    return point[1] - basis.linear_combination(weights)(point[0])

We can then define the gradient of the error function with respect to the weights and a single data point. Recall, the error function is defined as

$ \displaystyle E(\mathbf{w}) = \sum_{j=1}^k (y_j - \hat{f}(\mathbf{x}_j))^2$

where $ \hat{f}$ is a linear combination of basis functions

$ \displaystyle \hat{f}(\mathbf{x}_j) = \sum_{s=1}^m w_s f_s(\mathbf{x}_j)$

Since we’ll do stochastic gradient descent, the error formula is a bit simpler. We compute it not for the whole data set but only a single random point at a time. So the error is

$ \displaystyle E(\mathbf{w}) = (y_j - \hat{f}(\mathbf{x}_j))^2$

Then we compute the gradient with respect to the individual entries of $ \mathbf{w}$, using the chain rule and noting that the only term of the linear combination that has a nonzero contribution to the gradient for $ \frac{\partial E}{\partial w_i}$ is the term containing $ w_i$. This is one of the major benefits of using linear combinations: the gradient computation is easy.

$ \displaystyle \frac{\partial E}{\partial w_i} = -2 (y_j - \hat{f}(\mathbf{x}_j)) \frac{\partial \hat{f}}{\partial w_i}(\mathbf{x}_j) = -2 (y_j - \hat{f}(\mathbf{x}_j)) f_i(\mathbf{x}_j)$
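If you don’t trust the calculus, a finite-difference check is a cheap way to sanity-test this formula. The sketch below uses a tiny three-function basis that I made up for illustration (it is not the quadratic basis from this article), and all names are hypothetical.

```python
from typing import Callable, List, Sequence

# A made-up basis, just for the check: 1, x_1, and x_1 * x_2.
basis_fns: List[Callable[[Sequence[float]], float]] = [
    lambda x: 1.0,
    lambda x: x[0],
    lambda x: x[0] * x[1],
]


def f_hat(w: Sequence[float], x: Sequence[float]) -> float:
    return sum(w_i * f(x) for w_i, f in zip(w, basis_fns))


def point_error(w: Sequence[float], x: Sequence[float], y: float) -> float:
    return (y - f_hat(w, x)) ** 2


def analytic_gradient(w, x, y) -> List[float]:
    # dE/dw_i = -2 (y - f_hat(x)) f_i(x)
    residual = y - f_hat(w, x)
    return [-2 * residual * f(x) for f in basis_fns]


def numeric_gradient(w, x, y, h: float = 1e-6) -> List[float]:
    # Central differences, perturbing one weight at a time.
    grad = []
    for i in range(len(w)):
        up, down = list(w), list(w)
        up[i] += h
        down[i] -= h
        grad.append((point_error(up, x, y) - point_error(down, x, y)) / (2 * h))
    return grad
```

The two gradients should agree up to floating point error, and the check keeps working if you swap in different basis functions, which is the “agnostic to the basis” point in action.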

Another advantage to being linear is that this formula is agnostic to the content of the underlying basis functions. This will hold so long as the weights don’t show up in the formula for the basis functions. As an exercise: try changing the implementation to use radial basis functions around each data point. (see the note at the end for why this would be problematic in real life)

def gradient(weights: Coefficients, data_point: Tuple[Input, float]) -> Gradient:
    error = single_point_error(weights, data_point)
    dE_dw = [0] * len(weights)

    for i, w in enumerate(weights):
        dE_dw[i] = -2 * error * basis[i](data_point[0])

    return dE_dw

Finally, the gradient descent core with a debugging helper.

import random

def print_debug_info(step, grad_norm, error, progress):
    print(f"{step}, {progress:.4f}, {error:.4f}, {grad_norm:.4f}")


def gradient_descent(
        data: Dataset,
        learning_rate: float,
        tolerance: float,
        training_callback = None,
) -> Hypothesis:
    weights = [random.random() * 2 - 1 for i in range(len(basis))]
    last_error = total_error(weights, data)
    step = 0
    progress = tolerance * 2
    grad_norm = 1.0

    if training_callback:
        training_callback(step, 0.0, last_error, 0.0)

    while abs(progress) > tolerance or grad_norm > tolerance:
        grad = gradient(weights, random.choice(data))
        grad_norm = sum(x**2 for x in grad)
        
        for i in range(len(weights)):
            weights[i] -= learning_rate * grad[i]

        error = total_error(weights, data)
        progress = error - last_error
        last_error = error
        step += 1

        if training_callback:
            training_callback(step, grad_norm, error, progress)

    return basis.linear_combination(weights)

Next create some sample data and run the optimization

def example_quadratic_data(num_points: int):
    def fn(x, y, z):
        return 2 - 4*x*y + z + z**2

    data = []
    for i in range(num_points):
        x, y, z = random.random(), random.random(), random.random()
        data.append(((x, y, z), fn(x, y, z)))

    return data


if __name__ == "__main__":
    data = example_quadratic_data(30)
    gradient_descent(
        data, 
        learning_rate=0.01, 
        tolerance=1e-06, 
        training_callback=print_debug_info
    )

Depending on the randomness, it may take a few thousand steps, but it typically converges to an error of < 1. Here’s the plot of error against gradient descent steps.

Gradient descent showing log(total error) vs number of steps, for the quadratic regression problem.

Kernels and Regularization

I’ll finish with explanations of the parentheticals above.

The real polynomial kernel. We chose a simple set of polynomial functions. This is closely related to the concept of a “kernel”, but the “real” polynomial kernel uses slightly different basis functions. It scales some of the basis functions by $ \sqrt{2}$. This is OK because a linear combination can compensate by using coefficients that are appropriately divided by $ \sqrt{2}$. But why would one want to do this? The answer boils down to a computational efficiency technique called the “Kernel trick.” In short, it allows you to compute the dot product between two linear combinations of vectors in this vector space without explicitly representing the vectors in the space to begin with. If your regression algorithm uses only dot products in its code (as is true of the closed form solution for regression), you get the benefits of nonlinear feature modeling without the cost of computing the features directly. There’s a lot more mathematical theory to discuss here (cf. Reproducing Kernel Hilbert Space) but I’ll have to leave it there for now.

What’s wrong with the radial basis function exercise? This exercise asked you to create a family of basis functions, one for each data point. The problem here is that having so many basis functions makes the linear combination space too expressive. The optimization will overfit the data. It’s like a lookup table: there’s one entry dedicated to each data point. New data points not in the training data would rarely be handled well, since they aren’t in the “lookup table” the optimization algorithm found. To get around this, in practice one would add an extra term to the error corresponding to the L1 or L2 norm of the weight vector. This allows one to ensure that the total size of the weights is small, and in the L1 case that usually corresponds to most weights being zero, and only a few weights (the most important) being nonzero. The process of penalizing the “magnitude” of the linear combination is called regularization.
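As a rough illustration of that extra term, here is a minimal sketch of a regularized error function. It is independent of the gradient descent code above: the function names, the `strength` parameter, and the representation of the data as precomputed basis-function values are all assumptions made for this example. For the L2 case it uses the squared norm, which is the common choice in ridge regression.

```python
def squared_error(weights, features, targets):
    """Sum of squared differences between predictions and targets.

    Each row of `features` holds the basis function values for one data point.
    """
    total = 0.0
    for row, target in zip(features, targets):
        prediction = sum(w * f for w, f in zip(weights, row))
        total += (prediction - target) ** 2
    return total


def l1_penalty(weights):
    """L1 norm: tends to push many weights to exactly zero."""
    return sum(abs(w) for w in weights)


def l2_penalty(weights):
    """Squared L2 norm: keeps the overall size of the weights small."""
    return sum(w * w for w in weights)


def regularized_error(weights, features, targets, strength=0.1, norm="l2"):
    """Data error plus a penalty on the magnitude of the weight vector."""
    penalty = l1_penalty(weights) if norm == "l1" else l2_penalty(weights)
    return squared_error(weights, features, targets) + strength * penalty
```

A larger `strength` trades a worse fit on the training data for smaller weights. Choosing it well is usually done by checking the error on held-out data.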

A parlor trick for SET

Tai-Danae Bradley is one of the hosts of PBS Infinite Series, a delightful series of vignettes into fun parts of math. The video below is about the game of SET, a favorite among mathematicians. Specifically, Tai-Danae explains how SET cards lie in (using more technical jargon) a vector space over a finite field, and that valid sets correspond to lines. If you don’t immediately know how this would work, watch the video.

In this post I want to share a parlor trick for SET that I originally heard from Charlotte Chan. It uses the same ideas from the video above, which I’ll only review briefly.

In the game of SET you see a board of cards like the following, and players look for sets.

SetCards

Image source: theboardgamefamily.com

A valid set is a triple of cards where, feature by feature, the characteristics on the cards are either all the same or all different. A valid set above is {one empty blue oval, two solid blue ovals, three shaded blue ovals}. The feature of “fill” is different on all the cards, but the feature of “color” is the same, etc.

In a game of SET, the cards are dealt in order from a shuffled deck, players race to claim sets, removing the set if it’s valid, and three cards are dealt to replace the removed set. Eventually the deck is exhausted and the game is over, and the winner is the player who collected the most sets.

There are a handful of mathematical tricks you can use to help you search for sets faster, but the parlor trick in this post adds a fun variant to the end of the game.

Play the game of SET normally, but when you get down to the last card in the deck, don’t reveal it. Keep searching for sets until everyone agrees no visible sets are left. Then you start the variant: the first player to guess the last un-dealt card in the deck gets a bonus set.

The math comes in when you discover that you don’t need to guess, or remember anything about the game that was just played! A clever stranger could walk into the room at the end of the game and win the bonus point.

Theorem: As long as every player claimed a valid set throughout the game, the information on the remaining board uniquely determines the last (un-dealt) card.

Before we get to the proof, some reminders. Recall that there are four features on a SET card, each of which has three options. Enumerate the options for each feature (e.g., {Squiggle, Oval, Diamond} = {0, 1, 2}).

While we will not need the geometry induced by this, this implies each card is a vector in the vector space $ \mathbb{F}_3^4$, where $ \mathbb{F}_3 = \mathbb{Z}/3\mathbb{Z}$ is the finite field of three elements, and the exponent means “dimension 4.” As Tai-Danae points out in the video, each SET is an affine line in this vector space. For example, if this is the enumeration:

joyofset

Source: “The Joy of Set”

Then using the enumeration, a set might be given by

$ \displaystyle \{ (1, 1, 1, 1), (1, 2, 0, 1), (1, 0, 2, 1) \}$

The crucial feature for us is that the vector-sum (using the modular field arithmetic on each entry) of the cards in a valid set is the zero vector $ (0, 0, 0, 0)$. This is because, in each feature, the three cards are either all the same or all different, and $ 0+0+0 = 0$, $ 1+1+1 = 0$, $ 2+2+2 = 0,$ and $ 0+1+2=0$ are all true mod 3.
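This zero-sum property is easy to check by brute force. In the sketch below, cards are tuples in $ \{0, 1, 2\}^4$, and the helper names `is_valid_set` and `vector_sum_mod3` are made up for this example. The code confirms that the example set above sums to the zero vector.

```python
import itertools


def is_valid_set(a, b, c):
    """Each feature is either all the same or all different across the cards."""
    return all(len({x, y, z}) in (1, 3) for x, y, z in zip(a, b, c))


def vector_sum_mod3(cards):
    """Entrywise sum of the card vectors, reduced mod 3."""
    return tuple(sum(card[i] for card in cards) % 3 for i in range(4))


# The full deck: every vector in F_3^4, 81 cards in total.
deck = list(itertools.product(range(3), repeat=4))

if __name__ == "__main__":
    example = [(1, 1, 1, 1), (1, 2, 0, 1), (1, 0, 2, 1)]
    print(is_valid_set(*example))    # True
    print(vector_sum_mod3(example))  # (0, 0, 0, 0)
```

In fact the converse also holds: three distinct cards form a valid set exactly when their sum is zero, since two equal entries plus a different third entry never sum to zero mod 3.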

Proof of Theorem. Consider the vector-valued invariant $ S_t$ equal to the sum of the remaining cards (those on the board together with those still in the deck) after $ t$ sets have been taken. At the beginning of the game the deck has 81 cards that can be partitioned into valid sets. Because each valid set sums to the zero vector, $ S_0 = (0, 0, 0, 0)$. Removing a valid set via normal play does not affect the invariant, because you’re subtracting a set of vectors whose sum is zero. So $ S_t = (0, 0, 0, 0)$ for all $ t$.

At the end of the game, the invariant still holds even if there are no valid sets left to claim. Let $ x$ be the vector corresponding to the last un-dealt card, and $ c_1, \dots, c_n$ be the remaining visible cards. Then $ x + \sum_{i=1}^n c_i = (0,0,0,0)$, meaning $ x = -\sum_{i=1}^n c_i$.

$ \square$

I would provide an example, but I want to encourage everyone to play a game of SET and try it out live!

Charlotte, who originally showed me this trick, was quick enough to compute this sum in her head. So were the other math students we played SET with. It’s a bit easier than it seems since you can do the sum feature by feature. Even though I’ve known about this trick for years, I still require a piece of paper and a few minutes.

Because this is Math Intersect Programming, the reader is encouraged to implement this scheme as an exercise, and simulate a game of SET by removing randomly chosen valid sets to verify experimentally that this scheme works.
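For readers who want something to compare their attempt against, here is one possible sketch of that simulation. Everything in it is an assumption made for illustration: the simplified dealing rules (refill the board to 12 cards, and deal up to three extra whenever no set is visible), the rule of always claiming the first set found, and the helper names.

```python
import itertools
import random


def is_valid_set(a, b, c):
    """Each feature is either all the same or all different across the cards."""
    return all(len({x, y, z}) in (1, 3) for x, y, z in zip(a, b, c))


def find_set(board):
    """Return some valid set on the board, or None if there is none."""
    for triple in itertools.combinations(board, 3):
        if is_valid_set(*triple):
            return triple
    return None


def predict_last_card(visible_cards):
    """The hidden card is the negated sum of the visible cards, mod 3."""
    return tuple(-sum(card[i] for card in visible_cards) % 3 for i in range(4))


def play_game(rng):
    """Simulate a game and return (final board, hidden last card)."""
    deck = list(itertools.product(range(3), repeat=4))
    rng.shuffle(deck)
    hidden = deck.pop()  # the last card is never revealed
    board = []
    while True:
        # Refill the board to 12 cards while the deck lasts.
        while deck and len(board) < 12:
            board.append(deck.pop())
        found = find_set(board)
        if found is None:
            if not deck:
                return board, hidden
            # No visible set: deal up to three extra cards and look again.
            for _ in range(min(3, len(deck))):
                board.append(deck.pop())
            continue
        for card in found:
            board.remove(card)


if __name__ == "__main__":
    rng = random.Random()
    for _ in range(100):
        board, hidden = play_game(rng)
        assert predict_last_card(board) == hidden
    print("The last card was predicted correctly in every simulated game.")
```

Note that `predict_last_card` never looks at which sets were claimed during the game, only at the cards left on the board, which is exactly the content of the theorem.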

Until next time!