Encoding Schemes in FHE

In cryptography, we need a distinction between a cleartext and a plaintext. A cleartext is a message in its natural form. A plaintext is a cleartext that is represented in a specific way to prepare it for encryption in a specific scheme. The process of taking a cleartext and turning it into a plaintext is called encoding, and the reverse is called decoding.

In homomorphic encryption, the distinction matters. Cleartexts are generally all integers, though the bit width of allowed integers can be restricted (e.g., 4-bit integers). On the other hand, each homomorphic encryption (HE) scheme has its own special process for encoding and decoding, and since HEIR hopes to support all HE schemes, I set about cataloguing the different encoding schemes. This article is my notes on what they are.

If you’re not familiar with the terms Learning With Errors (LWE) and its ring variant (RLWE), then you may want to read up on those Wikipedia pages first. These problems are fundamental to most FHE schemes.

Bit field encoding for LWE

A bit field encoding simply places the bits of a small integer cleartext within a larger integer plaintext. An example might be a 3-bit integer cleartext placed in the top-most bits of a 32-bit integer plaintext. This is necessary because operations on FHE ciphertexts accumulate noise, which pollutes the lower-order bits of the corresponding plaintext (BGV is a special case that inverts this, see below).

Many papers in the literature will describe “placing in the top-most bits” as “applying a scaling factor,” which essentially means pick a power of 2 $\Delta$ and encode an integer $x$ as $\Delta x$. However, by using a scaling factor it’s not immediately clear if all of the top-most bits of the plaintext are safe to use.

To wit, the CGGI (aka TFHE) scheme has a slightly more specific encoding because it requires the topmost bit to be zero in order to use its fancy programmable bootstrapping feature. Don’t worry if you don’t know what it means, but just know that in this scheme the top-most bit is set aside.

This encoding is hence most generally described by specifying a starting bit and a bit width for the location of the cleartext in a plaintext integer. The code would look like

plaintext = message << (plaintext_bit_width - starting_bit - cleartext_bit_width)
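As a minimal sketch of the formula above, here is an encode/decode pair in Python. The function names and default parameters are hypothetical, and I am assuming `starting_bit` counts from the most significant bit, so `starting_bit=1` leaves the top-most bit unused as in CGGI. Decoding rounds to the nearest multiple of the scaling factor, which is one reasonable way to discard noise in the low-order bits.

```python
def encode_bit_field(message, plaintext_bit_width=32, starting_bit=1,
                     cleartext_bit_width=3):
    """Place `message` in a bit field of a wider plaintext integer.

    Assumption: `starting_bit` counts from the most significant bit.
    """
    if not 0 <= message < (1 << cleartext_bit_width):
        raise ValueError("message does not fit in the cleartext bit width")
    shift = plaintext_bit_width - starting_bit - cleartext_bit_width
    return message << shift


def decode_bit_field(plaintext, plaintext_bit_width=32, starting_bit=1,
                     cleartext_bit_width=3):
    """Recover the message, rounding away noise in the low-order bits."""
    shift = plaintext_bit_width - starting_bit - cleartext_bit_width
    # Adding half of the scaling factor before shifting rounds to the
    # nearest multiple of 2**shift instead of truncating.
    rounded = (plaintext + (1 << (shift - 1))) >> shift
    return rounded & ((1 << cleartext_bit_width) - 1)


plaintext = encode_bit_field(5)
noisy = plaintext - 1000  # hypothetical noise polluting the low-order bits
recovered = decode_bit_field(noisy)
```

The rounding step matters because noise can be negative: truncating `noisy` would decode to 4 rather than 5.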

There are additional steps that come into play when one wants to encode a decimal value in LWE, which can be done with fixed point representations.

As mentioned above, the main HE scheme that uses bit field LWE encodings is CGGI, but all the schemes use this encoding as part of their encoding because all schemes need to ensure there is space for noise growth during FHE operations.

Coefficient encoding for RLWE

One of the main benefits of RLWE-based FHE schemes is that you can pack lots of cleartexts into one plaintext. For this and all the other RLWE-based sections, the cleartext space is something like $(\mathbb{Z}/3\mathbb{Z})^{1024}$, vectors of small integers of some dimension. Many folks in the FHE world write $p$ for the modulus of the cleartexts (here $p = 3$). And the plaintext space is something like $(\mathbb{Z}/2^{32}\mathbb{Z})[x] / (x^{1024} + 1)$, i.e., polynomials with large integer coefficients and a polynomial degree matching the cleartext space dimension. Many people write $q$ for the coefficient modulus of the plaintext space (here $q = 2^{32}$).

In the coefficient encoding for RLWE, the bit-field encoding is applied to each input, and they are interpreted as coefficients of the polynomial.

This encoding scheme is also used in CGGI, in order to encrypt a lookup table as a polynomial for use in programmable bootstrapping. It can also be used in the BGV and BFV schemes, though it rarely is, because both of those schemes want polynomial multiplication to have semantic meaning. When you encode RLWE with the coefficient encoding, polynomial multiplication corresponds to a convolution of the underlying cleartexts, whereas most of the time those schemes prefer that multiplication corresponds to some kind of point-wise multiplication. The next encoding handles exactly that.
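To make the convolution claim concrete, here is a small sketch of multiplication in $(\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$ with $N = 4$. The function name is hypothetical, and the bit-field scaling is left out so the cleartexts are easy to read off.

```python
def negacyclic_multiply(a, b, q):
    """Multiply two polynomials in (Z/qZ)[x] / (x^N + 1).

    `a` and `b` are coefficient lists of length N, lowest degree first.
    """
    n = len(a)
    result = [0] * n
    for i, a_i in enumerate(a):
        for j, b_j in enumerate(b):
            k = i + j
            if k < n:
                result[k] = (result[k] + a_i * b_j) % q
            else:
                # x^N = -1 in this ring, so wrapped terms pick up a sign.
                result[k - n] = (result[k - n] - a_i * b_j) % q
    return result


q = 2**32
a = [1, 2, 0, 0]  # cleartexts (1, 2, 0, 0), read as the polynomial 1 + 2x
b = [3, 1, 0, 0]  # cleartexts (3, 1, 0, 0), read as the polynomial 3 + x
product = negacyclic_multiply(a, b, q)               # [3, 7, 2, 0]
pointwise = [(x * y) % q for x, y in zip(a, b)]      # [3, 2, 0, 0]
```

The product $3 + 7x + 2x^2$ mixes the input entries together, which is the convolution behavior described above, and it differs from the entry-wise product.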

Evaluation encoding for RLWE

The evaluation encoding borrows ideas from the Discrete Fourier Transform literature. See this post for a little bit more about why the DFT and polynomial multiplication are related.

The evaluation encoding encodes a vector $(v_1, \dots, v_N)$ by interpreting it as the output values of a polynomial $p(x)$ at some implicitly determined, but fixed points. These points are usually the roots of $x^N + 1$ in the ring $\mathbb{Z}/q\mathbb{Z}$ (recall, the coefficient ring of the polynomial ring), which are the odd powers of a primitive $2N$-th root of unity. One finds them by picking $q$ in such a way that the multiplicative group $(\mathbb{Z}/q\mathbb{Z})^\times$ contains an element of order $2N$ (for prime $q$, this means $q \equiv 1 \pmod{2N}$). That element plays the role of the $2N$-th root of unity that you would normally see in the complex numbers.

Once you have the root of unity, you can convert from the evaluation form to a coefficient form (which many schemes need for the encryption step) via an inverse number-theoretic transform (INTT). And then, of course, one must scale the coefficients using the bit field encoding to give room for noise. The coefficient form here is considered the “encoded” version of the cleartext vector.

Aside: one can perform the bit field encoding step before or after the INTT, since the bitfield encoding is equivalent to multiplying by a constant, and scaling a polynomial by a constant is equivalent to scaling its point evaluations by the same constant. Polynomial evaluation is a linear function of the coefficients.

The evaluation encoding is the most commonly used encoding for both the BGV and BFV schemes. After encryption is done, one usually applies an NTT to return to the evaluation representation so that polynomial multiplication can be more quickly implemented as entry-wise multiplication.
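Here is a toy sketch of the evaluation encoding with $N = 4$ and $q = 17$, chosen by me because $17 \equiv 1 \pmod 8$ and $2$ has order $8$ modulo $17$. The function names are hypothetical, the transforms are the naive quadratic-time versions rather than a fast NTT, and the bit-field scaling step is omitted.

```python
Q = 17   # toy prime with Q = 1 (mod 2N), so a 2N-th root of unity exists
N = 4
PSI = 2  # 2 has order 8 = 2N modulo 17, since 2**4 = 16 = -1 (mod 17)
POINTS = [pow(PSI, 2 * i + 1, Q) for i in range(N)]  # the roots of x^N + 1


def encode_evaluations(values):
    """Naive inverse transform: point values -> polynomial coefficients."""
    n_inv = pow(N, -1, Q)
    return [
        n_inv * sum(v * pow(p, -j, Q) for v, p in zip(values, POINTS)) % Q
        for j in range(N)
    ]


def decode_coefficients(coeffs):
    """Naive forward transform: evaluate the polynomial at each point."""
    return [
        sum(c * pow(p, j, Q) for j, c in enumerate(coeffs)) % Q
        for p in POINTS
    ]


def negacyclic_multiply(a, b):
    """Multiply in (Z/QZ)[x] / (x^N + 1), where x^N = -1."""
    result = [0] * N
    for i, a_i in enumerate(a):
        for j, b_j in enumerate(b):
            sign = 1 if i + j < N else -1
            result[(i + j) % N] = (result[(i + j) % N] + sign * a_i * b_j) % Q
    return result


u = [1, 2, 3, 4]
v = [5, 6, 7, 8]
product = negacyclic_multiply(encode_evaluations(u), encode_evaluations(v))
decoded = decode_coefficients(product)
expected = [(x * y) % Q for x, y in zip(u, v)]
```

Under these assumptions `decoded` equals `expected`, i.e., polynomial multiplication of the encoded values acts entry-wise on the cleartext vectors.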

Rounded canonical embeddings for RLWE

This embedding is for a family of FHE schemes related to the CKKS scheme, which focuses on approximate computation.

Here the cleartext space and plaintext spaces change slightly. The cleartext space is $\mathbb{C}^{N/2}$, and the plaintext space is again $(\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$ for some machine-word-sized power of two $q$. As you’ll note, the cleartext space is continuous but the plaintext space is discrete, so this necessitates some sort of approximation.

Aside: In the literature you will see the plaintext space described as just $\mathbb{Z}[x] / (x^N + 1)$, and while this works in principle, in practice doing so requires multiprecision integer computations, and ends up being slower than the alternative, which is to use a residue number system before encoding, and treat the plaintext space as $(\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$. I’ll say more about RNS encoding in the next section.

The encoding is easier to understand by first describing the decoding step. Given a polynomial $f \in (\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$, there is a map called the canonical embedding $\varphi: (\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1) \to \mathbb{C}^N$ that evaluates $f$ at the odd powers of a primitive $2N$-th root of unity. I.e., letting $\omega = e^{2\pi i / 2N}$, we have

\[ \varphi(f) = (f(\omega), f(\omega^3), f(\omega^5), \dots, f(\omega^{2N-1})) \]

Aside: My algebraic number theory is limited (not much beyond a standard graduate course covering Galois theory), but this paper has some more background. My understanding is that we’re viewing the input polynomials as actually sitting inside the number field $\mathbb{Q}[x] / (x^N + 1)$ (or else $q$ is a prime and the original polynomial ring is a field), and the canonical embedding is a specialization of a more general theorem that says that for any subfield $K \subset \mathbb{C}$, the Galois group $K/\mathbb{Q}$ is exactly the set of injective homomorphisms $K \to \mathbb{C}$. I don’t recall exactly why these polynomial quotient rings count as subfields of $\mathbb{C}$, and I think it is not completely trivial (see, e.g., this stack exchange question).

As specialized to this setting, the canonical embedding is a scaled isometry for the 2-norm in both spaces. See this paper for a lot more detail on that. This is a critical aspect of the analysis for FHE, since operations in the ciphertext space add perturbations (noise) in the plaintext space, and it must be the case that those perturbations decode to similar perturbations so that one can use bounds on noise growth in the plaintext space to ensure the corresponding cleartexts stay within some desired precision.

Because polynomials with real coefficients commute with complex conjugation ($f(\overline{z}) = \overline{f(z)}$), and roots of unity satisfy $\overline{\omega^k} = \omega^{-k}$, this canonical embedding is duplicating information: the second half of the entries are the complex conjugates of the first half. We can throw out the second half of the roots of unity and retain the same structure (the scaling in the isometry changes as well). The result is that the canonical embedding is redefined as $\varphi: (\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1) \to \mathbb{C}^{N/2}$ via

\[ \varphi(f) = (f(\omega), f(\omega^3), \dots, f(\omega^{N-1})) \]

Since we’re again using the bit-field encoding to scale up the inputs for noise, the decoding is then defined by applying the canonical embedding, and then applying bit-field decoding (scaling down).

This decoding process embeds the discrete polynomial space inside $\mathbb{C}^{N/2}$ as a lattice, but input cleartexts need not lie on that lattice. And so we get to the encoding step, which involves rounding to a point on the lattice, then inverting the canonical embedding, then applying the bit-field encoding to scale up for noise.

Using commutativity, one can more specifically implement this by first inverting the canonical embedding (which again uses an FFT-like operation), the result of which is in $\mathbb{C}[x] / (x^N + 1)$, then applying the bit-field encoding to scale up, then rounding the coefficients to be in $\mathbb{Z}[x] / (x^N + 1)$. As mentioned above, if you want the coefficients to be machine-word-sized integers, you’ll have to design this all to ensure the outputs are sufficiently small, and then treat the output as $(\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$. Or else use an RNS mechanism.
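The three steps (invert the embedding, scale up, round) can be sketched with only the standard library. The function names, the choice $N = 8$, and the scaling factor $2^{20}$ are my own assumptions, and this sketch keeps the coefficients as plain Python integers rather than reducing them mod $q$ or using RNS. The inverse is the naive quadratic-time formula, not an FFT.

```python
import cmath


def odd_roots(n):
    """Return omega^1, omega^3, ..., omega^(n-1) for omega = exp(2*pi*i / 2n)."""
    return [cmath.exp(1j * cmath.pi * (2 * k + 1) / n) for k in range(n // 2)]


def ckks_encode(values, n, scale):
    """Invert the canonical embedding, scale up, then round to integers."""
    roots = odd_roots(n)
    coeffs = []
    for j in range(n):
        # The discarded half of the embedding is the conjugate of the kept
        # half, so each coefficient is (2/n) * Re(sum of z * conj(root^j)).
        total = sum((z * (r ** j).conjugate()).real
                    for z, r in zip(values, roots))
        coeffs.append(round(scale * 2 * total / n))
    return coeffs


def ckks_decode(coeffs, n, scale):
    """Apply the canonical embedding, then scale down."""
    return [sum(c * r ** j for j, c in enumerate(coeffs)) / scale
            for r in odd_roots(n)]


n, scale = 8, 2**20
values = [1 + 2j, -0.5 + 0.25j, 3 + 0j, -1j]  # a cleartext in C^(N/2)
coeffs = ckks_encode(values, n, scale)
decoded = ckks_decode(coeffs, n, scale)
max_error = max(abs(d - v) for d, v in zip(decoded, values))
```

The round trip is only approximate: each coefficient is off by at most one half after rounding, so with these parameters the decoded entries differ from the inputs by at most a few parts in a million.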

Residue Number System Pre-processing

In all of the above schemes, the cleartext spaces can be too small for practical use. In the CGGI scheme, for example, a typical cleartext space is only 3 or 4 bits. Some FHE schemes manage this by representing everything in terms of boolean circuits, and pack inputs to various boolean gates in those bits. That is what I’ve mainly focused on, but it has the downside of increasing the number of FHE operations, requiring deeper circuits and more noise management operations, which are slow. Other approaches try to use the numerical structure of the ciphertexts more deliberately, and Sunzi’s Theorem (colloquially known as the Chinese Remainder Theorem) comes to the rescue here.

There will be two “cleartext” spaces floating around here, one for the “original” message, which I’ll call the “original” cleartext space, and one for the Sunzi’s-theorem-decomposed message, which I’ll call the “RNS” cleartext space (RNS for residue number system).

The original cleartext space size $M$ must be a product of distinct primes or, more generally, pairwise co-prime integers $M = m_1 \cdot \dots \cdot m_r$, with each $m_i$ being small enough to be compatible with the desired FHE’s encoding. E.g., for a bit-field encoding, $M$ might be large, but each $m_i$ would have to be at most a 4-bit prime (which severely limits how much we can decompose).

Then, we represent a single original cleartext message $x \in \mathbb{Z}/M\mathbb{Z}$ via its residues mod each $m_i$. I.e., $x$ becomes $r$ different cleartexts $(x \bmod m_1, x \bmod m_2, \dots, x \bmod m_r)$ in the RNS cleartext space. From there we can either encode all the cleartexts in a single plaintext (the various RLWE encodings support this so long as $r < N$, or $N/2$ for the canonical embedding), or else encode them as different plaintexts. In the latter case, the executing program needs to ensure the plaintexts are jointly processed. E.g., any operation that happens to one must happen to all, to ensure that the residues stay in sync and can be reconstructed at the end.

And finally, after decoding we use the standard reconstruction algorithm from Sunzi’s theorem to rebuild the original cleartext from the decoded RNS cleartexts.
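A minimal sketch of the decompose, jointly process, and reconstruct cycle follows. The moduli and function names are hypothetical choices of mine, and the "FHE operations" are modeled as plain modular arithmetic on each residue.

```python
from math import prod

MODULI = [3, 5, 7, 11, 13]  # pairwise co-prime, each at most 4 bits
M = prod(MODULI)            # size of the original cleartext space (15015)


def to_rns(x):
    """Split an original cleartext into its RNS cleartexts."""
    return [x % m for m in MODULI]


def from_rns(residues):
    """Standard reconstruction from Sunzi's theorem."""
    total = 0
    for residue, m in zip(residues, MODULI):
        partial = M // m
        # pow(partial, -1, m) is the inverse of `partial` modulo m.
        total += residue * partial * pow(partial, -1, m)
    return total % M


x, y = 1234, 56
x_rns, y_rns = to_rns(x), to_rns(y)
# Every operation is applied to all residues so they stay in sync.
result_rns = [(a * b + a) % m for a, b, m in zip(x_rns, y_rns, MODULI)]
result = from_rns(result_rns)
```

Here `result` agrees with computing $xy + x$ directly modulo $M$, even though no individual residue ever exceeded 4 bits.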

I’d like to write a bit more about RNS decompositions and Sunzi’s theorem in a future article, because it is critical to how many FHE schemes operate, and influences a lot of their designs. For example, I glossed over how inverting the canonical embedding works in detail, and it is related to Sunzi’s theorem in a deep way. So more on that in the future.

MLIR — Verifiers

Last time we defined folders and used them to enable some canonicalization and the sccp constant propagation pass for the poly dialect. This time we’ll add some additional safety checks to the dialect in the form of verifiers.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Purpose of a verifier

Verifiers ensure the types and operations in a concrete MLIR program are well-formed. Verifiers are run before and after every pass, and help to ensure that individual passes, folders, rewrite patterns, etc., emit proper IR. This allows you to enforce invariants of each op, and it makes passes simpler because they can rely on the invariants to avoid edge case checking.

The official docs for verifiers of attributes and types are here, and for operations here.

That said, most common kinds of verification code are implemented as Traits (mixins, recall the earlier article). So we’ll start with those.

Trait-based verifiers

In the last article we added SameOperandsAndResultElementType to enable poly.add to have a mixed poly and tensor-of-poly input semantics. This technically did add a verifier to the IR, but to demonstrate this more clearly I want to restrict that behavior to assert that the input and output types must all agree (all tensors-of-polys or all polys).

This commit shows the work to do this, which is mainly changing the trait name to SameOperandsAndResultType. As a result, we get a few new generated things for free. First, the verification engine uses verifyTrait to check that the types agree. There, verifyInvariants is an Operation base class method that the generated code overrides when traits inject verification, the same thing that checks the type constraints on operation types. (Note: a custom verifier instead gets a method named verify to distinguish it from verifyInvariants.) Since the SameOperandsAndResultType is a generic check, this doesn’t affect the generated code.

Next, an inferReturnTypes function is generated, below shown for AddOp.

::mlir::LogicalResult AddOp::inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
  inferredReturnTypes.resize(1);
  ::mlir::Builder odsBuilder(context);
  ::mlir::Type odsInferredType0 = operands[0].getType();
  inferredReturnTypes[0] = odsInferredType0;
  return ::mlir::success();
}

With a type inference hook present, we can simplify the operation’s assembly format, so that the type need only be specified once instead of three times (type, type) -> type. If we tried to simplify it before this trait, tablegen would complain that it can’t infer the types needed to build a parser.

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))";

This also requires updating all the tests to use the new assembly format (I did not try to find a way to make both the functional and abbreviated forms allowed at the same time, no big deal).

Finally, this trait adds builders that don’t need you to specify the return type. Another example for AddOp:

void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs) {
  odsState.addOperands(lhs);
  odsState.addOperands(rhs);

  ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
  if (::mlir::succeeded(AddOp::inferReturnTypes(odsBuilder.getContext(),
          odsState.location, odsState.operands,
          odsState.attributes.getDictionary(odsState.getContext()),
          odsState.getRawProperties(),
          odsState.regions, inferredReturnTypes)))
    odsState.addTypes(inferredReturnTypes);
  else
    ::llvm::report_fatal_error("Failed to infer result type(s).");
}

For another example, the EvalOp can’t use SameOperandsAndResultType, because its operands require different types, but we can use AllTypesMatch which generates similar code, but restricts the verification to a subset of types. This is added in this commit.

def Poly_EvalOp : Op<Poly_Dialect, "eval", [AllTypesMatch<["point", "output"]>]> {
  let summary = "Evaluates a Polynomial at a given input value.";
  let arguments = (ins Polynomial:$input, AnyInteger:$point);
  let results = (outs AnyInteger:$output);
  let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)";
}

You can see many similar sorts of type-inference traits here and their corresponding verifiers here.

Aside: the nomenclature for verification with regard to traits is a bit confusing. There is a concept called constraint in MLIR, and the docs describe traits as subclasses of a Constraint base class. But at the time of this writing (2023-09-11) that particular claim is wrong. There’s a TraitBase base class for traits, and the Constraint base class appears to be used for verification on type declarations and attribute declarations (the let arguments = (ins ...) stuff). These go by the names “type constraints” and “attribute constraints.” AnyInteger is an example of a type constraint because it can match multiple types, and it does inherit (indirectly) from the Constraint base class. I think type constraints are a bit more complicated because the examples I see in MLIR all involve injecting C++ code through the tablegen (you’ll see stuff like CPred) and I haven’t explored how that materializes as generated code yet.

A custom verifier

I couldn’t think of any legitimate custom verification I wanted to have in poly, so to make one up arbitrarily, I will assert that EvalOp’s point operand must be a 32-bit integer type. I could do this with the type constraint in tablegen, but I will do it in a custom verifier instead for demonstration purposes.

Repeating our routine, we start by adding let hasVerifier = 1; to the op’s tablegen, and inspect the generated signature in the header, in this commit.

class EvalOp ... {
  ...
  ::mlir::LogicalResult verify();
}

And the implementation

LogicalResult EvalOp::verify() {
  return getPoint().getType().isInteger(32)
             ? success()
             : emitOpError("argument point must be a 32-bit integer");
}

The new thing here is emitOpError, which is required because if you just return failure, then the verifier will fail but not output any information, resulting in an empty stdout.

And then to test for failure, the lit run command should redirect stderr to a temporary file, and have FileCheck operate on that

// tests/poly_verifier.mlir
// RUN: tutorial-opt %s 2>%t; FileCheck %s < %t

A trait-based custom verifier

We can combine these two ideas together by defining a custom trait that includes a verification hook.

Each trait in MLIR has an optional verifyTrait hook (which is checked before the custom verifier created via hasVerifier), and we can use this to define generic verifiers that apply to many ops. We’ll do this by making a verifier that extends the previous section by asserting—generically for any op—that all integer-like operands must be 32-bit. Again, this is a silly and arbitrary constraint to enforce, but it’s a demonstration.

The process of defining this is almost entirely C++, and contained in this commit. We start by defining a so-called NativeOpTrait subclass in tablegen:

def Has32BitArguments:  NativeOpTrait<"Has32BitArguments"> {
  let cppNamespace = "::mlir::tutorial::poly";
}

This has an almost trivial effect: add a template argument called ::mlir::tutorial::poly::Has32BitArguments to the generated header class for an op that has this trait. E.g., for EvalOp,

def Poly_EvalOp : Op<Poly_Dialect, "eval", [
    AllTypesMatch<["point", "output"]>, 
    Has32BitArguments
]> {
  ...
}

Generates

// PolyOps.h.inc
class EvalOp : public ::mlir::Op<
    EvalOp, ::mlir::OpTrait::ZeroRegions, 
    ...,
    ::mlir::tutorial::poly::Has32BitArguments,
    ...
> {
  ...
}

The rest is up to you in C++. Define an implementation of Has32BitArguments, following the curiously-recurring-template pattern required of OpTrait::TraitBase, and then implement a verifyTrait hook. In our case, iterate over all operand types, check which ones are integer-like, and then assert the width of those.

// PolyTraits.h
template <typename ConcreteType>
class Has32BitArguments : public OpTrait::TraitBase<ConcreteType, Has32BitArguments> {
 public:
  static LogicalResult verifyTrait(Operation *op) {
    for (auto type : op->getOperandTypes()) {
      // OK to skip non-integer operand types
      if (!type.isIntOrIndex()) continue;

      if (!type.isInteger(32)) {
        return op->emitOpError()
               << "requires each numeric operand to be a 32-bit integer";
      }
    }

    return success();
  }
};

This has the upside of being more generic, but the downside of requiring awkward casting to support specific ops and their named arguments. I.e., here we can’t refer to getPoint unless we do a dynamic cast to EvalOp. Doing so would make the trait less general, so a custom op-specific verifier is more appropriate for that.

MLIR — Folders and Constant Propagation

Last time we saw how to use pre-defined MLIR traits to enable upstream MLIR passes like loop-invariant-code-motion to apply to poly programs. We left out -sccp (sparse conditional constant propagation), and so this time we’ll add what is needed to make that pass work. It requires the concept of folding.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Constant Propagation vs Canonicalization

-sccp is sparse conditional constant propagation, which attempts to infer when an operation has a constant output, and then replaces the operation with the constant value. Repeating this, it “propagates” the constants as far as possible through the program. You can think of it like eagerly computing whatever values it can at compile time, and then sticking them into the compiled program as constants.

Here’s what it looks like for arith, where all the needed tools are implemented. For an input like:

func.func @test_arith_sccp() -> i32 {
  %0 = arith.constant 7 : i32
  %1 = arith.constant 8 : i32
  %2 = arith.addi %0, %0 : i32
  %3 = arith.muli %0, %0 : i32
  %4 = arith.addi %2, %3 : i32
  return %2 : i32
}

The output of tutorial-opt --sccp is

func.func @test_arith_sccp() -> i32 {
  %c63_i32 = arith.constant 63 : i32
  %c49_i32 = arith.constant 49 : i32
  %c14_i32 = arith.constant 14 : i32
  %c8_i32 = arith.constant 8 : i32
  %c7_i32 = arith.constant 7 : i32
  return %c14_i32 : i32
}

Note two additional facts: sccp doesn’t delete dead code, and what is not shown here is the main novelty in sccp, which is that it can propagate constants through control flow (ifs and loops).
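To illustrate just the "propagation" part, here is a toy model in Python of folding constants through the straight-line arith example above. It is not how MLIR implements sccp (the real pass runs a dataflow analysis that also handles control flow), and the program representation and function name are hypothetical.

```python
# Each entry is (result name, op name, operands). Constants carry their value.
PROGRAM = [
    ("%0", "constant", [7]),
    ("%1", "constant", [8]),
    ("%2", "addi", ["%0", "%0"]),
    ("%3", "muli", ["%0", "%0"]),
    ("%4", "addi", ["%2", "%3"]),
]

FOLDERS = {
    "addi": lambda lhs, rhs: lhs + rhs,
    "muli": lambda lhs, rhs: lhs * rhs,
}


def propagate_constants(program):
    """Return a map from SSA names to their statically known values."""
    known = {}
    for result, op, operands in program:
        if op == "constant":
            known[result] = operands[0]
        elif all(name in known for name in operands):
            # Every operand is a known constant, so the op can be replaced
            # by a constant holding the computed value.
            known[result] = FOLDERS[op](*(known[name] for name in operands))
    return known


known = propagate_constants(PROGRAM)
```

The values computed for `%2`, `%3`, and `%4` are 14, 49, and 63, matching the constants in the sccp output, and the unused `%1` is still present because nothing here deletes dead code.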

A related concept is the idea of canonicalization, which gets its own --canonicalize pass, and which hides a lot of the heavy lifting in MLIR. Canonicalize overlaps a little bit with sccp, in that it also computes constants and materializes them in the IR. Take, for example, the --canonicalize pass on the same IR:

func.func @test_arith_sccp() -> i32 {
  %c14_i32 = arith.constant 14 : i32
  return %c14_i32 : i32
}

The intermediate constants are all pruned, and all that remains is the return value and no operations. Canonicalize cannot propagate constants through control flow, and as such should be thought of as more “local” than sccp.

Both of these, however, are supported via folding, which is the process of taking a series of ops and merging them together into simpler ops. It also requires that our dialect has some sort of constant operation, which is inserted (“materialized”) with the results of a fold. Folding and canonicalization are more general than what I’m showing here, so we’ll come back to what else they can do in a future article.

The rough outline of what is needed to support folding in this way is:

  • Adding a constant operation
  • Adding a materialization hook
  • Adding folders for each op

This would result in a situation (test case commit) as follows. Starting from

%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%1 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>
%2 = poly.mul %1, %1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %1, %1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>

We would get

%0 = poly.constant dense<[2, 8, 20, 24, 18]> : tensor<5xi32> : <10>
%1 = poly.constant dense<[1, 4, 10, 12, 9]> : tensor<5xi32> : <10>
%2 = poly.constant dense<[1, 2, 3]> : tensor<3xi32> : <10>

Making a constant op

Currently we’re imitating a constant polynomial construction by combining a from_tensor op with arith.constant. Like this:

%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>

While a constant operation might combine them into one op.

%0 = poly.constant dense<[1, 2, 3]> : tensor<3xi32> : !poly.poly<10>

The from_tensor op can also be used to build a polynomial from data, not just constants, so it’s worth having around even after we implement poly.constant.

Having a dedicated constant operation has benefits explained in the MLIR documentation on folding. What’s relevant here is that fold can be used to signal to passes like sccp that the result of an op is constant (statically known), or it can be used to say that the result of an op is equivalent to a pre-existing value created by a different op. For the constant case, a materializeConstant hook is also needed to tell MLIR how to take the constant result and turn it into a proper IR op.

The constant op itself, in this commit, comes with two new concepts, the ConstantLike trait and an argument that is an attribute constraint.

def Poly_ConstantOp : Op<Poly_Dialect, "constant", [Pure, ConstantLike]> {   // new
  let summary = "Define a constant polynomial via an attribute.";
  let arguments = (ins AnyIntElementsAttr:$coefficients);    // new
  let results = (outs Polynomial:$output);
  let assemblyFormat = "$coefficients attr-dict `:` type($output)";
}

The ConstantLike trait is checked here during folding via the constant op matcher as an assertion. [Aside: I’m not sure why the trait is specifically required, so long as the materialization function is present on the dialect; it just seems like this check is used for assertions. Perhaps it’s just a safeguard.]

Next we have the line let arguments = (ins AnyIntElementsAttr:$coefficients); This defines the input to the op as an attribute (statically defined data) rather than a previous SSA value. The AnyIntElementsAttr is itself an attribute constraint, allowing any attribute that has IntElementsAttrBase as a base class to be used (e.g., 32-bit or 64-bit integer attributes). This means that we could use all of the following syntax forms:

%10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
%11 = poly.constant dense<[2, 3, 4]> : tensor<3xi8> : !poly.poly<10>
%12 = poly.constant dense<"0x020304"> : tensor<3xi8> : !poly.poly<10>
%13 = poly.constant dense<4> : tensor<100xi32> : !poly.poly<10>

Adding folders

We add folders in these commits:

Each has the same structure: add let hasFolder = 1; to the tablegen for the op, which adds a header declaration of the following form (noting that the signature would be different if the op has more than one result value, see docs).

OpFoldResult <OpName>::fold(<OpName>::FoldAdaptor adaptor);

Then we implement it in PolyOps.cpp. The semantics of this function are such that, if the fold decides the op should be replaced with a constant, it must return an attribute representing that constant, which can be given as the input to a poly.constant. The FoldAdaptor is a shim that has the same method names as an instance of the op’s C++ class, but arguments that have been folded themselves are replaced with Attribute instances representing the constants they were folded with. This will be relevant for folding add and mul, since the body needs to actually compute the result eagerly, and needs access to the actual values to do so.

For poly.constant the implementation is trivial: you just return the input attribute.

OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) {
  return adaptor.getCoefficients();
}

The from_tensor op is similar, but has an extra cast that acts as an assertion, since the tensor might have been constructed with weird types we don’t want as input. If the dyn_cast fails, the result is nullptr, which is cast by MLIR to a failed OpFoldResult.

OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
  // Returns null if the cast failed, which corresponds to a failed fold.
  return dyn_cast<DenseIntElementsAttr>(adaptor.getInput());
}

The poly binary ops are slightly more complicated since they are actually doing work. Each of these fold methods effectively takes as input two DenseIntElementsAttr, one for each operand, and expects us to return another DenseIntElementsAttr for the result.

For add/sub which are elementwise operations on the coefficients, we get to use an existing upstream helper method, constFoldBinaryOp, which through some template metaprogramming wizardry, allows us to specify only the elementwise operation itself.

OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) {
  return constFoldBinaryOp<IntegerAttr, APInt>(
      adaptor.getOperands(), [&](APInt a, APInt b) { return a + b; });
}

For mul, we have to write out the multiplication routine manually. In what’s below, I’m implementing the naive textbook polymul algorithm, which could be optimized if one expects people to start compiling programs with large, static polynomials in them.

OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
  auto lhs = cast<DenseIntElementsAttr>(adaptor.getOperands()[0]);
  auto rhs = cast<DenseIntElementsAttr>(adaptor.getOperands()[1]);
  auto degree = getResult().getType().cast<PolynomialType>().getDegreeBound();
  auto maxIndex = lhs.size() + rhs.size() - 1;

  SmallVector<APInt, 8> result;
  result.reserve(maxIndex);
  for (int i = 0; i < maxIndex; ++i) {
    result.push_back(APInt((*lhs.begin()).getBitWidth(), 0));
  }

  int i = 0;
  for (auto lhsIt = lhs.value_begin<APInt>(); lhsIt != lhs.value_end<APInt>();
       ++lhsIt) {
    int j = 0;
    for (auto rhsIt = rhs.value_begin<APInt>(); rhsIt != rhs.value_end<APInt>();
         ++rhsIt) {
      // index is modulo degree because poly's semantics are defined modulo x^N = 1.
      result[(i + j) % degree] += *rhsIt * (*lhsIt);
      ++j;
    }
    ++i;
  }

  return DenseIntElementsAttr::get(
      RankedTensorType::get(static_cast<int64_t>(result.size()),
                            IntegerType::get(getContext(), 32)),
      result);
}
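As a sanity check on the arithmetic, here is a Python sketch that mirrors the loop structure of the fold above. The function name is hypothetical, and it uses unbounded Python integers, so it ignores the fixed-width wraparound that APInt arithmetic would apply.

```python
def fold_mul(lhs, rhs, degree_bound):
    """Naive textbook polynomial multiplication on coefficient lists.

    Mirrors the C++ fold: the output index wraps modulo the degree bound.
    """
    max_index = len(lhs) + len(rhs) - 1
    result = [0] * max_index
    for i, lhs_coeff in enumerate(lhs):
        for j, rhs_coeff in enumerate(rhs):
            result[(i + j) % degree_bound] += lhs_coeff * rhs_coeff
    return result


coeffs = [1, 2, 3]                       # 1 + 2x + 3x^2
squared = fold_mul(coeffs, coeffs, 10)   # matches poly.mul %1, %1
doubled = [a + b for a, b in zip(squared, squared)]  # matches poly.add
```

These reproduce the coefficients in the earlier test case: `squared` is `[1, 4, 10, 12, 9]` and `doubled` is `[2, 8, 20, 24, 18]`.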

Adding a constant materializer

Finally, we add a constant materializer. This is a dialect-level feature, so we start by adding let hasConstantMaterializer = 1; to the dialect tablegen, and observing the newly generated header signature:

Operation *PolyDialect::materializeConstant(
    OpBuilder &builder, Attribute value, Type type, Location loc);

The Attribute input represents the result of each folding step above. The Type is the desired result type of the op, which is needed in cases like arith.constant where the same attribute can generate multiple different types via different interpretations of a hex string or splatting with a result tensor that has different dimensions.

In our case the implementation is trivial: just construct a constant op from the attribute.

Operation *PolyDialect::materializeConstant(
    OpBuilder &builder, Attribute value, Type type, Location loc) {
  auto coeffs = dyn_cast<DenseIntElementsAttr>(value);
  if (!coeffs)
    return nullptr;
  return builder.create<ConstantOp>(loc, type, coeffs);
}

Other kinds of folding

While this has demonstrated a generic kind of folding with respect to static constants, many folding functions in MLIR use simple matches to determine when an op can be replaced with a value from a previously computed op.

Take, for example, the complex dialect (for complex numbers). A complex.create op constructs a complex number from real and imaginary parts. A folder in that dialect checks for a pattern like complex.create(complex.re(%z), complex.im(%z)), and replaces it with %z directly. The arith dialect similarly has folds for things like (a - b) + b -> a and a + 0 -> a.

However, most work on simplifying an IR according to algebraic rules belongs in the canonicalization pass, since while it supports folding, it also supports general rewrite patterns that are allowed to delete and create ops as needed to simplify the IR. We’ll cover canonicalization in more detail in a future article. But just remember, folds may only modify the single operation being folded, use existing SSA values, and may not create new ops. So they are limited in power and decidedly local operations.
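As an illustration of how local a match-based fold is, here is a toy sketch that does not use MLIR at all. ToyExpr and foldAdd are hypothetical names, and pointer equality stands in for "the same SSA value". The fold for (a - b) + b -> a only inspects the op and its operands' definitions, and it returns a value that already exists rather than building anything new.

```cpp
#include <cassert>

// A toy expression node. Leaves model values defined elsewhere.
struct ToyExpr {
  enum Kind { Leaf, Add, Sub } kind;
  const ToyExpr *lhs;
  const ToyExpr *rhs;
};

// Tries to fold an Add node. Returns an existing value that can replace
// the node, or nullptr when no fold applies. It never allocates a node.
const ToyExpr *foldAdd(const ToyExpr &op) {
  assert(op.kind == ToyExpr::Add && op.lhs && op.rhs);
  const ToyExpr *inner = op.lhs;
  // Match (a - b) + b and replace it with a.
  if (inner->kind == ToyExpr::Sub && inner->rhs == op.rhs)
    return inner->lhs;
  return nullptr;
}
```

Anything that would need to create a fresh node, such as rewriting a + a into 2 * a, falls outside what this kind of function can express, which is the limitation described above.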

MLIR — Using Traits

Last time we defined a new dialect poly for polynomial arithmetic. This time we’ll spruce up the dialect by adding some pre-defined MLIR traits, and see how the application of traits enables some general purpose passes to optimize poly programs.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Traits and Loop Invariant Code Motion

As a compiler toolchain, MLIR heavily emphasizes code reuse. This is accomplished largely through two constructs: traits and interfaces. An interface is an interface in the normal programming sense of the word: a set of function signatures implemented by a type and providing some limited-scope behavior. Applied to MLIR, you can implement interfaces on operations and types, and then passes can operate at the interface level. Traits are closely related: a trait is an interface with no methods. Traits can just be “slapped on” an operation and passes can magically start working with them. They can also serve as mixins for common op verification routines, type inference, and more.

In this article I’ll show how adding traits to the operations in the poly dialect (defined last time) allows us to reuse existing MLIR passes on poly constructs. In future articles we’ll see how to define new traits and interfaces. But for now, existing traits are an extremely simple way to start using the batteries included in MLIR on new dialects.

As mentioned, one applies traits primarily to enable passes to do things with custom operations. So let’s start from the passes. The general transformation passes list includes a pass called loop invariant code motion. This checks loop bodies for any operations that don’t need to be in the loop, and moves them out of the loop body. Using this requires us to add two traits to ops to express that they are safe to move around. The first, NoMemoryEffect (which is technically an empty implementation of an interface), asserts that the operation has no memory side effects: it does not read, write, allocate, or free memory. The second is AlwaysSpeculatable (technically a list of two traits), which says that an operation is allowed to be “speculatively executed,” i.e., computed early. If it is speculatable, then the compiler can move the op to another location. If not, say because it reads from a memory location that can be written to, there is an earliest point before which it’s not safe to move the op.

Loop invariant code motion takes ops with these two traits, and hoists them outside of loops when the operation’s operands are unchanged by the body of the loop. Conveniently, MLIR also defines a single named list of traits called Pure, which is NoMemoryEffect and AlwaysSpeculatable. So we can just add the trait name to our tablegen op definition (via a template parameter that defaults to an empty list of traits) as in this commit.

//lib/Dialect/Poly/PolyOps.td
@@ -3,9 +3,10 @@

 include "PolyDialect.td"
 include "PolyTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"


-class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic> {
+class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure]> {
   let arguments = (ins Polynomial:$lhs, Polynomial:$rhs);

This commit adds the boilerplate to register all default MLIR passes in tutorial-opt, and adds an example test asserting a poorly-placed poly.mul is hoisted out of the loop body.

// RUN: tutorial-opt %s --loop-invariant-code-motion > %t
// RUN: FileCheck %s < %t
... <setup> ...
// CHECK: poly.mul
// CHECK: affine.for
%ret_val = affine.for %i = 0 to 100 iter_args(%sum_iter = %p0) -> !poly.poly<10> {
  // The poly.mul should be hoisted out of the loop.
  // CHECK-NOT: poly.mul
  %2 = poly.mul %p0, %p1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
  %sum_next = poly.add %sum_iter, %2 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
  affine.yield %sum_next : !poly.poly<10>
}
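The hoisting rule described above (both traits present, and every operand defined outside the loop) can be sketched as a toy model. This is not the MLIR pass; ToyOp, hoistableOps, and the integer value ids are all hypothetical names of mine, chosen to mirror the test case.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// A toy operation: it consumes values (identified by integer ids) and
// produces one result value.
struct ToyOp {
  std::string name;
  bool noMemoryEffect;      // stands in for the NoMemoryEffect trait
  bool alwaysSpeculatable;  // stands in for the AlwaysSpeculatable trait list
  std::vector<int> operands;
  int result;
};

// Returns the names of the ops that may be hoisted out of a loop body, given
// the set of values defined outside the loop. Ops are visited in program
// order, so a hoisted op's result counts as "outside" for later ops.
std::vector<std::string> hoistableOps(const std::vector<ToyOp> &body,
                                      std::set<int> definedOutside) {
  std::vector<std::string> hoisted;
  for (const ToyOp &op : body) {
    if (!op.noMemoryEffect || !op.alwaysSpeculatable)
      continue;
    bool invariant = true;
    for (int operand : op.operands)
      invariant = invariant && definedOutside.count(operand) > 0;
    if (!invariant)
      continue;
    hoisted.push_back(op.name);
    definedOutside.insert(op.result);
  }
  return hoisted;
}
```

In terms of the test, %p0 and %p1 are defined outside the loop, so poly.mul qualifies, while poly.add depends on the loop-carried %sum_iter and stays put. Dropping either flag on the multiplication leaves nothing to hoist.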

In the generated C++ code, adding new traits or interfaces adds new template arguments in the op’s class definition:

// PolyOps.h.inc
class SubOp : public ::mlir::Op<SubOp, 
::mlir::OpTrait::ZeroRegions, 
::mlir::OpTrait::OneResult, 
::mlir::OpTrait::OneTypedResult<::mlir::tutorial::poly::PolynomialType>::Impl, 
::mlir::OpTrait::ZeroSuccessors, 
::mlir::OpTrait::NOperands<2>::Impl, 
::mlir::OpTrait::OpInvariants,
::mlir::ConditionallySpeculatable::Trait,       // <-- new
::mlir::OpTrait::AlwaysSpeculatableImplTrait,   // <-- new
::mlir::MemoryEffectOpInterface::Trait>         // <-- new
{ ... }

And NoMemoryEffect adds a trivial implementation of the memory effects interface:

// PolyOps.h.inc
void SubOp::getEffects(
  ::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) {
}

As far as using traits goes, this is it. However, to figure out what each trait does, you have to dig through the pass implementations a bit. On top of that, the “helper” definitions like Pure that combine multiple traits are not documented, nor is the complete list of available traits and interfaces (the traits list is missing quite a few, like ConstantLike, Involution, Idempotent, etc.). When I first drafted this article, I had only applied the AlwaysSpeculatable trait list to Poly_BinOp, and I was confused when --loop-invariant-code-motion was a no-op. I had to dig into the pass’s implementation here to see that it also needs isMemoryEffectFree, which uses MemoryEffectOpInterface.

So next we’ll explore the general passes and traits, and add relevant traits to the poly dialect.

Passes already handled by Pure, or not relevant to poly

--control-flow-sink moves ops that are only used in one branch of a conditional into the relevant branch. It requires the op to be memory-effect free, which Pure already covers. To demonstrate, I added the Pure trait to poly.from_tensor and added a test in this commit.

--cse is common subexpression elimination, which removes unnecessarily repeated computations when possible. Again, a lack of memory effects suffices. Demonstrated in this commit.

--inline inlines function calls, which does not apply to poly.

--mem2reg replaces memory stores and loads with direct usage of the underlying value, when possible. It should not require any changes, and is not interesting enough for me to demo here.

--remove-dead-values does things like remove function arguments that are unused, or return values that are not used by any caller. It should not require any changes.

--sroa is “scalar replacement of aggregates,” which seems like it is about reshuffling memory allocations around. Not really sure why this is useful.

--symbol-dce eliminates dead private functions, which does not apply to poly.

Punting one pass to next time

--sccp is sparse conditional constant propagation, which attempts to infer when an operation has a constant output, and then replaces the operation with the constant value. Repeating this, it “propagates” the constants as far as possible through the program. Supporting this takes a bit of extra work and requires me to introduce some more concepts, so I’ll do that next time.

Elementwise mappings

Now that we’ve gone through the most generic passes, I’ll cover some remaining traits I’m aware of.

There are a number of traits that extend scalar operations to tensor operations and vice versa. Elementwise, Scalarizable, Tensorizable, and Vectorizable, whose docs you can read in detail here, essentially allow you to apply ops defined on scalars to tensors (and vectors) in the natural elementwise way. The trait list ElementwiseMappable combines them into a single trait. This commit demonstrates how adding the trait allows the poly.add op to magically work on tensor arguments. It also requires relaxing the ins arguments to the op in the tablegen, and we do this by using a so-called type constraint that permits polynomials and tensors containing them:

// PolyOps.td
def PolyOrContainer : TypeOrContainer<Polynomial, "poly-or-container">;

class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, ElementwiseMappable]> {
  let arguments = (ins PolyOrContainer:$lhs, PolyOrContainer:$rhs);
  let results = (outs PolyOrContainer:$output);
  ...
}

Under the hood, the generated code has a new type checking routine used when verifying the op:

// PolyOps.cpp.inc
static ::mlir::LogicalResult __mlir_ods_local_type_constraint_PolyOps0(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!(((::llvm::isa<::mlir::tutorial::poly::PolynomialType>(type))) ||
        ((((::llvm::isa<::mlir::VectorType>(type))) &&
          ((::llvm::cast<::mlir::VectorType>(type).getRank() > 0))) &&
         ([](::mlir::Type elementType) {
           return (::llvm::isa<::mlir::tutorial::poly::PolynomialType>(elementType));
         }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))) ||
        (((::llvm::isa<::mlir::TensorType>(type))) &&
         ([](::mlir::Type elementType) {
           return (::llvm::isa<::mlir::tutorial::poly::PolynomialType>(elementType));
         }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))))) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be poly-or-container, but got " << type;
  }
  return ::mlir::success();
}

Moreover, the accessors that previously returned a PolynomialType or ::mlir::TypedValue<PolynomialType> now must return a more generic ::mlir::Type or ::mlir::Value because the values could be tensors or vectors as well. It’s left to the caller to type-switch or dyn_cast these manually during passes.

Note that after adding this, we now get the following legal syntax:

    %0 = ... -> !poly.poly<10>
    %1 = ... -> !poly.poly<10>

    %2 = tensor.from_elements %0, %1 : tensor<2x!poly.poly<10>>
    %3 = poly.add %2, %0 : (tensor<2x!poly.poly<10>>, !poly.poly<10>) -> tensor<2x!poly.poly<10>>

I would say the implied semantics is that the second poly is constant across the mapping.
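Under that reading, the mixed tensor/scalar form behaves like the toy sketch below. This is only my interpretation of the implied semantics, not something the trait itself computes; Poly, addPoly, and addTensorAndPoly are hypothetical names, and coefficients are assumed to be 32-bit unsigned integers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical coefficient-vector model of a !poly.poly value.
using Poly = std::vector<uint32_t>;

// Coefficient-wise sum of two polynomials of the same length.
Poly addPoly(const Poly &a, const Poly &b) {
  assert(a.size() == b.size());
  Poly sum(a.size());
  for (std::size_t i = 0; i < a.size(); ++i)
    sum[i] = a[i] + b[i];
  return sum;
}

// Adds the same polynomial to every element of a "tensor" of polynomials,
// so the second operand is held constant across the mapping.
std::vector<Poly> addTensorAndPoly(const std::vector<Poly> &tensor,
                                   const Poly &scalar) {
  std::vector<Poly> out;
  out.reserve(tensor.size());
  for (const Poly &element : tensor)
    out.push_back(addPoly(element, scalar));
  return out;
}
```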

Verifier traits

Some traits add extra verification checks to the operation as mixins. While we’ll cover custom verifiers in a future article, for now we can notice that the following is legal (with or without ElementwiseMappable):

    %0 = ... -> !poly.poly<10>
    %1 = ... -> !poly.poly<9>
    %2 = poly.add %0, %1 : (!poly.poly<10>, !poly.poly<9>) -> !poly.poly<10>

While one could give this meaning by defining the semantics of add to embed the smaller-degree polynomial ring into the larger, we’re demonstrating traits, so we’ll instead forbid it by adding the SameOperandsAndResultElementType trait (a vectorized cousin of SameOperandsAndResultType), which asserts that the poly type in all the arguments and results (including the element types of containers) is the same. This commit does it.
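The check this trait adds can be pictured with a toy model. The sketch below reflects my understanding of the rule (compare element types, ignoring whether a value is a container) and is not a copy of the upstream verifier; ToyType and sameOperandsAndResultElementType are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Toy model of a value's type: either a bare polynomial or a container
// (tensor or vector) whose elements are polynomials.
struct ToyType {
  unsigned degreeBound;  // N in !poly.poly<N>
  bool isContainer;      // true for something like tensor<2x!poly.poly<N>>
};

// Accepts the op only if every operand has the same polynomial element type
// as the result. Whether a value is a container is deliberately ignored.
bool sameOperandsAndResultElementType(const std::vector<ToyType> &operands,
                                      const ToyType &result) {
  for (const ToyType &operand : operands)
    if (operand.degreeBound != result.degreeBound)
      return false;
  return true;
}
```

With this rule, adding a !poly.poly<10> to a !poly.poly<9> is rejected, while adding a tensor of !poly.poly<10> to a bare !poly.poly<10> is still accepted.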

Last few unneeded traits

Involution is for operations that are their own inverse, $f(f(x)) = x$. This is a common math concept, and if we had something like a poly.neg op, it would be perfect for it. Adding it would enable a free canonicalization to remove the repeated ops.

Idempotent is for operations $f(x)$ for which $f(f(x)) = f(x)$. This is a common math concept, but none of the poly ops have it. A rounding op like ceiling or floor would. If it did apply, adding this trait would enable a free canonicalization like involution.
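The two free canonicalizations just described can be sketched on a toy chain of unary ops. Nothing here is MLIR code: ToyUnaryOp and simplifyChain are hypothetical names, and poly.neg and floor are the hypothetical ops mentioned above, not ops that exist in the poly dialect.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct ToyUnaryOp {
  std::string name;
  bool involution;  // f(f(x)) == x
  bool idempotent;  // f(f(x)) == f(x)
};

// Simplifies a chain of unary ops applied to one value, listed innermost
// first, and returns the names of the ops that remain.
std::vector<std::string> simplifyChain(const std::vector<ToyUnaryOp> &chain) {
  std::vector<ToyUnaryOp> kept;
  for (const ToyUnaryOp &op : chain) {
    assert(!(op.involution && op.idempotent) && "only the identity is both");
    bool repeatsPrevious = !kept.empty() && kept.back().name == op.name;
    if (repeatsPrevious && op.involution) {
      kept.pop_back();  // f(f(x)) -> x: both applications disappear
    } else if (repeatsPrevious && op.idempotent) {
      continue;  // f(f(x)) -> f(x): drop the repeated application
    } else {
      kept.push_back(op);
    }
  }
  std::vector<std::string> names;
  for (const ToyUnaryOp &op : kept)
    names.push_back(op.name);
  return names;
}
```

A pair of negations cancels entirely, three negations leave one, and any run of floor ops collapses to a single floor.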

Broadcastable handles broadcast semantics for tensor/vector ops, which is not relevant.

Commutative is for ops whose arguments can be reordered (including across multiple ops of the same kind). This is used by a pass here for the purpose of simplifying pattern matching, but as far as I can tell the pass is never registered or manually applied so this trait is a no-op.

AffineScope, IsolatedFromAbove, Terminator, SingleBlock, and SingleBlockImplicitTerminator are for ops that have regions and scoping rules, which we don’t.

SymbolTable is for ops that define a symbol table, which I think is mainly module (in which the symbols are functions), and this is mainly used in the --inline and --symbol-dce passes.