MLIR — Folders and Constant Propagation

Last time we saw how to use pre-defined MLIR traits to enable upstream MLIR passes like loop-invariant-code-motion to apply to poly programs. We left out -sccp (sparse conditional constant propagation), and so this time we’ll add what is needed to make that pass work. It requires the concept of folding.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Constant Propagation vs Canonicalization

-sccp is sparse conditional constant propagation, which attempts to infer when an operation has a constant output, and then replaces the operation with the constant value. Repeating this, it “propagates” the constants as far as possible through the program. You can think of it like eagerly computing values it can during compile time, and then sticking them into the compiled program as constants.

Here’s what it looks like for arith, where all the needed tools are implemented. For an input like:

func.func @test_arith_sccp() -> i32 {
  %0 = arith.constant 7 : i32
  %1 = arith.constant 8 : i32
  %2 = arith.addi %0, %0 : i32
  %3 = arith.muli %0, %0 : i32
  %4 = arith.addi %2, %3 : i32
  return %2 : i32
}

The output of tutorial-opt --sccp is

func.func @test_arith_sccp() -> i32 {
  %c63_i32 = arith.constant 63 : i32
  %c49_i32 = arith.constant 49 : i32
  %c14_i32 = arith.constant 14 : i32
  %c8_i32 = arith.constant 8 : i32
  %c7_i32 = arith.constant 7 : i32
  return %c14_i32 : i32
}

Note two additional facts: sccp doesn’t delete dead code, and, though it isn’t shown here, the main novelty of sccp is that it can propagate constants through control flow (ifs and loops).

A related concept is the idea of canonicalization, which gets its own --canonicalize pass, and which hides a lot of the heavy lifting in MLIR. Canonicalization overlaps a little bit with sccp, in that it also computes constants and materializes them in the IR. Take, for example, the output of the --canonicalize pass on the same IR:

func.func @test_arith_sccp() -> i32 {
  %c14_i32 = arith.constant 14 : i32
  return %c14_i32 : i32
}

The intermediate constants are all pruned, and all that remains is the final constant and the return statement. Canonicalize cannot propagate constants through control flow, and as such should be thought of as more “local” than sccp.

Both of these, however, are supported via folding, which is the process of taking a series of ops and merging them together into simpler ops. It also requires that our dialect have some sort of constant operation, which is inserted (“materialized”) with the results of a fold. Folding and canonicalization are more general than what I’m showing here, so we’ll come back to what else they can do in a future article.

The rough outline of what is needed to support folding in this way is:

  • Adding a constant operation
  • Adding a materialization hook
  • Adding folders for each op

This would result in the following behavior (see the test case commit). Starting from

%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%1 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>
%2 = poly.mul %1, %1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %1, %1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>

We would get

%0 = poly.constant dense<[2, 8, 20, 24, 18]> : tensor<5xi32> : <10>
%1 = poly.constant dense<[1, 4, 10, 12, 9]> : tensor<5xi32> : <10>
%2 = poly.constant dense<[1, 2, 3]> : tensor<3xi32> : <10>

Making a constant op

Currently we’re imitating a constant polynomial construction by combining a from_tensor op with arith.constant, like this:

%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>

A constant operation would combine them into one op:

%0 = poly.constant dense<[1, 2, 3]> : tensor<3xi32> : !poly.poly<10>

The from_tensor op can also be used to build a polynomial from data, not just constants, so it’s worth having around even after we implement poly.constant.

Having a dedicated constant operation has benefits explained in the MLIR documentation on folding. What’s relevant here is that fold can be used to signal to passes like sccp that the result of an op is constant (statically known), or it can be used to say that the result of an op is equivalent to a pre-existing value created by a different op. For the constant case, a materializeConstant hook is also needed to tell MLIR how to take the constant result and turn it into a proper IR op.

The constant op itself, in this commit, comes with two new concepts, the ConstantLike trait and an argument that is an attribute constraint.

def Poly_ConstantOp : Op<Poly_Dialect, "constant", [Pure, ConstantLike]> {   // new
  let summary = "Define a constant polynomial via an attribute.";
  let arguments = (ins AnyIntElementsAttr:$coefficients);    // new
  let results = (outs Polynomial:$output);
  let assemblyFormat = "$coefficients attr-dict `:` type($output)";
}

The ConstantLike trait is checked here during folding via the constant op matcher, as an assertion. [Aside: I’m not sure why the trait is specifically required, so long as the materialization function is present on the dialect; it just seems like this check is used for assertions. Perhaps it’s just a safeguard.]

Next we have the line let arguments = (ins AnyIntElementsAttr:$coefficients);, which defines the input to the op as an attribute (statically defined data) rather than a previous SSA value. AnyIntElementsAttr is itself an attribute constraint, allowing any attribute that has IntElementsAttrBase as a base class (e.g., 32-bit or 64-bit integer element attributes). This means that we could use all of the following syntax forms:

%10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
%11 = poly.constant dense<[2, 3, 4]> : tensor<3xi8> : !poly.poly<10>
%12 = poly.constant dense<"0x020304"> : tensor<3xi8> : !poly.poly<10>
%13 = poly.constant dense<4> : tensor<100xi32> : !poly.poly<10>

Adding folders

We add folders in these commits.

Each has the same structure: add let hasFolder = 1; to the tablegen for the op, which adds a header declaration of the following form (the signature would be different if the op had more than one result value; see the docs).

OpFoldResult <OpName>::fold(<OpName>::FoldAdaptor adaptor);

Then we implement it in PolyOps.cpp. The semantics of this function are such that, if the fold decides the op should be replaced with a constant, it must return an attribute representing that constant, which can be given as the input to a poly.constant. The FoldAdaptor is a shim that has the same method names as an instance of the op’s C++ class, but arguments that have been folded themselves are replaced with Attribute instances representing the constants they were folded with. This will be relevant for folding add and mul, since the body needs to actually compute the result eagerly, and needs access to the actual values to do so.
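To make that contract concrete, here is a hedged sketch (the op name and helper are hypothetical, not the tutorial’s code) of how a binary op’s fold uses the adaptor, bailing out when an operand was not folded to a constant:

OpFoldResult SomeBinOp::fold(SomeBinOp::FoldAdaptor adaptor) {
  // Each adaptor accessor returns an Attribute if that operand was already
  // folded to a constant, and a null Attribute otherwise.
  auto lhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getLhs());
  auto rhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getRhs());
  if (!lhs || !rhs)
    return nullptr;  // Not all operands are constant, so this fold fails.
  // Both operands are constant; compute and return the folded attribute.
  return computeFoldedAttr(lhs, rhs);  // hypothetical helper
}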

For poly.constant the implementation is trivial: you just return the input attribute.

OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) {
  return adaptor.getCoefficients();
}

The from_tensor op is similar, but has an extra cast that acts as an assertion, since the tensor might have been constructed with weird types we don’t want as input. If the dyn_cast fails, the result is nullptr, which is cast by MLIR to a failed OpFoldResult.

OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
  // Returns null if the cast failed, which corresponds to a failed fold.
  return dyn_cast<DenseIntElementsAttr>(adaptor.getInput());
}

The poly binary ops are slightly more complicated, since they actually do work. Each of these fold methods takes a DenseIntElementsAttr for each operand as input, and expects us to return another DenseIntElementsAttr for the result.

For add and sub, which are elementwise operations on the coefficients, we get to use an existing upstream helper, constFoldBinaryOp, which, through some template metaprogramming wizardry, allows us to specify only the elementwise operation itself.

OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) {
  return constFoldBinaryOp<IntegerAttr, APInt>(
      adaptor.getOperands(), [&](APInt a, APInt b) { return a + b; });
}
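The folder for sub (not shown in the excerpt here) would be the analogous call with subtraction as the elementwise operation, something like:

OpFoldResult SubOp::fold(SubOp::FoldAdaptor adaptor) {
  return constFoldBinaryOp<IntegerAttr, APInt>(
      adaptor.getOperands(), [&](APInt a, APInt b) { return a - b; });
}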

For mul, we have to write out the multiplication routine manually. In what’s below, I’m implementing the naive textbook polymul algorithm, which could be optimized if one expects people to start compiling programs with large, static polynomials in them.

OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
  // If either operand was not folded to a constant attribute, fail the fold.
  auto lhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getOperands()[0]);
  auto rhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getOperands()[1]);
  if (!lhs || !rhs)
    return nullptr;

  auto degree = getResult().getType().cast<PolynomialType>().getDegreeBound();
  auto maxIndex = lhs.size() + rhs.size() - 1;

  SmallVector<APInt, 8> result;
  result.reserve(maxIndex);
  for (int i = 0; i < maxIndex; ++i) {
    result.push_back(APInt((*lhs.begin()).getBitWidth(), 0));
  }

  int i = 0;
  for (auto lhsIt = lhs.value_begin<APInt>(); lhsIt != lhs.value_end<APInt>();
       ++lhsIt) {
    int j = 0;
    for (auto rhsIt = rhs.value_begin<APInt>(); rhsIt != rhs.value_end<APInt>();
         ++rhsIt) {
      // index is modulo degree because poly's semantics are defined modulo x^N = 1.
      result[(i + j) % degree] += *rhsIt * (*lhsIt);
      ++j;
    }
    ++i;
  }

  return DenseIntElementsAttr::get(
      RankedTensorType::get(static_cast<int64_t>(result.size()),
                            IntegerType::get(getContext(), 32)),
      result);
}

Adding a constant materializer

Finally, we add a constant materializer. This is a dialect-level feature, so we start by adding let hasConstantMaterializer = 1; to the dialect tablegen, and observing the newly generated header signature:

Operation *PolyDialect::materializeConstant(
    OpBuilder &builder, Attribute value, Type type, Location loc);

The Attribute input represents the result of each folding step above. The Type is the desired result type of the op, which is needed in cases like arith.constant where the same attribute can generate multiple different types via different interpretations of a hex string or splatting with a result tensor that has different dimensions.

In our case the implementation is trivial: just construct a constant op from the attribute.

Operation *PolyDialect::materializeConstant(
    OpBuilder &builder, Attribute value, Type type, Location loc) {
  auto coeffs = dyn_cast<DenseIntElementsAttr>(value);
  if (!coeffs)
    return nullptr;
  return builder.create<ConstantOp>(loc, type, coeffs);
}

Other kinds of folding

While this has demonstrated a generic kind of folding with respect to static constants, many folding functions in MLIR use simple matches to determine when an op can be replaced with a value from a previously computed op.

Take, for example, the complex dialect (for complex numbers). A complex.create op constructs a complex number from real and imaginary parts. A folder in that dialect checks for a pattern like complex.create(complex.re(%z), complex.im(%z)), and replaces it with %z directly. The arith dialect similarly has folds for things like a-b + b -> a and a + 0 -> a.
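For reference, the upstream fold looks roughly like the following (paraphrased; accessor names may differ across MLIR versions). The key point is that it returns an existing SSA value rather than an attribute:

OpFoldResult CreateOp::fold(CreateOp::FoldAdaptor adaptor) {
  // Match complex.create(complex.re(%z), complex.im(%z)) and return %z.
  auto reOp = getReal().getDefiningOp<ReOp>();
  auto imOp = getImaginary().getDefiningOp<ImOp>();
  if (reOp && imOp && reOp.getComplex() == imOp.getComplex())
    return reOp.getComplex();  // an existing Value, not an Attribute
  return nullptr;
}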

However, most work on simplifying an IR according to algebraic rules belongs in the canonicalization pass: in addition to folding, it supports general rewrite patterns that are allowed to delete and create ops as needed to simplify the IR. We’ll cover canonicalization in more detail in a future article. But remember: folds may only modify the single operation being folded and use existing SSA values; they may not create new ops. So they are limited in power and decidedly local operations.

MLIR — Using Traits

Last time we defined a new dialect poly for polynomial arithmetic. This time we’ll spruce up the dialect by adding some pre-defined MLIR traits, and see how the application of traits enables some general purpose passes to optimize poly programs.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Traits and Loop Invariant Code Motion

As a compiler toolchain, MLIR heavily emphasizes code reuse. This is accomplished largely through two constructs: traits and interfaces. An interface is an interface in the normal programming sense of the word: a set of function signatures implemented by a type and providing some limited-scope behavior. Applied to MLIR, you can implement interfaces on operations and types, and then passes can operate at the interface level. Traits are closely related: a trait is an interface with no methods. Traits can just be “slapped on” an operation and passes can magically start working with them. They can also serve as mixins for common op verification routines, type inference, and more.

In this article I’ll show how adding traits to the operations in the poly dialect (defined last time) allows us to reuse existing MLIR passes on poly constructs. In future articles we’ll see how to define new traits and interfaces. But for now, existing traits are an extremely simple way to start using the batteries included in MLIR on new dialects.

As mentioned, one applies traits primarily to enable passes to do things with custom operations. So let’s start from the passes. The general transformation passes list includes a pass called loop invariant code motion, which checks loop bodies for operations that don’t need to be in the loop and moves them out of the loop body. Using it requires us to add two traits to our ops to express that they are safe to move around. The first is NoMemoryEffect (technically an empty implementation of an interface), which asserts that the operation has no side effects related to writing to memory. The second is AlwaysSpeculatable (technically a list of two traits), which says that an operation may be “speculatively executed,” i.e., computed early. If an op is speculatable, the compiler can move it to another location. If not (say it reads from a memory location that can be written to), there is an earliest point before which it’s not safe to move the op.

Loop invariant code motion takes ops with these two traits, and hoists them outside of loops when the operation’s operands are unchanged by the body of the loop. Conveniently, MLIR also defines a single named list of traits called Pure, which is NoMemoryEffect and AlwaysSpeculatable. So we can just add the trait name to our tablegen op definition (via a template parameter that defaults to an empty list of traits) as in this commit.

//lib/Dialect/Poly/PolyOps.td
@@ -3,9 +3,10 @@

 include "PolyDialect.td"
 include "PolyTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"


-class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic> {
+class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure]> {
   let arguments = (ins Polynomial:$lhs, Polynomial:$rhs);

This commit adds the boilerplate to register all default MLIR passes in tutorial-opt, and adds an example test asserting a poorly-placed poly.mul is hoisted out of the loop body.

// RUN: tutorial-opt %s --loop-invariant-code-motion > %t
// RUN: FileCheck %s < %t
... <setup> ...
// CHECK: poly.mul
// CHECK: affine.for
%ret_val = affine.for %i = 0 to 100 iter_args(%sum_iter = %p0) -> !poly.poly<10> {
  // The poly.mul should be hoisted out of the loop.
  // CHECK-NOT: poly.mul
  %2 = poly.mul %p0, %p1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
  %sum_next = poly.add %sum_iter, %2 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
  affine.yield %sum_next : !poly.poly<10>
}

In the generated C++ code, adding new traits or interfaces adds new template arguments in the op’s class definition:

// PolyOps.h.inc
class SubOp : public ::mlir::Op<SubOp,
    ::mlir::OpTrait::ZeroRegions,
    ::mlir::OpTrait::OneResult,
    ::mlir::OpTrait::OneTypedResult<::mlir::tutorial::poly::PolynomialType>::Impl,
    ::mlir::OpTrait::ZeroSuccessors,
    ::mlir::OpTrait::NOperands<2>::Impl,
    ::mlir::OpTrait::OpInvariants,
    ::mlir::ConditionallySpeculatable::Trait,      // <-- new
    ::mlir::OpTrait::AlwaysSpeculatableImplTrait,  // <-- new
    ::mlir::MemoryEffectOpInterface::Trait>        // <-- new
{ ... }

And NoMemoryEffect adds a trivial implementation of the memory effects interface:

// PolyOps.h.inc
void SubOp::getEffects(
  ::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects) {
}

As far as using traits goes, this is it. However, to figure out what each trait does, you have to dig through the pass implementations a bit. Compounding that, the “helper” definitions like Pure that combine multiple traits are not documented, nor is the complete list of available traits and interfaces (the traits list is missing quite a few, like ConstantLike, Involution, Idempotent, etc.). When I first drafted this article, I had only applied the AlwaysSpeculatable trait list to Poly_BinOp, and I was confused when --loop-invariant-code-motion was a no-op. I had to dig into the pass implementation here to see that it also requires isMemoryEffectFree, which uses MemoryEffectOpInterface.

So next we’ll explore the general passes and traits, and add relevant traits to the poly dialect.

Passes already handled by Pure, or not relevant to poly

control-flow-sink moves ops that are only used in one branch of a conditional into the relevant branch. Requires the op to be memory-effect free, which Pure already covers. To demonstrate, I added the Pure trait to poly.from_tensor and added a test in this commit.

cse is common subexpression elimination, which removes unnecessarily repeated computations when possible. Again, a lack of memory effects suffices. Demonstrated in this commit.

-inline inlines function calls, which does not apply to poly.

-mem2reg replaces memory store/loads with direct usage of the underlying value, when possible. Should not require any changes and not interesting enough for me to demo here.

-remove-dead-values does things like remove function arguments that are unused, or return values that are not used by any caller. Should not require any changes.

-sroa is “scalar replacement of aggregates,” which seems like it is about reshuffling memory allocations around. Not really sure why this is useful.

-symbol-dce eliminates dead private functions, which does not apply to poly.

Punting one pass to next time

-sccp is sparse conditional constant propagation, which attempts to infer when an operation has a constant output, and then replaces the operation with the constant value. Repeating this, it “propagates” the constants as far as possible through the program. Supporting it requires a bit of extra work and requires me to introduce some more concepts, so I’ll do that next time.

Elementwise mappings

Now that we’ve gone through the most generic passes, I’ll cover some remaining traits I’m aware of.

There are a number of traits that extend scalar operations to tensor operations and vice versa. Elementwise, Scalarizable, Tensorizable, and Vectorizable, whose docs you can read in detail here, essentially allow you to use ops that work on scalars inside tensors in the natural way. The trait list ElementwiseMappable combines them into a single trait. This commit demonstrates how adding the trait allows the poly.add op to magically work on tensor arguments. It also requires relaxing the ins arguments to the op in the tablegen, which we do with a so-called type constraint that permits polynomials and tensors containing them:

// PolyOps.td
def PolyOrContainer : TypeOrContainer<Polynomial, "poly-or-container">;

class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, ElementwiseMappable]> {
  let arguments = (ins PolyOrContainer:$lhs, PolyOrContainer:$rhs);
  let results = (outs PolyOrContainer:$output);
  ...
}

Under the hood, the generated code has a new type checking routine used in parsing:

// PolyOps.cpp.inc
static ::mlir::LogicalResult __mlir_ods_local_type_constraint_PolyOps0(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!(((::llvm::isa<::mlir::tutorial::poly::PolynomialType>(type))) ||
        ((((::llvm::isa<::mlir::VectorType>(type))) &&
          ((::llvm::cast<::mlir::VectorType>(type).getRank() > 0))) &&
         ([](::mlir::Type elementType) {
            return (::llvm::isa<::mlir::tutorial::poly::PolynomialType>(elementType));
          }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))) ||
        (((::llvm::isa<::mlir::TensorType>(type))) &&
         ([](::mlir::Type elementType) {
            return (::llvm::isa<::mlir::tutorial::poly::PolynomialType>(elementType));
          }(::llvm::cast<::mlir::ShapedType>(type).getElementType()))))) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be poly-or-container, but got " << type;
  }
  return ::mlir::success();
}

Moreover, the accessors that previously returned a PolynomialType or ::mlir::TypedValue<PolynomialType> now must return a more generic ::mlir::Type or ::mlir::Value because the values could be tensors or vectors as well. It’s left to the caller to type-switch or dyn_cast these manually during passes.
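For example, a pass that inspects a Poly_BinOp operand might now need a type switch like the following sketch (names hypothetical; dyn_cast and ShapedType are the standard MLIR utilities):

Value lhs = addOp.getLhs();
if (auto polyTy = dyn_cast<PolynomialType>(lhs.getType())) {
  // Scalar polynomial case.
} else if (auto shapedTy = dyn_cast<ShapedType>(lhs.getType())) {
  // Tensor/vector-of-polynomials case; the element type is a poly.
  auto elemTy = cast<PolynomialType>(shapedTy.getElementType());
}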

Note that after adding this, we now get the following legal syntax:

    %0 = ... -> !poly.poly<10>
    %1 = ... -> !poly.poly<10>

    %2 = tensor.from_elements %0, %1 : tensor<2x!poly.poly<10>>
    %3 = poly.add %2, %0 : (tensor<2x!poly.poly<10>>, !poly.poly<10>) -> tensor<2x!poly.poly<10>>

I would say the implied semantics is that the second poly is constant across the mapping.

Verifier traits

Some traits add extra verification checks to the operation as mixins. While we’ll cover custom verifiers in a future article, for now we can notice that the following is legal (with or without ElementwiseMappable):

    %0 = ... -> !poly.poly<10>
    %1 = ... -> !poly.poly<9>
    %2 = poly.add %0, %1 : (!poly.poly<10>, !poly.poly<9>) -> !poly.poly<10>

While one could make this legal by defining the semantics of add to embed the smaller-degree polynomial ring into the larger, we’re demonstrating traits, so we’ll add the SameOperandsAndResultElementType trait (a vectorized cousin of SameOperandsAndResultType), which asserts that the poly type in all the arguments (and elements of containers) are the same. This commit does it.

Last few unneeded traits

Involution is for operations that are their own inverse, $f(f(x)) = x$. This is a common math concept, and if we had something like a poly.neg op, it would be a perfect fit. Adding it would enable a free canonicalization that removes repeated ops.

Idempotent is for operations $f$ for which $f(f(x)) = f(x)$. This is a common math concept, but none of the poly ops have this property. A rounding op like ceiling or floor would. If it applied, adding this trait would enable a free canonicalization like involution does.

Broadcastable handles broadcast semantics for tensor/vector ops, which is not relevant.

Commutative is for ops whose arguments can be reordered (including across multiple ops of the same kind). This is used by a pass here for the purpose of simplifying pattern matching, but as far as I can tell the pass is never registered or manually applied, so this trait is a no-op.

AffineScope, IsolatedFromAbove, Terminator, SingleBlock, and SingleBlockImplicitTerminator are for ops that have regions and scoping rules, which we don’t.

SymbolTable is for ops that define a symbol table, which I think is mainly module (in which the symbols are functions), and this is mainly used in the -inline and -symbol-dce passes.

MLIR — Using Tablegen for Passes

In the last article in this series, we defined some custom lowering passes that modified an MLIR program. Notably, we accomplished that by implementing the required interfaces of the MLIR API directly.

This is not the way that most MLIR developers work. Instead, they use a code generation tool called tablegen to generate boilerplate for them, and then only add the implementation methods that are custom to their work. In this article, we’ll convert the custom passes we wrote previously to use this infrastructure, and going forward we will use tablegen for defining new dialects as well.

The code relevant to this article is contained in this pull request, and as usual the commits are organized so that they can be read in order.

How to think about tablegen

The basic idea of tablegen is that you can define a pass—or a dialect, as we’ll do next time—in the tablegen DSL (as specialized for MLIR) and then run the mlir-tblgen binary with appropriate flags, and it will generate headers for the functions you need to implement your pass, along with default implementations of some common functions, documentation, and hooks to register the pass/dialect with the mlir-opt-like entry point tool.

It sounds nice, but I have a love-hate relationship with tablegen. I personally find it unpleasant to use, primarily because it provides poor diagnostic information when you do things wrong. Today, though, I realize that part of my frustration came from having the wrong initial mindset around tablegen. I thought, incorrectly, that tablegen was an abstraction layer. That is, I thought I could write my tablegen files, build them, and only think about the parts of the generated code that I needed to implement.

Instead, I’ve found that I still need to understand the details of the generated code, enough that I may as well read the generated code closely until it becomes familiar. This materializes in a few ways:

  • mlir-tblgen doesn’t clearly tell you what functions are left unimplemented or explain the contract of the functions you have to write. For that you have to read the docs (which are often incomplete), or often the source code of the base classes that the generated boilerplate subclasses. Or, sadly, sometimes the Discourse threads or Phabricator commit messages are the only places to find the answers to some questions.
  • The main way to determine what is missing is to try to build the generated code with some code that uses it, and then sift through hundreds of lines of C++ compiler errors, which in turn requires understanding the various template gymnastics in the generated code.
  • The generated code will make use of symbols that you have to know to import or forward-declare in the right places, and it expects you to manage the namespaces in which the generated code lives.

Once you expect tablegen to be a totally see-through code generator, and just sit down and learn the MLIR APIs, the result is not so bad. Indeed, that was the motivation for building the passes in the last article from “scratch” (without tablegen): it makes the transition to using tablegen more transparent. Pass generation is also a good starting point for tablegen because, by comparison, using tablegen for dialects involves many more moving parts and configurable options.

Tablegen files and the mlir-tblgen binary

We’ll start by migrating the AffineFullUnroll pass from last time. The starting point is to write some tablegen code. This commit implements it, which I’ll abbreviate below. If you want the official docs on how to do this, see Declarative Pass Specification.

include "mlir/Pass/PassBase.td"

def AffineFullUnroll : Pass<"affine-full-unroll"> {
  let summary = "Fully unroll all affine loops";
  let description = [{
    Fully unroll all affine loops. (could add more docs here like code examples)
  }];
  let dependentDialects = ["mlir::affine::AffineDialect"];
}

[Aside: In the actual commit, there are two definitions, one for the AffineFullUnroll that walks the IR, and the other for AffineFullUnrollPatternRewrite which uses the pattern rewrite engine. In the article I’ll just show the generated code for the first.]

As you can see, tablegen has concepts like classes and class inheritance (the : Pass<...> is subclassing the Pass base class defined in PassBase.td, an upstream MLIR file). The def keyword, by contrast, implies we’re instantiating this thing in a way that the codegen tool should see and generate real code for (as opposed to the class keyword, which is just for tablegen templating and code reuse).

Tablegen also lets you let-define string variables and lists. One thing that won’t be apparent in this article, but will be in the next, is that tablegen lets you define variables and use them across definitions, as well as define snippets of C++ code that should be put into the generated classes (which can use the defined variables). This is visible in the present context in PassBase.td, which defines a code-typed constructor variable. If you have a special constructor for your pass, you can write the C++ code for it there.

Next we define a build rule using gentbl_cc_library to run the mlir-tblgen binary (in the same commit) with the right options. This gentbl_cc_library bazel rule is provided by the MLIR devs, and it basically just assembles the CLI flags to mlir-tblgen and ensures the code-genned files end up in the right places on the filesystem and are compatible with the bazel dependency system. The build rule invocation looks like

gentbl_cc_library(
    name = "pass_inc_gen",
    tbl_outs = [
        (
            [
                "-gen-pass-decls",
                "-name=Affine",
            ],
            "Passes.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "Passes.td",
    deps = [
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)

The important part here is that td_file specifies our input file, and tbl_outs defines the generated file, Passes.h.inc, which ends up at $GIT_ROOT/bazel-bin/lib/Transform/Affine/Passes.h.inc.

The main quirk with gentbl_cc_library is that the name of the bazel rule is not the target that actually generates the code. That is, if you run bazel build pass_inc_gen (or from the git root, bazel build lib/Transform/Affine:pass_inc_gen), it won’t create the files but the build will be successful. Instead, under the hood gentbl_cc_library is a bazel macro that generates the rule pass_inc_gen_filegroup, which is what you have to bazel build to see the actual files.

I’ve pasted the generated code (with both versions of the AffineFullUnroll) into a gist and will highlight the important parts here. The first quirky thing the generated code does is use #ifdef as a sort of function interface for what code is produced. For example, you will see:

#ifdef GEN_PASS_DECL_AFFINEFULLUNROLL
std::unique_ptr<::mlir::Pass> createAffineFullUnroll();
#undef GEN_PASS_DECL_AFFINEFULLUNROLL
#endif // GEN_PASS_DECL_AFFINEFULLUNROLL

#ifdef GEN_PASS_DEF_AFFINEFULLUNROLL
... <lots of C++ code> ...
#undef GEN_PASS_DEF_AFFINEFULLUNROLL
#endif // GEN_PASS_DEF_AFFINEFULLUNROLL

This means that to use this file, we will need to define the appropriate symbol in a #define macro before including this header. You can see it happening in this commit, but in brief it will look like this

// in file AffineFullUnroll.h

#define GEN_PASS_DECL_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

// in file AffineFullUnroll.cpp

#define GEN_PASS_DEF_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

... <implement the missing functions from the generated code> ...

I’m no C++ expert, and this was the first time I’d seen this pattern of using #include as a function with #define as the argument. It was a little unsettling to me, until I landed on the mindset that it’s meant to be a white-box code generator, not an abstraction. So read the generated code. Inside the GEN_PASS_DECL_... guard, it declares a single function, std::unique_ptr<::mlir::Pass> createAffineFullUnroll();, which is a very limited sole entry point for code that wants to use the pass. We don’t need to implement it unless our pass has a custom constructor. Then in the GEN_PASS_DEF_... guard it defines a base class, whose functions I’ll summarize; you should recognize many of them, because we implemented them by hand last time.

template <typename DerivedT>
class AffineFullUnrollBase : public ::mlir::OperationPass<> {
  AffineFullUnrollBase() : ::mlir::OperationPass<>(::mlir::TypeID::get<DerivedT>()) {}
  AffineFullUnrollBase(const AffineFullUnrollBase &other) : ::mlir::OperationPass<>(other) {}

  static ::llvm::StringLiteral getArgumentName() {...}
  static ::llvm::StringRef getArgument() { ... }
  static ::llvm::StringRef getDescription() { ... }
  static ::llvm::StringLiteral getPassName() { ... }
  static ::llvm::StringRef getName() { ... }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) { ... }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override { ... }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::affine::AffineDialect>();
  }

  ... <type_id stuff> ...
}

Notably, this doesn’t tell us what functions are left for us to implement. For that we have to either build it and read the compiler error messages, or compare it to the base class (OperationPass) and its base class (Pass) to see that the only function left to implement is runOnOperation(). Or, since we did this last time from the raw API, we can observe that the boilerplate functions we implemented before, like getArgument, are here, but runOnOperation is not.

Another notable aspect of the generated code is that it uses the curiously recurring template pattern (CRTP), so that the base class can know the eventual name of its subclass, and use that name to hook the concrete subclass into the rest of the framework.
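If CRTP is new to you, here is a minimal standalone illustration of the idea, unrelated to MLIR’s actual class hierarchy:

#include <iostream>

// The base class takes its eventual subclass as a template parameter, so
// its methods can refer to the concrete subclass at compile time.
template <typename DerivedT>
struct PassBase {
  void run() { std::cout << DerivedT::name() << " running\n"; }
};

struct MyPass : PassBase<MyPass> {
  static const char *name() { return "my-pass"; }
};

int main() { MyPass().run(); }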

Lower in the generated file you’ll see another #define-guarded block for GEN_PASS_REGISTRATION, which implements hooks for tutorial-opt to register the passes without having to depend on each internal Pass class directly.

#ifdef GEN_PASS_REGISTRATION

inline void registerAffineFullUnroll() {
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
    return createAffineFullUnroll();
  });
}

inline void registerAffineFullUnrollPatternRewrite() {
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
    return createAffineFullUnrollPatternRewrite();
  });
}

inline void registerAffinePasses() {
  registerAffineFullUnroll();
  registerAffineFullUnrollPatternRewrite();
}
#undef GEN_PASS_REGISTRATION
#endif // GEN_PASS_REGISTRATION

This implies that, once we link everything properly, the changes to tutorial-opt (in this commit) simplify to calling registerAffinePasses. This registration macro is intended to go into a Passes.h file that includes all the individual pass header files, as done in this commit. And we can use that header file as an anchor for a bazel build target that includes all the passes defined in lib/Transform/Affine at once.

#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "lib/Transform/Affine/AffineFullUnrollPatternRewrite.h"

namespace mlir {
namespace tutorial {

#define GEN_PASS_REGISTRATION
#include "lib/Transform/Affine/Passes.h.inc"

}  // namespace tutorial
}  // namespace mlir

Finally, after all this (abbreviated from this commit), the actual content of the pass reduces to the following subclass (with CRTP) and implementation of runOnOperation, the body of which is identical to the last article except for a change from reference to pointer for the return value of getOperation.

#define GEN_PASS_DEF_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

struct AffineFullUnroll : impl::AffineFullUnrollBase<AffineFullUnroll> {
  using AffineFullUnrollBase::AffineFullUnrollBase;

  void runOnOperation() {
    getOperation()->walk([&](AffineForOp op) {
      if (failed(loopUnrollFull(op))) {
        op.emitError("unrolling failed");
        signalPassFailure();
      }
    });
  }
};

I split the AffineFullUnroll migration into multiple commits to highlight the tablegen vs C++ code changes. For MulToAdd, I did it all in one commit. The tests are unchanged, because the entry point is still the tutorial-opt binary with the appropriate CLI flags, and those names are unchanged in the tablegen’ed code.

Bonus: mlir-tblgen also has an option, -gen-pass-doc (which you’ll see in the commits), that generates a markdown file containing auto-generated documentation for the pass. A CI workflow can copy these to a website directory, as we do in HEIR, and you get free docs. See this example from HEIR.

Addendum: hermetic Python

When I first set up this tutorial project, I didn’t realize that bazel’s Python rules use the system Python by default. Some early readers found an error that Python couldn’t find the lit module when running tests. While pip install lit in your system Python would work, I also migrated in this commit to a hermetic python runtime and explicit dependency on lit. It should all be handled automatically by bazel now.

MLIR — Writing Our First Pass

This series is an introduction to MLIR and an onboarding tutorial for the HEIR project.

Last time we saw how to run and test a basic lowering. This time we will write some simple passes to illustrate the various parts of the MLIR API and the pass infrastructure.

As mentioned previously, the main work in MLIR is defining passes that either optimize part of a program, lower from parts of one dialect to others, or perform various normalization and canonicalization operations. In this article, we’ll start by defining a pass that operates entirely within a given dialect by fully unrolling loops. Then we’ll define a pass that does a simple replacement of one instruction with another. Neither pass will be particularly complex, but rather they will show how to set up a pass, how to navigate through a program via the MLIR API, and how to modify the IR by deleting and adding operations.

The code for this post is contained within this pull request.

tutorial-opt and project organization

Last time we used the mlir-opt binary as the main entry point to parse MLIR, run a pass, and emit the output IR. A compiler might run mlir-opt as a subroutine in between the front end (C++ to some MLIR dialects) and the backend (MLIR’s LLVM dialect to LLVM to machine code).

In an out-of-tree MLIR project, mlir-opt can’t be used because it isn’t compiled with the project’s custom dialects or passes. Instead, MLIR makes it easy to build a custom version of the mlir-opt tool for an out-of-tree project. It primarily provides a set of registration hooks that you can use to plug in your dialects and passes; the framework handles reading/writing and CLI flags, and adds all that on top of the baseline MLIR passes and dialects. We’ll start this article by creating the shell for such a tool with an empty custom pass, which we’ll call tutorial-opt. If this repository were to become one step of an end-to-end compiler, then tutorial-opt would be the main interface to the MLIR part.

The structure of the codebase is a persnickety question here. A typical MLIR codebase seems to split the code into two directories with roughly equivalent hierarchies: an include/ directory for headers and tablegen files (more on tablegen in a future article), and a lib/ directory for implementation code. Then, within those two directories a project would have a Transform/ subdirectory that stores the files for passes that transform code within a dialect, Conversion/ for passes that convert between dialects, Analysis/ for analysis passes, etc. Each of these directories might have subdirectories for the specific dialects they operate on.

For this tutorial we will do it slightly differently by merging include/ and lib/ together (header files will live next to implementation files). I believe the reason C++ codebases separate these is to enforce an implicit public/private interface boundary (client code should only depend on headers in include/, not headers in lib/ or src/). But bazel has many more facilities for enforcing public/private interface boundaries, I find it tedious to navigate parallel directory structures, and this is a tutorial, so simpler is better.

So the project’s directory structure will look like this once we create the initial pass:

.
├── README.md
├── WORKSPACE
├── bazel
│   └──  . . .
├── lib
│   └── Transform
│       └── Affine
│           ├── AffineFullUnroll.cpp
│           ├── AffineFullUnroll.h
│           └── BUILD
├── tests
│   └── . . .
└── tools
    ├── BUILD
    └── tutorial-opt.cpp

Unrolling loops, a starter pass

Though MLIR provides multiple mechanisms for defining loops and control flow, the highest level one is in the affine dialect. Originally defined for polyhedral loop analysis (using lattices to study loop structure!), it also simply defines a nice for operation that you can use whenever you have simple loop bounds like iterating over a range with an optional step size. An example loop that sums some values in an array stored in memory might look like:

func.func @sum_buffer(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}

The iter_args is a custom bit of syntax that defines accumulation variables operating across the loop body (to keep the IR in compliance with SSA form; for more on SSA, see this MLIR doc), along with their initial values.

Unrolling loops is a nontrivial operation, but thankfully MLIR provides a utility method for fully unrolling a loop, so our pass will be a thin wrapper around this function call, to showcase some of the rest of the infrastructure before we write a more meaningful pass. The code for this section is in this commit.

This will technically be the most general implementation, working directly with the C++ API rather than using more specialized features like the pattern rewrite engine, the dialect conversion framework, or tablegen. Those will all come later.

The main idea is to implement the required methods for the OperationPass base class, which “anchors” the pass to work within the context of a specific instance of a specific type of operation, and is applied to every operation of that type. It looks like this:

// lib/Transform/Affine/AffineFullUnroll.h
class AffineFullUnrollPass
    : public PassWrapper<AffineFullUnrollPass,
                         OperationPass<mlir::func::FuncOp>> {
private:
  void runOnOperation() override;  // implemented in AffineFullUnroll.cpp

  StringRef getArgument() const final { return "affine-full-unroll"; }

  StringRef getDescription() const final {
    return "Fully unroll all affine loops";
  }
};

The PassWrapper helps implement some of the required methods for free (mainly adding a compliant copy method), and uses the Curiously Recurring Template Pattern (CRTP) to achieve that. But what matters for us is that OperationPass<FuncOp> anchors this pass to operate on function bodies, and provides the getOperation method on the class, which returns the FuncOp being operated on.

Aside: The MLIR docs more formally describe what is required of an OperationPass, and in particular it limits the “anchoring” to specific operations like functions and modules, the insides of which are isolated from modifying the semantics of the program outside of the operation’s scope. That’s a fancy way of saying FuncOps in MLIR can’t screw with variables outside the lexical scope of their function body. More importantly for this example, it explains why we can’t anchor this pass on a for loop operation directly: a loop can modify stuff outside its body (like the contents of memory) via the operations within the loop (store, etc.). This matters because the MLIR pass infrastructure runs passes in parallel. If some other pass is tinkering with neighboring operations, race conditions ensue.

The three functions we need to implement are

  • runOnOperation: the function that performs the pass logic.
  • getArgument: the CLI argument for an mlir-opt-like tool.
  • getDescription: the CLI description when running --help on the mlir-opt-like tool.

The initial implementation of runOnOperation is empty in the commit for this section. Next, we create a tutorial-opt binary that registers the pass.

// tools/tutorial-opt.cpp
#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/include/mlir/InitAllDialects.h"
#include "mlir/include/mlir/Pass/PassManager.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
#include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);

  mlir::PassRegistration<mlir::tutorial::AffineFullUnrollPass>();

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry));
}

This registers all the built-in MLIR dialects, adds our AffineFullUnrollPass, and then calls the MlirOptMain function, which handles the rest. At this point we can run bazel run tools:tutorial-opt -- --help and see a long list of options with our new pass in it.

OVERVIEW: Tutorial Pass Driver
Available Dialects: acc, affine, amdgpu, <...SNIP...>
USAGE: tutorial-opt [options] <input file>

OPTIONS:

General options:

  Compiler passes to run
    Passes:
      --affine-full-unroll                             -   Fully unroll all affine loops
  --allow-unregistered-dialect                         - Allow operation with no registered dialects
  --disable-i2p-p2i-opt                                - Disables inttoptr/ptrtoint roundtrip optimization
  <...SNIP...>

To allow us to run lit tests that use this tool, we add it to the test_utilities target in this commit, and then we add a first (failing) test in this commit. To avoid complexity, I’m just asserting that the output has no for loops in it.

// RUN: tutorial-opt %s --affine-full-unroll > %t
// RUN: FileCheck %s < %t

func.func @test_single_nested_loop(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  // CHECK-NOT: affine.for
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}

Next, we can implement the pass itself in this commit:

#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/include/mlir/Pass/Pass.h"

using mlir::affine::AffineForOp;
using mlir::affine::loopUnrollFull;

void AffineFullUnrollPass::runOnOperation() {
  getOperation().walk([&](AffineForOp op) {
    if (failed(loopUnrollFull(op))) {
      op.emitError("unrolling failed");
      signalPassFailure();
    }
  });
}

getOperation returns a FuncOp, though we don’t use any specific information about it being a function. We instead call the walk method (present on all Operation instances), which traverses the abstract syntax tree (AST) of the operation in post-order (i.e., the function body), and for each operation it encounters, if the type of that operation matches the input type of the callback, the callback is executed. In our case, we attempt to unroll the loop, and if it fails we quit with a diagnostic error.
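One more feature of walk worth knowing about: the callback can return a WalkResult to control the traversal, e.g., to stop early once you’ve found what you’re looking for. A small sketch:

// Stop the walk at the first affine.for op encountered.
bool hasLoop = false;
getOperation().walk([&](AffineForOp op) {
  hasLoop = true;
  return WalkResult::interrupt();  // halts the rest of the traversal
});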

Exercise: determine how the loop unrolling might fail, and create a test MLIR input that causes it to fail, and observe the error messages that result.

Running this on our test shows the operation is applied:

$ bazel run tools:tutorial-opt -- --affine-full-unroll < tests/affine_loop_unroll.mlir
<...>
#map = affine_map<(d0) -> (d0 + 1)>
#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 3)>
module {
  func.func @test_single_nested_loop(%arg0: memref<4xi32>) -> i32 {
    %c0 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = affine.load %arg0[%c0] : memref<4xi32>
    %1 = arith.addi %c0_i32, %0 : i32
    %2 = affine.apply #map(%c0)
    %3 = affine.load %arg0[%2] : memref<4xi32>
    %4 = arith.addi %1, %3 : i32
    %5 = affine.apply #map1(%c0)
    %6 = affine.load %arg0[%5] : memref<4xi32>
    %7 = arith.addi %4, %6 : i32
    %8 = affine.apply #map2(%c0)
    %9 = affine.load %arg0[%8] : memref<4xi32>
    %10 = arith.addi %7, %9 : i32
    return %10 : i32
  }
}

I won’t explain what this affine.apply thing is doing, but suffice it to say the loop is correctly unrolled. A subsequent commit does the same test for a doubly-nested loop.

A Rewrite Pattern Version

In this commit, we rewrite the loop unroll pass using the next level of abstraction provided by MLIR: the pattern rewrite engine. It is useful in situations where one wants to repeatedly apply the same subset of transformations to a given IR substructure until that substructure is completely removed. The next section will write a pass that uses it in a meaningful way, but for now we’ll just rewrite the loop unroll pass to show the extra boilerplate.

A rewrite pattern is a subclass of OpRewritePattern, and it has a method called matchAndRewrite which performs the transformation.

struct AffineFullUnrollPattern :
  public OpRewritePattern<AffineForOp> {
  AffineFullUnrollPattern(mlir::MLIRContext *context)
      : OpRewritePattern<AffineForOp>(context, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(AffineForOp op,
                                PatternRewriter &rewriter) const override {
    return loopUnrollFull(op);
  }
};

The return value of matchAndRewrite is a LogicalResult, which is a wrapper around a boolean signaling success or failure, with named utility functions like failure() and success() to generate instances, and failed(...) to test for failure. LogicalResult also comes with a companion class FailureOr<T>, a subclass of std::optional<T> that inter-operates with LogicalResult via the presence or absence of a value.
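A brief sketch of how these types are typically used (a toy helper, not from the tutorial):

// FailureOr<T> carries either a value or a failure; failed() tests which.
FailureOr<int64_t> getConstantValue(Value v) {
  if (auto constOp = v.getDefiningOp<arith::ConstantIntOp>())
    return constOp.value();
  return failure();
}

// At a call site:
//   auto result = getConstantValue(someValue);
//   if (failed(result)) { /* bail out */ }
//   int64_t value = *result;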

Aside: In a proper OpRewritePattern, the mutations of the IR must go through the PatternRewriter argument, but because loopUnrollFull doesn’t have a variant that takes a PatternRewriter as input, we’re violating that part of the function contract. More generally, the PatternRewriter handles atomicity of the mutations that occur within the OpRewritePattern, ensuring that the operations are applied only if the method reaches the end and succeeds.

Then we instantiate the pattern inside the pass

// A pass that invokes the pattern rewrite engine.
void AffineFullUnrollPassAsPatternRewrite::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());
  patterns.add<AffineFullUnrollPattern>(&getContext());
  // One could use GreedyRewriteConfig here to slightly tweak the behavior of
  // the pattern application.
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

The overall pass is still anchored on FuncOp, but an OpRewritePattern can match against any op. The rewrite engine invokes the walk that we did manually, and one can pass an optional configuration struct that chooses the walk order.

The RewritePatternSet can accept any number of patterns, and the greedy rewrite engine will keep trying to apply them (in an order related to the benefit constructor argument) until there are no matching operations left, all applied patterns return failure, or a large iteration limit is reached to avoid infinite loops. The config mentioned in the comment above can tweak this behavior, as sketched below.
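For instance (a sketch; the available fields can vary across MLIR versions):

GreedyRewriteConfig config;
config.maxIterations = 10;  // cap the number of re-application rounds
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config);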

A proper greedy RewritePattern

In fully homomorphic encryption (FHE), multiplication ops are typically MUCH more expensive than addition ops. I’m not an expert in classical CPU hardware performance, but if I recall correctly, on a typical CPU multiplication is something like 2-4x slower than addition, and that gap probably shrinks when doing multiplications in bulk/pipelines.

In FHE, each operation increases the noise in a ciphertext, and in some schemes multiplication introduces something like 100x the noise of addition; reducing that noise is very time consuming. So it makes sense that one would want to convert multiplication ops into addition ops. In this section we’ll write a very simple pass that greedily rewrites multiplication ops as repeated additions.

The idea is to rewrite an operation like y = 9*x as y = 8*x + x (8 being a power of 2) and then expand the power-of-2 part as a = x+x; b = a+a; c = b+b; y = c+x. This replaces a multiplication by a constant with roughly a logarithmic number of additions (the base-2 logarithm of the constant), though it gets worse the further the constant is from a power of two.
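To make the “roughly log” claim concrete, here is a small standalone sketch (not part of the pass) that counts the additions produced by the halving and peeling steps just described:

#include <cstdint>
#include <cstdio>

// Halve when the constant is a power of two (y = C*x -> y = (C/2)*x + (C/2)*x),
// otherwise peel one addition off (y = C*x -> y = (C-1)*x + x).
int numAdditions(int64_t c) {
  int adds = 0;
  while (c > 1) {
    if ((c & (c - 1)) == 0)
      c /= 2;  // power of two: one add doubles the running term
    else
      c -= 1;  // otherwise: one add peels off a single x
    ++adds;
  }
  return adds;
}

int main() {
  printf("9*x costs %d additions\n", numAdditions(9));  // prints 4
}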

This commit contains a similar “empty shell” of a pass, with two patterns defined. The first, PowerOfTwoExpand, will be a pattern that rewrites y=C*x as y = C/2*x + C/2*x, when C is a power of 2, otherwise fails. The second, PeelFromMul “peels” a single addition off a product that is not with a power of 2, rewriting y = 9x as y = 8*x + x. These are applied repeatedly via the greedy pattern rewrite engine. By setting the benefit argument of PowerOfTwoExpand to be larger than PeelFromMul, we tell the greedy rewrite engine to prefer PowerOfTwoExpand whenever possible. Together, that achieves the transformation mentioned above.

This commit adds a failing test that only exercises PowerOfTwoExpand, and then this commit implements it. Here’s the implementation:

  LogicalResult matchAndRewrite(
       MulIOp op, PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);

    // canonicalization patterns ensure the constant is on the right, if there is a constant
    // See https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules
    Value rhs = op.getOperand(1);
    auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
    if (!rhsDefiningOp) {
      return failure();
    }

    int64_t value = rhsDefiningOp.value();
    bool is_power_of_two = (value & (value - 1)) == 0;

    if (!is_power_of_two) {
      return failure();
    }

    ConstantOp newConstant = rewriter.create<ConstantOp>(
        rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value / 2));
    MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
    AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, newMul);

    rewriter.replaceOp(op, {newAdd});
    rewriter.eraseOp(rhsDefiningOp);

    return success();
  }

Some notes:

  • Value is the type that represents an SSA value (i.e., an MLIR variable), and getDefiningOp fetches the unique operation that defines it in its scope.
  • There are a variety of “casting” operations like rhs.getDefiningOp<arith::ConstantIntOp>() that take the desired type as a template parameter and return null if the conversion fails. You might also see dyn_cast<> used for the same purpose.
  • (value & (value - 1)) == 0 is a classic bit-twiddling trick to check whether an integer is a power of two (e.g., 8 & 7 = 0b1000 & 0b0111 = 0). We check it and skip the pattern if the constant is not a power of two.
  • The actual constant itself is represented as an MLIR attribute, which is essentially compile-time static data attached to the op. You can put strings or dictionaries as attributes, but for ConstantOp it’s just an int.

The rewriter.create calls are where we actually do the real work: create a new constant that is half the original constant, create new multiplication and addition ops, and then finally rewriter.replaceOp removes the original multiplication op and uses the output of newAdd for any other operations that used the original multiplication op’s output.

It’s worth noting that we’re relying on MLIR’s built-in canonicalization passes in a few ways here:

  • To ensure that the constant is always the second operand of a multiplication op.
  • To ensure that the base case (x*1) is “folded” into a plain x and the constant 1 is removed.
  • The fold part of applyPatternsAndFoldGreedily is what runs these cleanup steps for us.

PeelFromMul is similar, implemented and tested in this commit:

  LogicalResult matchAndRewrite(MulIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
    if (!rhsDefiningOp) { return failure(); }
    int64_t value = rhsDefiningOp.value();

    // We are guaranteed `value` is not a power of two, because the greedy
    // rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
    // it has higher benefit.

    ConstantOp newConstant = rewriter.create<ConstantOp>(
        rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value - 1));
    MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
    AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, lhs);

    rewriter.replaceOp(op, {newAdd});
    rewriter.eraseOp(rhsDefiningOp);

    return success();
  }

Running it! Input:

func.func @power_of_two_plus_one(%arg: i32) -> i32 {
  %0 = arith.constant 9 : i32
  %1 = arith.muli %arg, %0 : i32
  func.return %1 : i32
}

Output:

module { 
  func.func @power_of_two_plus_one(%arg0: i32) -> i32 {
    %0 = arith.addi %arg0, %arg0 : i32
    %1 = arith.addi %0, %0 : i32
    %2 = arith.addi %1, %1 : i32
    %3 = arith.addi %2, %arg0 : i32
    return %3 : i32
  }
}

Exercise: Try swapping the benefit arguments to see how the output changes.

Though this pass is quite naive, you can imagine a more sophisticated technique that builds a cost model for multiplications and additions, and optimizes for the cheapest cost representation of an arithmetic operation in terms of repeated additions, multiplications, and other supported ops.

Should we walk?

With two options for how to define a pass—one to walk the entire syntax tree from the root operation, and one to match and rewrite patterns with the rewrite engine—the natural question is when should you use one versus the other.

The MLIR docs describe the motivation behind the pattern rewrite engine, and it comes from a long history of experience with the LLVM project. For one, the pattern rewrite engine expresses a convenient subset of what can be achieved with an MLIR pass. This is conceptually trivial, in the sense that anyone who can walk the entire AST can, with enough effort, do anything they want including reimplementing the pattern rewrite engine.

More practically, the pattern rewrite engine is convenient for representing local transformations. “Local” here means that the inputs and outputs of the transformation can be detected from a small subgraph of the IR’s dataflow. More pragmatically, think of it as any operation you can identify by looking around at neighboring operations in the same block and applying some filtering logic. E.g., “is this exp operation followed by a log operation, with no other uses of the output of the exp?”
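As an illustration, here is a hedged sketch of such a local pattern for the upstream math dialect, rewriting log(exp(x)) to x (assuming exact arithmetic; accessor names approximate):

struct EraseExpFollowedByLog : public OpRewritePattern<math::LogOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const override {
    auto expOp = op.getOperand().getDefiningOp<math::ExpOp>();
    if (!expOp)
      return failure();
    // Replace all uses of log(exp(x)) with x; if the exp has no other
    // uses it becomes dead and is cleaned up later.
    rewriter.replaceOp(op, expOp.getOperand());
    return success();
  }
};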

On the other hand, some analyses and optimizations need to construct the entire dataflow of a program to work. A good example is common subexpression elimination, which determines whether it is cost effective to extract a subexpression used in multiple places into a separate variable. Doing so may introduce additional cost of memory access, so it depends both on the operation’s cost and on the availability of registers at that point in the program. You can’t get this information by pattern matching the AST locally.

The wisdom seems to be: using the pattern rewrite engine is usually easier than writing a pass that walks the AST. You don’t need large case/switch statements to handle everything that could show up in the IR. The engine handles re-applying patterns many times. And so you can write the patterns in isolation and trust the engine to combine them appropriately.

Bonus: IDEs and CI

Since we explored the C++ API, it helps to have an IDE integration. I use neovim with the clangd LSP, and to make it work with a Bazel C++ project, one needs to use something analogous to Hedron Vision’s compile_commands extractor, which I configured for this tutorial project in this commit. It’s optional, but if you want to use it you have to run bazel run @hedron_compile_commands//:refresh_all once to set it up, and then clangd and clang-tidy, etc., should find the generated json file and use it. Also, if you edit a BUILD file, you have to re-run refresh_all for the changes to show up in the LSP.

Though it’s not particularly relevant to this tutorial, I also added a commit that configures GitHub Actions to build and test the project in CI. It is worth noting that the GitHub cache action reduces subsequent build times from 1-2 hours down to just a few minutes.

Thanks to Patrick Schmidt for feedback on a draft of this article.