Last time, we covered a Bazel build system setup for an MLIR project. This time we’ll give an overview of a simple lowering and show how end-to-end tests work in MLIR. All of the code for this article is contained in this pull request on GitHub, and the commits are nicely organized and quite readable.
Two of the central concepts in MLIR are dialects and lowerings. These are the scaffolding within which we can do the truly interesting parts of a compiler—that is, the optimizations and analyses. In traditional compilers, there is typically one “dialect” (called an intermediate representation, or IR) that is the textual or data-structural description of a program within the compiler’s code. For example, in GCC the IR is called GIMPLE, and in LLVM it’s called LLVM-IR. They convert the input program to the IR, do their optimizations, and then convert the optimized IR to machine code.
In MLIR one splits the job into much smaller steps. First, MLIR allows one to define many dialects, some considered “high level” and some “low level,” but each with a set of types, operations, metadata, and semantics that defines what the operations do. Then, one writes a set of lowering passes that incrementally convert different parts of the program from higher level dialects to lower and lower dialects until one gets to machine code (or, in many cases, LLVM, which finishes the job). Along the way, optimizing passes are run to make the code more efficient. The main point here is that the high level dialects exist to make it easy to write these important optimizing passes. And there is no special distinction between lowering passes and optimizing passes: both are just called passes in MLIR, and both are generic IR-rewriting modules.
Aside: From what I can gather, a big part of the motivation for MLIR was to build the affine dialect, which is specifically designed to enable polyhedral optimizations for loop transformations, along with the linalg dialect, which does optimization passes like tiling for low-level ML operations on specialized hardware. Folks built polyhedral optimizations in LLVM and GCC (without affine), and it was a huge pain in the ass, mainly because they had to take a low-level mess of branches and GOTOs and try to reconstruct a simple (‘affine’) for loop structure from it. This was necessary even if the input program was a simple set of for loops, because by the time they got to the compiler, the rigid for loop structure had been discarded. MLIR instead says, keep the structure in the higher level dialect, optimize there, and then discard it when you lower to lower level dialects.
Two example programs
A general understanding properly begins with concrete examples. Here are two MLIR programs that define a function that counts the leading zeroes of a 32-bit integer (i32). The first uses the math dialect’s ctlz operation and just returns its result.
func.func @main(%arg0: i32) -> i32 {
  %0 = math.ctlz %arg0 : i32
  func.return %0 : i32
}
This shows the basic structure of an MLIR operation (see here for a more complete spec). Variable names are prefixed with %, functions with @, and each variable/value in a program has a type, often expressed after a colon. In this case all the types are i32, except for the function type, which is (i32) -> i32 (not specified explicitly above, but you’ll see it in the func.call in the next code snippet).
Each statement is anchored around an expression like math.ctlz, which specifies the dialect math and the operation ctlz. The rest of the syntax of the operation is determined by a parser defined by the dialect, so different operations can have different syntaxes, though many are pulled from a fixed set of options we’ll see later in the series. In the simple case of math.ctlz, the sole argument is the integer whose leading zeros are to be counted, and the trailing : i32 denotes the output type.
It’s also important to note that func is itself a dialect, and func.func is considered an “operation,” where the braces and the function’s body are part of the syntax. In MLIR a set of operations within braces is called a region, and an operation can have zero or more regions.
There is a lot more to say about regions, and their cousins “basic blocks,” but in brief: operations may have attached regions, like the body of a for loop, and each region is a list of blocks (with an implicit block if none is explicitly listed). A block is a list of operations with exactly one entry point (its label) and one exit point (its final, terminator operation). I think of the label of a block as the destination of a jump command in assembly languages. The precise definition aligns with the classical compiler concept of a basic block.
Also note, in MLIR multiple dialects often coexist in the same program as it is progressively lowered to some final backend target.
The second version of this program has a software implementation of the ctlz function and calls it.
func.func @main(%arg0: i32) -> i32 {
  %0 = func.call @my_ctlz(%arg0) : (i32) -> i32
  func.return %0 : i32
}
func.func @my_ctlz(%arg0: i32) -> i32 {
  %c32_i32 = arith.constant 32 : i32
  %c0_i32 = arith.constant 0 : i32
  %0 = arith.cmpi eq, %arg0, %c0_i32 : i32
  %1 = scf.if %0 -> (i32) {
    scf.yield %c32_i32 : i32
  } else {
    %c1 = arith.constant 1 : index
    %c1_i32 = arith.constant 1 : i32
    %c32 = arith.constant 32 : index
    %c0_i32_0 = arith.constant 0 : i32
    %2:2 = scf.for %arg1 = %c1 to %c32 step %c1 iter_args(%arg2 = %arg0, %arg3 = %c0_i32_0) -> (i32, i32) {
      %3 = arith.cmpi slt, %arg2, %c0_i32 : i32
      %4:2 = scf.if %3 -> (i32, i32) {
        scf.yield %arg2, %arg3 : i32, i32
      } else {
        %5 = arith.addi %arg3, %c1_i32 : i32
        %6 = arith.shli %arg2, %c1_i32 : i32
        scf.yield %6, %5 : i32, i32
      }
      scf.yield %4#0, %4#1 : i32, i32
    }
    scf.yield %2#1 : i32
  }
  func.return %1 : i32
}
The details of the algorithm are not essential to this post, but it is quite simple: count the leading zeros by shifting the input left one bit at a time until it becomes negative (as a signed integer), because that occurs exactly when its leading bit is a 1. Then add a special case to handle zero, which never becomes negative no matter how far it is shifted; without that case, the bounded loop would stop after 31 iterations and return 31 instead of 32.
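To double-check the description above, here is a minimal Python sketch that mirrors @my_ctlz step by step. It is only an illustration, not code from the MLIR project: to_i32 is a hypothetical helper that emulates 32-bit signed wraparound (Python integers are unbounded), and the loop runs 31 times because the upper bound of scf.for %arg1 = %c1 to %c32 is exclusive.

```python
def to_i32(value: int) -> int:
    """Reinterpret the low 32 bits of value as a signed i32 (two's complement)."""
    value &= 0xFFFFFFFF
    return value - (1 << 32) if value & 0x80000000 else value


def my_ctlz(arg0: int) -> int:
    """Mirror of the @my_ctlz MLIR function above."""
    x = to_i32(arg0)
    if x == 0:  # the outer scf.if: special case for zero
        return 32
    count = 0
    # scf.for %arg1 = %c1 to %c32 step %c1 has an exclusive upper bound,
    # so the body runs 31 times; (x, count) play the role of the iter_args.
    for _ in range(1, 32):
        if x < 0:  # arith.cmpi slt: the leading bit is already a 1
            continue  # yield the loop-carried values unchanged
        count += 1  # arith.addi
        x = to_i32(x << 1)  # arith.shli, wrapped to 32 bits
    return count


for n in [0, 1, 7, 0x80000000, 0xFFFFFFFF]:
    print(n, my_ctlz(n))
```

Deleting the zero check and calling my_ctlz(0) would return 31, which is exactly why the special case is needed.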
Here you can see two more MLIR dialects. arith is for low-level arithmetic and boolean conditions on integers and floats. You can define constants, compare integers with arith.cmpi, and do things like add and bit shift (arith.shli is a left shift). scf, short for “structured control flow,” defines for loops, while loops, and control flow branching. scf.yield defines the “output” value from each region of an if/else operation or loop body, which is necessary here because, as you can see, an if operation has a result value.
Two other minor aspects of the syntax are on display. First is the syntax %4:2, which names the two results of a single operation as a group %4 (you can think of it as a tuple of two values). The corresponding %4#1 accesses the second entry in that group. Second, you’ll notice there’s a type called index that is different from i32. Though they both represent integers, index is intended to be a platform-dependent integer type which is suitable for indexing arrays, representing sizes and dimensions of things, and, in our case, being loop counters and iteration bounds. More details on index are in the MLIR docs.
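One way to build intuition for scf.for with iter_args, scf.if with results, and scf.yield is to model them in Python: the loop body receives the induction variable and the current loop-carried values, whatever it yields becomes the next iteration's values, and the final yield becomes the results of the whole operation. This is a simplified sketch for illustration only; scf_for and scf_if are hypothetical helpers, not MLIR APIs.

```python
from typing import Callable, Tuple

Values = Tuple[int, ...]


def scf_for(lower: int, upper: int, step: int, init: Values,
            body: Callable[[int, Values], Values]) -> Values:
    """Model of scf.for with iter_args: lower <= iv < upper, threading values."""
    carried = init  # like iter_args(%arg2 = ..., %arg3 = ...)
    for iv in range(lower, upper, step):
        carried = body(iv, carried)  # the body's return value is its scf.yield
    return carried  # the final yield becomes the op's results, e.g. %2:2


def scf_if(cond: bool, then_region: Callable[[], Values],
           else_region: Callable[[], Values]) -> Values:
    """Model of scf.if with results: whichever region runs supplies the yield."""
    return then_region() if cond else else_region()


def body(iv: int, carried: Values) -> Values:
    # Count the even numbers and sum the odd numbers in two loop-carried values.
    evens, odd_sum = carried
    return scf_if(iv % 2 == 0,
                  lambda: (evens + 1, odd_sum),
                  lambda: (evens, odd_sum + iv))


results = scf_for(0, 10, 1, (0, 0), body)
print(results[0], results[1])  # like %2#0 and %2#1; prints 5 25
```

As in MLIR, a loop whose bounds give zero iterations simply returns its initial values.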
Lowerings and the math-to-funcs pass
We have two versions of the same program because one is a lowered version of the other. In most cases, the machine you’re going to run a program on has a “count leading zeros” instruction, so the lowering would simply map math.ctlz to the corresponding machine instruction. But if there is no ctlz instruction, a lowering can provide an implementation in terms of lower level dialects and ops. Specifically, this one lowers ctlz to {func, arith, scf}.
The second version of this code was actually generated by the mlir-opt command line tool, which is the main entry point for running MLIR passes on specific MLIR code. For starters, one can run mlir-opt with no pass flags on any MLIR code, and it will parse it, verify that it is well formed, and print it back out with some slight normalizations. In this case, it will wrap the code in a module, which is a namespace isolation mechanism.
$ echo 'func.func @main(%arg0: i32) -> i32 {
  %0 = math.ctlz %arg0 : i32
  func.return %0 : i32
}' > ctlz.mlir
$ bazel run @llvm-project//mlir:mlir-opt -- $(pwd)/ctlz.mlir
<... snip ...>
module {
  func.func @main(%arg0: i32) -> i32 {
    %0 = math.ctlz %arg0 : i32
    return %0 : i32
  }
}
Aside: The -- $(pwd)/ctlz.mlir is a quirk of bazel. When one program runs another program, -- is the standard mechanism for separating the CLI flags of the runner program (bazel) from those of the run program (mlir-opt). Everything after -- goes to mlir-opt. Also, $(pwd) is needed because bazel runs mlir-opt with a working directory in some temporary, isolated location on the filesystem, so we need to give it an absolute path to the input MLIR file. Alternatively, we could pipe the file in through standard input, or run the mlir-opt binary directly from bazel-bin/external/llvm-project/mlir/mlir-opt.
Next we can run our first lowering, which is already built into mlir-opt, and which generates the long program above.
$ bazel run @llvm-project//mlir:mlir-opt -- --convert-math-to-funcs=convert-ctlz $(pwd)/ctlz.mlir
<... snip ...>
module {
  func.func @main(%arg0: i32) -> i32 {
    %0 = call @__mlir_math_ctlz_i32(%arg0) : (i32) -> i32
    return %0 : i32
  }
  func.func private @__mlir_math_ctlz_i32(%arg0: i32) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
  <... snip ...>
Each pass gets its own command line flag, and some passes are grouped into pipelines. The --pass-pipeline command line flag can be used to provide a serialized, ordered list of passes to run on the input MLIR.
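To make the textual pipeline format concrete, here is a small Python sketch that assembles a --pass-pipeline string of the kind used in the functional test later in this article: passes are comma-separated, pass options go in braces, and passes that should run on nested operations (such as each func.func) are wrapped in that operation's name. Nested and with_options are hypothetical helpers written only for illustration; mlir-opt itself parses the resulting string.

```python
class Nested:
    """A group of passes anchored on an op name, e.g. func.func(...)."""

    def __init__(self, op_name: str, *passes: object) -> None:
        self.op_name = op_name
        self.passes = passes

    def __str__(self) -> str:
        return f"{self.op_name}({','.join(str(p) for p in self.passes)})"


def with_options(name: str, *options: str) -> str:
    """Attach pass options in braces, e.g. convert-math-to-funcs{convert-ctlz}."""
    return f"{name}{{{' '.join(options)}}}" if options else name


pipeline = Nested(
    "builtin.module",
    with_options("convert-math-to-funcs", "convert-ctlz"),
    Nested("func.func", "convert-scf-to-cf", "convert-arith-to-llvm"),
    "convert-func-to-llvm",
    "convert-cf-to-llvm",
    "reconcile-unrealized-casts",
)
print(f'--pass-pipeline="{pipeline}"')
```

Building the string programmatically is only a convenience; writing it by hand, as the lit test does, works just as well.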
We won’t cover the internal workings of the math-to-funcs pass in this or a future article, but next time we will actually write our own, simpler pass that does something nontrivial. Until then, I’ll explain a bit about how testing works in MLIR, using these two ctlz programs as example test cases.
For those who are interested, the MLIR documentation contains a complete list of passes owned by the upstream MLIR project, which can be used by invoking the corresponding command line flag or nesting it inside of a larger --pass-pipeline.
Lit, FileCheck, and Bazel again
The LLVM and MLIR projects both use the same testing framework, which is split into two parts. The first is lit (LLVM Integrated Tester; though I don’t know why it’s called “integrated”), which handles test discovery and running. The second is FileCheck, which handles test assertions and reporting.
I don’t know why they’re two separate tools, but both are primarily end-to-end testing tools, as opposed to unit testing tools. Because end-to-end testing in a compiler toolchain implies the inputs and outputs are essentially big strings (programs) in unknown languages (user-defined MLIR dialects), these tools basically have you express the test setup and assertions in comments inside of the file representing the input program to be tested. An example might look like this:
// RUN: mlir-opt %s --convert-math-to-funcs=convert-ctlz | FileCheck %s
func.func @main(%arg0: i32) -> i32 {
  // CHECK-NOT: math.ctlz
  // CHECK: call
  %0 = math.ctlz %arg0 : i32
  func.return %0 : i32
}
I added this test in this commit. This trivial function calls math.ctlz on its input and promptly returns the result. The interesting parts are the comments, which define the test command and assertions.
Warning: you may run into a Python issue where Python cannot find the lit module. See https://github.com/j2kun/mlir-tutorial/issues/8, wherein I realized too late that bazel uses the system Python by default. tl;dr: you can either run pip install lit on your system Python, or else cherry-pick a commit in that issue to use a bazel-managed Python.
A lit test file contains some number of comment lines that begin with RUN:, and the text after that describes a shell script to run, with some magic strings instructing lit to make substitutions. In this case %s is the current file path, but there is a table of default substitutions, and one can add custom substitutions in a config file (see later).
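As a rough model of what lit does with these directives, the Python sketch below collects the RUN: lines from a test file, joins lines that end in a backslash, and substitutes %s with the file path. This is a simplified assumption for illustration: real lit supports many more substitutions (such as %t), custom substitutions from the config file, and other directives, and extract_run_commands is a hypothetical name.

```python
import re
from typing import List

RUN_RE = re.compile(r"//\s*RUN:(.*)$")


def extract_run_commands(text: str, path: str) -> List[str]:
    """Collect RUN: commands, joining backslash continuations and expanding %s."""
    commands: List[str] = []
    pending = ""
    for line in text.splitlines():
        match = RUN_RE.search(line)
        if match is None:
            continue
        part = match.group(1).strip()
        if part.endswith("\\"):
            # A trailing backslash continues the command on the next RUN: line.
            pending += part[:-1].strip() + " "
            continue
        commands.append((pending + part).replace("%s", path))
        pending = ""
    return commands


test_file = """// RUN: mlir-opt %s --convert-math-to-funcs=convert-ctlz | FileCheck %s
func.func @main(%arg0: i32) -> i32 {
  // CHECK: call
  %0 = math.ctlz %arg0 : i32
  func.return %0 : i32
}"""
print(extract_run_commands(test_file, "/tmp/ctlz.mlir"))
```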
In typical Unix fashion, all lit does is check the exit status of the RUN command to determine whether the test passes or fails. Hence, this test pipes the output of mlir-opt to the FileCheck program, again passing in the current file path, since that file contains the assertions to check.
FileCheck is a bit more complicated than lit, but in brief it takes the input passed to stdin, scans for CHECK comments in the file passed as the CLI argument, and then for each CHECK comment, it does some logic to determine whether the assertion passes. The simplest kind of assertion is a // CHECK: foo, which searches the input for a line containing foo. (Patterns are matched literally by default; regular expressions go inside {{ and }}.) Similarly, a // CHECK-NOT: foo assertion asserts that foo does not occur in the span between the previous CHECK match and the next one (using the start or end of the input where there is no such match). Beyond that, the main thing that is enforced is that multiple CHECK assertions match the input in the same order that the CHECK comments occur. So if you had
// RUN: mlir-opt %s --convert-math-to-funcs=convert-ctlz | FileCheck %s
func.func @main(%arg0: i32) -> i32 {
  // CHECK: call
  // CHECK: foo
  // CHECK: return
  %0 = math.ctlz %arg0 : i32
  func.return %0 : i32
}
Then it would expect three lines (possibly with other lines between them) that match these patterns in order. An input with a line matching call and a later line matching return would fail unless there is a line between them matching foo.
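The ordering rule, and how CHECK-NOT fits into it, can be modeled with a short Python sketch. It is deliberately simplified and hypothetical (filecheck here is an illustrative function, not FileCheck's API): patterns are plain substrings, matching is done a whole line at a time, and pending CHECK-NOT patterns only have to be absent between the previous positive match and the next one, or up to the end of the input. Real FileCheck also supports regular expressions, variables, and directives such as CHECK-NEXT.

```python
from typing import List, Tuple


def filecheck(output: str, checks: List[Tuple[str, str]]) -> bool:
    """Simplified FileCheck: checks is a list of (directive, pattern) pairs."""
    lines = output.splitlines()
    start = 0  # first line the next CHECK may match
    pending_nots: List[str] = []
    for directive, pattern in checks:
        if directive == "CHECK-NOT":
            pending_nots.append(pattern)
            continue
        # CHECK: find the next line at or after `start` containing the pattern.
        hit = next((i for i in range(start, len(lines)) if pattern in lines[i]), None)
        if hit is None:
            return False
        # Every pending CHECK-NOT must be absent between the two matches.
        if any(p in line for p in pending_nots for line in lines[start:hit]):
            return False
        pending_nots, start = [], hit + 1
    # Trailing CHECK-NOTs apply to the rest of the input.
    return not any(p in line for p in pending_nots for line in lines[start:])


lowered = """func.func @main(%arg0: i32) -> i32 {
  %0 = call @__mlir_math_ctlz_i32(%arg0) : (i32) -> i32
  return %0 : i32
}"""
print(filecheck(lowered, [("CHECK-NOT", "math.ctlz"), ("CHECK", "call")]))  # True
print(filecheck(lowered, [("CHECK", "call"), ("CHECK", "foo"), ("CHECK", "return")]))  # False
```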
FileCheck can do a lot more, like use regular expressions to capture variable names and then refer to them in later CHECK assertions. With this, one can write a much more precise test of the ctlz lowering, expecting a relatively rigid structure of the output function, as in this commit. I won’t give full details here, but you can read the FileCheck documentation here and intuitively tell that an expression like %[[ARGCMP:.*]] captures a variable name, where ARGCMP is how it is referred to in later assertions, while .* is the regular expression used to capture it (and the literal % ensures it only applies to a variable name).
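The capture-and-reuse idea can also be sketched in Python by translating each pattern into a regular expression: [[NAME:regex]] becomes a named group, a later [[NAME]] becomes the text captured earlier, {{regex}} is an inline regular expression, and everything else is matched literally. This is a simplified, hypothetical model (compile_check and run_checks are illustrative names, and the input lines are adapted from the my_ctlz example), not FileCheck's actual implementation.

```python
import re

# [[NAME:regex]] defines a variable, [[NAME]] uses one, {{regex}} is inline regex.
TOKEN = re.compile(r"\[\[(\w+):(.*?)\]\]|\[\[(\w+)\]\]|\{\{(.*?)\}\}")


def compile_check(pattern, env):
    """Translate a FileCheck-style pattern into a Python regular expression."""
    parts, pos = [], 0
    for m in TOKEN.finditer(pattern):
        parts.append(re.escape(pattern[pos:m.start()]))  # literal text
        define_name, define_re, use_name, inline_re = m.groups()
        if define_name:
            parts.append(f"(?P<{define_name}>{define_re})")
        elif use_name:
            parts.append(re.escape(env[use_name]))  # reuse the earlier capture
        else:
            parts.append(f"(?:{inline_re})")
        pos = m.end()
    parts.append(re.escape(pattern[pos:]))
    return re.compile("".join(parts))


def run_checks(lines, patterns):
    """Match patterns in order; return the captured variables, or None on failure."""
    env, start = {}, 0
    for pattern in patterns:
        regex = compile_check(pattern, env)
        for i in range(start, len(lines)):
            m = regex.search(lines[i])
            if m:
                env.update(m.groupdict())
                start = i + 1
                break
        else:
            return None
    return env


lines = [
    "%0 = arith.cmpi eq, %arg0, %c0_i32 : i32",
    "%1 = scf.if %0 -> (i32) {",
]
print(run_checks(lines, ["%[[ARGCMP:.*]] = arith.cmpi eq", "scf.if %[[ARGCMP]]"]))
# {'ARGCMP': '0'}
```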
Running these tests requires a bit of finicky configuration. lit can be run like a normal Python program, and if all the executables invoked in the RUN directives are in the PATH environment variable, then it will just work. However, in any build system the executables live in exotic places, and this is especially true in Bazel.
Aside: In typical CMake-oriented MLIR projects, there are actually two config files: one called something like lit.site.cfg.py.in, which has variables that CMake substitutes in pointing to the build artifact paths, and one called lit.cfg.py, which configures lit to use those paths. In my opinion the Bazel configuration is marginally simpler, but I am biased because I’m more familiar with it.
In bazel a single test would correspond to a build target in a BUILD file of the following form:
py_test(
    name = "my_test_file.mlir.test",
    srcs = ["@llvm-project//llvm:lit"],
    args = ["-v", "tests/my_test_file.mlir"],
    data = ["@llvm-project//llvm:FileCheck", ..., ":my_test_file.mlir"],
    main = "lit.py",
)
This tells bazel to run lit with the right arguments, and crucially, the data argument allows us to pull in binary targets corresponding to the commands used in the RUN directives. Then two things happen. First, lit runs in a working directory chosen by bazel (with the data dependencies pulled in somehow). We need to understand the structure of this working directory in order to configure lit properly. Then, when lit runs, it looks for a file called lit.cfg.py in the directory containing the test (and recursively upward to the project root), loads it, and uses it to set the PATH and other configuration.
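The upward search for a configuration file can be pictured with the following Python sketch, which walks from the test's directory toward a project root and returns the first lit.cfg.py it finds. It is a simplified, hypothetical model (find_lit_config is not a lit API, and the directory layout in the usage is made up); lit's real discovery logic handles more cases, such as lit.site.cfg.py files.

```python
import tempfile
from pathlib import Path
from typing import Optional


def find_lit_config(test_file: Path, project_root: Path) -> Optional[Path]:
    """Return the nearest lit.cfg.py at or above test_file's directory."""
    root = project_root.resolve()
    directory = test_file.resolve().parent
    while True:
        candidate = directory / "lit.cfg.py"
        if candidate.is_file():
            return candidate
        if directory == root or directory == directory.parent:
            return None  # reached the project root (or the filesystem root)
        directory = directory.parent


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()
    (root / "tests" / "nested").mkdir(parents=True)
    (root / "tests" / "lit.cfg.py").write_text("# lit configuration goes here\n")
    test = root / "tests" / "nested" / "ctlz.mlir"
    test.write_text("// RUN: true\n")
    found = find_lit_config(test, root)  # found in tests/, one level up
    missing = find_lit_config(root / "elsewhere.mlir", root)  # no config above
    print(found.relative_to(root), missing)
```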
In our case, the lit.cfg.py looks like this:
import os
from pathlib import Path

from lit.formats import ShTest

config.name = "mlir_tutorial"
config.test_format = ShTest()
config.suffixes = [".mlir"]

runfiles_dir = Path(os.environ["RUNFILES_DIR"])

tool_relpaths = [
    "llvm-project/mlir",
    "llvm-project/llvm",
]

config.environment["PATH"] = (
    ":".join(str(runfiles_dir.joinpath(Path(path))) for path in tool_relpaths)
    + ":"
    + os.environ["PATH"]
)
Two weird things are happening here. First, config is an undefined variable at first glance, but the lit documentation states that an instance is inserted into the module scope when lit runs. Second, we are using the RUNFILES_DIR environment variable as the base for the paths we will construct pointing to the binaries. RUNFILES_DIR is defined by bazel, and it is generally different from the working directory of the binary run by native.py_test. It contains a directory tree for the current project (mlir_tutorial) as well as all dependent projects defined in the WORKSPACE, so long as some targets from those projects were included in the data option of the test rule.
Once this is all worked out, one can define individual test targets for each lit test. However, since that is laborious, in this commit I instead defined all of the above configuration together with a bazel macro that searches for all .mlir files in a given directory and creates test targets for them. So in this project, a new .mlir file added to tests/ will automatically be run when you run bazel test //.... Then, tests/BUILD contains the glob_lit_tests macro invocation and a filegroup that describes all the tools and files that should be included in the data to run them.
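The bazel/lit.bzl file itself is not shown here, but the idea behind a macro like glob_lit_tests can be sketched in Python: for every .mlir file in tests/, emit one py_test definition shaped like the single-test target above, with the shared test_utilities filegroup in its data. The function below is a hypothetical illustration that only computes the target attributes as dictionaries; it is not the actual macro and does not call Bazel, and the file names in the usage are made up.

```python
from typing import Dict, List


def lit_test_targets(mlir_files: List[str], package: str = "tests") -> List[Dict[str, object]]:
    """Describe one py_test per .mlir file, shaped like the BUILD example above.

    Hypothetical illustration: a real Bazel macro would call native.py_test
    with these attributes instead of returning dictionaries.
    """
    targets = []
    for name in sorted(f for f in mlir_files if f.endswith(".mlir")):
        targets.append({
            "name": f"{name}.test",
            "srcs": ["@llvm-project//llvm:lit"],
            "args": ["-v", f"{package}/{name}"],
            "data": [":test_utilities", f":{name}"],
            "main": "lit.py",
        })
    return targets


# Hypothetical directory listing for tests/: only the .mlir files get targets.
targets = lit_test_targets(["ctlz.mlir", "ctlz_runner.mlir", "BUILD", "lit.cfg.py"])
for target in targets:
    print(target["name"], target["args"])
```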
# tests/BUILD
load("//bazel:lit.bzl", "glob_lit_tests")

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tests:lit.cfg.py",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:count",
        "@llvm-project//llvm:not",
        "@llvm-project//mlir:mlir-opt",
    ],
)

glob_lit_tests()
Bonus: functional testing
The previous lowering of math.ctlz to a software implementation has a very detailed test, but in MLIR, tests of lowerings are primarily syntactic in nature. That is, the test does not assert that the lowering itself is functionally correct. The author may have written assertions that align with the generated code, while the generated code has a bug or otherwise does not compute what it is supposed to compute.
One way around this is to continue compiling the MLIR code down through LLVM to machine code, running it, and asserting something about the output (presumably printed to stdout). While this is possible in lit, since RUN can run anything, it does require pulling in quite a few more dependencies. A slightly more lightweight means to achieve this is to use mlir-cpu-runner, which executes code in the lowest-level MLIR dialects (in particular, the llvm dialect, which is the “exit” dialect before going to LLVM) by JIT-compiling it through LLVM.
Here’s what such a test might look like in lit for our ctlz lowering pass, which tests that 7, as a 32-bit integer, has 29 leading zeros. I added the test in this commit. Notably, I had to add the mlir-cpu-runner binary to the list of test_utilities mentioned in the previous section; otherwise the test fails because it cannot find the mlir-cpu-runner binary.
// RUN: mlir-opt %s \
// RUN: --pass-pipeline="builtin.module( \
// RUN: convert-math-to-funcs{convert-ctlz}, \
// RUN: func.func(convert-scf-to-cf,convert-arith-to-llvm), \
// RUN: convert-func-to-llvm, \
// RUN: convert-cf-to-llvm, \
// RUN: reconcile-unrealized-casts)" \
// RUN: | mlir-cpu-runner -e test_7i32_to_29 -entry-point-result=i32 > %t
// RUN: FileCheck %s --check-prefix=CHECK_TEST_7i32_TO_29 < %t
func.func @test_7i32_to_29() -> i32 {
  %arg = arith.constant 7 : i32
  %0 = math.ctlz %arg : i32
  func.return %0 : i32
}
// CHECK_TEST_7i32_TO_29: 29
The RUN command is quite a bit more complicated. First, we need to run more passes than just convert-math-to-funcs in order to get the code down to the LLVM dialect, which is what mlir-cpu-runner supports. The --pass-pipeline flag allows you to build a more complex chain of passes on the command line. Then the result is piped to mlir-cpu-runner, which takes as command line flags the top level function to run and the type of its result. Finally, the output is redirected to %t, a piece of lit substitution magic that represents a per-test temporary file. In this case, it is used so that if this first command fails, its error message is displayed in the test failure, rather than the subsequent failure of FileCheck on an empty input from the pipe.
Then, a second RUN command runs FileCheck, again using the current file for the test assertions, reading the input to check from %t, and adding the special --check-prefix flag so that it only checks the subset of assertions in the file that use that prefix (allowing us to add a second test in the same file, as in the next commit, which runs a similar test for an i64 input). Then, because mlir-cpu-runner prints the result of the function to stdout, the CHECK assertion just expects the output to be 29 for input 7.
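The effect of --check-prefix can be approximated in Python: only comments whose directive is exactly the chosen prefix followed by a colon are treated as assertions, so tests with different prefixes can share one file. checks_for_prefix is a hypothetical helper, not part of FileCheck; it ignores suffixed directives such as PREFIX-NOT for brevity, and the CHECK_TEST_I64 line in the usage is made up for illustration (an i64 holding 7 has 61 leading zeros).

```python
import re
from typing import List


def checks_for_prefix(text: str, prefix: str) -> List[str]:
    """Collect the patterns of '// PREFIX: pattern' comments for one prefix."""
    directive = re.compile(r"//\s*" + re.escape(prefix) + r":\s*(.*)$")
    patterns = []
    for line in text.splitlines():
        match = directive.search(line)
        if match:
            patterns.append(match.group(1))
    return patterns


test_file = """// RUN: FileCheck %s --check-prefix=CHECK_TEST_7i32_TO_29 < %t
// CHECK_TEST_7i32_TO_29: 29
// CHECK_TEST_I64: 61"""
print(checks_for_prefix(test_file, "CHECK_TEST_7i32_TO_29"))  # ['29']
```

Because the prefix must be followed immediately by a colon, a plain CHECK prefix does not pick up either of these assertions.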
It may also be interesting to see what MLIR outputs when I run the full pass pipeline (but not mlir-cpu-runner) on the input. Here it is:
module attributes {llvm.data_layout = ""} {
  llvm.func @test_7i32_to_29() -> i32 {
    %0 = llvm.mlir.constant(7 : i32) : i32
    %1 = llvm.mlir.constant(29 : i32) : i32
    llvm.return %1 : i32
  }
  llvm.func linkonce_odr @__mlir_math_ctlz_i32(%arg0: i32) -> i32 attributes {sym_visibility = "private"} {
    %0 = llvm.mlir.constant(32 : i32) : i32
    %1 = llvm.mlir.constant(0 : i32) : i32
    %2 = llvm.icmp "eq" %arg0, %1 : i32
    llvm.cond_br %2, ^bb1, ^bb2
  ^bb1: // pred: ^bb0
    llvm.br ^bb10(%0 : i32)
  ^bb2: // pred: ^bb0
    %3 = llvm.mlir.constant(1 : index) : i64
    %4 = llvm.mlir.constant(1 : i32) : i32
    %5 = llvm.mlir.constant(32 : index) : i64
    %6 = llvm.mlir.constant(0 : i32) : i32
    llvm.br ^bb3(%3, %arg0, %6 : i64, i32, i32)
  ^bb3(%7: i64, %8: i32, %9: i32): // 2 preds: ^bb2, ^bb8
    %10 = llvm.icmp "slt" %7, %5 : i64
    llvm.cond_br %10, ^bb4, ^bb9
  ^bb4: // pred: ^bb3
    %11 = llvm.icmp "slt" %8, %1 : i32
    llvm.cond_br %11, ^bb5, ^bb6
  ^bb5: // pred: ^bb4
    llvm.br ^bb7(%8, %9 : i32, i32)
  ^bb6: // pred: ^bb4
    %12 = llvm.add %9, %4 : i32
    %13 = llvm.shl %8, %4 : i32
    llvm.br ^bb7(%13, %12 : i32, i32)
  ^bb7(%14: i32, %15: i32): // 2 preds: ^bb5, ^bb6
    llvm.br ^bb8
  ^bb8: // pred: ^bb7
    %16 = llvm.add %7, %3 : i64
    llvm.br ^bb3(%16, %14, %15 : i64, i32, i32)
  ^bb9: // pred: ^bb3
    llvm.br ^bb10(%9 : i32)
  ^bb10(%17: i32): // 2 preds: ^bb1, ^bb9
    llvm.br ^bb11
  ^bb11: // pred: ^bb10
    llvm.return %17 : i32
  }
}
The main new thing here, besides all of the llvm dialect operations, is the ^bb1 syntax, which is the label of one of the basic blocks mentioned earlier. With the basic syntax and testing down, next time we will define a custom lowering and explore the MLIR API from that perspective. Then we’ll dive into defining a new dialect.
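One way to read the ^bb structure in the output above is as a state machine: each block runs straight through to its terminator, and llvm.br or llvm.cond_br picks the next block and passes values to that block's arguments, such as ^bb3(%7: i64, %8: i32, %9: i32). The Python sketch below is a hand translation of @__mlir_math_ctlz_i32 into that form, for illustration only; block names follow the listing, and to_i32 is a hypothetical helper that emulates 32-bit wraparound. Values defined in a block, like the arguments of ^bb3, stay usable in the blocks it dominates, which the sketch models with ordinary local variables.

```python
def to_i32(value: int) -> int:
    """Reinterpret the low 32 bits of value as a signed i32."""
    value &= 0xFFFFFFFF
    return value - (1 << 32) if value & 0x80000000 else value


def ctlz_cfg(arg0: int) -> int:
    """Hand translation of @__mlir_math_ctlz_i32, one branch per ^bb label."""
    arg0 = to_i32(arg0)
    block, args = "bb0", ()
    while True:
        if block == "bb0":  # entry block: compare with zero, then llvm.cond_br
            block, args = ("bb1" if arg0 == 0 else "bb2"), ()
        elif block == "bb1":
            block, args = "bb10", (32,)
        elif block == "bb2":  # loop setup: branch to the header with (1, %arg0, 0)
            block, args = "bb3", (1, arg0, 0)
        elif block == "bb3":  # loop header with block arguments %7, %8, %9
            i, x, count = args
            block, args = ("bb4" if i < 32 else "bb9"), ()
        elif block == "bb4":
            block, args = ("bb5" if x < 0 else "bb6"), ()
        elif block == "bb5":  # already negative: pass the values through
            block, args = "bb7", (x, count)
        elif block == "bb6":  # llvm.shl and llvm.add
            block, args = "bb7", (to_i32(x << 1), count + 1)
        elif block == "bb7":  # block arguments %14, %15
            new_x, new_count = args
            block, args = "bb8", ()
        elif block == "bb8":  # increment the counter and jump back to the header
            block, args = "bb3", (i + 1, new_x, new_count)
        elif block == "bb9":
            block, args = "bb10", (count,)
        elif block == "bb10":  # block argument %17; ^bb11 just returns it
            (result,) = args
            return result


print([ctlz_cfg(n) for n in (0, 1, 7, 0x80000000)])  # [32, 31, 29, 0]
```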
Thanks to Patrick Schmidt for feedback on a draft of this article.