Encoding Schemes in FHE

In cryptography, we need a distinction between a cleartext and a plaintext. A cleartext is a message in its natural form. A plaintext is a cleartext that is represented in a specific way to prepare it for encryption in a specific scheme. The process of taking a cleartext and turning it into a plaintext is called encoding, and the reverse is called decoding.

In homomorphic encryption, the distinction matters. Cleartexts are generally all integers, though the bit width of allowed integers can be restricted (e.g., 4-bit integers). On the other hand, each homomorphic encryption (HE) scheme has its own special process for encoding and decoding, and since HEIR hopes to support all HE schemes, I set about cataloguing the different encoding schemes. This article is my notes on what they are.

If you’re not familiar with the terms Learning With Errors (LWE) and its ring variant (RLWE), then you may want to read up on those Wikipedia pages first. These problems are fundamental to most FHE schemes.

Bit field encoding for LWE

A bit field encoding simply places the bits of a small integer cleartext within a larger integer plaintext. An example might be a 3-bit integer cleartext placed in the top-most bits of a 32-bit integer plaintext. This is necessary because operations on FHE ciphertexts accumulate noise, which pollutes the lower-order bits of the corresponding plaintext (BGV is a special case that inverts this, see below).

Many papers in the literature will describe “placing in the top-most bits” as “applying a scaling factor,” which essentially means pick a power of 2 $\Delta$ and encode an integer $x$ as $\Delta x$. However, by using a scaling factor it’s not immediately clear if all of the top-most bits of the plaintext are safe to use.

To wit, the CGGI (aka TFHE) scheme has a slightly more specific encoding because it requires the topmost bit to be zero in order to use its fancy programmable bootstrapping feature. Don’t worry if you don’t know what it means, but just know that in this scheme the top-most bit is set aside.

This encoding is hence most generally described by specifying a starting bit and a bit width for the location of the cleartext in a plaintext integer. The code would look like

plaintext = message << (plaintext_bit_width - starting_bit - cleartext_bit_width)
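As a rough sketch of how the encoding and its inverse could fit together (the function names and default widths here are illustrative assumptions, not taken from any particular library), decoding rounds away the noise in the low-order bits before extracting the bit field. The starting bit is counted from the top-most bit, so `starting_bit=1` leaves the top bit free as CGGI requires.

```python
def encode(message, plaintext_bit_width=32, starting_bit=1, cleartext_bit_width=3):
    """Place `message` in a bit field, counting starting_bit from the top-most bit."""
    assert 0 <= message < (1 << cleartext_bit_width)
    shift = plaintext_bit_width - starting_bit - cleartext_bit_width
    return message << shift


def decode(plaintext, plaintext_bit_width=32, starting_bit=1, cleartext_bit_width=3):
    """Round away noise in the low-order bits, then extract the bit field."""
    shift = plaintext_bit_width - starting_bit - cleartext_bit_width
    # Adding half of the scaling factor before shifting rounds to the
    # nearest multiple of 2^shift instead of truncating.
    rounded = (plaintext + (1 << (shift - 1))) >> shift
    return rounded % (1 << cleartext_bit_width)


noisy = encode(5) + 12345  # simulated noise in the low-order bits
recovered = decode(noisy)  # 5, as long as the noise is below 2^(shift - 1)
```

This sketch ignores wrap-around of the plaintext modulo $2^{32}$, which a real implementation would have to handle.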

There are additional steps that come into play when one wants to encode a decimal value in LWE, which can be done with fixed point representations.

As mentioned above, the main HE scheme that uses bit field LWE encodings is CGGI, but all the schemes use this encoding as part of their encoding because all schemes need to ensure there is space for noise growth during FHE operations.

Coefficient encoding for RLWE

One of the main benefits of RLWE-based FHE schemes is that you can pack lots of cleartexts into one plaintext. For this and all the other RLWE-based sections, the cleartext space is something like $(\mathbb{Z}/3\mathbb{Z})^{1024}$, vectors of small integers of some dimension. Many folks in the FHE world call the modulus of the cleartexts $p$ (here $p = 3$). And the plaintext space is something like $(\mathbb{Z}/2^{32}\mathbb{Z})[x] / (x^{1024} + 1)$, i.e., polynomials with large integer coefficients and a polynomial degree matching the cleartext space dimension. Many people call the coefficient modulus of the plaintext space $q$ (here $q = 2^{32}$).

In the coefficient encoding for RLWE, the bit-field encoding is applied to each input, and they are interpreted as coefficients of the polynomial.

This encoding scheme is also used in CGGI, in order to encrypt a lookup table as a polynomial for use in programmable bootstrapping. It can also be used in the BGV and BFV schemes, though it rarely is, because both of those schemes want polynomial multiplication to have semantic meaning. When you encode RLWE with the coefficient encoding, polynomial multiplication corresponds to a convolution of the underlying cleartexts, whereas most of the time those schemes prefer that multiplication corresponds to some kind of point-wise multiplication. The next encoding handles exactly that.
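To make the convolution concrete, here is a small sketch of multiplication in $(\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$ on coefficient lists. The toy parameters ($N = 4$, $q = 17$) and the function name are assumptions for illustration, and the scaling step is left out to keep the focus on the multiplication.

```python
def polymul_negacyclic(a, b, q):
    """Multiply two coefficient lists modulo x^N + 1 and modulo q."""
    n = len(a)
    out = [0] * n
    for i in range(n):
        for j in range(n):
            k = i + j
            if k < n:
                out[k] = (out[k] + a[i] * b[j]) % q
            else:
                # x^N = -1, so the term wraps around with a sign flip
                out[k - n] = (out[k - n] - a[i] * b[j]) % q
    return out


a = [1, 2, 0, 0]  # cleartexts (1, 2, 0, 0) as the polynomial 1 + 2x
b = [0, 0, 0, 3]  # cleartexts (0, 0, 0, 3) as the polynomial 3x^3
# (1 + 2x) * 3x^3 = 3x^3 + 6x^4 = 3x^3 - 6  (mod x^4 + 1)
product = polymul_negacyclic(a, b, q=17)  # [11, 0, 0, 3]
```

The point-wise product of the two cleartext vectors would be all zeros, while the polynomial product mixes entries across positions.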

Evaluation encoding for RLWE

The evaluation encoding borrows ideas from the Discrete Fourier Transform literature. See this post for a little bit more about why the DFT and polynomial multiplication are related.

The evaluation encoding encodes a vector $(v_1, \dots, v_N)$ by interpreting it as the output values of a polynomial $p(x)$ at some implicitly determined, but fixed points. These points are usually the roots of $x^N + 1$ in the ring $\mathbb{Z}/q\mathbb{Z}$ (recall, the coefficient ring of the polynomial ring). For these roots to exist, one picks $q$ so that the multiplicative group $(\mathbb{Z}/q\mathbb{Z})^\times$ contains an element of order $2N$ (for example, a prime $q$ with $q \equiv 1 \pmod{2N}$). That element plays the role of the primitive $2N$-th root of unity that you would normally see in the complex numbers, and its odd powers are the $N$ roots of $x^N + 1$.

Once you have the root of unity, you can convert from the evaluation form to a coefficient form (which many schemes need for the encryption step) via an inverse number-theoretic transform (INTT). And then, of course, one must scale the coefficients using the bit field encoding to give room for noise. The coefficient form here is considered the “encoded” version of the cleartext vector.

Aside: one can perform the bit field encoding step before or after the INTT, since the bitfield encoding is equivalent to multiplying by a constant, and scaling a polynomial by a constant is equivalent to scaling its point evaluations by the same constant. Polynomial evaluation is a linear function of the coefficients.

The evaluation encoding is the most commonly used encoding for both the BGV and BFV schemes. After encryption is done, one usually NTT’s back to the evaluation representation so that polynomial multiplication can be implemented more quickly as entry-wise multiplication.
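A toy sketch of the evaluation encoding, under assumed parameters $N = 4$ and $q = 17$ (a prime with $q \equiv 1 \pmod{2N}$, so that $2$ is a primitive $8$th root of unity). The transforms are written as naive quadratic-time sums rather than a fast NTT, and the scaling step is omitted; all names are illustrative.

```python
N = 4
Q = 17   # prime with Q = 1 (mod 2N)
PSI = 2  # primitive 2N-th root of unity mod Q, since 2^4 = 16 = -1 (mod 17)
POINTS = [pow(PSI, 2 * k + 1, Q) for k in range(N)]  # the roots of x^N + 1


def evaluate(coeffs):
    """Coefficient form -> evaluation form (a naive NTT)."""
    return [sum(c * pow(x, j, Q) for j, c in enumerate(coeffs)) % Q for x in POINTS]


def interpolate(values):
    """Evaluation form -> coefficient form (a naive inverse NTT)."""
    n_inv = pow(N, -1, Q)  # modular inverse, Python 3.8+
    return [
        n_inv * sum(v * pow(x, -j, Q) for v, x in zip(values, POINTS)) % Q
        for j in range(N)
    ]


def multiply(a, b):
    """Multiply in (Z/Q)[x] / (x^N + 1); x^N wraps around to -1."""
    out = [0] * N
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            sign = 1 if i + j < N else -1
            out[(i + j) % N] = (out[(i + j) % N] + sign * ai * bj) % Q
    return out


u = [1, 2, 3, 4]
v = [5, 6, 7, 8]
product = multiply(interpolate(u), interpolate(v))
slotwise = [(x * y) % Q for x, y in zip(u, v)]
# Decoding the polynomial product gives the entry-wise product of u and v.
matches = evaluate(product) == slotwise
```

Because every evaluation point is a root of $x^N + 1$, evaluation respects multiplication in the quotient ring, which is why the polynomial product decodes to the entry-wise product.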

Rounded canonical embeddings for RLWE

This embedding is for a family of FHE schemes related to the CKKS scheme, which focuses on approximate computation.

Here the cleartext space and plaintext spaces change slightly. The cleartext space is $\mathbb{C}^{N/2}$, and the plaintext space is again $(\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$ for some machine-word-sized power of two $q$. As you’ll note, the cleartext space is continuous but the plaintext space is discrete, so this necessitates some sort of approximation.

Aside: In the literature you will see the plaintext space described as just $\mathbb{Z}[x] / (x^N + 1)$, and while this works in principle, in practice doing so requires multiprecision integer computations, and ends up being slower than the alternative, which is to use a residue number system before encoding, and treat the plaintext space as $(\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$. I’ll say more about RNS encoding in the next section.

The encoding is easier to understand by first describing the decoding step. Given a polynomial $f \in (\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$, there is a map called the canonical embedding $\varphi: (\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1) \to \mathbb{C}^N$ that evaluates $f$ at the odd powers of a primitive $2N$-th root of unity. I.e., letting $\omega = e^{2\pi i / 2N}$, we have

\[ \varphi(f) = (f(\omega), f(\omega^3), f(\omega^5), \dots, f(\omega^{2N-1})) \]

Aside: My algebraic number theory is limited (not much beyond a standard graduate course covering Galois theory), but this paper has some more background. My understanding is that we’re viewing the input polynomials as actually sitting inside the number field $\mathbb{Q}[x] / (x^N + 1)$ (or else $q$ is a prime and the original polynomial ring is a field), and the canonical embedding is a specialization of a more general theorem that says that for any subfield $K \subset \mathbb{C}$, the Galois group $K/\mathbb{Q}$ is exactly the set of injective homomorphisms $K \to \mathbb{C}$. I don’t recall exactly why these polynomial quotient rings count as subfields of $\mathbb{C}$, and I think it is not completely trivial (see, e.g., this stack exchange question).

As specialized to this setting, the canonical embedding is a scaled isometry for the 2-norm in both spaces. See this paper for a lot more detail on that. This is a critical aspect of the analysis for FHE, since operations in the ciphertext space add perturbations (noise) in the plaintext space, and it must be the case that those perturbations decode to similar perturbations so that one can use bounds on noise growth in the plaintext space to ensure the corresponding cleartexts stay within some desired precision.

Because polynomials with real coefficients commute with complex conjugation ($f(\overline{z}) = \overline{f(z)}$), and roots of unity satisfy $\overline{\omega^k} = \omega^{-k}$, this canonical embedding is duplicating information. We can throw out the second half of the roots of unity and retain the same structure (the scaling in the isometry changes as well). The result is that the canonical embedding is defined $\varphi: (\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1) \to \mathbb{C}^{N/2}$ via

\[ \varphi(f) = (f(\omega), f(\omega^3), \dots, f(\omega^{N-1})) \]

Since we’re again using the bit-field encoding to scale up the inputs for noise, the decoding is then defined by applying the canonical embedding, and then applying bit-field decoding (scaling down).

This decoding process embeds the discrete polynomial space inside $\mathbb{C}^{N/2}$ as a lattice, but input cleartexts need not lie on that lattice. And so we get to the encoding step, which involves rounding to a point on the lattice, then inverting the canonical embedding, then applying the bit-field encoding to scale up for noise.

Using commutativity, one can more specifically implement this by first inverting the canonical embedding (which again uses an FFT-like operation), the result of which is in $\mathbb{C}[x] / (x^N + 1)$, then applying the bit-field encoding to scale up, then rounding the coefficients to be in $\mathbb{Z}[x] / (x^N + 1)$. As mentioned above, if you want the coefficients to be machine-word-sized integers, you’ll have to design this all to ensure the outputs are sufficiently small, and then treat the output as $(\mathbb{Z}/q\mathbb{Z})[x] / (x^N + 1)$. Or else use an RNS mechanism.
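The invert, scale, round order described above can be sketched in a few lines of plain Python. The parameters ($N = 8$, a scaling factor of $2^{20}$) and function names are assumptions for illustration, the inversion is a naive sum rather than an FFT, and the coefficients are left as unreduced integers instead of being taken mod $q$.

```python
import cmath

N = 8            # polynomial degree; cleartexts live in C^(N/2)
DELTA = 1 << 20  # scaling factor (an arbitrary choice for this sketch)
OMEGA = cmath.exp(2j * cmath.pi / (2 * N))        # primitive 2N-th root of unity
ROOTS = [OMEGA ** (2 * k + 1) for k in range(N)]  # odd powers of OMEGA


def encode(z):
    """Map N/2 complex numbers to N integer coefficients."""
    # Fill in the second half with conjugates, mirroring conj(w^k) = w^(-k),
    # so that the interpolated polynomial has real coefficients.
    full = list(z) + [v.conjugate() for v in reversed(z)]
    coeffs = []
    for j in range(N):
        # Invert the embedding: a_j = (1/N) * sum_k full[k] * conj(root_k)^j
        a_j = sum(v * r.conjugate() ** j for v, r in zip(full, ROOTS)) / N
        # Scale up, then round to the nearest integer coefficient.
        coeffs.append(round(DELTA * a_j.real))
    return coeffs


def decode(coeffs):
    """Evaluate at the first N/2 odd powers of OMEGA, then scale back down."""
    return [
        sum(c * r ** j for j, c in enumerate(coeffs)) / DELTA
        for r in ROOTS[: N // 2]
    ]


z = [1 + 2j, -0.5 + 0.25j, 3 - 1j, 0.125j]
approx = decode(encode(z))
max_error = max(abs(a - b) for a, b in zip(z, approx))
```

The round trip is only approximate: rounding moves each coefficient by at most $1/2$, so the decoded values differ from the inputs by roughly $N / (2\Delta)$ at most in this setup.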

Residue Number System Pre-processing

In all of the above schemes, the cleartext spaces can be too small for practical use. In the CGGI scheme, for example, a typical cleartext space is only 3 or 4 bits. Some FHE schemes manage this by representing everything in terms of boolean circuits, and pack inputs to various boolean gates in those bits. That is what I’ve mainly focused on, but it has the downside of increasing the number of FHE operations, requiring deeper circuits and more noise management operations, which are slow. Other approaches try to use the numerical structure of the ciphertexts more deliberately, and Sunzi’s Theorem (colloquially known as the Chinese Remainder Theorem) comes to the rescue here.

There will be two “cleartext” spaces floating around here, one for the “original” message, which I’ll call the “original” cleartext space, and one for the Sunzi’s-theorem-decomposed message, which I’ll call the “RNS” cleartext space (RNS for residue number system).

The original cleartext space size $M$ must be a product of primes or co-prime integers $M = m_1 \cdot \dots \cdot m_r$, with each $m_i$ being small enough to be compatible with the desired FHE’s encoding. E.g., for a bit-field encoding, $M$ might be large, but each $m_i$ would have to be at most a 4-bit prime (which severely limits how much we can decompose).

Then, we represent a single original cleartext message $x \in \mathbb{Z}/M\mathbb{Z}$ via its residues mod each $m_i$. I.e., $x$ becomes $r$ different cleartexts $(x \bmod m_1, x \bmod m_2, \dots, x \bmod m_r)$ in the RNS cleartext space. From there we can either encode all the cleartexts in a single plaintext (the various RLWE encodings support this so long as $r < N$, or $N/2$ for the canonical embedding), or else encode them as different plaintexts. In the latter case, the executing program needs to ensure the plaintexts are jointly processed. E.g., any operation that happens to one must happen to all, to ensure that the residues stay in sync and can be reconstructed at the end.

And finally, after decoding we use the standard reconstruction algorithm from Sunzi’s theorem to rebuild the original cleartext from the decoded RNS cleartexts.
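A minimal sketch of the decomposition and reconstruction, with an assumed set of moduli that each fit in 4 bits. The names are illustrative, and the "FHE operation" in the middle is simulated by plain modular arithmetic on the residues.

```python
from math import prod

MODULI = [3, 5, 7, 11]  # pairwise co-prime, each fits in 4 bits
M = prod(MODULI)        # size of the original cleartext space, 1155


def to_rns(x):
    """Decompose x in Z/MZ into its residues mod each m_i."""
    return [x % m for m in MODULI]


def from_rns(residues):
    """Standard reconstruction from Sunzi's theorem."""
    x = 0
    for r, m in zip(residues, MODULI):
        partial = M // m
        # pow(partial, -1, m) is the inverse of partial modulo m (Python 3.8+)
        x += r * partial * pow(partial, -1, m)
    return x % M


a, b = 23, 42
# Any operation applied to one residue must be applied to all of them.
product_residues = [(x * y) % m for x, y, m in zip(to_rns(a), to_rns(b), MODULI)]
result = from_rns(product_residues)  # equals (a * b) % M
```

The reconstruction is only meaningful while the true result stays below $M$ (or is intended modulo $M$), which is why the size of the original cleartext space matters.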

I’d like to write a bit more about RNS decompositions and Sunzi’s theorem in a future article, because it is critical to how many FHE schemes operate, and influences a lot of their designs. For example, I glossed over how inverting the canonical embedding works in detail, and it is related to Sunzi’s theorem in a deep way. So more on that in the future.

MLIR — Using Tablegen for Passes

In the last article in this series, we defined some custom lowering passes that modified an MLIR program. Notably, we accomplished that by implementing the required interfaces of the MLIR API directly.

This is not the way that most MLIR developers work. Instead, they use a code generation tool called tablegen to generate boilerplate for them, and then only add the implementation methods that are custom to their work. In this article, we’ll convert the custom passes we wrote previously to use this infrastructure, and going forward we will use tablegen for defining new dialects as well.

The code relevant to this article is contained in this pull request, and as usual the commits are organized so that they can be read in order.

How to think about tablegen

The basic idea of tablegen is that you can define a pass—or a dialect, as we’ll do next time—in the tablegen DSL (as specialized for MLIR) and then run the mlir-tblgen binary with appropriate flags, and it will generate headers for the functions you need to implement your pass, along with default implementations of some common functions, documentation, and hooks to register the pass/dialect with the mlir-opt-like entry point tool.

It sounds nice, but I have a love-hate relationship with tablegen. I personally find it to be unpleasant to use, primarily because it provides poor diagnostic information when you do things wrong. Today, though, I realize that part of my frustration came from having the wrong initial mindset around tablegen. I thought, incorrectly, that tablegen was an abstraction layer. That is, I could write my tablegen files, build them, and only think about the parts of the generated code that I needed to implement.

Instead, I’ve found that I still need to understand the details of the generated code, enough that I may as well read the generated code closely until it becomes familiar. This materializes in a few ways:

  • mlir-tblgen doesn’t clearly tell you what functions are left unimplemented or explain the contract of the functions you have to write. For that you have to read the docs (which are often incomplete), or often the source code of the base classes that the generated boilerplate subclasses. Or, sadly, sometimes the Discourse threads or Phabricator commit messages are the only places to find the answers to some questions.
  • The main way to determine what is missing is to try to build the generated code with some code that uses it, and then sift through hundreds of lines of C++ compiler errors, which in turn requires understanding the various template gymnastics in the generated code.
  • The generated code will make use of symbols that you have to know to import or forward-declare in the right places, and it expects you to manage the namespaces in which the generated code lives.

If you treat tablegen as a totally see-through code generator, and then just sit down and learn the MLIR APIs, the result is not so bad. Indeed, that was the motivation for building the passes in the last article from “scratch” (without tablegen), because it makes the transition to using tablegen more transparent. Pass generation is also a good starting point for tablegen because, by comparison, using tablegen for dialects results in many more moving parts and configurable options.

Tablegen files and the mlir-tblgen binary

We’ll start by migrating the AffineFullUnroll pass from last time. The starting point is to write some tablegen code. This commit implements it, which I’ll abbreviate below. If you want the official docs on how to do this, see Declarative Pass Specification.

include "mlir/Pass/PassBase.td"

def AffineFullUnroll : Pass<"affine-full-unroll"> {
  let summary = "Fully unroll all affine loops";
  let description = [{
    Fully unroll all affine loops. (could add more docs here like code examples)
  }];
  let dependentDialects = ["mlir::affine::AffineDialect"];
}

[Aside: In the actual commit, there are two definitions, one for the AffineFullUnroll that walks the IR, and the other for AffineFullUnrollPatternRewrite which uses the pattern rewrite engine. In the article I’ll just show the generated code for the first.]

As you can see, tablegen has concepts like classes and class inheritance (the : Pass<...> is subclassing the Pass base class defined in PassBase.td, an upstream MLIR file). But the def keyword here specifically implies we’re instantiating this thing in a way that the codegen tool should see and generate real code for (as opposed to the class keyword, which is just for tablegen templating and code reuse).

So tablegen lets you let-define string variables and lists. One thing that won’t be apparent in this article, but will be in the next, is that tablegen also lets you define variables and use them across definitions, as well as define snippets of C++ code that should be put into the generated classes (which can use the defined variables). This is visible in the present context in the base class in PassBase.td, which defines a code constructor variable. If you have a special constructor for your pass, you can write the C++ code for it there.

Next we define a build rule using gentbl_cc_library to run the mlir-tblgen binary (in the same commit) with the right options. This gentbl_cc_library bazel rule is provided by the MLIR devs, and it basically just assembles the CLI flags to mlir-tblgen and ensures the code-genned files end up in the right places on the filesystem and are compatible with the bazel dependency system. The build rule invocation looks like

gentbl_cc_library(
    name = "pass_inc_gen",
    tbl_outs = [
        (
            [
                "-gen-pass-decls",
                "-name=Affine",
            ],
            "Passes.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "Passes.td",
    deps = [
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)

The important part here is that td_file specifies our input file, and tbl_outs defines the generated file, Passes.h.inc, which is at $GIT_ROOT/bazel-bin/lib/Transform/Affine/Passes.h.inc.

The main quirk with gentbl_cc_library is that the name of the bazel rule is not the target that actually generates the code. That is, if you run bazel build pass_inc_gen (or from the git root, bazel build lib/Transform/Affine:pass_inc_gen), it won’t create the files but the build will be successful. Instead, under the hood gentbl_cc_library is a bazel macro that generates the rule pass_inc_gen_filegroup, which is what you have to bazel build to see the actual files.

I’ve pasted the generated code (with both versions of the AffineFullUnroll) into a gist and will highlight the important parts here. The first quirky thing the generated code does is use #ifdef as a sort of function interface for what code is produced. For example, you will see:

#ifdef GEN_PASS_DECL_AFFINEFULLUNROLL
std::unique_ptr<::mlir::Pass> createAffineFullUnroll();
#undef GEN_PASS_DECL_AFFINEFULLUNROLL
#endif // GEN_PASS_DECL_AFFINEFULLUNROLL

#ifdef GEN_PASS_DEF_AFFINEFULLUNROLL
... <lots of C++ code> ...
#undef GEN_PASS_DEF_AFFINEFULLUNROLL
#endif // GEN_PASS_DEF_AFFINEFULLUNROLL

This means that to use this file, we will need to define the appropriate symbol in a #define macro before including this header. You can see it happening in this commit, but in brief it will look like this

// in file AffineFullUnroll.h

#define GEN_PASS_DECL_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

// in file AffineFullUnroll.cpp

#define GEN_PASS_DEF_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

... <implement the missing functions from the generated code> ...

I’m no C++ expert, and this was the first time I’d seen this pattern of using #include as a function with #define as the argument. It was a little unsettling to me, until I landed on that mindset that it’s meant to be a white-box codegen, not an abstraction. So read the generated code. Inside the GEN_PASS_DECL_... guard, it defines a single function std::unique_ptr<::mlir::Pass> createAffineFullUnroll(); that is a very limited sole entry point for code that wants to use the pass. We don’t need to implement it unless our Pass has a custom constructor. Then in the GEN_PASS_DEF_... guard it defines a base class, whose functions I’ll summarize, but you should recognize many of them because we implemented them by hand last time.

template <typename DerivedT>
class AffineFullUnrollBase : public ::mlir::OperationPass<> {
  AffineFullUnrollBase() : ::mlir::OperationPass<>(::mlir::TypeID::get<DerivedT>()) {}
  AffineFullUnrollBase(const AffineFullUnrollBase &other) : ::mlir::OperationPass<>(other) {}

  static ::llvm::StringLiteral getArgumentName() {...}
  static ::llvm::StringRef getArgument() { ... }
  static ::llvm::StringRef getDescription() { ... }
  static ::llvm::StringLiteral getPassName() { ... }
  static ::llvm::StringRef getName() { ... }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) { ... }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override { ... }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::affine::AffineDialect>();
  }

  ... <type_id stuff> ...
}

Notably, this doesn’t tell us what functions are left for us to implement. For that we have to either build it and read compiler error messages, or compare it to the base class (OperationPass) and its base class (Pass) to see that the only function left to implement is runOnOperation(). Or, since we did this last time from the raw API, we can observe that the boilerplate functions we implemented before like getArgument are here, but runOnOperation is not.

Another notable aspect of the generated code is that it uses the curiously recurring template pattern (CRTP), so that the base class can know the eventual name of its subclass, and use that name to hook the concrete subclass into the rest of the framework.

Lower in the generated file you’ll see another #define-guarded block for GEN_PASS_REGISTRATION, which implements hooks for tutorial-opt to register the passes without having to depend on each internal Pass class directly.

#ifdef GEN_PASS_REGISTRATION

inline void registerAffineFullUnroll() {
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
    return createAffineFullUnroll();
  });
}

inline void registerAffineFullUnrollPatternRewrite() {
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
    return createAffineFullUnrollPatternRewrite();
  });
}

inline void registerAffinePasses() {
  registerAffineFullUnroll();
  registerAffineFullUnrollPatternRewrite();
}
#undef GEN_PASS_REGISTRATION
#endif // GEN_PASS_REGISTRATION

This implies that, once we link everything properly, the changes to tutorial-opt (in this commit) simplify to calling registerAffinePasses. This registration macro is intended to go into a Passes.h file that includes all the individual pass header files, as done in this commit. And we can use that header file as an anchor for a bazel build target that includes all the passes defined in lib/Transform/Affine at once.

#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "lib/Transform/Affine/AffineFullUnrollPatternRewrite.h"

namespace mlir {
namespace tutorial {

#define GEN_PASS_REGISTRATION
#include "lib/Transform/Affine/Passes.h.inc"

}  // namespace tutorial
}  // namespace mlir

Finally, after all this (abbreviated from this commit), the actual content of the pass reduces to the following subclass (with CRTP) and implementation of runOnOperation, the body of which is identical to the last article except for a change from reference to pointer for the return value of getOperation.

#define GEN_PASS_DEF_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

struct AffineFullUnroll : impl::AffineFullUnrollBase<AffineFullUnroll> {
  using AffineFullUnrollBase::AffineFullUnrollBase;

  void runOnOperation() {
    getOperation()->walk([&](AffineForOp op) {
      if (failed(loopUnrollFull(op))) {
        op.emitError("unrolling failed");
        signalPassFailure();
      }
    });
  }
};

I split the AffineFullUnroll migration into multiple commits to highlight the tablegen vs C++ code changes. For MulToAdd, I did it all in one commit. The tests are unchanged, because the entry point is still the tutorial-opt binary with the appropriate CLI flags, and those names are unchanged in the tablegen’ed code.

Bonus: mlir-tblgen also has an option -gen-pass-doc, which you’ll see in the commits, which generates a markdown file containing auto-generated documentation for the pass. A CI workflow can copy these to a website directory, as we do in HEIR, and you get free docs. See this example from HEIR.

Addendum: hermetic Python

When I first set up this tutorial project, I didn’t realize that bazel’s Python rules use the system Python by default. Some early readers found an error that Python couldn’t find the lit module when running tests. While pip install lit in your system Python would work, I also migrated in this commit to a hermetic python runtime and explicit dependency on lit. It should all be handled automatically by bazel now.

MLIR — Writing Our First Pass

This series is an introduction to MLIR and an onboarding tutorial for the HEIR project.

Last time we saw how to run and test a basic lowering. This time we will write some simple passes to illustrate the various parts of the MLIR API and the pass infrastructure.

As mentioned previously, the main work in MLIR is defining passes that either optimize part of a program, lower from parts of one dialect to others, or perform various normalization and canonicalization operations. In this article, we’ll start by defining a pass that operates entirely within a given dialect by fully unrolling loops. Then we’ll define a pass that does a simple replacement of one instruction with another. Neither pass will be particularly complex, but rather they will show how to set up a pass, how to navigate through a program via the MLIR API, and how to modify the IR by deleting and adding operations.

The code for this post is contained within this pull request.

tutorial-opt and project organization

Last time we used the mlir-opt binary as the main entry point to parse MLIR, run a pass, and emit the output IR. A compiler might run mlir-opt as a subroutine in between the front end (C++ to some MLIR dialects) and the backend (MLIR’s LLVM dialect to LLVM to machine code).

In an out-of-tree MLIR project, mlir-opt can’t be used because it isn’t compiled with the project’s custom dialects or passes. Instead, MLIR makes it easy to build a custom version of the mlir-opt tool for an out-of-tree project. It primarily provides a set of registration hooks that you can use to plug in your dialects and passes, and the framework handles reading/writing, CLI flags, and adds that all on top of the baseline MLIR passes and dialects. We’ll start this article by creating the shell for such a tool with an empty custom pass, which we’ll call tutorial-opt. If this repository were to become one step of an end-to-end compiler, then tutorial-opt would be the main interface to the MLIR part.

The structure of the codebase is a persnickety question here. A typical MLIR codebase seems to split the code into two directories with roughly equivalent hierarchies: an include/ directory for headers and tablegen files (more on tablegen in a future article), and a lib/ directory for implementation code. Then, within those two directories a project would have a Transform/ subdirectory that stores the files for passes that transform code within a dialect, Conversion/ for passes that convert between dialects, Analysis/ for analysis passes, etc. Each of these directories might have subdirectories for the specific dialects they operate on.

For this tutorial we will do it slightly differently by merging include/ and lib/ together (header files will live next to implementation files). I believe the reason that C++ codebases separate these is mainly to get an implicit public/private interface (client code should only depend on headers in include/, not headers in lib/ or src/). But bazel has many more facilities for enforcing private/public interface boundaries, I find it tedious to navigate parallel directory structures, and this is a tutorial so simpler is better.

So the project’s directory structure will look like this once we create the initial pass:

.
├── README.md
├── WORKSPACE
├── bazel
│   └──  . . .
├── lib
│   └── Transform
│       └── Affine
│           ├── AffineFullUnroll.cpp
│           ├── AffineFullUnroll.h
│           └── BUILD
├── tests
│   └── . . .
└── tools
    ├── BUILD
    └── tutorial-opt.cpp

Unrolling loops, a starter pass

Though MLIR provides multiple mechanisms for defining loops and control flow, the highest level one is in the affine dialect. Originally defined for polyhedral loop analysis (using lattices to study loop structure!), it also simply defines a nice for operation that you can use whenever you have simple loop bounds like iterating over a range with an optional step size. An example loop that sums some values in an array stored in memory might look like:

func.func @sum_buffer(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}

The iter_args is a custom bit of syntax that defines accumulation variables to operate across the loop body (to be in compliance with SSA form; for more on SSA, see this MLIR doc), along with an initial value.

Unrolling loops is a nontrivial operation, but thankfully MLIR provides a utility method for fully unrolling a loop, so our pass will be a thin wrapper around this function call, to showcase some of the rest of the infrastructure before we write a more meaningful pass. The code for this section is in this commit.

This implementation will be technically the most general implementation, by implementing directly from the C++ API, rather than using the more special case features like the pattern rewrite engine, the dialect conversion framework, or tablegen. Those will all come later.

The main idea is to implement the required methods for the OperationPass base class, which “anchors” the pass to work within the context of a specific instance of a specific type of operation, and is applied to every operation of that type. It looks like this:

// lib/Transform/Affine/AffineFullUnroll.h
class AffineFullUnrollPass
    : public PassWrapper<AffineFullUnrollPass,
                         OperationPass<mlir::func::FuncOp>> {
private:
  void runOnOperation() override;  // implemented in AffineFullUnroll.cpp

  StringRef getArgument() const final { return "affine-full-unroll"; }

  StringRef getDescription() const final {
    return "Fully unroll all affine loops";
  }
};

The PassWrapper helps implement some of the required methods for free (mainly adding a compliant copy method), and uses the Curiously Recurring Template Pattern (CRTP) to achieve that. But what matters for us is that OperationPass<FuncOp> anchors this pass to operate on function bodies, and provides the getOperation method in the class which returns the FuncOp being operated on.
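If CRTP is new to you, here is a minimal sketch of how a wrapper templated on its own subclass can supply a copy method for free. ToyPass, ToyPassWrapper, and ToyUnrollPass are hypothetical stand-ins invented for this illustration. They are not MLIR's classes, and the real PassWrapper does more than this.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-in for a pass base class.
class ToyPass {
public:
  virtual ~ToyPass() = default;
  virtual std::string getArgument() const = 0;
  virtual std::unique_ptr<ToyPass> clone() const = 0;
};

// The wrapper knows the concrete type through its template parameter, so it
// can copy-construct that type without the subclass writing clone() itself.
template <typename DerivedT>
class ToyPassWrapper : public ToyPass {
public:
  std::unique_ptr<ToyPass> clone() const override {
    return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
  }
};

// The subclass passes itself as the template argument: that is the
// "curiously recurring" part.
class ToyUnrollPass : public ToyPassWrapper<ToyUnrollPass> {
public:
  std::string getArgument() const override { return "toy-unroll"; }
};
```

Calling clone() on a ToyUnrollPass returns a distinct ToyUnrollPass object, even though ToyUnrollPass never defines clone.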

Aside: The MLIR docs more formally describe what is required of an OperationPass, and in particular it limits the “anchoring” to specific operations like functions and modules, the insides of which are isolated from modifying the semantics of the program outside of the operation’s scope. That’s a fancy way of saying FuncOps in MLIR can’t screw with variables outside the lexical scope of their function body. More importantly for this example, it explains why we can’t anchor this pass on a for loop operation directly: a loop can modify stuff outside its body (like the contents of memory) via the operations within the loop (store, etc.). This matters because the MLIR pass infrastructure runs passes in parallel. If some other pass is tinkering with neighboring operations, race conditions ensue.

The three functions we need to implement are

  • runOnOperation: the function that performs the pass logic.
  • getArgument: the CLI argument for an mlir-opt-like tool.
  • getDescription: the CLI description when running --help on the mlir-opt-like tool.

The initial implementation of runOnOperation is empty in the commit for this section. Next, we create a tutorial-opt binary that registers the pass.

// tools/tutorial-opt.cpp
#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/include/mlir/InitAllDialects.h"
#include "mlir/include/mlir/Pass/PassManager.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
#include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);

  mlir::PassRegistration<mlir::tutorial::AffineFullUnrollPass>();

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry));
}

This registers all the built-in MLIR dialects, adds our AffineFullUnrollPass, and then calls the MlirOptMain function which handles the rest. At this point we can run bazel run tools:tutorial-opt -- --help and see a long list of options with our new pass in it.

OVERVIEW: Tutorial Pass Driver
Available Dialects: acc, affine, amdgpu, <...SNIP...>
USAGE: tutorial-opt [options] <input file>

OPTIONS:

General options:

  Compiler passes to run
    Passes:
      --affine-full-unroll                             -   Fully unroll all affine loops
  --allow-unregistered-dialect                         - Allow operation with no registered dialects
  --disable-i2p-p2i-opt                                - Disables inttoptr/ptrtoint roundtrip optimization
  <...SNIP...>

To allow us to run lit tests that use this tool, we add it to the test_utilities target in this commit, and then we add a first (failing) test in this commit. To avoid complexity, I’m just asserting that the output has no for loops in it.

// RUN: tutorial-opt %s --affine-full-unroll > %t
// RUN: FileCheck %s < %t

func.func @test_single_nested_loop(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  // CHECK-NOT: affine.for
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}

Next, we can implement the pass itself in this commit:

#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/include/mlir/Pass/Pass.h"

using mlir::affine::AffineForOp;
using mlir::affine::loopUnrollFull;

void AffineFullUnrollPass::runOnOperation() {
  getOperation().walk([&](AffineForOp op) {
    if (failed(loopUnrollFull(op))) {
      op.emitError("unrolling failed");
      signalPassFailure();
    }
  });
}

getOperation returns a FuncOp, though we don’t use any specific information about it being a function. We instead call the walk method (present on all Operation instances), which traverses the abstract syntax tree (AST) of the operation (i.e., the function body) in post-order. For each operation it encounters, if the type of that operation matches the input type of the callback, the callback is executed. In our case, we attempt to unroll the loop, and if it fails we quit with a diagnostic error.
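To make the traversal order concrete, here is a toy post-order walk over a made-up tree type. ToyOp and walkPostOrder are hypothetical names for this sketch. MLIR's walk filters on the C++ type of the callback's argument, whereas this sketch filters on a string name to stay self-contained.

```cpp
#include <functional>
#include <string>
#include <vector>

// A made-up, minimal stand-in for an operation with nested operations.
struct ToyOp {
  std::string name;           // e.g. "affine.for"
  std::string label;          // only used to tell nodes apart in this sketch
  std::vector<ToyOp> nested;  // operations contained in this one
};

// Post-order: visit everything nested inside an op before the op itself,
// and only invoke the callback on ops whose name matches.
void walkPostOrder(const ToyOp &root, const std::string &wantedName,
                   const std::function<void(const ToyOp &)> &callback) {
  for (const ToyOp &child : root.nested) {
    walkPostOrder(child, wantedName, callback);
  }
  if (root.name == wantedName) {
    callback(root);
  }
}
```

With a doubly-nested loop, a post-order walk reaches the inner loop before the outer one.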

Exercise: determine how the loop unrolling might fail, and create a test MLIR input that causes it to fail, and observe the error messages that result.

Running this on our test shows the operation is applied:

$ bazel run tools:tutorial-opt -- --affine-full-unroll < tests/affine_loop_unroll.mlir
<...>
#map = affine_map<(d0) -> (d0 + 1)>
#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 3)>
module {
  func.func @test_single_nested_loop(%arg0: memref<4xi32>) -> i32 {
    %c0 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = affine.load %arg0[%c0] : memref<4xi32>
    %1 = arith.addi %c0_i32, %0 : i32
    %2 = affine.apply #map(%c0)
    %3 = affine.load %arg0[%2] : memref<4xi32>
    %4 = arith.addi %1, %3 : i32
    %5 = affine.apply #map1(%c0)
    %6 = affine.load %arg0[%5] : memref<4xi32>
    %7 = arith.addi %4, %6 : i32
    %8 = affine.apply #map2(%c0)
    %9 = affine.load %arg0[%8] : memref<4xi32>
    %10 = arith.addi %7, %9 : i32
    return %10 : i32
  }
}

I won’t explain what this affine.apply thing is doing, but suffice it to say the loop is correctly unrolled. A subsequent commit does the same test for a doubly-nested loop.

A Rewrite Pattern Version

In this commit, we rewrote the loop unroll pass in the next level of abstraction provided by MLIR: the pattern rewrite engine. It is useful in the kind of situation where one wants to repeatedly apply the same subset of transformations to a given IR substructure until that substructure is completely removed. The next section will write a pass that uses that in a meaningful way, but for now we’ll just rewrite the loop unroll pass to show the extra boilerplate.

A rewrite pattern is a subclass of OpRewritePattern, and it has a method called matchAndRewrite which performs the transformation.

struct AffineFullUnrollPattern :
  public OpRewritePattern<AffineForOp> {
  AffineFullUnrollPattern(mlir::MLIRContext *context)
      : OpRewritePattern<AffineForOp>(context, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(AffineForOp op,
                                PatternRewriter &rewriter) const override {
    return loopUnrollFull(op);
  }
};

The return value of matchAndRewrite is a LogicalResult, which is a wrapper around a boolean to signal success or failure, along with named utility functions like failure() and success() to generate instances, and failed(...) to test for failure. LogicalResult also comes with a companion class template FailureOr, which is a subclass of optional and inter-operates with LogicalResult via the presence or absence of a value.
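The sketch below shows the shape of that idea in plain C++. It is a simplified stand-in, not MLIR's actual definition: the Toy-prefixed names and the halveIfEven helper are invented for illustration.

```cpp
#include <cassert>
#include <optional>
#include <utility>

// Simplified stand-in for LogicalResult: a named wrapper around a boolean.
class ToyLogicalResult {
public:
  static ToyLogicalResult make(bool isSuccess) {
    return ToyLogicalResult(isSuccess);
  }
  bool isSuccess() const { return success_; }

private:
  explicit ToyLogicalResult(bool isSuccess) : success_(isSuccess) {}
  bool success_;
};

inline ToyLogicalResult toySuccess() { return ToyLogicalResult::make(true); }
inline ToyLogicalResult toyFailure() { return ToyLogicalResult::make(false); }
inline bool toyFailed(ToyLogicalResult result) { return !result.isSuccess(); }

// Simplified stand-in for FailureOr: an optional that converts to and from
// the result type via the presence or absence of a value.
template <typename T>
class ToyFailureOr : public std::optional<T> {
public:
  // Constructed from a failure: holds no value.
  ToyFailureOr(ToyLogicalResult result) : std::optional<T>(std::nullopt) {
    assert(toyFailed(result) && "a success must carry a value");
    (void)result;
  }
  // Constructed from a value: counts as success.
  ToyFailureOr(T value) : std::optional<T>(std::move(value)) {}

  operator ToyLogicalResult() const {
    return this->has_value() ? toySuccess() : toyFailure();
  }
};

// Hypothetical helper: fails on odd input, otherwise returns half of it.
ToyFailureOr<int> halveIfEven(int value) {
  if (value % 2 != 0) {
    return toyFailure();
  }
  return value / 2;
}
```

A caller can test halveIfEven(3) with toyFailed(...) exactly as it would a plain result, and dereference halveIfEven(8) like any optional to get 4.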

Aside: In a proper OpRewritePattern, the mutations of the IR must go through the PatternRewriter argument, but because loopUnrollFull doesn’t have a variant that takes a PatternRewriter as input, we’re violating that part of the function contract. More generally, the PatternRewriter handles atomicity of the mutations that occur within the OpRewritePattern, ensuring that the operations are applied only if the method reaches the end and succeeds.

Then we instantiate the pattern inside the pass:

// A pass that invokes the pattern rewrite engine.
void AffineFullUnrollPassAsPatternRewrite::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());
  patterns.add<AffineFullUnrollPattern>(&getContext());
  // One could use GreedyRewriteConfig here to slightly tweak the behavior of
  // the pattern application.
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

The overall pass is still anchored on FuncOp, but an OpRewritePattern can match against any op. The rewrite engine invokes the walk that we did manually, and one can pass an optional configuration struct that chooses the walk order.

The RewritePatternSet can accept any number of patterns, and the greedy rewrite engine will keep trying to apply them (in a certain order related to the benefit constructor argument) until there are no matching operations to apply, all applied patterns return failure, or some large iteration limit is reached to avoid infinite loops.
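As a mental model only, the loop just described can be sketched as follows. This is not how MLIR's driver is implemented: ToyPattern, applyGreedily, and the default limit of 10 are assumptions made for this illustration, and the real engine works on a worklist of operations rather than a single state value.

```cpp
#include <algorithm>
#include <functional>
#include <vector>

// A toy pattern: a benefit plus a function that either rewrites the state
// and returns true, or leaves it alone and returns false.
template <typename State>
struct ToyPattern {
  int benefit;
  std::function<bool(State &)> matchAndRewrite;
};

// Repeatedly apply the highest-benefit pattern that succeeds. Returns true
// if a fixed point was reached, false if the iteration limit was hit first.
template <typename State>
bool applyGreedily(State &state, std::vector<ToyPattern<State>> patterns,
                   int maxIterations = 10) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern<State> &a, const ToyPattern<State> &b) {
                     return a.benefit > b.benefit;
                   });
  for (int iteration = 0; iteration < maxIterations; ++iteration) {
    bool changed = false;
    for (const ToyPattern<State> &pattern : patterns) {
      if (pattern.matchAndRewrite(state)) {
        changed = true;
        break;  // restart from the highest-benefit pattern
      }
    }
    if (!changed) {
      return true;  // every pattern failed: nothing left to rewrite
    }
  }
  return false;  // gave up to avoid looping forever
}
```

The iteration limit matters because nothing stops two patterns from undoing each other's work indefinitely.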

A proper greedy RewritePattern

In fully homomorphic encryption (FHE), multiplication ops are typically MUCH more expensive than addition ops. I’m not an expert in classical CPU hardware performance, but if I recall correctly, on a typical CPU multiplication is something like 2-4x slower than addition, and that advantage probably goes away when doing multiplications in bulk/pipelines.

In FHE, the operations introduce noise growth in the ciphertext, and in some schemes multiplication introduces something like 100x the noise of addition, and reducing that noise is very time consuming. So it makes sense that one would want to convert multiplication ops into addition ops. In this section we’ll write a very simple pass that greedily rewrites multiplication ops as repeated additions.

The idea is to rewrite an operation like y = 9*x as y = 8*x + x (the 8 is a power of 2) and then expand it further as a = x+x; b = a+a; c = b+b; y = c+x. It replaces a multiplication by a constant with a roughly log-number of additions (the base-2 logarithm of the constant), though it gets worse the further away the constant gets from a power of two.

This commit contains a similar “empty shell” of a pass, with two patterns defined. The first, PowerOfTwoExpand, rewrites y = C*x as y = C/2*x + C/2*x when C is a power of 2, and otherwise fails. The second, PeelFromMul, “peels” a single addition off a product whose constant is not a power of 2, rewriting y = 9*x as y = 8*x + x. These are applied repeatedly via the greedy pattern rewrite engine. By setting the benefit argument of PowerOfTwoExpand to be larger than that of PeelFromMul, we tell the greedy rewrite engine to prefer PowerOfTwoExpand whenever possible. Together, that achieves the transformation mentioned above.
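Before looking at the MLIR implementation, it may help to simulate the two rules on plain integers. The sketch below is not the pass itself: Expansion and multiplyByAdditions are names invented here, and the recursion stands in for the rewrite engine applying the higher-benefit rule first. It assumes a constant of at least 1.

```cpp
#include <cassert>
#include <cstdint>

struct Expansion {
  int64_t result;  // the value of c * x, computed using additions only
  int additions;   // how many additions were needed
};

// Simulates the two rewrite rules for y = c * x, with c >= 1.
Expansion multiplyByAdditions(int64_t x, int64_t c) {
  assert(c >= 1 && "this sketch only handles positive constants");
  if (c == 1) {
    return {x, 0};  // base case: x * 1 is just x
  }
  bool isPowerOfTwo = (c & (c - 1)) == 0;
  if (isPowerOfTwo) {
    // PowerOfTwoExpand: c*x = (c/2)*x + (c/2)*x. The half product is
    // computed once and used for both operands of the addition.
    Expansion half = multiplyByAdditions(x, c / 2);
    return {half.result + half.result, half.additions + 1};
  }
  // PeelFromMul: c*x = (c-1)*x + x.
  Expansion rest = multiplyByAdditions(x, c - 1);
  return {rest.result + x, rest.additions + 1};
}
```

For c = 9 this uses 4 additions (one peel, then three doublings). For c = 15 it uses 10, because seven peels are needed before reaching 8, which illustrates how the scheme degrades for constants just below a power of two.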

This commit adds a failing test that only exercises PowerOfTwoExpand, and then this commit implements it. Here’s the implementation:

  LogicalResult matchAndRewrite(
       MulIOp op, PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);

    // canonicalization patterns ensure the constant is on the right, if there is a constant
    // See https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules
    Value rhs = op.getOperand(1);
    auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
    if (!rhsDefiningOp) {
      return failure();
    }

    int64_t value = rhsDefiningOp.value();
    bool is_power_of_two = (value & (value - 1)) == 0;

    if (!is_power_of_two) {
      return failure();
    }

    ConstantOp newConstant = rewriter.create<ConstantOp>(
        rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value / 2));
    MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
    AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, newMul);

    rewriter.replaceOp(op, {newAdd});
    rewriter.eraseOp(rhsDefiningOp);

    return success();
  }

Some notes:

  • Value is the type that represents an SSA value (i.e., an MLIR variable), and getDefiningOp fetches the unique operation that defines it in its scope.
  • There are a variety of “casting” operations like rhs.getDefiningOp<arith::ConstantIntOp>() that take the type you want as output as a template parameter, and return null if the type cannot be converted. You might also see dyn_cast<>.
  • (value & (value - 1)) == 0 is a classic bit-twiddling trick to test whether an integer is a power of two (note that it also holds for zero). We check it and skip the pattern if it fails.
  • The actual constant itself is represented as an MLIR attribute, which is essentially compile-time static data attached to the op. You can put strings or dictionaries as attributes, but for ConstantOp it’s just an int.
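The bit trick in the notes above is easy to check in isolation. In the sketch below the function names are made up for illustration. The stricter variant that also rejects zero and negative numbers is an addition of this sketch, not something the pass does.

```cpp
#include <cstdint>

// Subtracting 1 flips the lowest set bit and every zero below it, so
// AND-ing with the original clears exactly the lowest set bit.
constexpr int64_t clearLowestSetBit(int64_t value) {
  return value & (value - 1);
}

// The raw check: true when at most one bit is set. That covers every power
// of two, but also zero.
constexpr bool hasAtMostOneBitSet(int64_t value) {
  return clearLowestSetBit(value) == 0;
}

// A stricter variant (an assumption of this sketch) that rejects zero and
// negative values before applying the trick.
constexpr bool isPositivePowerOfTwo(int64_t value) {
  return value > 0 && hasAtMostOneBitSet(value);
}
```

For example, 12 is 1100 in binary and 11 is 1011, so clearLowestSetBit(12) is 1000, which is 8 and not zero: 12 is not a power of two.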

The rewriter.create part is where we actually do the real work. Create a new constant that is half the original constant, create new multiplication and addition ops, and then finally rewriter.replaceOp removes the original multiplication op and uses the output of newAdd for any other operations that used the original multiplication op’s output.

It’s worth noting that we’re relying on MLIR’s built-in canonicalization passes in a few ways here:

  • To ensure that the constant is always the second operand of a multiplication op.
  • To ensure that the base case (x*1) is “folded” into a plain x and the constant 1 is removed.
  • The fold part of applyPatternsAndFoldGreedily is what runs these cleanup steps for us.

PeelFromMul is similar, implemented and tested in this commit:

  LogicalResult matchAndRewrite(MulIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
    if (!rhsDefiningOp) { return failure(); }
    int64_t value = rhsDefiningOp.value();

    // We are guaranteed `value` is not a power of two, because the greedy
    // rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
    // it has higher benefit.

    ConstantOp newConstant = rewriter.create<ConstantOp>(
        rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value - 1));
    MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
    AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, lhs);

    rewriter.replaceOp(op, {newAdd});
    rewriter.eraseOp(rhsDefiningOp);

    return success();
  }

Running it! Input:

func.func @power_of_two_plus_one(%arg: i32) -> i32 {
  %0 = arith.constant 9 : i32
  %1 = arith.muli %arg, %0 : i32
  func.return %1 : i32
}

Output:

module { 
  func.func @power_of_two_plus_one(%arg0: i32) -> i32 {
    %0 = arith.addi %arg0, %arg0 : i32
    %1 = arith.addi %0, %0 : i32
    %2 = arith.addi %1, %1 : i32
    %3 = arith.addi %2, %arg0 : i32
    return %3 : i32
  }
}

Exercise: Try swapping the benefit arguments to see how the output changes.

Though this pass is quite naive, you can imagine a more sophisticated technique that builds a cost model for multiplications and additions, and optimizes for the cheapest cost representation of an arithmetic operation in terms of repeated additions, multiplications, and other supported ops.

Should we walk?

With two options for how to define a pass—one to walk the entire syntax tree from the root operation, and one to match and rewrite patterns with the rewrite engine—the natural question is when should you use one versus the other.

The MLIR docs describe the motivation behind the pattern rewrite engine, and it comes from a long history of experience with the LLVM project. For one, the pattern rewrite engine expresses a convenient subset of what can be achieved with an MLIR pass. This is conceptually trivial, in the sense that anyone who can walk the entire AST can, with enough effort, do anything they want including reimplementing the pattern rewrite engine.

More practically, the pattern rewrite engine is convenient to represent local transformations. “Local” here means that the input and output can be detected via a subset of the AST as a directed acyclic graph. More pragmatically, think of it as any operation you can identify by looking around at neighboring operations in the same block and applying some filtering logic. E.g., “is this exp operation followed by a log operation with no other uses of the output of the exp?”

On the other hand, some analyses and optimizations need to construct the entire dataflow of a program to work. A good example is common subexpression elimination, which determines whether it is cost effective to extract a subexpression used in multiple places into a separate variable. Doing so may introduce additional cost of memory access, so it depends both on the operation’s cost and on the availability of registers at that point in the program. You can’t get this information by pattern matching the AST locally.

The wisdom seems to be: using the pattern rewrite engine is usually easier than writing a pass that walks the AST. You don’t need large case/switch statements to handle everything that could show up in the IR. The engine handles re-applying patterns many times. And so you can write the patterns in isolation and trust the engine to combine them appropriately.

Bonus: IDEs and CI

Since we explored the C++ API, it helps to have an IDE integration. I use neovim with the clangd LSP, and to make it work with a Bazel C++ project, one needs to use something analogous to Hedron Vision’s compile_commands extractor, which I configured for this tutorial project in this commit. It’s optional, but if you want to use it you have to run bazel run @hedron_compile_commands//:refresh_all once to set it up, and then clangd and clang-tidy, etc., should find the generated json file and use it. Also, if you edit a BUILD file, you have to re-run refresh_all for the changes to show up in the LSP.

Though it’s not particularly relevant to this tutorial, I also configured GitHub Actions to build and test the project in CI in this commit. It is worth noting that the GitHub cache action reduces subsequent build times from 1-2 hours down to just a few minutes.

Thanks to Patrick Schmidt for feedback on a draft of this article.

MLIR — Getting Started


As we announced recently, my team at Google has started a new effort to build production-worthy engineering tools for Fully Homomorphic Encryption (FHE). One focal point of this, and one which I’ll be focusing on as long as Google is willing to pay me to do so, is building out a compiler toolchain for FHE in the MLIR framework (Multi-Level Intermediate Representation). The project is called Homomorphic Encryption Intermediate Representation, or HEIR.

The MLIR community is vibrant. But because it’s both a new and a fast-moving project, there isn’t a lot in the way of tutorials and documentation available for it. There is no authoritative MLIR book. Most of the reasoning around things is in folklore and heavily technical RFCs. And because MLIR is built on top of LLVM (the acronym formerly meaning “Low Level Virtual Machine”), much of the documentation that exists explains concepts by analogy to LLVM, which is unhelpful for someone like me who isn’t familiar with the internals of how LLVM works. Finally, the “proper” tutorials that do exist are, in my opinion, too high level to allow one to really get a sense for how to write programs in the framework.

I want people interested in FHE to contribute to HEIR. To that end, I want to lower the barrier to entry to working with MLIR. And so this series of blog posts will be a detailed introduction to MLIR in general, with some bias toward the topics that show up in HEIR and that I have spent time studying and internalizing.

This first article describes a typical MLIR project’s structure, and the build system that we use in HEIR. But the series as a whole will be built up along with a GitHub repository that breaks down each step into clean, communicative commits, similar to my series about the Riemann Hypothesis. To avoid being broken by upstream changes to MLIR (our project will be “out of tree”, so to speak), we will pin the dependency on MLIR to a specific commit hash. While this implies that the content in these articles will eventually become stale, I will focus on parts of MLIR that are relatively stable.

A brief history of MLIR and LLVM

The first thing you’ll notice about MLIR is that it lives within the LLVM project’s monorepo under a folder called mlir/. LLVM is a sort of abstracted assembly language that compiler developers can target as a backend, and then LLVM itself comes packaged with a host of optimizations and “real” backend targets that can be compiled to. If you’re, say, the Rust programming language and you want to compile to x86, ARM, and WebAssembly without having to do all that work, you can just output LLVM code and then run LLVM’s compilation suite.

I don’t want to get too much into the history of LLVM (see this interview for more details), and I don’t have any first hand knowledge of it, but from what I can gather LLVM (formerly standing for “Low Level Virtual Machine”) was the PhD project of Chris Lattner in the early 2000s, aiming to be a next-generation C compiler. Chris moved to Apple, where he worked on LLVM and languages like Swift which build on LLVM. In 2017 he moved to Google Brain as a director of the TensorFlow infrastructure team, and he and his team built MLIR to unify the siloed tooling in their ecosystem.

We’ll talk more about what exactly MLIR is and what it provides in a future article. For a high level overview, see the MLIR paper. In short, it’s a framework for building compilers, with the underlying philosophy that a big compiler should be broken up into lots of small compilers between sub-languages (which compiler folks call “intermediate representations” or “IR”s), where each sub-language is designed to make a particular kind of optimization more natural to express. Hence the MLIR acronym standing for Multi-Level Intermediate Representation.

MLIR is relevant for TensorFlow because training and inference can both be thought of as programs whose instructions are things like “2d convolution” and “softmax.” And the process for optimizing those instructions, while converting them to lower level hardware instructions (especially on TPU accelerators) is very much a compilers problem. MLIR breaks the process up into IRs at various levels of abstraction, like Tensor operations, linear algebra, and lower-level control flow.

But LLVM just couldn’t be directly reused as a TensorFlow compiler. It was too legacy and too specialized to CPU, operated at a much lower abstraction layer, and had incidental tech debt. But LLVM did have lots of reusable pieces, like data structures, error handling, and testing infrastructure. And combined with Lattner’s intimate familiarity with a project he’d worked on for almost 20 years, it was probably just easier to jumpstart MLIR by putting it in the monorepo.

Build systems

The rest of this article is going to focus on setting up the build system for our tutorial project. It will describe each commit in this pull request.

Now, the official build system of LLVM and MLIR is CMake. But I’ll be using Bazel for a few reasons. First, I want to induct interested readers into HEIR, and that’s what HEIR uses because it’s a Google-owned project. Second, though one might worry that the Bazel configuration is complicated or unsupported, because MLIR and LLVM have become critical to Google’s production infrastructure, Google helps to maintain a Bazel “overlay” in parallel with the CMake configuration, and Google has on-call engineers responsible for ensuring both that Google’s internal copy of MLIR stays up to date with the LLVM monorepo, and that any build issues are promptly fixed. The rough edges that remain are simple enough for an impatient dummy like me to handle.

So here’s an overview of Bazel (with parts repeated from my prior article). Bazel is the open source analogue of Google’s internal build system, “Blaze”, and Starlark is its Python-inspired scripting language. There are lots of opinions about Bazel that I won’t repeat here. You can install it using the bazelisk program.

First some terminology. To work with Bazel you do the following.

  • Define a WORKSPACE file which defines all your project’s external dependencies, how to fetch their source code, and what bazel commands should be used to build them. This can be thought of as a top-level CMakeLists, except that it doesn’t contain any instructions for building the project beyond declaring the root of the project’s directory tree and the project’s name.
  • Define a set of BUILD files in each subdirectory, declaring the build targets that can be built from the source files in that directory (but not its subdirectories). This is analogous to CMakeLists files in subdirectories. Each build target can declare dependence on other build targets, and bazel build ensures the dependencies are built first, and caches the build results across a session. Many projects have a BUILD file in the project root to expose the project’s public libraries and APIs.
  • Use the built-in bazel rules like cc_library and cc_binary and cc_test to group files into libraries that can be built with bazel build, executable binaries that can also be run with bazel run, and tests that can also be run with bazel test. Most bazel rules boil down to calling some executable program like gcc or javac with specific arguments, while also keeping track of the accumulated dependency set of build artifacts in a “hermetic” location on the filesystem.
  • Define new bazel rules that execute custom programs, and which declare dependencies and outputs for the static dependency graph. MLIR’s custom rules revolve around the tblgen program, which is MLIR’s custom templating language that generates C++ code.
  • Write any additional bazel macros that chain together built-in bazel commands. Macros look like Python functions that call individual bazel rules and possibly pass data between them. They’re written in .bzl files (containing Starlark code) which are interpreted directly by bazel. We’ll see a good example of a bazel macro when we talk about MLIR’s testing framework lit, but this article contains a simple one for setting up the LLVM dependency in the WORKSPACE file (which is also Starlark).

Generally, bazel builds targets in two phases. First—the analysis phase—it loads all the BUILD files and imported .bzl files, and scans for all the rules that were called. In particular, it runs the macros, because it needs to know what rules are called by the macros (and rules can be guarded by control flow, or their arguments can be generated dynamically, etc.). But it doesn’t run the build rules themselves. In doing this, it can build a complete graph of dependencies, and report errors about typos, missing dependencies, cycles, etc. Once the analysis phase is complete, it runs the underlying rules in dependency order, and caches the results. Bazel will only run a rule again if something changes with the files it depends on or its underlying dependencies.
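To illustrate the split between analyzing the graph and running the rules, here is a toy version of the first phase. It has nothing to do with Bazel's real implementation: DepGraph and analyze are names invented for this sketch, targets are plain strings, and caching is left out entirely. It only shows how building the full dependency graph up front makes it possible to report missing dependencies and cycles before any build step runs.

```cpp
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Maps each target name to the names of the targets it depends on.
using DepGraph = std::map<std::string, std::vector<std::string>>;

// Toy "analysis phase": returns an order in which every target comes after
// its dependencies, or nothing if a dependency is undeclared or the graph
// contains a cycle. No build step is executed here.
std::optional<std::vector<std::string>> analyze(const DepGraph &graph) {
  enum class Mark { Visiting, Done };
  std::map<std::string, Mark> marks;
  std::vector<std::string> order;

  std::function<bool(const std::string &)> visit =
      [&](const std::string &target) -> bool {
    auto node = graph.find(target);
    if (node == graph.end()) {
      return false;  // dependency on a target nobody declared
    }
    auto mark = marks.find(target);
    if (mark != marks.end()) {
      // Done means already ordered; Visiting means we looped back: a cycle.
      return mark->second == Mark::Done;
    }
    marks[target] = Mark::Visiting;
    for (const std::string &dep : node->second) {
      if (!visit(dep)) {
        return false;
      }
    }
    marks[target] = Mark::Done;
    order.push_back(target);  // all dependencies are already in the order
    return true;
  };

  for (const auto &entry : graph) {
    if (!visit(entry.first)) {
      return std::nullopt;
    }
  }
  return order;
}
```

A second phase would then walk the returned order and run each target's build step, skipping any whose inputs have not changed.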

The WORKSPACE and llvm-project dependency

The commits in this section will come from https://github.com/j2kun/mlir-tutorial/pull/1.

After adding a .gitignore to filter out Bazel’s build directories, this commit sets up an initial WORKSPACE file and two bazel files that perform an unusual two-step dance for configuring the LLVM codebase. The workspace file looks like this:

workspace(name = "mlir_tutorial")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")

load("//bazel:import_llvm.bzl", "import_llvm")

import_llvm("llvm-raw")

load("//bazel:setup_llvm.bzl", "setup_llvm")

setup_llvm("llvm-project")

This is not a normal sort of dependency. A normal dependency might look like this:

http_archive(
    name = "abc",
    build_file = "//bazel:abc.BUILD",
    sha256 = "7fa5a448a4309fb4d6cf856c3fe4cc4be46b09dd552a05d5cfacd75f8d9504ad",
    urls = [
        "https://github.com/berkeley-abc/abc/archive/eb44a80bf2eb8723231e72bb095c97d1e4834d56.zip",
    ],
)

The above tells bazel: go pull the zip file from the given URL, double check its hashsum, and then (because the dependent project is not built with bazel) I’ll tell you where in my repository to find the BUILD file that you should use to build it. If the project had a BUILD file, we could omit build_file and it would just work.

Now, LLVM has bazel build files, but they are hidden in the utils/bazel subdirectory of the project. Bazel requires its special files to be in the right places, plus the bazel configuration is designed to be in sync with the CMake configuration. So the utils/bazel directory has an llvm_configure bazel macro which executes a python script that symlinks everything properly. More info about the upstream system can be found here.

So to run this macro we have to download the LLVM code as a repository, which I put into the import_llvm.bzl file, as well as call the macro, which I put into setup_llvm.bzl. Why two files? An apparent quirk of bazel is that you can’t load() a macro from a dependency’s bazel file in the same WORKSPACE file in which you download the dependency.

It’s also worth mentioning that import_llvm.bzl is where I put the hard-coded commit hash that pins this project to a specific LLVM version.

Getting past some build errors

In an ideal world this would be enough, but trying to build MLIR now gives errors. In the following examples I will try to build the @llvm-project//mlir:IR build target (arbitrarily chosen).

Side note: some readers of early drafts have had trouble getting these steps to work exactly. Despite bazel aiming to be a perfectly hermetic build system, it has to store temporary files somewhere, and that can lead to inconsistencies and permission errors. If you’re not able to get these steps to work, check out these links:

For starters, the build fails with

$ bazel build @llvm-project//mlir:IR
ERROR: Skipping '@llvm-project//mlir:IR': error loading package '@llvm-project//mlir': 
Unable to find package for @bazel_skylib//rules:write_file.bzl: 
The repository '@bazel_skylib' could not be resolved: 
Repository '@bazel_skylib' is not defined.

Bazel complains that it can’t find @bazel_skylib, which is a sort of extended standard library for Bazel. The MLIR Bazel overlay uses it for macros like “run shell command.” And so we learn another small quirk about Bazel, that each project must declare all transitive workspace dependencies (for now).

So in this commit we add bazel_skylib as a dependency.

Now it fails because of two other dependencies, llvm_zlib and llvm_zstd. This commit adds them.

$ bazel build @llvm-project//mlir:IR
ERROR: /home/j2kun/.cache/bazel/_bazel_j2kun/fc8ffaa09c93321753c7c87483153cea/external/llvm-project/llvm/BUILD.bazel:184:11: 
no such package '@llvm_zlib//': 
The repository '@llvm_zlib' could not be resolved: 
Repository '@llvm_zlib' is not defined and referenced by '@llvm-project//llvm:Support'

Now when you try to build you get a bona-fide compiler error.

$ bazel build @llvm-project//mlir:IR
INFO: Analyzed target @llvm-project//mlir:IR (41 packages loaded, 1495 targets configured).
INFO: Found 1 target...
ERROR: <... snip ...>
In file included from external/llvm-project/llvm/lib/Demangle/Demangle.cpp:13:
external/llvm-project/llvm/include/llvm/Demangle/Demangle.h:35:28: error: 
'string_view' is not a member of 'std'
   35 | char *itaniumDemangle(std::string_view mangled_name);
      |                            ^~~~~~~~~~~
external/llvm-project/llvm/include/llvm/Demangle/Demangle.h:35:28: note: 'std::string_view' is only available from C++17 onwards

“note: ‘std::string_view’ is only available from C++17 onwards” suggests something is still wrong with our setup, and indeed, we need to tell bazel to compile with C++17 support. This can be done in a variety of ways, but the way that has been the most reliable for me is to add a .bazelrc file that enables this by default in every bazel build command run while the working directory is underneath the project root. This is done in this commit. (Also see this extra step that may be needed for macOS users.)

# in .bazelrc
build --action_env=BAZEL_CXXOPTS=-std=c++17

Then, finally, it builds.

At this point you could build ALL of the LLVM/MLIR project by running bazel build @llvm-project//mlir/...:all. However, while you will need to do something similar to this eventually, and doing it now (while you read) is a good way to eagerly populate the build cache, it will take 30 minutes to an hour, make your computer go brrr, and use a few gigabytes of disk space for the cached build artifacts. (After working on three projects that each depend on LLVM and/or MLIR, my bazel cache is currently sitting at 23 GiB).

But! If you try, there’s still one more error:

$ bazel build @llvm-project//mlir/...:all
ERROR: /home/j2kun/.cache/bazel/_bazel_j2kun/fc8ffaa09c93321753c7c87483153cea/external/llvm-project/mlir/test/BUILD.bazel:591:11: 
no such target '@llvm-project//llvm:NVPTXCodeGen': 
target 'NVPTXCodeGen' not declared in package 'llvm' defined by
/home/j2kun/.cache/bazel/_bazel_j2kun/fc8ffaa09c93321753c7c87483153cea/external/llvm-project/llvm/BUILD.bazel 
(Tip: use `query "@llvm-project//llvm:*"` to see all the targets in that package) and referenced by '@llvm-project//mlir/test:TestGPU'

This is another little bug in the Bazel overlays that I hope will go away soon. It took me a while to figure this one out when I first encountered it, but here’s what’s happening. In the bazel/setup_llvm.bzl file that chooses which backend targets to compile, we chose only X86. The bazel overlay files are supposed to treat all backends as optional, and only define targets when the chosen backend dependencies are present. This is how you can avoid compiling a bunch of code for doing GPU optimization when you don’t want to target GPUs.

But in this case, the NVPTX backend (a GPU backend) is referenced by mlir/test:TestGPU whether or not you include it as a target, so the build fails when the NVPTXCodeGen target is not defined. The simple option is to just include NVPTX as a target and take the hit on the cold-start build time. This commit fixes it.
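
In terms of configuration, the change amounts to something like the following in bazel/setup_llvm.bzl. This is a sketch under the assumption that the raw LLVM checkout is registered as @llvm-raw and that the file wraps the overlay's llvm_configure macro in a setup_llvm helper; the names in your project may differ.

```starlark
# In bazel/setup_llvm.bzl. Sketch only: the @llvm-raw repository name and
# the setup_llvm helper are assumptions about how the project is laid out.
load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")

# The subset of LLVM backend targets to compile.
_LLVM_TARGETS = [
    "X86",
    # Included only because mlir/test:TestGPU depends on
    # //llvm:NVPTXCodeGen, which is not defined without this entry.
    "NVPTX",
]

def setup_llvm(name):
    """Build the overlay-configured LLVM repository under the given name."""
    llvm_configure(
        name = name,
        targets = _LLVM_TARGETS,
    )
```

Leaving a comment next to the extra backend makes it easier to remove again if the overlay is later fixed.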

Now you can build all of LLVM, and in particular you can build the main MLIR binary mlir-opt.

$ bazel run @llvm-project//mlir:mlir-opt -- --help
OVERVIEW: MLIR modular optimizer driver

Available Dialects: acc, affine, amdgpu, amx, arith, arm_neon, arm_sve, async, bufferization, builtin, cf,
complex, dlti, emitc, func, gpu, index, irdl, linalg, llvm, math, memref, ml_program, nvgpu, nvvm, omp, pdl, 
pdl_interp, quant, rocdl, scf, shape, sparse_tensor, spirv, tensor, test, test_dyn, tosa, transform, vector, 
x86vector
USAGE: mlir-opt [options] <input file>

OPTIONS:
...

mlir-opt is the main entry point for running optimization passes and lowering code from one MLIR dialect to another. Next time, we’ll explore what some of the simpler dialects look like, run some pre-defined lowerings, and learn about how the end-to-end testing framework works.

Thanks to Patrick Schmidt for feedback on a draft of this article.