
This series is an introduction to MLIR and an onboarding tutorial for the HEIR project.

Last time we saw how to run and test a basic lowering. This time we will write some simple passes to illustrate the various parts of the MLIR API and the pass infrastructure.

As mentioned previously, the main work in MLIR is defining passes that either optimize part of a program, lower from parts of one dialect to others, or perform various normalization and canonicalization operations. In this article, we’ll start by defining a pass that operates entirely within a given dialect by fully unrolling loops. Then we’ll define a pass that does a simple replacement of one instruction with another. Neither pass will be particularly complex, but rather they will show how to set up a pass, how to navigate through a program via the MLIR API, and how to modify the IR by deleting and adding operations.

The code for this post is contained within this pull request.

tutorial-opt and project organization

Last time we used the mlir-opt binary as the main entry point to parse MLIR, run a pass, and emit the output IR. A compiler might run mlir-opt as a subroutine in between the front end (C++ to some MLIR dialects) and the backend (MLIR’s LLVM dialect to LLVM to machine code).

In an out-of-tree MLIR project, mlir-opt can’t be used because it isn’t compiled with the project’s custom dialects or passes. Instead, MLIR makes it easy to build a custom version of the mlir-opt tool for an out-of-tree project. It primarily provides a set of registration hooks that you can use to plug in your dialects and passes, and the framework handles reading/writing, CLI flags, and adds that all on top of the baseline MLIR passes and dialects. We’ll start this article by creating the shell for such a tool with an empty custom pass, which we’ll call tutorial-opt. If this repository were to become one step of an end-to-end compiler, then tutorial-opt would be the main interface to the MLIR part.

The structure of the codebase is a persnickety question here. A typical MLIR codebase seems to split the code into two directories with roughly equivalent hierarchies: an include/ directory for headers and tablegen files (more on tablegen in a future article), and a lib/ directory for implementation code. Then, within those two directories a project would have a Transform/ subdirectory that stores the files for passes that transform code within a dialect, Conversion/ for passes that convert between dialects, Analysis/ for analysis passes, etc. Each of these directories might have subdirectories for the specific dialects they operate on.

For this tutorial we will do it slightly differently by merging include/ and lib/ together (header files will live next to implementation files). I believe the reason that C++ codebases separate these is the implicit public/private interface it provides (client code should only depend on headers in include/, not headers in lib/ or src/). But bazel has many more facilities for enforcing private/public interface boundaries, I find it tedious to navigate parallel directory structures, and this is a tutorial, so simpler is better.

So the project’s directory structure will look like this once we create the initial pass:

.
├── README.md
├── WORKSPACE
├── bazel
│   └──  . . .
├── lib
│   └── Transform
│       └── Affine
│           ├── AffineFullUnroll.cpp
│           ├── AffineFullUnroll.h
│           └── BUILD
├── tests
│   └── . . .
└── tools
    ├── BUILD
    └── tutorial-opt.cpp

Unrolling loops, a starter pass

Though MLIR provides multiple mechanisms for defining loops and control flow, the highest level one is in the affine dialect. Originally defined for polyhedral loop analysis (using lattices to study loop structure!), it also simply defines a nice for operation that you can use whenever you have simple loop bounds like iterating over a range with an optional step size. An example loop that sums some values in an array stored in memory might look like:

func.func @sum_buffer(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}

The iter_args is a custom bit of syntax that defines accumulation variables to operate across the loop body (to be in compliance with SSA form; for more on SSA, see this MLIR doc), along with an initial value.

Unrolling loops is a nontrivial operation, but thankfully MLIR provides a utility method for fully unrolling a loop, so our pass will be a thin wrapper around this function call, to showcase some of the rest of the infrastructure before we write a more meaningful pass. The code for this section is in this commit.

This implementation is technically the most general one: it works directly against the C++ API, rather than using more special-case features like the pattern rewrite engine, the dialect conversion framework, or tablegen. Those will all come later.

The main idea is to implement the required methods for the OperationPass base class, which “anchors” the pass to work within the context of a specific instance of a specific type of operation, and is applied to every operation of that type. It looks like this:

// lib/Transform/Affine/AffineFullUnroll.h
class AffineFullUnrollPass
    : public PassWrapper<AffineFullUnrollPass,
                         OperationPass<mlir::func::FuncOp>> {
private:
  void runOnOperation() override;  // implemented in AffineFullUnroll.cpp

  StringRef getArgument() const final { return "affine-full-unroll"; }

  StringRef getDescription() const final {
    return "Fully unroll all affine loops";
  }
};

The PassWrapper helps implement some of the required methods for free (mainly adding a compliant copy method), and uses the Curiously Recurring Template Pattern (CRTP) to achieve that. But what matters for us is that OperationPass<FuncOp> anchors this pass to operate on function bodies, and provides the getOperation method in the class which returns the FuncOp being operated on.
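To make the CRTP remark concrete, here is a minimal standalone sketch of how a wrapper template can supply a copy method on behalf of the class that derives from it. This is not MLIR code: ToyPass, ToyPassWrapper, ToyUnrollPass, and clonePass are hypothetical names that only loosely mirror the idea behind PassWrapper.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-in for a pass base class; this is not the MLIR API.
struct ToyPass {
  virtual ~ToyPass() = default;
  virtual std::string name() const = 0;
  virtual std::unique_ptr<ToyPass> clonePass() const = 0;
};

// CRTP: the wrapper is parameterized by the class that derives from it, so it
// can invoke Derived's copy constructor without Derived writing any code.
template <typename Derived>
struct ToyPassWrapper : ToyPass {
  std::unique_ptr<ToyPass> clonePass() const override {
    return std::make_unique<Derived>(static_cast<const Derived &>(*this));
  }
};

// The concrete pass only implements what is specific to it.
struct ToyUnrollPass : ToyPassWrapper<ToyUnrollPass> {
  std::string name() const override { return "toy-unroll"; }
};
```

Calling clonePass on a ToyUnrollPass returns a distinct object of the same dynamic type, even though ToyUnrollPass never mentions copying.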

Aside: The MLIR docs more formally describe what is required of an OperationPass, and in particular it limits the “anchoring” to specific operations like functions and modules, the insides of which are isolated from modifying the semantics of the program outside of the operation’s scope. That’s a fancy way of saying FuncOps in MLIR can’t screw with variables outside the lexical scope of their function body. More importantly for this example, it explains why we can’t anchor this pass on a for loop operation directly: a loop can modify stuff outside its body (like the contents of memory) via the operations within the loop (store, etc.). This matters because the MLIR pass infrastructure runs passes in parallel. If some other pass is tinkering with neighboring operations, race conditions ensue.

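The three functions we need to implement are runOnOperation, which performs the pass logic, getArgument, which returns the CLI flag that selects the pass in an mlir-opt-like tool, and getDescription, which returns the description shown next to that flag in the tool's --help output.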

The initial implementation of runOnOperation is empty in the commit for this section. Next, we create a tutorial-opt binary that registers the pass.

// tools/tutorial-opt.cpp
#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/include/mlir/InitAllDialects.h"
#include "mlir/include/mlir/Pass/PassManager.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
#include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);

  mlir::PassRegistration<mlir::tutorial::AffineFullUnrollPass>();

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry));
}

This registers all the built-in MLIR dialects, adds our AffineFullUnrollPass, and then calls the MlirOptMain function which handles the rest. At this point we can run bazel run tools:tutorial-opt -- --help and see a long list of options with our new pass in it.

OVERVIEW: Tutorial Pass Driver
Available Dialects: acc, affine, amdgpu, <...SNIP...>
USAGE: tutorial-opt [options] <input file>

OPTIONS:

General options:

  Compiler passes to run
    Passes:
      --affine-full-unroll                             -   Fully unroll all affine loops
  --allow-unregistered-dialect                         - Allow operation with no registered dialects
  --disable-i2p-p2i-opt                                - Disables inttoptr/ptrtoint roundtrip optimization
  <...SNIP...>

To allow us to run lit tests that use this tool, we add it to the test_utilities target in this commit, and then we add a first (failing) test in this commit. To avoid complexity, I’m just asserting that the output has no for loops in it.

// RUN: tutorial-opt %s --affine-full-unroll > %t
// RUN: FileCheck %s < %t

func.func @test_single_nested_loop(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  // CHECK-NOT: affine.for
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}

Next, we can implement the pass itself in this commit:

#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/include/mlir/Pass/Pass.h"

using mlir::affine::AffineForOp;
using mlir::affine::loopUnrollFull;

void AffineFullUnrollPass::runOnOperation() {
  getOperation().walk([&](AffineForOp op) {
    if (failed(loopUnrollFull(op))) {
      op.emitError("unrolling failed");
      signalPassFailure();
    }
  });
}

getOperation returns a FuncOp, though we don’t use any specific information about it being a function. We instead call the walk method (present on all Operation instances), which traverses the abstract syntax tree (AST) of the operation (i.e., the function body) in post-order. For each operation it encounters, if the type of that operation matches the input type of the callback, the callback is executed. In our case, we attempt to unroll the loop, and if it fails we emit a diagnostic error and signal that the pass failed.
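The traversal order can be illustrated with a small standalone sketch. It is a toy model and not the MLIR API: ToyOp, walkPostOrder, and visitedLoops are hypothetical names, and filtering by a string kind stands in for MLIR's filtering by the callback's argument type.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy IR node; real MLIR operations hold regions and blocks.
struct ToyOp {
  std::string kind;
  std::string label;
  std::vector<ToyOp> nested;
};

// Visit nested ops first, then the op itself (post-order), and only invoke
// the callback on ops whose kind matches the filter.
void walkPostOrder(const ToyOp &op, const std::string &kindFilter,
                   const std::function<void(const ToyOp &)> &callback) {
  for (const ToyOp &child : op.nested)
    walkPostOrder(child, kindFilter, callback);
  if (op.kind == kindFilter)
    callback(op);
}

// Collect the labels of all "for" ops in the order the walk visits them.
std::vector<std::string> visitedLoops(const ToyOp &root) {
  std::vector<std::string> labels;
  walkPostOrder(root, "for",
                [&](const ToyOp &op) { labels.push_back(op.label); });
  return labels;
}
```

For a function containing a loop nested inside another loop, this walk reaches the inner loop before the outer one, so the callback has finished with every nested op by the time it is called on the enclosing op.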

Exercise: determine how the loop unrolling might fail, and create a test MLIR input that causes it to fail, and observe the error messages that result.

Running this on our test shows the operation is applied:

$ bazel run tools:tutorial-opt -- --affine-full-unroll < tests/affine_loop_unroll.mlir
<...>
#map = affine_map<(d0) -> (d0 + 1)>
#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 3)>
module {
  func.func @test_single_nested_loop(%arg0: memref<4xi32>) -> i32 {
    %c0 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = affine.load %arg0[%c0] : memref<4xi32>
    %1 = arith.addi %c0_i32, %0 : i32
    %2 = affine.apply #map(%c0)
    %3 = affine.load %arg0[%2] : memref<4xi32>
    %4 = arith.addi %1, %3 : i32
    %5 = affine.apply #map1(%c0)
    %6 = affine.load %arg0[%5] : memref<4xi32>
    %7 = arith.addi %4, %6 : i32
    %8 = affine.apply #map2(%c0)
    %9 = affine.load %arg0[%8] : memref<4xi32>
    %10 = arith.addi %7, %9 : i32
    return %10 : i32
  }
}

I won’t explain what this affine.apply thing is doing, but suffice it to say the loop is correctly unrolled. A subsequent commit does the same test for a doubly-nested loop.

A Rewrite Pattern Version

In this commit, we rewrote the loop unroll pass in the next level of abstraction provided by MLIR: the pattern rewrite engine. It is useful in the kind of situation where one wants to repeatedly apply the same subset of transformations to a given IR substructure until that substructure is completely removed. The next section will write a pass that uses that in a meaningful way, but for now we’ll just rewrite the loop unroll pass to show the extra boilerplate.

A rewrite pattern is a subclass of OpRewritePattern, and it has a method called matchAndRewrite which performs the transformation.

struct AffineFullUnrollPattern :
  public OpRewritePattern<AffineForOp> {
  AffineFullUnrollPattern(mlir::MLIRContext *context)
      : OpRewritePattern<AffineForOp>(context, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(AffineForOp op,
                                PatternRewriter &rewriter) const override {
    return loopUnrollFull(op);
  }
};

The return value of matchAndRewrite is a LogicalResult, which is a wrapper around a boolean to signal success or failure, along with named utility functions like failure() and success() to generate instances, and failed(...) to test for failure. LogicalResult also comes with a companion class template, FailureOr, which is a subclass of optional that inter-operates with LogicalResult via the presence or absence of a value.

Aside: In a proper OpRewritePattern, the mutations of the IR must go through the PatternRewriter argument, but because loopUnrollFull doesn’t have a variant that takes a PatternRewriter as input, we’re violating that part of the function contract. More generally, the PatternRewriter handles atomicity of the mutations that occur within the OpRewritePattern, ensuring that the operations are applied only if the method reaches the end and succeeds.

Then we instantiate the pattern inside the pass:

// A pass that invokes the pattern rewrite engine.
void AffineFullUnrollPassAsPatternRewrite::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());
  patterns.add<AffineFullUnrollPattern>(&getContext());
  // One could use GreedyRewriteConfig here to slightly tweak the behavior of
  // the pattern application.
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

The overall pass is still anchored on FuncOp, but an OpRewritePattern can match against any op. The rewrite engine invokes the walk that we did manually, and one can pass an optional configuration struct that chooses the walk order.

The RewritePatternSet can accept any number of patterns, and the greedy rewrite engine will keep trying to apply them (in a certain order related to the benefit constructor argument) until there are no matching operations to apply, all applied patterns return failure, or some large iteration limit is reached to avoid infinite loops.
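The control flow described here can be sketched in a few lines of standalone C++. This is only a toy model of the idea and makes no claim about how MLIR's driver is implemented: ToyPattern, applyGreedily, and runToyDriver are hypothetical names, the "IR" is a single integer, and the two example patterns (halve an even number, decrement anything larger than one) are made up for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy pattern: rewrites an integer "state" or declines.
struct ToyPattern {
  std::string name;
  int benefit;
  std::function<std::optional<int>(int)> matchAndRewrite;
};

// Repeatedly apply the highest-benefit pattern that matches, until nothing
// matches or the iteration limit is hit. Returns the names of the patterns
// that were applied, in order.
std::vector<std::string> applyGreedily(int &state,
                                       std::vector<ToyPattern> patterns,
                                       int maxIterations = 100) {
  assert(maxIterations > 0);
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  std::vector<std::string> applied;
  for (int i = 0; i < maxIterations; ++i) {
    bool changed = false;
    for (const ToyPattern &p : patterns) {
      if (std::optional<int> next = p.matchAndRewrite(state)) {
        state = *next;
        applied.push_back(p.name);
        changed = true;
        break;  // restart from the highest-benefit pattern
      }
    }
    if (!changed)
      break;  // every pattern declined: a fixed point was reached
  }
  return applied;
}

// Example: "halve" has higher benefit than "decrement".
std::vector<std::string> runToyDriver(int start) {
  std::vector<ToyPattern> patterns;
  patterns.push_back(ToyPattern{"decrement", 1, [](int n) -> std::optional<int> {
                                  if (n > 1) return n - 1;
                                  return std::nullopt;
                                }});
  patterns.push_back(ToyPattern{"halve", 2, [](int n) -> std::optional<int> {
                                  if (n > 1 && n % 2 == 0) return n / 2;
                                  return std::nullopt;
                                }});
  int state = start;
  return applyGreedily(state, patterns);
}
```

Starting from 9, the higher-benefit halve pattern declines because 9 is odd, so decrement fires once, after which halve applies three times and the loop stops when both patterns decline.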

A proper greedy RewritePattern

In this section, we’re going to pretend we’re in a computational model where multiplication ops are much more expensive than addition ops. I’m not an expert in classical CPU hardware performance, but if I recall correctly, on a typical CPU multiplication is something like 2-4x slower than addition, and that advantage probably goes away when doing multiplications in bulk/pipelines. So you can imagine we’re not in a classical CPU hardware model.

The idea is to rewrite an operation like y = 9*x as y = 8*x + x (the 8 is a power of 2) and then expand it further as a = x+x; b = a+a; c = b+b; y = c+x. It replaces a multiplication by a constant with a roughly log-number of additions (the base-2 logarithm of the constant), though it gets worse the further away the constant gets from a power of two.
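As a sanity check on the arithmetic, the following standalone sketch computes c * x with additions only and counts how many it uses. It is plain C++ rather than MLIR, and AddChain and multiplyByAdditions are hypothetical names; it assumes a positive constant and that the halved product is computed once and reused for both operands of the addition.

```cpp
#include <cassert>
#include <cstdint>

struct AddChain {
  int64_t result;    // value computed using additions only
  int numAdditions;  // how many additions were needed
};

// Compute c * x using only additions, mirroring the rewrite described above:
// a power-of-two factor is halved and the half is added to itself, and any
// other factor has a single x peeled off.
AddChain multiplyByAdditions(int64_t c, int64_t x) {
  assert(c >= 1 && "sketch only handles positive constants");
  if (c == 1)
    return {x, 0};
  bool isPowerOfTwo = (c & (c - 1)) == 0;
  if (isPowerOfTwo) {
    AddChain half = multiplyByAdditions(c / 2, x);
    return {half.result + half.result, half.numAdditions + 1};
  }
  AddChain rest = multiplyByAdditions(c - 1, x);
  return {rest.result + x, rest.numAdditions + 1};
}
```

With c = 9 this uses 4 additions (one peel, then three doublings), while c = 15 needs 10 (seven peels to get down to 8, then three doublings), which shows how the cost grows as the constant moves away from a power of two.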

This commit contains a similar “empty shell” of a pass, with two patterns defined. The first, PowerOfTwoExpand, will be a pattern that rewrites y = C*x as y = C/2*x + C/2*x when C is a power of 2, and otherwise fails. The second, PeelFromMul, “peels” a single addition off a product whose constant is not a power of 2, rewriting y = 9*x as y = 8*x + x. These are applied repeatedly via the greedy pattern rewrite engine. By setting the benefit argument of PowerOfTwoExpand to be larger than PeelFromMul, we tell the greedy rewrite engine to prefer PowerOfTwoExpand whenever possible. Together, that achieves the transformation mentioned above.

This commit adds a failing test that only exercises PowerOfTwoExpand, and then this commit implements it. Here’s the implementation:

  LogicalResult matchAndRewrite(
       MulIOp op, PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);

    // canonicalization patterns ensure the constant is on the right, if there is a constant
    // See https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules
    Value rhs = op.getOperand(1);
    auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
    if (!rhsDefiningOp) {
      return failure();
    }

    int64_t value = rhsDefiningOp.value();
    bool is_power_of_two = (value & (value - 1)) == 0;

    if (!is_power_of_two) {
      return failure();
    }

    ConstantOp newConstant = rewriter.create<ConstantOp>(
        rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value / 2));
    MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
    AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, newMul);

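    rewriter.replaceOp(op, {newAdd});
    // Only erase the old constant if nothing else uses it; erasing an op
    // that still has uses is an error.
    if (rhsDefiningOp->use_empty()) {
      rewriter.eraseOp(rhsDefiningOp);
    }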

    return success();
  }

Some notes:

The rewriter.create part is where we actually do the real work. Create a new constant that is half the original constant, create new multiplication and addition ops, and then finally rewriter.replaceOp removes the original multiplication op and uses the output of newAdd for any other operations that used the original multiplication op’s output.
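The power-of-two test in the pattern, (value & (value - 1)) == 0, is a classic bit trick, and it has one edge case worth seeing in isolation. The sketch below is standalone C++ with hypothetical function names. Whether the zero case can actually reach the pattern depends on what has already been folded away beforehand, which this sketch does not model.

```cpp
#include <cstdint>

// The check used in the pattern: clearing the lowest set bit leaves zero.
// Note that it also accepts 0, which has no set bits at all.
bool looseIsPowerOfTwo(int64_t value) { return (value & (value - 1)) == 0; }

// A stricter variant that rejects zero and negative values.
bool strictIsPowerOfTwo(int64_t value) {
  return value > 0 && (value & (value - 1)) == 0;
}
```

Both versions accept 8 and reject 9 and 12, but only the loose one accepts 0.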

It’s worth noting that we’re relying on MLIR’s built-in canonicalization here, for example to ensure that a constant operand, if there is one, is always the second operand of a multiplication op.

PeelFromMul is similar, implemented and tested in this commit:

  LogicalResult matchAndRewrite(MulIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
    if (!rhsDefiningOp) { return failure(); }
    int64_t value = rhsDefiningOp.value();

    // We are guaranteed `value` is not a power of two, because the greedy
    // rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
    // it has higher benefit.

    ConstantOp newConstant = rewriter.create<ConstantOp>(
        rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value - 1));
    MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
    AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, lhs);

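    rewriter.replaceOp(op, {newAdd});
    // Only erase the old constant if nothing else uses it.
    if (rhsDefiningOp->use_empty()) {
      rewriter.eraseOp(rhsDefiningOp);
    }

    return success();
  }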

Running it! Input:

func.func @power_of_two_plus_one(%arg: i32) -> i32 {
  %0 = arith.constant 9 : i32
  %1 = arith.muli %arg, %0 : i32
  func.return %1 : i32
}

Output:

module {
  func.func @power_of_two_plus_one(%arg0: i32) -> i32 {
    %0 = arith.addi %arg0, %arg0 : i32
    %1 = arith.addi %0, %0 : i32
    %2 = arith.addi %1, %1 : i32
    %3 = arith.addi %2, %arg0 : i32
    return %3 : i32
  }
}

Exercise: Try swapping the benefit arguments to see how the output changes.

Exercise: When multiplying by a power of two, replace it with an appropriate left-shift op instead. Browse the arith dialect docs to find the right op, and then browse the ArithOps.td tablegen file to find the name of the right op. We’ll discuss the syntax of tablegen and op definitions in the next two articles.

Though this pass is quite naive, you can imagine a more sophisticated technique that builds a cost model for multiplications and additions, and optimizes for the cheapest cost representation of an arithmetic operation in terms of repeated additions, multiplications, and other supported ops.

Should we walk?

With two options for how to define a pass—one to walk the entire syntax tree from the root operation, and one to match and rewrite patterns with the rewrite engine—the natural question is when should you use one versus the other.

The MLIR docs describe the motivation behind the pattern rewrite engine, and it comes from a long history of experience with the LLVM project. For one, the pattern rewrite engine expresses a convenient subset of what can be achieved with an MLIR pass. This is conceptually trivial, in the sense that anyone who can walk the entire AST can, with enough effort, do anything they want including reimplementing the pattern rewrite engine.

More practically, the pattern rewrite engine is convenient to represent local transformations. “Local” here means that the input and output can be detected via a subset of the AST as a directed acyclic graph. More pragmatically, think of it as any operation you can identify by looking around at neighboring operations in the same block and applying some filtering logic. E.g., “is this exp operation followed by a log operation with no other uses of the output of the exp?”

On the other hand, some analyses and optimizations need to construct the entire dataflow of a program to work. A good example is common subexpression elimination, which determines whether it is cost effective to extract a subexpression used in multiple places into a separate variable. Doing so may introduce additional cost of memory access, so it depends both on the operation’s cost and on the availability of registers at that point in the program. You can’t get this information by pattern matching the AST locally.

The wisdom seems to be: using the pattern rewrite engine is usually easier than writing a pass that walks the AST. You don’t need large case/switch statements to handle everything that could show up in the IR. The engine handles re-applying patterns many times. And so you can write the patterns in isolation and trust the engine to combine them appropriately.

Bonus: IDEs and CI

Since we explored the C++ API, it helps to have an IDE integration. I use neovim with the clangd LSP, and to make it work with a Bazel C++ project, one needs to use something analogous to Hedron Vision’s compile_commands extractor, which I configured for this tutorial project in this commit. It’s optional, but if you want to use it you have to run bazel run @hedron_compile_commands//:refresh_all once to set it up, and then clangd and clang-tidy, etc., should find the generated json file and use it. Also, if you edit a BUILD file, you have to re-run refresh_all for the changes to show up in the LSP.

Though it’s not particularly relevant to this tutorial, I also configured GitHub Actions to build and test the project in CI in this commit. It is worth noting that the GitHub cache action reduces subsequent build times from 1-2 hours down to just a few minutes.

Thanks to Patrick Schmidt for feedback on a draft of this article, and Florent Michel for formatting corrections.



DOI: https://doi.org/10.59350/46txm-t5b81