Modulus Switching in LWE

The Learning With Errors problem is the basis of a few cryptosystems, and a foundation for many fully homomorphic encryption (FHE) schemes. In this article I’ll describe a technique used in some of these schemes, called modulus switching.

In brief, an LWE sample is a vector of values in \mathbb{Z}/q\mathbb{Z} for some q, and in LWE cryptosystems an LWE sample can be modified so that it hides a secret message m. Modulus switching allows one to convert an LWE encryption from having entries in \mathbb{Z}/q\mathbb{Z} to entries in some other \mathbb{Z}/q'\mathbb{Z}, i.e., change the modulus from q to q' < q.

The reasons you’d want to do this are a bit involved, so I won’t get into them here; instead, I’ll refer back to this article in the future.

LWE encryption

Briefly, the LWE encryption scheme I’ll use has the following parameters:

  • A plaintext space \mathbb{Z}/q\mathbb{Z}, where q \geq 2 is a positive integer. This is the space that the underlying message comes from.
  • An LWE dimension n \in \mathbb{N}.
  • A discrete Gaussian error distribution D with a mean of zero and a fixed standard deviation.

An LWE secret key s is a vector sampled uniformly from \{0, 1\}^n. An LWE ciphertext of a plaintext m consists of a vector a = (a_1, \dots, a_n), sampled uniformly from (\mathbb{Z} / q\mathbb{Z})^n, and a scalar b = \langle a, s \rangle + m + e, where e is drawn from D and all arithmetic is done modulo q.
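For concreteness, here is a minimal Python sketch of key generation and encryption. The function names (keygen, sample_error, encrypt) and the toy parameters are my own illustrative choices, not part of any library. It approximates the discrete Gaussian D by rounding a continuous Gaussian, and it uses Python’s random module, which is not a cryptographically secure source of randomness, so treat it as a teaching aid only.

```python
import random

def keygen(n, rng):
    """Sample a binary secret key s uniformly from {0, 1}^n."""
    return [rng.randrange(2) for _ in range(n)]

def sample_error(sigma, rng):
    """Toy stand-in for D: round a continuous Gaussian to the nearest integer."""
    return round(rng.gauss(0, sigma))

def encrypt(m, s, q, sigma, rng):
    """Return (a, b) with a uniform in (Z/qZ)^n and b = <a, s> + m + e mod q."""
    a = [rng.randrange(q) for _ in range(len(s))]
    e = sample_error(sigma, rng)
    b = (sum(ai * si for ai, si in zip(a, s)) + m + e) % q
    return a, b

rng = random.Random(0)  # seeded for reproducibility; NOT cryptographically secure
q, n, sigma = 2**32, 16, 3.0
s = keygen(n, rng)
a, b = encrypt(12345, s, q, sigma, rng)

# b - <a, s> mod q recovers the noisy plaintext m + e
noisy = (b - sum(ai * si for ai, si in zip(a, s))) % q
print(noisy)  # 12345 plus a small error
```

Computing b - \langle a, s \rangle gives back m + e, which is the starting point for decryption.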

Without the error term, an attacker could determine the secret key from a polynomial-sized collection of LWE ciphertexts with something like Gaussian elimination. The set of samples looks like a linear (or affine) system, where the secret key entries are the unknown variables. With an error term, the problem of solving the system is believed to be hard, and only exponential time/space algorithms are known.
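To illustrate the Gaussian elimination remark, here is a toy Python sketch that recovers s from error-free samples. It assumes the messages are known (here m = 0, so b = \langle a, s \rangle), and it uses a prime modulus p = 2^{31} - 1 rather than a power of two, so that every nonzero entry has a modular inverse. The helper solve_mod_p is hypothetical, written just for this example; pow(x, -1, p) needs Python 3.8 or later.

```python
import random

def solve_mod_p(rows, rhs, p):
    """Solve rows * x = rhs over Z/pZ (p prime) by Gauss-Jordan elimination.

    Hypothetical helper; assumes the system has full column rank.
    """
    n = len(rows[0])
    M = [[v % p for v in row] + [y % p] for row, y in zip(rows, rhs)]
    for col in range(n):
        # Pick a row at or below position `col` with a nonzero pivot
        # (raises StopIteration if the system is rank-deficient)
        piv = next(i for i in range(col, len(M)) if M[i][col] != 0)
        M[col], M[piv] = M[piv], M[col]
        inv = pow(M[col][col], -1, p)  # modular inverse, Python 3.8+
        M[col] = [(v * inv) % p for v in M[col]]
        # Clear this column from every other row
        for i in range(len(M)):
            if i != col and M[i][col] != 0:
                f = M[i][col]
                M[i] = [(vi - f * vc) % p for vi, vc in zip(M[i], M[col])]
    return [M[i][n] for i in range(n)]

rng = random.Random(1)
p, n = 2**31 - 1, 8              # p is prime, so nonzero entries are invertible
s = [rng.randrange(2) for _ in range(n)]
A = [[rng.randrange(p) for _ in range(n)] for _ in range(2 * n)]
# Error-free samples of the known message m = 0: b = <a, s> mod p
bs = [sum(ai * si for ai, si in zip(a, s)) % p for a in A]
print(solve_mod_p(A, bs, p) == s)
```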

However, the error term in an LWE encryption encompasses all of the obstacles to FHE. For starters, if your message is m=1 and the error distribution is wide (say, a standard deviation of 10), then the error will completely obscure the message from the start. You can’t decrypt the LWE ciphertext because you can’t tell if the error generated in a particular instance was 9 or 10. So one thing people do is have a much smaller cleartext space (actual messages) and encode cleartexts as plaintexts by putting the messages in the higher-order bits of the plaintext space. E.g., you can encode 10-bit messages in the top 10 bits of a 32-bit integer, and leave the remaining 22 bits of the plaintext for the error distribution.
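A sketch of this encoding in Python, using the 10-bit cleartext / 32-bit plaintext split from the example above. encode and naive_decode are hypothetical helper names. The naive decoder simply shifts off the error bits, which works here only because the error is small and non-negative.

```python
def encode(x, msg_bits=10, total_bits=32):
    """Store a msg_bits-bit cleartext x in the top bits of a total_bits-bit plaintext."""
    assert 0 <= x < 2**msg_bits
    return x << (total_bits - msg_bits)

def naive_decode(v, msg_bits=10, total_bits=32):
    """Drop the low (total_bits - msg_bits) error bits by shifting them off."""
    return v >> (total_bits - msg_bits)

m = encode(5)          # 5 * 2**22
noisy = m + 9          # a small, non-negative error fits in the low 22 bits
print(m, naive_decode(noisy))  # 20971520 5
```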

Moreover, for FHE you need to be able to add and multiply ciphertexts to get the corresponding sum/product of the underlying plaintexts. It’s easy to check that the entrywise sum of two LWE ciphertexts encrypted under the same secret key is an LWE ciphertext of the sum of the plaintexts (multiplication is harder and beyond the scope of this article). But summing ciphertexts also sums their error terms, so the error grows with each homomorphic operation, and eventually the error may overtake the message, at which point decryption fails. How to deal with this error accumulation is 99% of the difficulty of FHE.
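The following toy Python sketch checks the additive property: summing two ciphertexts entrywise gives a ciphertext whose noisy plaintext b - \langle a, s \rangle is the sum of the plaintexts plus the sum of the errors. The helper names, the parameters, and the rounded-Gaussian error are illustrative assumptions, and encrypt returns its error only so we can inspect it.

```python
import random

def encrypt(m, s, q, rng, sigma=3.0):
    """Toy LWE encryption; also returns e so we can inspect the error."""
    a = [rng.randrange(q) for _ in s]
    e = round(rng.gauss(0, sigma))  # rounded Gaussian standing in for D
    b = (sum(ai * si for ai, si in zip(a, s)) + m + e) % q
    return (a, b), e

def add(ct1, ct2, q):
    """Homomorphic addition: entrywise sum of two ciphertexts under the same key."""
    (a1, b1), (a2, b2) = ct1, ct2
    return [(x + y) % q for x, y in zip(a1, a2)], (b1 + b2) % q

def phase(ct, s, q):
    """b - <a, s> mod q, i.e., the noisy plaintext m + e."""
    a, b = ct
    return (b - sum(ai * si for ai, si in zip(a, s))) % q

rng = random.Random(2)
q, n = 2**32, 16
s = [rng.randrange(2) for _ in range(n)]
m1, m2 = 3 << 22, 4 << 22      # cleartexts 3 and 4 encoded in the top 10 bits
ct1, e1 = encrypt(m1, s, q, rng)
ct2, e2 = encrypt(m2, s, q, rng)
ct_sum = add(ct1, ct2, q)

# Plaintexts add, and so do the errors
print(phase(ct_sum, s, q) == (m1 + m2 + e1 + e2) % q)  # True
```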

Finally, because the error can be negative, even if you store a message in the high-order bits of the plaintext, you can’t decrypt by simply clearing the low-order error bits. Say k is the number of bits “reserved” for error, as described above, so m is a multiple of 2^k. An error of -1 borrows from the message bits: clearing the low k bits of m - 1 yields m - 2^k, a corrupted message. Instead, to decrypt, we round the value b - \langle a, s \rangle = m + e to the nearest multiple of 2^k. In particular, decryption succeeds only if the error is small enough in absolute value (strictly less than 2^{k-1} suffices). So to make this work in practice, one must coordinate the encoding scheme (how many bits to reserve for error), the dimension of the vector a, and the standard deviation of the error distribution.
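Here’s a small Python sketch contrasting the two decryption strategies on an error of -1, using 22 reserved error bits as in the earlier 10-bit/32-bit example. The helper name decode is hypothetical. Rounding is done with integer arithmetic (add 2^{k-1}, then shift), and the final reduction modulo q handles the wraparound when a negative error pushes a zero message to just below q.

```python
def decode(v, k, q):
    """Round v in Z/qZ to the nearest multiple of 2**k (ties up) and return v / 2**k.

    Reducing modulo q wraps values just below q (e.g., 0 plus a negative error)
    back around to 0. Assumes q is a power of 2 and |error| < 2**(k - 1).
    """
    rounded = (((v + (1 << (k - 1))) >> k) << k) % q
    return rounded >> k

q, k = 2**32, 22             # 10-bit cleartexts, 22 bits reserved for error
m = 5 << k
noisy = (m - 1) % q          # m + e with e = -1
print(noisy >> k)            # 4: clearing the error bits corrupts the message
print(decode(noisy, k, q))   # 5: rounding recovers it
```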

Modulus switching

With a basic outline of an LWE ciphertext, we can talk about modulus switching.

Start with an LWE ciphertext for the plaintext m. Call it (a_1, \dots, a_n, b) \in (\mathbb{Z}/q\mathbb{Z})^{n+1}, where

\displaystyle b = \left ( \sum_{i=1}^n a_i s_i \right ) + m + e_{\textup{original}}

Given q' < q, we would like to produce a vector (a'_1, \dots, a'_n, b') \in (\mathbb{Z}/q'\mathbb{Z})^{n+1} that also encrypts m (the primes mark the quantities that change, most notably the new modulus q'), without knowing m or e_{\textup{original}}, i.e., without access to the secret key.

Failed attempt: why not simply reduce each entry of the ciphertext vector modulo q'? That would set a'_i = a_i \mod q' and b' = b \mod q'. When q' divides q, this produces a perfectly valid equation, but it still won’t work. The problem is that taking m \mod q' destroys part or all of the underlying message. For example, with q = 2^{32}, say x is a 12-bit number stored in the top 12 bits of the plaintext, i.e., m = x \cdot 2^{20}. If q' = 2^{15}, then m is already a multiple of q', so reducing modulo q' sends the message to zero.
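A quick Python check of the failed attempt, using the numbers above (q = 2^{32}, q' = 2^{15}, a 12-bit cleartext shifted up by 20 bits). The secret, the ciphertext vector, the cleartext 0xABC, and the fixed error e = 7 are arbitrary illustrative values. Since q' divides q, the reduced ciphertext is still consistent, but its noisy plaintext is just the error.

```python
import random

rng = random.Random(3)
q, q_new, n = 2**32, 2**15, 16
s = [rng.randrange(2) for _ in range(n)]

x = 0xABC                       # a 12-bit cleartext
m = x << 20                     # stored in the top 12 bits of a 32-bit plaintext
e = 7                           # a fixed small error, for illustration
a = [rng.randrange(q) for _ in range(n)]
b = (sum(ai * si for ai, si in zip(a, s)) + m + e) % q

# Failed attempt: reduce every entry modulo q'
a_red = [ai % q_new for ai in a]
b_red = b % q_new
phase = (b_red - sum(ai * si for ai, si in zip(a_red, s))) % q_new
print(m % q_new, phase)         # 0 7: the message is gone, only the error survives
```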

For this reason, we can’t hope to encrypt m exactly, as the new modulus may not be large enough to represent m at all. Rather, we can only hope to encrypt something like “the message x that’s encoded in m, but with x stored in lower-order bits than m originally used.” More succinctly, we can hope to encrypt m' = m q' / q. When q and q' are powers of 2, the map m \mapsto m q' / q shifts m up by \log_2(q') bits (temporarily exceeding the maximum allowable bit length) and then shifts it down by \log_2(q) bits.

For example, say the number x=7 is stored in the top 3 bits of a 32-bit unsigned integer (q = 2^{32}), i.e., m = 7 \cdot 2^{29}, and let q' = 2^{10}. Then m q' / q = 7 \cdot 2^{29} \cdot 2^{10} / 2^{32} = 7 \cdot 2^{29+10 - 32} = 7 \cdot 2^7, which stores the same underlying number x=7, but in the top three bits of a 10-bit plaintext. In particular, x is in the same “position” in the plaintext space, while the plaintext space has shrunk around it.
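The same arithmetic in Python, using fractions.Fraction to keep m q'/q exact rather than relying on floating point:

```python
from fractions import Fraction

q, q_new = 2**32, 2**10
m = 7 << 29                       # x = 7 in the top 3 bits of a 32-bit plaintext
m_new = Fraction(m * q_new, q)    # m q' / q, computed exactly
print(m_new, m_new == 7 << 7)     # 896 True
print(int(m_new) >> 7)            # 7: x now sits in the top 3 bits of a 10-bit plaintext
```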

Side note: because of this change to the cleartext-to-plaintext encoding, the decryption/decoding steps before and after a modulus switch are slightly different. In decryption you use different moduli, and in decoding you round to different powers of 2.

So the trick is instead to apply z \mapsto z q' / q to all the entries of the LWE ciphertext vector. However, because entries like a_i are uniformly random and use all the bits of \mathbb{Z}/q\mathbb{Z}, this transformation will generally not produce an integer. So we round the result to the nearest integer and analyze the effect of that rounding. The final proposal for a modulus switch is

\displaystyle a'_i = \textup{round}(a_i q' / q)

\displaystyle b' = \textup{round}(b q' / q)
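Putting it together, here is a hedged end-to-end Python sketch: encrypt x = 7 in the top 3 bits with q = 2^{32}, switch to q' = 2^{10}, and decode before and after. The names scale_round, mod_switch, encrypt, and decode are hypothetical, the rounded-Gaussian error and the small dimension n = 16 are toy choices, and ties are rounded up. The decoder rounds to multiples of 2^{29} before the switch and 2^7 after, reflecting the change in encoding noted above.

```python
import random

def scale_round(z, q, q_new):
    """round(z * q_new / q) with exact integer arithmetic (ties round up)."""
    return (z * q_new + q // 2) // q

def mod_switch(ct, q, q_new):
    """Apply z -> round(z q'/q) to every entry, then reduce modulo q'."""
    a, b = ct
    return [scale_round(ai, q, q_new) % q_new for ai in a], scale_round(b, q, q_new) % q_new

def encrypt(m, s, q, rng, sigma=3.0):
    """Toy LWE encryption with a rounded-Gaussian error (not secure)."""
    a = [rng.randrange(q) for _ in s]
    e = round(rng.gauss(0, sigma))
    b = (sum(ai * si for ai, si in zip(a, s)) + m + e) % q
    return a, b

def decode(ct, s, q, k):
    """Round b - <a, s> mod q to a multiple of 2**k and return the cleartext (q a power of 2)."""
    a, b = ct
    v = (b - sum(ai * si for ai, si in zip(a, s))) % q
    return ((v + (1 << (k - 1))) >> k) % (q >> k)

rng = random.Random(4)
q, q_new, n = 2**32, 2**10, 16
s = [rng.randrange(2) for _ in range(n)]

ct = encrypt(7 << 29, s, q, rng)        # x = 7 in the top 3 bits: 29 error bits
ct_new = mod_switch(ct, q, q_new)       # now 10-bit entries: 7 error bits
print(decode(ct, s, q, 29), decode(ct_new, s, q_new, 7))  # 7 7
```

The rounding uses the integer formula (z q' + \lfloor q/2 \rfloor) // q to keep the computation exact for arbitrary q and q', and an entry that rounds up to q' is reduced back to 0.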

Because the error growth of LWE ciphertexts permeates everything, in addition to proving this transformation produces a valid ciphertext, we also have to understand how it impacts the error term.

Analyzing the modulus switch

The statement summarizing the last section:

Theorem: Let \mathbf{c} = (a_1, \dots, a_n, b) \in (\mathbb{Z}/q\mathbb{Z})^{n+1} be an LWE ciphertext encrypting plaintext m under secret key s with error term e_\textup{original}. Let q' < q. Then \mathbf{c}' = \textup{round}(\mathbf{c} q' / q) (where rounding is performed entrywise) is an LWE encryption of m' = m q' / q under the same secret key s, provided m' is an integer.

Proof. The only substantial idea is that \textup{round}(x) = x + \varepsilon, where |\varepsilon| \leq 0.5. This is true by the definition of rounding, but expressing it this way lets us collect the rounding errors from a sum of rounded terms in isolation, while everything else shares a factor of q'/q that can be pulled out. Let’s proceed.

Let \mathbf{c}' = (a'_1, \dots, a'_n, b'), where a'_i = \textup{round}(a_i q' / q) and likewise for b'. We need to show that b' = \left ( \sum_{i=1}^n a'_i s_i \right ) + m q' / q + e_{\textup{new}} \mod q', where e_{\textup{new}} is a soon-to-be-derived error term.

Expanding b' and using the “only substantial idea” above, we get

\displaystyle b' = \textup{round}(b q' / q) = bq'/q + \varepsilon_b

for some \varepsilon_b with magnitude at most 1/2. Continuing to expand, and noting that b is related to the a_i only modulo q (so that, after scaling by q'/q, the relation holds only modulo q'), we have

\displaystyle \begin{aligned} b' &= bq'/q + \varepsilon_b \\ &= \left ( \left ( \sum_{i=1}^n a_i s_i \right ) + m + e_{\textup{original}} \right ) \frac{q'}{q} + \varepsilon_b \mod q' \end{aligned}

Because we’re switching moduli, it makes sense to rewrite this over the integers, which means we add a term Mq for some integer M and continue to expand

\displaystyle \begin{aligned} b' &= \left ( \left ( \sum_{i=1}^n a_i s_i \right ) + m + e_{\textup{original}} + Mq \right ) \frac{q'}{q} + \varepsilon_b \\ &= \left ( \sum_{i=1}^n \left ( a_i \frac{q'}{q} \right) s_i \right ) + m \frac{q'}{q} + e_{\textup{original}}\frac{q'}{q} + Mq \frac{q'}{q} + \varepsilon_b \\ &= \left ( \sum_{i=1}^n \left ( a_i \frac{q'}{q} \right) s_i \right ) + m' + e_{\textup{original}}\frac{q'}{q} + Mq' + \varepsilon_b \end{aligned}

The terms with a_i are still missing their rounding, so, just as for b', write a'_i = a_i q'/q + \varepsilon_i, i.e., a_i q'/q = a'_i - \varepsilon_i, then substitute, simplify, and finally reduce modulo q' to get

\displaystyle \begin{aligned} b' &= \left ( \sum_{i=1}^n \left ( a'_i - \varepsilon_i \right) s_i \right ) + m' + e_{\textup{original}}\frac{q'}{q} + Mq' + \varepsilon_b \\ &= \left ( \sum_{i=1}^n a'_i s_i \right ) - \left ( \sum_{i=1}^n \varepsilon_i s_i \right) + m' + e_{\textup{original}}\frac{q'}{q} + Mq' + \varepsilon_b \\   &= \left ( \sum_{i=1}^n a'_i s_i \right ) + m' + Mq' +  \left [ e_{\textup{original}}\frac{q'}{q} - \left ( \sum_{i=1}^n \varepsilon_i s_i \right)  + \varepsilon_b \right ] \\   &=  \left ( \sum_{i=1}^n a'_i s_i \right ) + m' + \left [ e_{\textup{original}}\frac{q'}{q} - \left ( \sum_{i=1}^n \varepsilon_i s_i \right)  + \varepsilon_b \right ]  \mod q' \end{aligned}

Define the square bracketed term as e_{\textup{new}}, and we have proved the theorem.

\square
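As a sanity check on the derivation, this Python sketch computes every \varepsilon_i and \varepsilon_b exactly with fractions.Fraction, forms e_{\textup{new}} as defined by the square-bracketed term, and confirms that b' - \langle a', s \rangle - m' - e_{\textup{new}} is a multiple of q'. It also checks that the rounding contribution to e_{\textup{new}} is at most (\left \| s \right \|_1 + 1)/2 in magnitude. The parameters, the error e_original = -5, and the helper name scale_round are illustrative assumptions; the a'_i are left unreduced, which doesn’t affect a congruence modulo q'.

```python
import random
from fractions import Fraction

def scale_round(z, q, q_new):
    """round(z * q_new / q) with exact integer arithmetic (ties round up)."""
    return (z * q_new + q // 2) // q

rng = random.Random(5)
q, q_new, n = 2**32, 2**10, 16
s = [rng.randrange(2) for _ in range(n)]
m, e_orig = 7 << 29, -5                   # illustrative plaintext and error
a = [rng.randrange(q) for _ in range(n)]
b = (sum(ai * si for ai, si in zip(a, s)) + m + e_orig) % q

# Rounded entries (left unreduced; that doesn't matter modulo q')
a_new = [scale_round(ai, q, q_new) for ai in a]
b_new = scale_round(b, q, q_new)

# Rounding errors: eps = round(z q'/q) - z q'/q, each with |eps| <= 1/2
eps = [an - Fraction(ai * q_new, q) for an, ai in zip(a_new, a)]
eps_b = b_new - Fraction(b * q_new, q)

# The square-bracketed error term from the proof
e_new = Fraction(e_orig * q_new, q) - sum(ei * si for ei, si in zip(eps, s)) + eps_b
m_new = m * q_new // q                    # m' = m q'/q, an integer here

# b' - <a', s> - m' - e_new should be an integer multiple of q'
diff = b_new - sum(an * si for an, si in zip(a_new, s)) - m_new - e_new
print(diff.denominator == 1 and diff.numerator % q_new == 0)                 # True
print(abs(e_new - Fraction(e_orig * q_new, q)) <= Fraction(sum(s) + 1, 2))  # True
```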

So the error after modulus switching is explicit: it’s the original error scaled by q'/q, plus at most n+1 rounding terms, each at most 1/2 in absolute value. This is larger than it may appear. If the new modulus is, say, q'=1024, and the dimension is n = 512, then in the worst case the rounding terms alone contribute an error of about 256 = 2^8, leaving only about 1 bit for the message right after modulus switching. This is not altogether unrealistic, as production (128-bit) security parameters for LWE put n around 600. But it is compensated for by the fact that the secret s is chosen uniformly at random, so in expectation only half of its entries will be 1. In other words, the error is more like \frac{1}{2} \left \| s \right \|_1, and you can bound the probability that the error exceeds, say, \frac{n}{4} + \sqrt{n \log n} using standard probabilistic arguments (e.g., a Chernoff bound on \left \| s \right \|_1). That gives you a few extra bits to work with.
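To see the gap between the worst case and typical behavior, here is a rough Python experiment with n = 512 and q' = 1024 as in the example above. It encrypts m = 0 with zero error, so that after switching e_{\textup{new}} = \varepsilon_b - \sum_i \varepsilon_i s_i is pure rounding error, computed exactly in units of 1/q. The trial count, seed, and helper names are arbitrary choices, and the threshold uses the natural log. The worst-case bound (\left \| s \right \|_1 + 1)/2 always holds; in practice the observed maximum tends to sit well below it, because the \varepsilon_i have random signs and partially cancel.

```python
import math
import random

def scale_round(z, q, q_new):
    """round(z * q_new / q) with exact integer arithmetic (ties round up)."""
    return (z * q_new + q // 2) // q

def scaled_eps(z, q, q_new):
    """q * eps for eps = round(z q'/q) - z q'/q; always an exact integer."""
    return scale_round(z, q, q_new) * q - z * q_new

rng = random.Random(6)
q, q_new, n, trials = 2**32, 2**10, 512, 200
s = [rng.randrange(2) for _ in range(n)]

observed = []
for _ in range(trials):
    a = [rng.randrange(q) for _ in range(n)]
    # Encrypt m = 0 with e = 0, so e_new after switching is pure rounding error
    b = sum(ai * si for ai, si in zip(a, s)) % q
    scaled = scaled_eps(b, q, q_new) - sum(
        scaled_eps(ai, q, q_new) for ai, si in zip(a, s) if si
    )
    observed.append(abs(scaled) / q)  # |eps_b - sum_i eps_i s_i|

worst_case = (sum(s) + 1) / 2                     # all n+1 terms at their max of 1/2
threshold = n / 4 + math.sqrt(n * math.log(n))    # threshold from the text (natural log)
print(f"max observed: {max(observed)}, worst case: {worst_case}, "
      f"n/4 + sqrt(n log n): {threshold:.1f}")
```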

Until next time!