This series is an introduction to MLIR and an onboarding tutorial for the HEIR project.
Last time we saw how to run and test a basic lowering. This time we will write some simple passes to illustrate the various parts of the MLIR API and the pass infrastructure.
As mentioned previously, the main work in MLIR is defining passes that either optimize part of a program, lower from parts of one dialect to others, or perform various normalization and canonicalization operations. In this article, we’ll start by defining a pass that operates entirely within a given dialect by fully unrolling loops. Then we’ll define a pass that does a simple replacement of one instruction with another. Neither pass will be particularly complex, but rather they will show how to set up a pass, how to navigate through a program via the MLIR API, and how to modify the IR by deleting and adding operations.
The code for this post is contained within this pull request.
tutorial-opt and project organization
Last time we used the mlir-opt
binary as the main entry point to parse MLIR, run a pass, and emit the output IR. A compiler might run mlir-opt
as a subroutine in between the front end (C++ to some MLIR dialects) and the backend (MLIR’s LLVM dialect to LLVM to machine code).
In an out-of-tree MLIR project, mlir-opt
can’t be used because it isn’t compiled with the project’s custom dialects or passes. Instead, MLIR makes it easy to build a custom version of the mlir-opt
tool for an out-of-tree project. It primarily provides a set of registration hooks that you can use to plug in your dialects and passes, and the framework handles reading/writing, CLI flags, and adds that all on top of the baseline MLIR passes and dialects. We’ll start this article by creating the shell for such a tool with an empty custom pass, which we’ll call tutorial-opt
. If this repository were to become one step of an end-to-end compiler, then tutorial-opt
would be the main interface to the MLIR part.
The structure of the codebase is a persnickety question here. A typical MLIR codebase seems to split the code into two directories with roughly equivalent hierarchies: an include/
directory for headers and tablegen files (more on tablegen in a future article), and a lib/
directory for implementation code. Then, within those two directories a project would have a Transform/
subdirectory that stores the files for passes that transform code within a dialect, Conversion/
for passes that convert between dialects, Analysis/
for analysis passes, etc. Each of these directories might have subdirectories for the specific dialects they operate on.
For this tutorial we will do it slightly differently by merging include/
and lib/
together (header files will live next to implementation files). I believe C++ codebases separate the two to enforce an implicit public/private interface boundary (client code should only depend on headers in include/
, not headers in lib/
or src/
). But Bazel has more robust facilities for enforcing public/private interface boundaries, I find it tedious to navigate parallel directory structures, and this is a tutorial, so simpler is better.
So the project’s directory structure will look like this once we create the initial pass:
.
├── README.md
├── WORKSPACE
├── bazel
│   └── . . .
├── lib
│   └── Transform
│       └── Affine
│           ├── AffineFullUnroll.cpp
│           ├── AffineFullUnroll.h
│           └── BUILD
├── tests
│   └── . . .
└── tools
    ├── BUILD
    └── tutorial-opt.cpp
Unrolling loops, a starter pass
Though MLIR provides multiple mechanisms for defining loops and control flow, the highest level one is in the affine
dialect. Originally defined for polyhedral loop analysis (using lattices to study loop structure!), it also simply defines a nice for
operation that you can use whenever you have simple loop bounds like iterating over a range with an optional step size. An example loop that sums some values in an array stored in memory might look like:
func.func @sum_buffer(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}
The iter_args
is a custom bit of syntax that defines accumulation variables carried across loop iterations (needed to comply with SSA form; for more on SSA, see this MLIR doc), along with their initial values.
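To make the iter_args mechanics concrete, here is a plain C++ rendering of what @sum_buffer computes (an illustration of the semantics, not code generated by MLIR):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Plain C++ analogue of @sum_buffer: %sum_iter is the accumulator that
// iter_args threads through the loop body, seeded with %sum_0, and
// affine.yield feeds the updated value into the next iteration (or out
// of the loop as %sum on the last one).
int32_t sum_buffer(const std::array<int32_t, 4> &buffer) {
  int32_t sum_iter = 0;  // %sum_0, the iter_args initial value
  for (int i = 0; i < 4; ++i) {
    int32_t t = buffer[i];    // affine.load
    sum_iter = sum_iter + t;  // arith.addi, then affine.yield
  }
  return sum_iter;  // %sum, the result of the affine.for
}
```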
Unrolling loops is a nontrivial operation, but thankfully MLIR provides a utility method for fully unrolling a loop, so our pass will be a thin wrapper around this function call, to showcase some of the rest of the infrastructure before we write a more meaningful pass. The code for this section is in this commit.
This implementation will be the most general and verbose one, working directly with the C++ API rather than using the more specialized features like the pattern rewrite engine, the dialect conversion framework, or tablegen. Those will all come later.
The main idea is to implement the required methods for the OperationPass
base class, which “anchors” the pass to work within the context of a specific instance of a specific type of operation, and is applied to every operation of that type. It looks like this:
// lib/Transform/Affine/AffineFullUnroll.h
class AffineFullUnrollPass
    : public PassWrapper<AffineFullUnrollPass,
                         OperationPass<mlir::func::FuncOp>> {
 private:
  void runOnOperation() override;  // implemented in AffineFullUnroll.cpp

  StringRef getArgument() const final { return "affine-full-unroll"; }

  StringRef getDescription() const final {
    return "Fully unroll all affine loops";
  }
};
The PassWrapper
helps implement some of the required methods for free (mainly adding a compliant copy method), and uses the Curiously Recurring Template Pattern (CRTP) to achieve that. But what matters for us is that OperationPass<FuncOp>
anchors this pass to operate on function bodies, and provides the getOperation
method in the class, which returns the FuncOp
being operated on.
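As a standalone illustration of the CRTP trick (hypothetical class names; a simplified sketch, not MLIR’s actual PassWrapper), the base class template knows the concrete derived type at compile time, so it can implement a compliant clone method once for all passes:

```cpp
#include <cassert>
#include <memory>
#include <string>

// The abstract interface every pass must satisfy.
struct PassBase {
  virtual ~PassBase() = default;
  virtual std::unique_ptr<PassBase> clonePass() const = 0;
  virtual std::string argument() const = 0;
};

// CRTP: the wrapper is templated on the derived class, so clonePass can
// downcast *this and invoke the derived class's copy constructor.
template <typename Derived>
struct PassWrapperSketch : PassBase {
  std::unique_ptr<PassBase> clonePass() const override {
    return std::make_unique<Derived>(static_cast<const Derived &>(*this));
  }
};

// A concrete pass gets clonePass "for free" by naming itself as the
// template argument, as AffineFullUnrollPass does with PassWrapper.
struct UnrollPassSketch : PassWrapperSketch<UnrollPassSketch> {
  std::string argument() const override { return "affine-full-unroll"; }
};
```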
Aside: The MLIR docs more formally describe what is required of an OperationPass, and in particular they limit the “anchoring” to specific operations, like functions and modules, whose insides are isolated: operations inside them cannot modify the semantics of the program outside the operation’s scope. That’s a fancy way of saying FuncOps
in MLIR can’t screw with variables outside the lexical scope of their function body. More importantly for this example, it explains why we can’t anchor this pass on a for
loop operation directly: a loop can modify stuff outside its body (like the contents of memory) via the operations within the loop (store, etc.). This matters because the MLIR pass infrastructure runs passes in parallel, and if some other pass is tinkering with neighboring operations, race conditions ensue.
The three functions we need to implement are:

- runOnOperation: the function that performs the pass logic.
- getArgument: the CLI argument for an mlir-opt-like tool.
- getDescription: the CLI description when running --help on the mlir-opt-like tool.
The initial implementation of runOnOperation
is empty in the commit for this section. Next, we create a tutorial-opt
binary that registers the pass.
// tools/tutorial-opt.cpp
#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/include/mlir/InitAllDialects.h"
#include "mlir/include/mlir/Pass/PassManager.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
#include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  mlir::PassRegistration<mlir::tutorial::AffineFullUnrollPass>();
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry));
}
This registers all the built-in MLIR dialects, adds our AffineFullUnrollPass
, and then calls the MlirOptMain
function which handles the rest. At this point we can run bazel run tools:tutorial-opt --help
and see a long list of options with our new pass in it.
OVERVIEW: Tutorial Pass Driver
Available Dialects: acc, affine, amdgpu, <...SNIP...>
USAGE: tutorial-opt [options] <input file>

OPTIONS:

General options:

  Compiler passes to run
    Passes:
      --affine-full-unroll           - Fully unroll all affine loops
      --allow-unregistered-dialect   - Allow operation with no registered dialects
      --disable-i2p-p2i-opt          - Disables inttoptr/ptrtoint roundtrip optimization
<...SNIP...>
To allow us to run lit
tests that use this tool, we add it to the test_utilities
target in this commit, and then we add a first (failing) test in this commit. To avoid complexity, I’m just asserting that the output has no for loops in it.
// RUN: tutorial-opt %s --affine-full-unroll > %t
// RUN: FileCheck %s < %t

func.func @test_single_nested_loop(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  // CHECK-NOT: affine.for
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}
Next, we can implement the pass itself in this commit:
#include "lib/Transform/Affine/AffineFullUnroll.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/include/mlir/Pass/Pass.h"

using mlir::affine::AffineForOp;
using mlir::affine::loopUnrollFull;

void AffineFullUnrollPass::runOnOperation() {
  getOperation().walk([&](AffineForOp op) {
    if (failed(loopUnrollFull(op))) {
      op.emitError("unrolling failed");
      signalPassFailure();
    }
  });
}
getOperation
returns the FuncOp being processed, though we don’t use any function-specific information about it. We instead call the walk
method (present on all Operation
instances), which traverses the abstract syntax tree (AST) of the operation (here, the function body) in post-order, and for each operation encountered, if its type matches the input type of the callback, the callback is executed. In our case, we attempt to unroll each loop, and if unrolling fails we quit with a diagnostic error.
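The traversal can be sketched with a toy node type (hypothetical, not MLIR’s Operation class): recurse into children first, then invoke the callback on nodes whose kind matches the filter.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// A toy IR node: just a kind string and a list of children.
struct NodeSketch {
  std::string kind;  // e.g. "affine.for" or "arith.addi"
  std::vector<NodeSketch> children;
};

// Post-order walk: children are visited before their parent, and the
// callback fires only on nodes matching the requested kind (standing in
// for walk's type-filtered callback dispatch).
void walkSketch(NodeSketch &node, const std::string &kindFilter,
                const std::function<void(NodeSketch &)> &callback) {
  for (NodeSketch &child : node.children) {
    walkSketch(child, kindFilter, callback);
  }
  if (node.kind == kindFilter) {
    callback(node);
  }
}
```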
Exercise: determine how the loop unrolling might fail, and create a test MLIR input that causes it to fail, and observe the error messages that result.
Running this on our test shows the operation is applied:
$ bazel run tools:tutorial-opt -- --affine-full-unroll < tests/affine_loop_unroll.mlir
<...>
#map = affine_map<(d0) -> (d0 + 1)>
#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 3)>
module {
  func.func @test_single_nested_loop(%arg0: memref<4xi32>) -> i32 {
    %c0 = arith.constant 0 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = affine.load %arg0[%c0] : memref<4xi32>
    %1 = arith.addi %c0_i32, %0 : i32
    %2 = affine.apply #map(%c0)
    %3 = affine.load %arg0[%2] : memref<4xi32>
    %4 = arith.addi %1, %3 : i32
    %5 = affine.apply #map1(%c0)
    %6 = affine.load %arg0[%5] : memref<4xi32>
    %7 = arith.addi %4, %6 : i32
    %8 = affine.apply #map2(%c0)
    %9 = affine.load %arg0[%8] : memref<4xi32>
    %10 = arith.addi %7, %9 : i32
    return %10 : i32
  }
}
I won’t explain what this affine.apply
thing is doing, but suffice it to say the loop is correctly unrolled. A subsequent commit does the same test for a doubly-nested loop.
A Rewrite Pattern Version
In this commit, we rewrote the loop unroll pass at the next level of abstraction provided by MLIR: the pattern rewrite engine. It is useful when one wants to repeatedly apply the same small set of transformations to a given IR substructure until that substructure is completely removed. The next section will use it in a meaningful way; for now we’ll just rewrite the loop unroll pass to show the extra boilerplate.
A rewrite pattern is a subclass of OpRewritePattern
, and it has a method called matchAndRewrite
which performs the transformation.
struct AffineFullUnrollPattern : public OpRewritePattern<AffineForOp> {
  AffineFullUnrollPattern(mlir::MLIRContext *context)
      : OpRewritePattern<AffineForOp>(context, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(AffineForOp op,
                                PatternRewriter &rewriter) const override {
    return loopUnrollFull(op);
  }
};
The return value of matchAndRewrite
is a LogicalResult
, which is a wrapper around a boolean used to signal success or failure, along with named utility functions like failure()
and success()
to generate instances, and failed(...)
to test for failure. LogicalResult
also comes with a subclass FailureOr
, a subclass of optional
that inter-operates with LogicalResult
via the presence or absence of a value.
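A simplified model of these types (hypothetical names, not MLIR’s actual implementation) captures the success/failure flag and the optional-like interplay:

```cpp
#include <cassert>
#include <optional>

// A boolean wrapper with named constructors, mirroring how
// LogicalResult is used in pass code.
struct ResultSketch {
  bool isFailureFlag;
};
inline ResultSketch success() { return {false}; }
inline ResultSketch failure() { return {true}; }
inline bool failed(ResultSketch r) { return r.isFailureFlag; }

// A FailureOr-like wrapper: holding a value means success, and an
// empty wrapper converts to failure.
template <typename T>
struct FailureOrSketch : std::optional<T> {
  using std::optional<T>::optional;
  ResultSketch toResult() const {
    return this->has_value() ? success() : failure();
  }
};
```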
Aside: In a proper OpRewritePattern
, the mutations of the IR must go through the PatternRewriter
argument, but because loopUnrollFull
doesn’t have a variant that takes a PatternRewriter
as input, we’re violating that part of the function contract. More generally, the PatternRewriter
handles atomicity of the mutations that occur within the OpRewritePattern
, ensuring that the operations are applied only if the method reaches the end and succeeds.
Then we instantiate the pattern inside the pass
// A pass that invokes the pattern rewrite engine.
void AffineFullUnrollPassAsPatternRewrite::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());
  patterns.add<AffineFullUnrollPattern>(&getContext());
  // One could use GreedyRewriteConfig here to slightly tweak the behavior of
  // the pattern application.
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
The overall pass is still anchored on FuncOp
, but an OpRewritePattern
can match against any op. The rewrite engine invokes the walk
that we did manually, and one can pass an optional configuration struct that chooses the walk order.
The RewritePatternSet
can accept any number of patterns, and the greedy rewrite engine will keep trying to apply them (in an order related to the benefit
constructor argument) until there are no matching operations left, all applied patterns return failure, or a large iteration limit is reached to avoid infinite loops.
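The driver’s core loop can be sketched in a few lines (hypothetical code operating on a toy integer “IR”, not MLIR’s actual implementation): try the highest-benefit patterns first and sweep until nothing changes or an iteration cap is hit.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// A toy pattern: a benefit plus a matchAndRewrite that mutates the
// "IR" (here just an integer) and reports whether it applied.
struct PatternSketch {
  int benefit;
  std::function<bool(int64_t &)> matchAndRewrite;
};

// Greedy driver sketch: higher-benefit patterns are tried first, each
// to exhaustion, and the sweep repeats until a fixpoint or the cap.
bool applyGreedily(std::vector<PatternSketch> patterns, int64_t &ir,
                   int maxIterations = 100) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const PatternSketch &a, const PatternSketch &b) {
                     return a.benefit > b.benefit;
                   });
  for (int i = 0; i < maxIterations; ++i) {
    bool changed = false;
    for (PatternSketch &p : patterns) {
      while (p.matchAndRewrite(ir)) changed = true;
    }
    if (!changed) return true;  // fixpoint: no pattern matched
  }
  return false;  // gave up at the iteration limit
}
```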
A proper greedy RewritePattern
In this section, we’re going to pretend we’re in a computational model where multiplication ops are much more expensive than addition ops. I’m not an expert in classical CPU hardware performance, but if I recall correctly, on a typical CPU multiplication is something like 2-4x slower than addition, and that advantage probably goes away when doing multiplications in bulk/pipelines. So you can imagine we’re not in a classical CPU hardware model.
The idea is to rewrite an operation like y = 9*x
as y = 8*x + x
(the 8 is a power of 2) and then expand it further as a = x+x; b = a+a; c = b+b; y = c+x
It replaces a multiplication by a constant with roughly logarithmically many additions (about the base-2 logarithm of the constant), though the count grows the further the constant is from a power of two.
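The transformation can be sketched as ordinary recursion on the constant. This is an illustration of the arithmetic only, and it generalizes slightly by halving any even constant, whereas the pass described below only halves powers of two; num_adds counts the addition ops the rewrite would emit.

```cpp
#include <cassert>
#include <cstdint>

// Compute x * c using only additions, mirroring the two rewrites:
// an even constant is halved and the result doubled (one addition),
// and an odd constant has one x peeled off (one addition).
int64_t mulViaAdds(int64_t x, int64_t c, int &num_adds) {
  assert(c >= 1);
  if (c == 1) return x;  // base case: x*1 folds away
  if (c % 2 == 0) {
    int64_t half = mulViaAdds(x, c / 2, num_adds);  // halving step
    ++num_adds;
    return half + half;
  }
  int64_t rest = mulViaAdds(x, c - 1, num_adds);  // peeling step
  ++num_adds;
  return rest + x;
}
```

For c = 9 this uses four additions: three doublings for the factor of 8, plus one peeled addition.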
This commit contains a similar “empty shell” of a pass, with two patterns defined. The first, PowerOfTwoExpand
, rewrites y = C*x
as y = (C/2)*x + (C/2)*x
when C
is a power of 2, and fails otherwise. The second, PeelFromMul
, “peels” a single addition off a product whose constant is not a power of 2, rewriting y = 9*x
as y = 8*x + x
. These are applied repeatedly via the greedy pattern rewrite engine. By setting the benefit
argument of PowerOfTwoExpand
to be larger than that of PeelFromMul
, we tell the greedy rewrite engine to prefer PowerOfTwoExpand
whenever possible. Together, these achieve the transformation described above.
This commit adds a failing test that only exercises PowerOfTwoExpand
, and then this commit implements it. Here’s the implementation:
LogicalResult matchAndRewrite(MulIOp op,
                              PatternRewriter &rewriter) const override {
  Value lhs = op.getOperand(0);

  // Canonicalization patterns ensure the constant is on the right, if there
  // is a constant.
  // See https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules
  Value rhs = op.getOperand(1);
  auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
  if (!rhsDefiningOp) {
    return failure();
  }

  int64_t value = rhsDefiningOp.value();
  bool is_power_of_two = (value & (value - 1)) == 0;
  if (!is_power_of_two) {
    return failure();
  }

  ConstantOp newConstant = rewriter.create<ConstantOp>(
      rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value / 2));
  MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
  AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, newMul);

  rewriter.replaceOp(op, {newAdd});
  rewriter.eraseOp(rhsDefiningOp);

  return success();
}
Some notes:
- Value is the type that represents an SSA value (i.e., an MLIR variable), and getDefiningOp fetches the unique operation that defines it in its scope.
- There are a variety of “casting” operations like rhs.getDefiningOp<arith::ConstantIntOp>() that take the type you want as output as a template parameter, and return null if the type cannot be converted. You might also see cast<>, dyn_cast<>, or dyn_cast_or_null<> to invoke these manually.
- (value & (value - 1)) is a classic bit-twiddling trick to compute whether an integer is a power of two. We check it and skip the pattern if it’s not.
- The actual constant itself is represented as an MLIR attribute, which is essentially compile-time static data attached to the op. You can put strings or dictionaries as attributes, but for ConstantOp it’s just an int.
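The bit trick deserves a standalone look: a power of two has exactly one set bit, so subtracting 1 clears that bit and sets all lower ones, making the AND zero. One caveat worth knowing: value = 0 also satisfies (value & (value - 1)) == 0, so a robust standalone version (a sketch, not the pass’s exact code) adds a positivity check:

```cpp
#include <cassert>
#include <cstdint>

// True exactly when value is a positive power of two: 8 is 0b1000 and
// 7 is 0b0111, so 8 & 7 == 0; for non-powers like 9 (0b1001), a bit
// survives the AND. The value > 0 guard excludes 0 (and negatives),
// since 0 & -1 == 0 would otherwise slip through.
bool isPowerOfTwo(int64_t value) {
  return value > 0 && (value & (value - 1)) == 0;
}
```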
The rewriter.create
calls are where the real work happens: we create a new constant that is half the original, create new multiplication and addition ops, and finally rewriter.replaceOp
removes the original multiplication op and replaces all uses of its result with the output of newAdd
.
It’s worth noting that we’re relying on MLIR’s built-in canonicalization passes in a few ways here:

- To ensure that the constant is always the second operand of a multiplication op.
- To ensure that the base case (x*1) is “folded” into a plain x and the constant 1 is removed.
- The fold part of applyPatternsAndFoldGreedily is what runs these cleanup steps for us.
PeelFromMul
is similar, implemented and tested in this commit:
LogicalResult matchAndRewrite(MulIOp op,
                              PatternRewriter &rewriter) const override {
  Value lhs = op.getOperand(0);
  Value rhs = op.getOperand(1);
  auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
  if (!rhsDefiningOp) {
    return failure();
  }

  int64_t value = rhsDefiningOp.value();
  // We are guaranteed `value` is not a power of two, because the greedy
  // rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
  // it has higher benefit.

  ConstantOp newConstant = rewriter.create<ConstantOp>(
      rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value - 1));
  MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
  AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, lhs);

  rewriter.replaceOp(op, {newAdd});
  rewriter.eraseOp(rhsDefiningOp);

  return success();
}
Running it! Input:
func.func @power_of_two_plus_one(%arg: i32) -> i32 {
  %0 = arith.constant 9 : i32
  %1 = arith.muli %arg, %0 : i32
  func.return %1 : i32
}
Output:
module {
  func.func @power_of_two_plus_one(%arg0: i32) -> i32 {
    %0 = arith.addi %arg0, %arg0 : i32
    %1 = arith.addi %0, %0 : i32
    %2 = arith.addi %1, %1 : i32
    %3 = arith.addi %2, %arg0 : i32
    return %3 : i32
  }
}
Exercise: Try swapping the benefit
arguments to see how the output changes.
Exercise: When multiplying by a power of two, replace it with an appropriate left-shift op instead. Browse the arith dialect docs to find the right op, and then browse the ArithOps.td tablegen file to find the name of the right op. We’ll discuss the syntax of tablegen and op definitions in the next two articles.
Though this pass is quite naive, you can imagine a more sophisticated technique that builds a cost model for multiplications and additions, and optimizes for the cheapest cost representation of an arithmetic operation in terms of repeated additions, multiplications, and other supported ops.
Should we walk?
With two options for defining a pass, walking the entire syntax tree from the root operation or matching and rewriting patterns with the rewrite engine, the natural question is when you should use one versus the other.
The MLIR docs describe the motivation behind the pattern rewrite engine, and it comes from a long history of experience with the LLVM project. For one, the pattern rewrite engine expresses a convenient subset of what can be achieved with an MLIR pass. This is conceptually trivial, in the sense that anyone who can walk the entire AST can, with enough effort, do anything they want including reimplementing the pattern rewrite engine.
More practically, the pattern rewrite engine is convenient to represent local transformations. “Local” here means that the input and output can be detected via a subset of the AST as a directed acyclic graph. More pragmatically, think of it as any operation you can identify by looking around at neighboring operations in the same block and applying some filtering logic. E.g., “is this exp
operation followed by a log
operation with no other uses of the output of the exp
?”
On the other hand, some analyses and optimizations need to construct the entire dataflow of a program to work. A good example is common subexpression elimination, which determines whether it is cost effective to extract a subexpression used in multiple places into a separate variable. Doing so may introduce additional cost of memory access, so it depends both on the operation’s cost and on the availability of registers at that point in the program. You can’t get this information by pattern matching the AST locally.
The wisdom seems to be: using the pattern rewrite engine is usually easier than writing a pass that walks the AST. You don’t need large case/switch statements to handle everything that could show up in the IR. The engine handles re-applying patterns many times. And so you can write the patterns in isolation and trust the engine to combine them appropriately.
Bonus: IDEs and CI
Since we explored the C++ API, it helps to have an IDE integration. I use neovim with the clangd LSP, and to make it work with a Bazel C++ project, one needs to use something analogous to Hedron Vision’s compile_commands
extractor, which I configured for this tutorial project in this commit. It’s optional, but if you want to use it you have to run bazel run @hedron_compile_commands//:refresh_all
once to set it up, and then clangd
and clang-tidy
, etc., should find the generated json
file and use it. Also, if you edit a BUILD file, you have to re-run refresh_all
for the changes to show up in the LSP.
Though it’s not particularly relevant to this tutorial, I also added a commit that configures GitHub Actions to build and test the project in CI. It is worth noting that the GitHub cache action reduces subsequent build times from 1-2 hours down to just a few minutes.
Thanks to Patrick Schmidt for feedback on a draft of this article, and Florent Michel for formatting corrections.
Want to respond? Send me an email, post a webmention, or find me elsewhere on the internet.