MLIR — Defining a New Dialect

In the last article in the series, we migrated the passes we had written to use the tablegen code generation framework. That was a preface to using tablegen to define dialects.

In this article we’ll define a dialect that represents arithmetic on single-variable polynomials, with coefficients in $\mathbb{Z} / 2^{32} \mathbb{Z}$ (32-bit unsigned integers).

Sketching out a design

The basic dialect will define a new polynomial type, and provide operations to define polynomials by specifying their coefficients from standard MLIR types, extract data about a polynomial to store the results in standard MLIR types, and to do arithmetic operations on polynomials.

There is quite a large surface area of design options when writing a dialect. This talk by Jeff Niu and Mehdi Amini gives some indication of how one might start to think about dialect design. But in brief, a polynomial dialect would fit into the “computation” bucket (a dialect to represent number crunching).

Another idea relevant to starting a dialect design is to ask how a dialect will enable easy optimizations. That is the Optimizer dialect class above. But since this tutorial series is focusing on the low-level, day to day, nitty gritty details of working with MLIR, we’re going to focus on getting some custom dialect defined, and then come back to what makes a good dialect design later.

An empty dialect

We’ll start by defining an empty dialect and just look at the tablegen-generated code, which is done in this commit. The tablegen file looks like

include "mlir/IR/DialectBase.td"

def Poly_Dialect : Dialect {
let name = "poly";
let summary = "A dialect for polynomial math";
let description = [{
The poly dialect defines types and operations for single-variable
polynomials over integers.
}];

let cppNamespace = "::mlir::tutorial::poly";
}


It is almost indistinguishable from the tablegen for a Pass, except for the Dialect base class. The build rule has different flags too.

gentbl_cc_library(
name = "dialect_inc_gen",
tbl_outs = [
(["-gen-dialect-decls"], "PolyDialect.h.inc"),
(["-gen-dialect-defs"], "PolyDialect.cpp.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "PolyDialect.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
],
)


Unlike the Pass codegen, here the dialect has us specify a codegen’ed header and implementation file, which are short enough to include directly in the article:

// bazel-bin/lib/Dialect/Poly/PolyDialect.h.inc
namespace mlir {
namespace tutorial {

class PolyDialect : public ::mlir::Dialect {
explicit PolyDialect(::mlir::MLIRContext *context);

void initialize();
friend class ::mlir::MLIRContext;
public:
~PolyDialect() override;
static constexpr ::llvm::StringLiteral getDialectNamespace() {
return ::llvm::StringLiteral("poly");
}
};
} // namespace tutorial
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::tutorial::PolyDialect)


And the cpp:

// bazel-bin/lib/Dialect/Poly/PolyDialect.cpp.inc
MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::tutorial::PolyDialect)
namespace mlir {
namespace tutorial {

PolyDialect::PolyDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<PolyDialect>()) {
initialize();
}

PolyDialect::~PolyDialect() = default;

} // namespace tutorial
} // namespace mlir



Basically, they are empty containers that will hold the types and ops that we define next.

In this commit we register the dialect with the tutorial-opt main program, which is handled by a single API call. Now the dialect shows up in the help text of the tutorial-opt binary, though nothing else is there because we have no passes associated with it.

$ bazel run tools:tutorial-opt -- --help
Available Dialects: ..., pdl_interp, poly, quant, ...

Adding a trivial type

Next we’ll define a poly.poly type with no semantics, and in the next section we’ll focus on the semantics. This commit defines a simple test that ensures we can parse and print our new type.

// RUN: tutorial-opt %s

module {
  func.func @main(%arg0: !poly.poly) -> !poly.poly {
    return %arg0 : !poly.poly
  }
}

Note that the exclamation mark ! sigil prefix is required for out-of-tree MLIR types.

Next, in this commit we add the tablegen for the poly type. The tablegen looks like

include "PolyDialect.td"
include "mlir/IR/AttrTypeBase.td"

// A base class for all types in this dialect
class Poly_Type<string name, string typeMnemonic> : TypeDef<Poly_Dialect, name> {
  let mnemonic = typeMnemonic;
}

def Poly : Poly_Type<"Polynomial", "poly"> {
  let summary = "A polynomial with i32 coefficients";
  let description = [{
    A type for polynomials with integer coefficients in a single-variable polynomial ring.
  }];
}

This is effectively a boilerplate shell since nothing here is specific to polynomial arithmetic. But it shows a few new things about tablegen worth mentioning.

First and most trivially, tablegen has include statements so that you can split definitions across files. I like to put each conceptual type of thing in its own tablegen file (types, ops, attributes, etc.), but conventions differ across projects.

Second, a common pattern when defining types (and ops) is to have a base class for each dialect that all concrete type definitions inherit from. This shows first of all that there is a difference between class and def in tablegen. def defines actual types, where by “actual” I mean it generates C++ code, whereas class is only an inheritance base and disappears after tablegen is done. Think of class as allowing us to refactor out common shared code among type definitions in one file.
I make a stink of this because I have mixed up class and def many times and been confused by the error messages and generated code. For example, if you change the class to a def above, you’ll see the following unhelpful error message.

lib/Dialect/Poly/PolyTypes.td:8:5: error: Expected a class name, got 'Poly_Type'
def Poly_Type<string name, string typeMnemonic> : TypeDef<Poly_Dialect, name> {

Moreover, TypeDef itself is a class that takes as template arguments the dialect the type should belong to and a name field (related to the type’s eventual C++ class name), and results in Tablegen associating the generated C++ classes with the same namespace as we told the dialect to use, among other things.

Third, the new field is the mnemonic declaration. This determines the name of the type in the textual representation of the IR.

The generated code again has separate .h.inc and .cpp.inc files:

// PolyTypes.h.inc
#ifdef GET_TYPEDEF_CLASSES
#undef GET_TYPEDEF_CLASSES

namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
namespace mlir {
namespace tutorial {
namespace poly {
class PolynomialType;
class PolynomialType : public ::mlir::Type::TypeBase<PolynomialType, ::mlir::Type, ::mlir::TypeStorage> {
public:
  using Base::Base;
  static constexpr ::llvm::StringLiteral getMnemonic() {
    return {"poly"};
  }
};
} // namespace poly
} // namespace tutorial
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::tutorial::poly::PolynomialType)

#endif // GET_TYPEDEF_CLASSES

// PolyTypes.cpp.inc
#ifdef GET_TYPEDEF_LIST
#undef GET_TYPEDEF_LIST

::mlir::tutorial::poly::PolynomialType

#endif // GET_TYPEDEF_LIST

#ifdef GET_TYPEDEF_CLASSES
#undef GET_TYPEDEF_CLASSES

static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
  return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
    .Case(::mlir::tutorial::poly::PolynomialType::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
      value = ::mlir::tutorial::poly::PolynomialType::get(parser.getContext());
      return ::mlir::success(!!value);
    })
    .Default([&](llvm::StringRef keyword, llvm::SMLoc) {
      *mnemonic = keyword;
      return std::nullopt;
    });
}

static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
  return ::llvm::TypeSwitch<::mlir::Type, ::mlir::LogicalResult>(def)
    .Case<::mlir::tutorial::poly::PolynomialType>([&](auto t) {
      printer << ::mlir::tutorial::poly::PolynomialType::getMnemonic();
      return ::mlir::success();
    })
    .Default([](auto) { return ::mlir::failure(); });
}

namespace mlir {
namespace tutorial {
namespace poly {
} // namespace poly
} // namespace tutorial
} // namespace mlir
MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::tutorial::poly::PolynomialType)

#endif // GET_TYPEDEF_CLASSES

The name PolynomialType is generated by adding Type to the "Polynomial" template argument we passed in the tablegen file. The name of the def itself is used to refer to the class elsewhere in tablegen files, and the two can be different.

One thing to pay attention to here is that tablegen is attempting to generate a type parser and printer for us. But it’s not usable yet; we’ll come back to this. If you build at this commit you’ll see a compiler warning like generatedTypePrinter defined but not used and a hard failure if you try running the test.

A second thing to notice is that it uses the header-guards-as-function-arguments style again, here to separate the cpp file into two include-guarded sections, one that just has a list of the types defined, and the other that has implementations of functions.

The first one, GET_TYPEDEF_LIST, is curious because it just includes a comma-separated list of class names. This is because the PolyDialect.cpp from this commit is responsible for registering the types with the dialect, and that happens by using this include to add the C++ class names for the types as template arguments in the Dialect’s initialization function.
// PolyDialect.cpp
#include "lib/Dialect/Poly/PolyDialect.h"

#include "lib/Dialect/Poly/PolyTypes.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h"

#include "lib/Dialect/Poly/PolyDialect.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/Poly/PolyTypes.cpp.inc"

namespace mlir {
namespace tutorial {
namespace poly {

void PolyDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "lib/Dialect/Poly/PolyTypes.cpp.inc"
      >();
}

} // namespace poly
} // namespace tutorial
} // namespace mlir

We’ll do the same registration dance for ops, attributes, etc., later.

The expected way to set up the C++ interface to the tablegen files is:

• Create a header file PolyTypes.h which is the only file allowed to include PolyTypes.h.inc,
• Include PolyTypes.cpp.inc inside PolyDialect.cpp with any additional #includes needed for the auto-generated implementations in PolyTypes.cpp.inc. In our case, the default parser/printer uses the type switch function from llvm/include/llvm/ADT/TypeSwitch.h.
• If needed, add a PolyTypes.cpp with any additional implementations needed for functions declared by tablegen that can’t be automatically generated.

I don’t think I’ve found the best arrangement of these files and build targets yet. What I really want is to have each header with a corresponding .h.inc, a corresponding .cpp file that includes the relevant .cpp.inc, and to connect them all with PolyDialect.cpp. However, PolyDialect::initialize does some introspection on the classes declared in PolyTypes.h.inc to ensure they’re valid (specifically looking for details of Storage types like is_trivially_destructible, which are generated for us in the next section), and that requires the implementations to exist in the same compilation unit (I believe). From what I’ve read of other projects, people just tend to cram things into the same cpp files, leading to multi-thousand line implementation files, which I dislike.
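Returning to the registration step in PolyDialect::initialize: the idea of feeding a generated, comma-separated list of class names into a variadic template can be modelled without MLIR. The following is a minimal sketch, not the generated code. DialectSketch, its addTypes method, the two type structs, and the SKETCH_TYPEDEF_LIST macro are all hypothetical stand-ins, and a plain macro plays the role of the guarded GET_TYPEDEF_LIST section of PolyTypes.cpp.inc.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for tablegen-generated type classes.
struct PolynomialTypeSketch {
  static const char *name() { return "poly"; }
};
struct OtherTypeSketch {
  static const char *name() { return "other"; }
};

// Stand-in for the GET_TYPEDEF_LIST section of a generated .cpp.inc file:
// nothing but a comma-separated list of class names.
#define SKETCH_TYPEDEF_LIST PolynomialTypeSketch, OtherTypeSketch

class DialectSketch {
public:
  DialectSketch() { initialize(); }
  const std::vector<std::string> &registered() const { return names; }

private:
  // One registration per template argument, in left-to-right order.
  template <typename... Ts>
  void addTypes() {
    int expand[] = {0, (names.push_back(Ts::name()), 0)...};
    (void)expand;
  }

  // The macro expands inside the template argument list, just as the
  // #include does in the real PolyDialect.cpp.
  void initialize() { addTypes<SKETCH_TYPEDEF_LIST>(); }

  std::vector<std::string> names;
};
```

Constructing a DialectSketch registers both names. Adding a third type in this model means touching only the list, which mirrors why the generated list lives in its own guarded section.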
As a final note, the build file uses a new rule called td_files to group all the tablegen files into one build target.

The last ingredient required to get this to compile and run is to make the type parser and printer usable. Thankfully it’s one line: adding let useDefaultTypePrinterParser = 1; to the dialect tablegen (see this commit). This adds the following declarations to PolyDialect.h.inc, and modifies the generated implementations to be member functions on PolyDialect.

/// Parse a type registered to this dialect.
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

/// Print a type registered to this dialect.
void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &os) const override;

Adding a poly type parameter

When designing the poly type, we have to think about what we want to express, and how the type will be lowered to types in lower level dialects.

For the sake of this tutorial we’ll restrict to single-variable polynomials, so the indeterminate can be arbitrary and implicit (we don’t need to keep track of the variable name x or y or t in the IR).

The next question is the polynomial analogue of integer bit width: degree. The difficulty here is that tracking the exact degree of a polynomial is impossible in general. The degree of a sum of two polynomials depends on the coefficient values (the highest degree terms may cancel), which may not be known statically because they are read from an external source at runtime. So at best we could track an upper bound on the degree.

The simplest approach is to require a polynomial type to declare its “degree upper bound” statically, and then define “overflow” semantics just like integers have. A natural option is to implicitly treat polynomials as members of a ring $R[x] / (x^D - 1)$, where $D$ is the degree upper bound and $R$ is the ring of 32-bit unsigned integers.
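To make the wrap-around concrete, here is a minimal C++ sketch of multiplication in $R[x] / (x^D - 1)$ with 32-bit unsigned coefficients. It is not taken from the tutorial repository: the names Coeffs and mulWrapped, and the choice to store coefficients lowest degree first, are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed layout: index i holds the coefficient of x^i, and the vector
// length is the degree upper bound D.
using Coeffs = std::vector<std::uint32_t>;

// Multiply two polynomials with the same degree bound D, reducing
// exponents modulo D (that is, using x^D = 1). Coefficient arithmetic
// wraps modulo 2^32 because the element type is an unsigned 32-bit integer.
Coeffs mulWrapped(const Coeffs &a, const Coeffs &b) {
  assert(!a.empty() && a.size() == b.size());
  const std::size_t D = a.size();
  Coeffs out(D, 0);
  for (std::size_t i = 0; i < D; ++i) {
    for (std::size_t j = 0; j < D; ++j) {
      // The x^(i+j) term lands on x^((i+j) mod D).
      out[(i + j) % D] += a[i] * b[j];
    }
  }
  return out;
}
```

For example, with D = 3 the product of 1 + x^2 and x + x^2 is x + x^2 + x^3 + x^4, and substituting x^3 = 1 folds it down to 1 + 2x + x^2. Note that the input and output vectors all have the same length, which is the property that keeps the op signatures uniform.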
In less mathematical terms, this means that whenever a term $x^n$ in a polynomial has degree $n \geq D$, we replace $x^n$ with $x^{n \bmod D}$, which is the same thing as “declaring” $x^D = 1$ and reducing polynomials via that substitution until all terms have degree less than $D$.

This simple approach has the benefit of making the dialect design easy: all binary polynomial ops have the same input and output type, and lowering a poly type requires replacing poly.poly<D> with a tensor of D coefficients. Since the shapes are static and most polynomials would have the same bound, it’s probably the most performant option. The downsides are that now the programmer has to worry about overflow semantics, and if you want polynomials of larger degree you have to insert bookkeeping operations to extend the degree bound.

On the other hand, we could allow a poly to have a growing degree, so that, e.g., a poly.mul operation with two poly.poly<7> input polynomials would have a poly.poly<14> as output. This would require more work in the dialect definition to ensure the IR is valid, and the lowering would be more complex in that you’d have to manage tensors of different sizes. It would probably also have an impact on performance, since there would be more memory allocations and copying involved (but maybe the compiler could be smart enough to avoid that).

I would like to do both so as to show how these differences materialize in lowerings and optimization passes. But for now I will start with the one that I think is easier: wrapping overflow semantics.

To make that work, we need to add an attribute to our polynomial type representing its degree upper bound. The official docs on how to do this are here. This is where tablegen starts to be quite useful, because the changes only require two lines added to the type definition’s tablegen.

let parameters = (ins "int":$degreeBound);
let assemblyFormat = "`<` $degreeBound `>`";

The first line associates each instance of the type with an integer parameter (the "int" can be any string containing a literal C++ type), and a name $degreeBound that is used both for the generated C++ and to refer to it elsewhere in the tablegen file.

The second line is required because now that a poly type has this associated data, we need to be able to print it to and parse it from the textual IR representation. The simplest option is to let tablegen auto-generate the parser and printer for us, but we could have also used the line let hasCustomAssemblyFormat = 1; which would generate declarations that we would then be expected to implement ourselves. In the syntax of the assemblyFormat line, tokens in backticks are literally printed/parsed.
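To get a feel for what such a parser and printer have to do, here is a hand-written sketch of the round trip for the poly<10> portion of the textual form. This is plain string handling, not MLIR’s AsmParser or AsmPrinter API, and the function names printPolyType and parsePolyType are hypothetical. The nine-digit limit is an arbitrary guard chosen for this sketch to avoid integer overflow.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>

// Print the mnemonic, then the literal '<', the parameter, and the literal '>'.
std::string printPolyType(int degreeBound) {
  assert(degreeBound >= 0);
  return "poly<" + std::to_string(degreeBound) + ">";
}

// Parse text of the form poly<N>, where N is one to nine decimal digits.
// Returns true and sets degreeBound on success, false otherwise.
bool parsePolyType(const std::string &text, int &degreeBound) {
  const std::string prefix = "poly<";
  // Need the prefix, at least one digit, and the closing '>'.
  if (text.size() < prefix.size() + 2) return false;
  if (text.compare(0, prefix.size(), prefix) != 0) return false;
  if (text[text.size() - 1] != '>') return false;
  const std::size_t numDigits = text.size() - prefix.size() - 1;
  if (numDigits > 9) return false;
  int value = 0;
  for (std::size_t i = prefix.size(); i + 1 < text.size(); ++i) {
    if (!std::isdigit(static_cast<unsigned char>(text[i]))) return false;
    value = value * 10 + (text[i] - '0');
  }
  degreeBound = value;
  return true;
}
```

The literal pieces in this sketch correspond to the backtick-quoted tokens in the assembly format, and the digits correspond to $degreeBound. The declarative format saves us from writing and maintaining this kind of code by hand.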

After adding these two lines in this commit, the generated code gets quite a bit more complicated. Here’s the entirety of both files. Some things to note:

• PolynomialType gets a new int getDegreeBound() method, as well as a static get(MLIRContext, int) factory method.
• The parser and printer are upgraded to the new format.
• There is a new class called PolynomialTypeStorage that holds the int parameter and is hidden in an inner detail namespace.

The storage class is autogenerated for us now because integers have simple construction/destruction semantics. If we had a more complicated argument like an array that needed allocation, we’d have to implement special classes to define those semantics. And at the most extreme end, if we had a fully custom type parameter, we’d have to implement a storage class manually, implement things like hash_code, and register it with the dialect. For the curious, I had to do this in heir/pull/74 to implement a custom Polynomial type parameter.
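As a rough mental model of what a storage class is for, the sketch below uniques a parameterized type by its key, so that two requests for the same degree bound return the same underlying object. It is a simplified illustration and not MLIR’s implementation: ContextSketch, PolyStorageSketch, and getPoly are hypothetical names, and std::unordered_map supplies the hashing and equality that a hand-written storage class would otherwise have to define for a custom parameter.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <unordered_map>

// Holds the type's parameter, like the generated storage holds the int.
struct PolyStorageSketch {
  int degreeBound;
};

// Owns one storage object per distinct key and hands out stable pointers.
class ContextSketch {
public:
  const PolyStorageSketch *getPoly(int degreeBound) {
    assert(degreeBound > 0);
    auto it = uniqued.find(degreeBound);
    if (it != uniqued.end()) return it->second.get();
    std::unique_ptr<PolyStorageSketch> storage(
        new PolyStorageSketch{degreeBound});
    const PolyStorageSketch *result = storage.get();
    uniqued.emplace(degreeBound, std::move(storage));
    return result;
  }

  std::size_t numUniqued() const { return uniqued.size(); }

private:
  // An int key gets hashing and equality for free; a custom parameter
  // type would need both supplied explicitly.
  std::unordered_map<int, std::unique_ptr<PolyStorageSketch>> uniqued;
};
```

In this model, comparing two types for equality is a pointer comparison, which is why the key’s hash and equality have to be well defined before anything else works.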

The same commit updates the syntax test to include the type parameter.

Moving a bit faster now, we want to add some polynomial operations. This commit adds a polynomial addition op. The tablegen:

include "PolyDialect.td"
include "PolyTypes.td"

def Poly_AddOp : Op<Poly_Dialect, "add"> {
  let summary = "Addition operation between polynomials.";
  let arguments = (ins Polynomial:$lhs, Polynomial:$rhs);
  let results = (outs Polynomial:$output);
  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)";
}


It looks very similar to a type, but the base class is Op, the arguments correspond to the operation’s inputs, and the assembly format is more complicated. The reason for the extra complexity is that, without inserting some special hooks, tablegen isn’t able to generate a parser that can infer what the types of the inputs and outputs should be when constructing the op from the textual representation. We’ll enhance it in a future article. [Aside: I feel like it should be able to with the simple example above, but for whatever reason the MLIR devs appear to have made auto-generated type inference opt-in via the trait infrastructure, see next article.]

Still, we can add a test and it passes

  // CHECK-LABEL: test_add_syntax
func.func @test_add_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> {
%0 = poly.add %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
return %0 : !poly.poly<10>
}


The generated C++ has quite a bit more going on. Here’s the full generated code. Mainly the generated header defines AddOp which has getters for the arguments and results, “mutable” getters for modifying the op in place, parse/print, and generated builder methods. It also defines an AddOpAdaptor class which is used during lowering. The cpp file contains mostly rote implementations of these, converting from a generic internal representation to specific types and named objects that client code would use.

Next, in this commit I added a sub and mul operation, with a slight refactoring to make a base tablegen class for a binary operation.

Next, in this commit I added a from_tensor operation that converts a list of coefficients to a polynomial, and an eval op that evaluates a polynomial at a given input value.
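As a sketch of the arithmetic an eval op stands for, the function below evaluates a polynomial given as a coefficient list using Horner’s method. It is an illustration rather than the dialect’s lowering: the name evalPoly is hypothetical, and it assumes coefficients are listed lowest degree first and that all arithmetic wraps modulo 2^32.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Evaluate c0 + c1*x + c2*x^2 + ... at the point x using Horner's method,
// working from the highest-degree coefficient down. Unsigned 32-bit
// arithmetic wraps modulo 2^32. An empty list evaluates to 0.
std::uint32_t evalPoly(const std::vector<std::uint32_t> &coeffs,
                       std::uint32_t x) {
  std::uint32_t acc = 0;
  for (auto it = coeffs.rbegin(); it != coeffs.rend(); ++it) {
    acc = acc * x + *it;
  }
  return acc;
}
```

For the coefficient list {1, 2, 3}, which under the assumed ordering means 1 + 2x + 3x^2, evaluating at x = 2 gives 17. Horner’s method uses one multiply and one add per coefficient, which is a natural shape for a loop-based lowering.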

In the next few articles we’ll expand on this dialect’s capabilities. First, we’ll study what other “batteries” we can include in the dialect itself, and how we can run optimizations on poly programs. Then we’ll study the nuances of lowering poly to existing MLIR dialects, and eventually through to LLVM and then machine code.

MLIR — Using Tablegen for Passes

In the last article in this series, we defined some custom lowering passes that modified an MLIR program. Notably, we accomplished that by implementing the required interfaces of the MLIR API directly.

This is not the way that most MLIR developers work. Instead, they use a code generation tool called tablegen to generate boilerplate for them, and then only add the implementation methods that are custom to their work. In this article, we’ll convert the custom passes we wrote previously to use this infrastructure, and going forward we will use tablegen for defining new dialects as well.

The basic idea of tablegen is that you can define a pass—or a dialect, as we’ll do next time—in the tablegen DSL (as specialized for MLIR) and then run the mlir-tblgen binary with appropriate flags, and it will generate headers for the functions you need to implement your pass, along with default implementations of some common functions, documentation, and hooks to register the pass/dialect with the mlir-opt-like entry point tool.

It sounds nice, but I have a love-hate relationship with tablegen. I personally find it to be unpleasant to use, primarily because it provides poor diagnostic information when you do things wrong. Today, though, I realize that part of my frustration came from having the wrong initial mindset around tablegen. I thought, incorrectly, that tablegen was an abstraction layer. That is, I could write my tablegen files, build them, and only think about the parts of the generated code that I needed to implement.

Instead, I’ve found that I still need to understand the details of the generated code, enough that I may as well read the generated code closely until it becomes familiar. This materializes in a few ways:

• mlir-tblgen doesn’t clearly tell you what functions are left unimplemented or explain the contract of the functions you have to write. For that you have to read the docs (which are often incomplete), or often the source code of the base classes that the generated boilerplate subclasses. Or, sadly, sometimes the Discourse threads or Phabricator commit messages are the only places to find the answers to some questions.
• The main way to determine what is missing is to try to build the generated code with some code that uses it, and then sift through hundreds of lines of C++ compiler errors, which in turn requires understanding the various template gymnastics in the generated code.
• The generated code will make use of symbols that you have to know to import or forward-declare in the right places, and it expects you to manage the namespaces in which the generated code lives.

With the expectation that tablegen is a totally see-through code generator, and then just sitting down and learning the MLIR APIs, the result is not so bad. Indeed, that was the motivation for building the passes in the last article from “scratch” (without tablegen), because it makes the transition to using tablegen more transparent. Pass generation is also a good starting point for tablegen because, by comparison, using tablegen for dialects results in many more moving parts and configurable options.

Tablegen files and the mlir-tblgen binary

We’ll start by migrating the AffineFullUnroll pass from last time. The starting point is to write some tablegen code. This commit implements it, which I’ll abbreviate below. If you want the official docs on how to do this, see Declarative Pass Specification.

include "mlir/Pass/PassBase.td"

def AffineFullUnroll : Pass<"affine-full-unroll"> {
let summary = "Fully unroll all affine loops";
let description = [{
Fully unroll all affine loops. (could add more docs here like code examples)
}];
let dependentDialects = ["mlir::affine::AffineDialect"];
}


[Aside: In the actual commit, there are two definitions, one for the AffineFullUnroll that walks the IR, and the other for AffineFullUnrollPatternRewrite which uses the pattern rewrite engine. In the article I’ll just show the generated code for the first.]

As you can see, tablegen has concepts like classes and class inheritance (the : Pass<...> is subclassing the Pass base class defined in PassBase.td, an upstream MLIR file). But the def keyword here specifically implies we’re instantiating this thing in a way that the codegen tool should see and generate real code for (as opposed to the class keyword, which is just for tablegen templating and code reuse).

So tablegen lets you let-define string variables and lists. One thing that won’t be apparent in this article, but will be in the next, is that tablegen also lets you define variables and use them across definitions, as well as define snippets of C++ code that should be put into the generated classes (which can use the defined variables). This is visible in the present context in PassBase.td, whose base class defines a code constructor variable. If you have a special constructor for your pass, you can write the C++ code for it there.

Next we define a build rule using gentbl_cc_library to run the mlir-tblgen binary (in the same commit) with the right options. This gentbl_cc_library bazel rule is provided by the MLIR devs, and it basically just assembles the CLI flags to mlir-tblgen and ensures the code-genned files end up in the right places on the filesystem and are compatible with the bazel dependency system. The build rule invocation looks like

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=Affine",
],
"Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)


The important part here is that td_file specifies our input file, and tbl_outs defines the generated file, Passes.h.inc, which is at $GIT_ROOT/bazel-bin/lib/Transform/Affine/Passes.h.inc.

The main quirk with gentbl_cc_library is that the name of the bazel rule is not the target that actually generates the code. That is, if you run bazel build pass_inc_gen (or from the git root, bazel build lib/Transform/Affine:pass_inc_gen), it won’t create the files but the build will be successful. Instead, under the hood gentbl_cc_library is a bazel macro that generates the rule pass_inc_gen_filegroup, which is what you have to bazel build to see the actual files.

I’ve pasted the generated code (with both versions of the AffineFullUnroll) into a gist and will highlight the important parts here.

The first quirky thing the generated code does is use #ifdef as a sort of function interface for what code is produced. For example, you will see:

#ifdef GEN_PASS_DECL_AFFINEFULLUNROLL
std::unique_ptr<::mlir::Pass> createAffineFullUnroll();
#undef GEN_PASS_DECL_AFFINEFULLUNROLL
#endif // GEN_PASS_DECL_AFFINEFULLUNROLL

#ifdef GEN_PASS_DEF_AFFINEFULLUNROLL
... <lots of C++ code> ...
#undef GEN_PASS_DEF_AFFINEFULLUNROLL
#endif // GEN_PASS_DEF_AFFINEFULLUNROLL

This means that to use this file, we will need to define the appropriate symbol in a #define macro before including this header. You can see it happening in this commit, but in brief it will look like this

// in file AffineFullUnroll.h

#define GEN_PASS_DECL_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

// in file AffineFullUnroll.cpp

#define GEN_PASS_DEF_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

... <implement the missing functions from the generated code> ...

I’m no C++ expert, and this was the first time I’d seen this pattern of using #include as a function with #define as the argument.
It was a little unsettling to me, until I landed on the mindset that it’s meant to be a white-box codegen, not an abstraction. So read the generated code.

Inside the GEN_PASS_DECL_... guard, it defines a single function std::unique_ptr<::mlir::Pass> createAffineFullUnroll(); that is a very limited sole entry point for code that wants to use the pass. We don’t need to implement it unless our Pass has a custom constructor.

Then in the GEN_PASS_DEF_... guard it defines a base class, whose functions I’ll summarize, but you should recognize many of them because we implemented them by hand last time.

template <typename DerivedT>
class AffineFullUnrollBase : public ::mlir::OperationPass<> {
  AffineFullUnrollBase() : ::mlir::OperationPass<>(::mlir::TypeID::get<DerivedT>()) {}
  AffineFullUnrollBase(const AffineFullUnrollBase &other) : ::mlir::OperationPass<>(other) {}

  static ::llvm::StringLiteral getArgumentName() {...}
  static ::llvm::StringRef getArgument() { ... }
  static ::llvm::StringRef getDescription() { ... }
  static ::llvm::StringLiteral getPassName() { ... }
  static ::llvm::StringRef getName() { ... }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) { ... }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override { ... }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::affine::AffineDialect>();
  }

  ... <type_id stuff> ...
}

Notably, this doesn’t tell us what functions are left for us to implement.
For that we have to either build it and read compiler error messages, or compare it to the base class (OperationPass) and its base class (Pass) to see that the only function left to implement is runOnOperation(). Or, since we did this last time from the raw API, we can observe that the boilerplate functions we implemented before like getArgument are here, but runOnOperation is not.

Another notable aspect of the generated code is that it uses the curiously recurring template pattern (CRTP), so that the base class can know the eventual name of its subclass, and use that name to hook the concrete subclass into the rest of the framework.

Lower in the generated file you’ll see another #define-guarded block for GEN_PASS_REGISTRATION, which implements hooks for tutorial-opt to register the passes without having to depend on each internal Pass class directly.

#ifdef GEN_PASS_REGISTRATION

inline void registerAffineFullUnroll() {
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
    return createAffineFullUnroll();
  });
}

inline void registerAffineFullUnrollPatternRewrite() {
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
    return createAffineFullUnrollPatternRewrite();
  });
}

inline void registerAffinePasses() {
  registerAffineFullUnroll();
  registerAffineFullUnrollPatternRewrite();
}
#undef GEN_PASS_REGISTRATION
#endif // GEN_PASS_REGISTRATION

This implies that, once we link everything properly, the changes to tutorial-opt (in this commit) simplify to calling registerAffinePasses.

This registration macro is intended to go into a Passes.h file that includes all the individual pass header files, as done in this commit. And we can use that header file as an anchor for a bazel build target that includes all the passes defined in lib/Transform/Affine at once.
#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "lib/Transform/Affine/AffineFullUnrollPatternRewrite.h"

namespace mlir {
namespace tutorial {

#define GEN_PASS_REGISTRATION
#include "lib/Transform/Affine/Passes.h.inc"

} // namespace tutorial
} // namespace mlir

Finally, after all this (abbreviated from this commit), the actual content of the pass reduces to the following subclass (with CRTP) and implementation of runOnOperation, the body of which is identical to the last article except for a change from reference to pointer for the return value of getOperation.

#define GEN_PASS_DEF_AFFINEFULLUNROLL
#include "lib/Transform/Affine/Passes.h.inc"

struct AffineFullUnroll : impl::AffineFullUnrollBase<AffineFullUnroll> {
  using AffineFullUnrollBase::AffineFullUnrollBase;

  void runOnOperation() {
    getOperation()->walk([&](AffineForOp op) {
      if (failed(loopUnrollFull(op))) {
        op.emitError("unrolling failed");
        signalPassFailure();
      }
    });
  }
};

I split the AffineFullUnroll migration into multiple commits to highlight the tablegen vs C++ code changes. For MulToAdd, I did it all in one commit. The tests are unchanged, because the entry point is still the tutorial-opt binary with the appropriate CLI flags, and those names are unchanged in the tablegen’ed code.

Bonus: mlir-tblgen also has an option -gen-pass-doc, which you’ll see in the commits, which generates a markdown file containing auto-generated documentation for the pass. A CI workflow can copy these to a website directory, as we do in HEIR, and you get free docs. See this example from HEIR.

Addendum: hermetic Python

When I first set up this tutorial project, I didn’t realize that bazel’s Python rules use the system Python by default. Some early readers found an error that Python couldn’t find the lit module when running tests. While pip install lit in your system Python would work, I also migrated in this commit to a hermetic python runtime and explicit dependency on lit.
It should all be handled automatically by bazel now.

MLIR — Writing Our First Pass

This series is an introduction to MLIR and an onboarding tutorial for the HEIR project.

Last time we saw how to run and test a basic lowering. This time we will write some simple passes to illustrate the various parts of the MLIR API and the pass infrastructure.

As mentioned previously, the main work in MLIR is defining passes that either optimize part of a program, lower from parts of one dialect to others, or perform various normalization and canonicalization operations. In this article, we’ll start by defining a pass that operates entirely within a given dialect by fully unrolling loops. Then we’ll define a pass that does a simple replacement of one instruction with another. Neither pass will be particularly complex, but rather they will show how to set up a pass, how to navigate through a program via the MLIR API, and how to modify the IR by deleting and adding operations.

The code for this post is contained within this pull request.

tutorial-opt and project organization

Last time we used the mlir-opt binary as the main entry point to parse MLIR, run a pass, and emit the output IR. A compiler might run mlir-opt as a subroutine in between the front end (C++ to some MLIR dialects) and the backend (MLIR’s LLVM dialect to LLVM to machine code).

In an out-of-tree MLIR project, mlir-opt can’t be used because it isn’t compiled with the project’s custom dialects or passes. Instead, MLIR makes it easy to build a custom version of the mlir-opt tool for an out-of-tree project. It primarily provides a set of registration hooks that you can use to plug in your dialects and passes, and the framework handles reading/writing, CLI flags, and adds that all on top of the baseline MLIR passes and dialects.

We’ll start this article by creating the shell for such a tool with an empty custom pass, which we’ll call tutorial-opt.
If this repository were to become one step of an end-to-end compiler, then tutorial-opt would be the main interface to the MLIR part.

The structure of the codebase is a persnickety question here. A typical MLIR codebase seems to split the code into two directories with roughly equivalent hierarchies: an include/ directory for headers and tablegen files (more on tablegen in a future article), and a lib/ directory for implementation code. Then, within those two directories a project would have a Transform/ subdirectory that stores the files for passes that transform code within a dialect, Conversion/ for passes that convert between dialects, Analysis/ for analysis passes, etc. Each of these directories might have subdirectories for the specific dialects they operate on.

For this tutorial we will do it slightly differently by merging include/ and lib/ together (header files will live next to implementation files). I believe the reason that C++ codebases separate the two is to get an implicit public/private interface (client code should only depend on headers in include/, not headers in lib/ or src/). But bazel has many more facilities for enforcing private/public interface boundaries, I find it tedious to navigate parallel directory structures, and this is a tutorial so simpler is better.

So the project’s directory structure will look like this once we create the initial pass:

.
├── README.md
├── WORKSPACE
├── bazel
│   └── . . .
├── lib
│   └── Transform
│       └── Affine
│           ├── AffineFullUnroll.cpp
│           ├── AffineFullUnroll.h
│           └── BUILD
├── tests
│   └── . . .
└── tools
    ├── BUILD
    └── tutorial-opt.cpp

Unrolling loops, a starter pass

Though MLIR provides multiple mechanisms for defining loops and control flow, the highest level one is in the affine dialect.
Originally defined for polyhedral loop analysis (using lattices to study loop structure!), it also simply defines a nice for operation that you can use whenever you have simple loop bounds like iterating over a range with an optional step size. An example loop that sums some values in an array stored in memory might look like:

func.func @sum_buffer(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}

The iter_args is a custom bit of syntax that defines accumulation variables to operate across the loop body (to be in compliance with SSA form; for more on SSA, see this MLIR doc), along with an initial value.

Unrolling loops is a nontrivial operation, but thankfully MLIR provides a utility method for fully unrolling a loop, so our pass will be a thin wrapper around this function call, to showcase some of the rest of the infrastructure before we write a more meaningful pass. The code for this section is in this commit.

This implementation is technically the most general one: it works directly with the C++ API, rather than using more special-purpose features like the pattern rewrite engine, the dialect conversion framework, or tablegen. Those will all come later.

The main idea is to implement the required methods for the OperationPass base class, which “anchors” the pass to work within the context of a specific instance of a specific type of operation, and is applied to every operation of that type.
It looks like this:

// lib/Transform/Affine/AffineFullUnroll.h
class AffineFullUnrollPass
    : public PassWrapper<AffineFullUnrollPass,
                         OperationPass<mlir::func::FuncOp>> {
private:
  void runOnOperation() override;  // implemented in AffineFullUnroll.cpp

  StringRef getArgument() const final { return "affine-full-unroll"; }

  StringRef getDescription() const final {
    return "Fully unroll all affine loops";
  }
};

The PassWrapper helps implement some of the required methods for free (mainly adding a compliant copy method), and uses the Curiously Recurring Template Pattern (CRTP) to achieve that. But what matters for us is that OperationPass<FuncOp> anchors this pass to operate on function bodies, and provides the getOperation method in the class which returns the FuncOp being operated on.

Aside: The MLIR docs more formally describe what is required of an OperationPass, and in particular it limits the “anchoring” to specific operations like functions and modules, the insides of which are isolated from modifying the semantics of the program outside of the operation’s scope. That’s a fancy way of saying FuncOps in MLIR can’t screw with variables outside the lexical scope of their function body. More importantly for this example, it explains why we can’t anchor this pass on a for loop operation directly: a loop can modify stuff outside its body (like the contents of memory) via the operations within the loop (store, etc.). This matters because the MLIR pass infrastructure runs passes in parallel. If some other pass is tinkering with neighboring operations, race conditions ensue.

The three functions we need to implement are

• runOnOperation: the function that performs the pass logic.
• getArgument: the CLI argument for an mlir-opt-like tool.
• getDescription: the CLI description when running --help on the mlir-opt-like tool.

The initial implementation of runOnOperation is empty in the commit for this section. Next, we create a tutorial-opt binary that registers the pass.

// tools/tutorial-opt.cpp
#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/include/mlir/InitAllDialects.h"
#include "mlir/include/mlir/Pass/PassManager.h"
#include "mlir/include/mlir/Pass/PassRegistry.h"
#include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);

  mlir::PassRegistration<mlir::tutorial::AffineFullUnrollPass>();

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry));
}

This registers all the built-in MLIR dialects, adds our AffineFullUnrollPass, and then calls the MlirOptMain function which handles the rest. At this point we can run bazel run tools:tutorial-opt --help and see a long list of options with our new pass in it.

OVERVIEW: Tutorial Pass Driver
Available Dialects: acc, affine, amdgpu, <...SNIP...>
USAGE: tutorial-opt [options] <input file>

OPTIONS:

General options:

  Compiler passes to run
    Passes:
      --affine-full-unroll - Fully unroll all affine loops
  --allow-unregistered-dialect - Allow operation with no registered dialects
  --disable-i2p-p2i-opt - Disables inttoptr/ptrtoint roundtrip optimization
  <...SNIP...>

To allow us to run lit tests that use this tool, we add it to the test_utilities target in this commit, and then we add a first (failing) test in this commit. To avoid complexity, I’m just asserting that the output has no for loops in it.

// RUN: tutorial-opt %s --affine-full-unroll > %t
// RUN: FileCheck %s < %t

func.func @test_single_nested_loop(%buffer: memref<4xi32>) -> (i32) {
  %sum_0 = arith.constant 0 : i32
  // CHECK-NOT: affine.for
  %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 {
    %t = affine.load %buffer[%i] : memref<4xi32>
    %sum_next = arith.addi %sum_iter, %t : i32
    affine.yield %sum_next : i32
  }
  return %sum : i32
}

Next, we can implement the pass itself in this commit:

#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/include/mlir/Pass/Pass.h"

using mlir::affine::AffineForOp;
using mlir::affine::loopUnrollFull;

void AffineFullUnrollPass::runOnOperation() {
  getOperation().walk([&](AffineForOp op) {
    if (failed(loopUnrollFull(op))) {
      op.emitError("unrolling failed");
      signalPassFailure();
    }
  });
}

getOperation returns a FuncOp, though we don’t use any specific information about it being a function. We instead call the walk method (present on all Operation instances), which traverses the abstract syntax tree (AST) of the operation (i.e., the function body) in post-order, and for each operation it encounters, if the type of that operation matches the input type of the callback, the callback is executed. In our case, we attempt to unroll the loop, and if it fails we quit with a diagnostic error.

Exercise: determine how the loop unrolling might fail, and create a test MLIR input that causes it to fail, and observe the error messages that result.

Running this on our test shows the operation is applied:

$ bazel run tools:tutorial-opt -- --affine-full-unroll < tests/affine_loop_unroll.mlir
<...>
#map = affine_map<(d0) -> (d0 + 1)>
#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 3)>
module {
func.func @test_single_nested_loop(%arg0: memref<4xi32>) -> i32 {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%0 = affine.load %arg0[%c0] : memref<4xi32>
%1 = arith.addi %c0_i32, %0 : i32
%2 = affine.apply #map(%c0)
%3 = affine.load %arg0[%2] : memref<4xi32>
%4 = arith.addi %1, %3 : i32
%5 = affine.apply #map1(%c0)
%6 = affine.load %arg0[%5] : memref<4xi32>
%7 = arith.addi %4, %6 : i32
%8 = affine.apply #map2(%c0)
%9 = affine.load %arg0[%8] : memref<4xi32>
%10 = arith.addi %7, %9 : i32
return %10 : i32
}
}


I won’t explain what this affine.apply thing is doing, but suffice it to say the loop is correctly unrolled. A subsequent commit does the same test for a doubly-nested loop.

A Rewrite Pattern Version

In this commit, we rewrote the loop unroll pass in the next level of abstraction provided by MLIR: the pattern rewrite engine. It is useful in the kind of situation where one wants to repeatedly apply the same subset of transformations to a given IR substructure until that substructure is completely removed. The next section will write a pass that uses that in a meaningful way, but for now we’ll just rewrite the loop unroll pass to show the extra boilerplate.

A rewrite pattern is a subclass of OpRewritePattern, and it has a method called matchAndRewrite which performs the transformation.

struct AffineFullUnrollPattern :
public OpRewritePattern<AffineForOp> {
AffineFullUnrollPattern(mlir::MLIRContext *context)
: OpRewritePattern<AffineForOp>(context, /*benefit=*/1) {}

LogicalResult matchAndRewrite(AffineForOp op,
PatternRewriter &rewriter) const override {
return loopUnrollFull(op);
}
};


The return value of matchAndRewrite is a LogicalResult, which is a wrapper around a boolean to signal success or failure, along with named utility functions like failure() and success() to generate instances, and failed(...) to test for failure. LogicalResult also comes with a companion class template, FailureOr, which is a subclass of optional that inter-operates with LogicalResult via the presence or absence of a value.

Aside: In a proper OpRewritePattern, the mutations of the IR must go through the PatternRewriter argument, but because loopUnrollFull doesn’t have a variant that takes a PatternRewriter as input, we’re violating that part of the function contract. More generally, the PatternRewriter handles atomicity of the mutations that occur within the OpRewritePattern, ensuring that the operations are applied only if the method reaches the end and succeeds.

Then we instantiate the pattern inside the pass

// A pass that invokes the pattern rewrite engine.
void AffineFullUnrollPassAsPatternRewrite::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());
  // Register the pattern defined above; without this the set is empty.
  patterns.add<AffineFullUnrollPattern>(&getContext());
  // One could use GreedyRewriteConfig here to slightly tweak the behavior of
  // the pattern application.
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}


The overall pass is still anchored on FuncOp, but an OpRewritePattern can match against any op. The rewrite engine invokes the walk that we did manually, and one can pass an optional configuration struct that chooses the walk order.

The RewritePatternSet can accept any number of patterns, and the greedy rewrite engine will keep trying to apply them (in a certain order related to the benefit constructor argument) until there are no matching operations to apply, all applied patterns return failure, or some large iteration limit is reached to avoid infinite loops.
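To make those stopping conditions concrete, here is a toy model of a greedy fixpoint loop in plain C++. It is not MLIR’s driver and uses no MLIR APIs: ToyPattern, applyToyPatternsGreedily, exampleToyPatterns, and the default limit of 10 sweeps are all hypothetical names and values chosen for illustration, and the real engine’s worklist and ordering are more involved.

```cpp
#include <algorithm>
#include <functional>
#include <vector>

// Toy stand-in for a rewrite pattern (hypothetical, not an MLIR class): it
// tries to rewrite one integer in place and reports whether it changed it.
struct ToyPattern {
  int benefit;
  std::function<bool(int &)> matchAndRewrite;
};

// Toy stand-in for the greedy driver. Sweeps over all values, trying patterns
// in order of decreasing benefit, until a sweep changes nothing (a fixpoint)
// or the iteration limit is hit. Returns true if a fixpoint was reached.
bool applyToyPatternsGreedily(std::vector<int> &values,
                              std::vector<ToyPattern> patterns,
                              int maxIterations = 10) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (int iter = 0; iter < maxIterations; ++iter) {
    bool changed = false;
    for (int &v : values) {
      for (const ToyPattern &p : patterns) {
        if (p.matchAndRewrite(v)) {
          changed = true;
          break;  // the highest-benefit pattern that applies wins this sweep
        }
      }
    }
    if (!changed) return true;  // every pattern failed on every value
  }
  return false;  // gave up at the iteration limit
}

// Two example patterns: halve even numbers (preferred), else subtract one.
std::vector<ToyPattern> exampleToyPatterns() {
  return {
      {1, [](int &v) { if (v > 1) { v -= 1; return true; } return false; }},
      {2, [](int &v) {
         if (v > 1 && v % 2 == 0) { v /= 2; return true; }
         return false;
       }},
  };
}
```

Starting from the single value 9, the sweeps go 9, 8, 4, 2, 1 and then a fifth sweep changes nothing, so the loop stops at a fixpoint. With a limit of two sweeps it stops early at 4 and reports that it did not converge.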

A proper greedy RewritePattern

In fully homomorphic encryption (FHE), multiplication ops are typically MUCH more expensive than addition ops. I’m not an expert in classical CPU hardware performance, but if I recall correctly, on a typical CPU multiplication is something like 2-4x slower than addition, and that advantage probably goes away when doing multiplications in bulk/pipelines.

In FHE, the operations introduce noise growth in the ciphertext, and in some schemes multiplication introduces something like 100x the noise of addition, and reducing that noise is very time consuming. So it makes sense that one would want to convert multiplication ops into addition ops. In this section we’ll write a very simple pass that greedily rewrites multiplication ops as repeated additions.

The idea is to rewrite an operation like y = 9*x as y = 8*x + x (the 8 is a power of 2) and then expand it further as a = x+x; b = a+a; c = b+b; y = c+x. It replaces a multiplication by a constant with a roughly log-number of additions (the base-2 logarithm of the constant), though it gets worse the further away the constant gets from a power of two.

This commit contains a similar “empty shell” of a pass, with two patterns defined. The first, PowerOfTwoExpand, will be a pattern that rewrites y = C*x as y = C/2*x + C/2*x when C is a power of 2, and otherwise fails. The second, PeelFromMul, “peels” a single addition off a product whose constant is not a power of 2, rewriting y = 9*x as y = 8*x + x. These are applied repeatedly via the greedy pattern rewrite engine. By setting the benefit argument of PowerOfTwoExpand to be larger than PeelFromMul, we tell the greedy rewrite engine to prefer PowerOfTwoExpand whenever possible. Together, that achieves the transformation mentioned above.
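The combined effect of the two rules can be sketched on plain integers, without any MLIR. The helper below is hypothetical (mulViaAdds is not part of the tutorial code): it applies the power-of-two rule when it can, the peel rule otherwise, and counts the additions it uses. It assumes that multiplication by 0 and by 1 has already been folded away, and it recurses once per peeled addition, so it is only meant for small constants.

```cpp
#include <cstdint>

// Hypothetical helper: computes c * x using only additions, mirroring the
// two rewrite rules, and counts how many additions were needed.
uint32_t mulViaAdds(uint32_t x, uint32_t c, unsigned &numAdds) {
  if (c == 0) return 0;  // assumed folded: x * 0 == 0
  if (c == 1) return x;  // assumed folded: x * 1 == x
  if ((c & (c - 1)) == 0) {
    // Power-of-two rule: C*x = (C/2)*x + (C/2)*x. The half product is
    // computed once and reused, like a shared SSA value.
    uint32_t half = mulViaAdds(x, c / 2, numAdds);
    ++numAdds;
    return half + half;
  }
  // Peel rule: C*x = (C-1)*x + x.
  uint32_t rest = mulViaAdds(x, c - 1, numAdds);
  ++numAdds;
  return rest + x;
}
```

For c = 9 this uses 4 additions (three doublings and one peel), while c = 7 uses 5 (two doublings to reach 4*x, then three peels), which illustrates how the cost grows as the constant moves above a power of two.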

This commit adds a failing test that only exercises PowerOfTwoExpand, and then this commit implements it. Here’s the implementation:

  LogicalResult matchAndRewrite(
      MulIOp op, PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);

    // canonicalization patterns ensure the constant is on the right, if there is a constant
    // See https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules
    Value rhs = op.getOperand(1);
    auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
    if (!rhsDefiningOp) {
      return failure();
    }

    int64_t value = rhsDefiningOp.value();
    bool is_power_of_two = (value & (value - 1)) == 0;

    if (!is_power_of_two) {
      return failure();
    }

    // Build y = (C/2)*x + (C/2)*x, reusing the single half-sized product.
    ConstantOp newConstant = rewriter.create<ConstantOp>(
        rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value / 2));
    MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
    AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, newMul);

    // Replace all uses of the original product, then drop the old constant.
    rewriter.replaceOp(op, newAdd.getResult());
    rewriter.eraseOp(rhsDefiningOp);

    return success();
  }


Some notes:

• Value is the type that represents an SSA value (i.e., an MLIR variable), and getDefiningOp fetches the unique operation that defines it in its scope.
• There are a variety of “casting” operations like rhs.getDefiningOp<arith::ConstantIntOp>() that take the type you want as output as a template parameter, and return null if the type cannot be converted. You might also see dyn_cast<>.
• (value & (value - 1)) is a classic bit-twiddling trick to compute if an integer is a power of two. We check it and skip the pattern if it’s not.
• The actual constant itself is represented as an MLIR attribute, which is essentially compile-time static data attached to the op. You can put strings or dictionaries as attributes, but for ConstantOp it’s just an int.
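As a standalone illustration of the bit trick, here is a minimal sketch in plain C++. The function names are hypothetical and not from the tutorial code. One detail worth knowing is that the bare expression also accepts 0, so a check that should only accept positive powers of two needs an extra guard; whether a zero constant can actually reach the pattern in the pass is not something this sketch establishes.

```cpp
#include <cstdint>

// The bare trick: clearing the lowest set bit leaves zero exactly when at
// most one bit is set. Done in unsigned arithmetic to avoid signed overflow.
bool bareBitTrick(int64_t value) {
  uint64_t v = static_cast<uint64_t>(value);
  return (v & (v - 1)) == 0;
}

// Hypothetical stricter variant: rejects zero and negative values first.
bool isPositivePowerOfTwo(int64_t value) {
  return value > 0 && bareBitTrick(value);
}
```

For example, 8 (binary 1000) and 7 (binary 0111) share no bits, so 8 passes, while 9 (1001) and 8 (1000) share the top bit, so 9 fails.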

The rewriter.create part is where we actually do the real work. Create a new constant that is half the original constant, create new multiplication and addition ops, and then finally rewriter.replaceOp removes the original multiplication op and uses the output of newAdd for any other operations that used the original multiplication op’s output.

It’s worth noting that we’re relying on MLIR’s built-in canonicalization passes in a few ways here:

• To ensure that the constant is always the second operand of a multiplication op.
• To ensure that the base case (x*1) is “folded” into a plain x and the constant 1 is removed.
• The fold part of applyPatternsAndFoldGreedily is what runs these cleanup steps for us.

PeelFromMul is similar, implemented and tested in this commit:

  LogicalResult matchAndRewrite(MulIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    auto rhsDefiningOp = rhs.getDefiningOp<arith::ConstantIntOp>();
    if (!rhsDefiningOp) { return failure(); }
    int64_t value = rhsDefiningOp.value();

    // We are guaranteed value is not a power of two, because the greedy
    // rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
    // it has higher benefit.

    // Build y = (C-1)*x + x.
    ConstantOp newConstant = rewriter.create<ConstantOp>(
        rhsDefiningOp.getLoc(), rewriter.getIntegerAttr(rhs.getType(), value - 1));
    MulIOp newMul = rewriter.create<MulIOp>(op.getLoc(), lhs, newConstant);
    AddIOp newAdd = rewriter.create<AddIOp>(op.getLoc(), newMul, lhs);

    rewriter.replaceOp(op, newAdd.getResult());
    rewriter.eraseOp(rhsDefiningOp);

    return success();
  }



Running it! Input:

func.func @power_of_two_plus_one(%arg: i32) -> i32 {
%0 = arith.constant 9 : i32
%1 = arith.muli %arg, %0 : i32
func.return %1 : i32
}


Output:

module {
func.func @power_of_two_plus_one(%arg0: i32) -> i32 {
%0 = arith.addi %arg0, %arg0 : i32
%1 = arith.addi %0, %0 : i32
%2 = arith.addi %1, %1 : i32
%3 = arith.addi %2, %arg0 : i32
return %3 : i32
}
}


Exercise: Try swapping the benefit arguments to see how the output changes.

Though this pass is quite naive, you can imagine a more sophisticated technique that builds a cost model for multiplications and additions, and optimizes for the cheapest cost representation of an arithmetic operation in terms of repeated additions, multiplications, and other supported ops.

Should we walk?

With two options for how to define a pass—one to walk the entire syntax tree from the root operation, and one to match and rewrite patterns with the rewrite engine—the natural question is when should you use one versus the other.

The MLIR docs describe the motivation behind the pattern rewrite engine, and it comes from a long history of experience with the LLVM project. For one, the pattern rewrite engine expresses a convenient subset of what can be achieved with an MLIR pass. This is conceptually trivial, in the sense that anyone who can walk the entire AST can, with enough effort, do anything they want including reimplementing the pattern rewrite engine.

More practically, the pattern rewrite engine is convenient to represent local transformations. “Local” here means that the input and output can be detected via a subset of the AST as a directed acyclic graph. More pragmatically, think of it as any operation you can identify by looking around at neighboring operations in the same block and applying some filtering logic. E.g., “is this exp operation followed by a log operation with no other uses of the output of the exp?”

On the other hand, some analyses and optimizations need to construct the entire dataflow of a program to work. A good example is common subexpression elimination, which determines whether it is cost effective to extract a subexpression used in multiple places into a separate variable. Doing so may introduce additional cost of memory access, so it depends both on the operation’s cost and on the availability of registers at that point in the program. You can’t get this information by pattern matching the AST locally.

The wisdom seems to be: using the pattern rewrite engine is usually easier than writing a pass that walks the AST. You don’t need large case/switch statements to handle everything that could show up in the IR. The engine handles re-applying patterns many times. And so you can write the patterns in isolation and trust the engine to combine them appropriately.

Bonus: IDEs and CI

Since we explored the C++ API, it helps to have an IDE integration. I use neovim with the clangd LSP, and to make it work with a Bazel C++ project, one needs to use something analogous to Hedron Vision’s compile_commands extractor, which I configured for this tutorial project in this commit. It’s optional, but if you want to use it you have to run bazel run @hedron_compile_commands//:refresh_all once to set it up, and then clangd and clang-tidy, etc., should find the generated json file and use it. Also, if you edit a BUILD file, you have to re-run refresh_all for the changes to show up in the LSP.

Though it’s not particularly relevant to this tutorial, I also configured GitHub Actions to build and test the project in CI in this commit. It is worth noting that the GitHub cache action reduces subsequent build times from 1-2 hours down to just a few minutes.

Thanks to Patrick Schmidt for feedback on a draft of this article.

MLIR — Running and Testing a Lowering

Last time, we covered a Bazel build system setup for an MLIR project. This time we’ll give an overview of a simple lowering and show how end-to-end tests work in MLIR. All of the code for this article is contained in this pull request on GitHub, and the commits are nicely organized and quite readable.

Two of the central concepts in MLIR are dialects and lowerings. These are the scaffolding within which we can do the truly interesting parts of a compiler—that is, the optimizations and analyses. In traditional compilers, there is typically one “dialect” (called an intermediate representation, or IR) that is the textual or data-structural description of a program within the compiler’s code. For example, in GCC the IR is called GIMPLE, and in LLVM it’s called LLVM-IR. They convert the input program to the IR, do their optimizations, and then convert the optimized IR to machine code.

In MLIR one splits the job into much smaller steps. First, MLIR allows one to define many dialects, some considered “high level” and some “low level,” but each with a set of types, operations, metadata, and semantics that defines what the operations do. Then, one writes a set of lowering passes that incrementally converts different parts of the program from higher level dialects to lower and lower dialects until you get to machine code (or, in many cases, LLVM, which finishes the job). Along the way, optimizing passes are run to make the code more efficient. The main point here is that the high level dialects exist so that they make it easy to write these important optimizing passes. And there’s no special distinction between lowering passes and optimizing passes: they’re both just called passes in MLIR and are generic IR-rewriting modules.

Aside: From what I can gather, a big part of the motivation for MLIR was to build the affine dialect, which is specifically designed to enable polyhedral optimizations for loop transformations, along with the linalg dialect, which does optimization passes like tiling for low-level ML operations on specialized hardware. Folks built polyhedral optimizations in LLVM and GCC (without affine), and it was a huge pain in the ass, mainly because they had to take a low-level mess of branches and GOTOs and try to reconstruct a simple (‘affine’) for loop structure from it. This was necessary even if the input program was a simple set of for loops, because by the time they got to the compiler, the rigid for loop structure had been discarded. MLIR instead says, keep the structure in the higher level dialect, optimize there, and then discard it when you lower to lower level dialects.

Two example programs

A general understanding properly begins with concrete examples. Here are two MLIR programs that define a function that counts the leading zeroes of a 32-bit integer (i32). The first uses the math dialect’s defined ctlz operation and just returns it.

func.func @main(%arg0: i32) -> i32 {
%0 = math.ctlz %arg0 : i32
func.return %0 : i32
}


This shows the basic structure of an MLIR operation (see here for a more complete spec). Variable names are prefixed with %, functions by @, and each variable/value in a program has a type, often expressed after a colon. In this case all the types are i32, except for the function type which is (i32) -> i32 (not specified explicitly above, but you’ll see it in the func.call in the next code snippet).

Each statement is anchored around an expression like math.ctlz which specifies the dialect math and the operation ctlz. The rest of the syntax of the operation is determined by a parser defined by the dialect, and so many operations will have different syntaxes, though many are pulled from a fixed set of options we’ll see later in the series. In the simple case of math.ctlz, the sole argument is the integer whose leading zeros are to be counted, and the trailing  : i32 denotes the output type.

It’s also important to note that func is itself a dialect, and func.func is considered an “operation,” where the braces and the function’s body are part of the syntax. In MLIR a set of operations within braces is called a region, and an operation can have zero or many regions.

There is a lot more to say about regions, and their cousins “basic blocks,” but in brief: operations may have attached regions, like the body of a for loop, and each region is a list of blocks (with an implicit block if none is explicitly listed). A block is a list of operations that is guaranteed to have only one entry and exit point. I think of the label in a block as the destination of a jump command in assembly languages. A block has exactly one “jumping in” point and one “jumping out” point. It has a more precise definition that aligns with the classical compiler concept of a basic block.

Also note, in MLIR multiple dialects often coexist in the same program as it is progressively lowered to some final backend target.

The second version of this program has a software implementation of the ctlz function and calls it.

func.func @main(%arg0: i32) -> i32 {
%0 = func.call @my_ctlz(%arg0) : (i32) -> i32
func.return %0 : i32
}
func.func @my_ctlz(%arg0: i32) -> i32 {
%c32_i32 = arith.constant 32 : i32
%c0_i32 = arith.constant 0 : i32
%0 = arith.cmpi eq, %arg0, %c0_i32 : i32
%1 = scf.if %0 -> (i32) {
scf.yield %c32_i32 : i32
} else {
%c1 = arith.constant 1 : index
%c1_i32 = arith.constant 1 : i32
%c32 = arith.constant 32 : index
%c0_i32_0 = arith.constant 0 : i32
%2:2 = scf.for %arg1 = %c1 to %c32 step %c1 iter_args(%arg2 = %arg0, %arg3 = %c0_i32_0) -> (i32, i32) {
%3 = arith.cmpi slt, %arg2, %c0_i32 : i32
%4:2 = scf.if %3 -> (i32, i32) {
scf.yield %arg2, %arg3 : i32, i32
} else {
%5 = arith.addi %arg3, %c1_i32 : i32
%6 = arith.shli %arg2, %c1_i32 : i32
scf.yield %6, %5 : i32, i32
}
scf.yield %4#0, %4#1 : i32, i32
}
scf.yield %2#1 : i32
}
func.return %1 : i32
}


The algorithm above is not relevant to this post, but either way it is quite simple: count the leading zeros by shifting the input left one bit at a time until it becomes negative (as a signed integer), because that occurs exactly when its leading bit is a 1. Then add a special case to handle zero, which would loop infinitely otherwise.
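For readers who find the MLIR hard to scan, here is a rough C++ transliteration of @my_ctlz. The function name softwareCtlz is hypothetical, and the sketch tests the top bit directly instead of doing a signed comparison against zero; the loop bounds (1 up to, but not including, 32) follow the scf.for above.

```cpp
#include <cstdint>

// Rough port of @my_ctlz: counts leading zeros of a 32-bit value by shifting
// left until the top bit is set.
uint32_t softwareCtlz(uint32_t input) {
  if (input == 0) return 32;  // special case, matches the outer scf.if
  uint32_t value = input;     // %arg2 in the MLIR
  uint32_t count = 0;         // %arg3 in the MLIR
  for (int i = 1; i < 32; ++i) {
    // "value < 0 as a signed integer" is the same as "top bit is set".
    bool topBitSet = (value & 0x80000000u) != 0;
    if (!topBitSet) {
      ++count;
      value <<= 1;
    }
    // Otherwise both loop-carried values pass through unchanged.
  }
  return count;
}
```

In this bounded-loop form, skipping the zero check would make an input of 0 return 31 instead of 32, since the loop body runs 31 times.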

Here you can see two more MLIR dialects. arith is for low-level arithmetic and boolean conditions on integers and floats. You can define constants, compare integers with arith.cmpi, and do things like add and bit shift (arith.shli is a left shift). scf, short for “structured control flow,” defines for loops, while loops, and control flow branching. scf.yield defines the “output” value from each region of an if/else operation or loop body which is necessary here because, as you can see, an if operation has a result value.

Two other minor aspects of the syntax are on display. First is the syntax %4:2, which defines a variable %4 which is a tuple of two values. The corresponding %4#1 accesses the second entry in the tuple. Second, you’ll notice there’s a type called index that is different from i32. Though they both represent integers, index is intended to be a platform-dependent integer type which is suitable for indexing arrays, representing sizes and dimensions of things, and, in our case, being loop counters and iteration bounds. More details on index in the MLIR docs.

Lowerings and the math-to-funcs pass

We have two versions of the same program because one is a lowered version of the other. In most cases, the machine you’re going to run a program on has a “count leading zeros” function, so the lowering would simply map math.ctlz to the corresponding machine instruction. But if there is no ctlz instruction, a lowering can provide an implementation in terms of lower level dialects and ops. Specifically, this one lowers ctlz to {func, arith, scf}.

The second version of this code was actually generated by the mlir-opt command line tool, which is the main entry-point to running MLIR passes on specific MLIR code. For starters, one can take the mlir-opt tool and run it with no arguments on any MLIR code, and it will parse it, verify it is well formed, and print it back out with some slight normalizations. In this case, it will wrap the code in a module, which is a namespace isolation mechanism.

$ echo 'func.func @main(%arg0: i32) -> i32 {
  %0 = math.ctlz %arg0 : i32
  func.return %0 : i32
}' > ctlz.mlir
$ bazel run @llvm-project//mlir:mlir-opt -- $(pwd)/ctlz.mlir
<... snip ...>
module {
  func.func @main(%arg0: i32) -> i32 {
    %0 = math.ctlz %arg0 : i32
    return %0 : i32
  }
}

Aside: The -- $(pwd)/ctlz.mlir is a quirk of bazel. When one program runs another program, the -- is the standard mechanism to separate CLI flags for the runner program (bazel) from those for the run program (mlir-opt). Everything after -- goes to mlir-opt. Also, the need for $(pwd) is because when bazel runs mlir-opt, it runs it with a working directory that is in some temporary, isolated location on the filesystem. So we need to give it an absolute path to the MLIR file to input. Or we could pipe from standard in. Or we could run the mlir-opt binary directly from bazel-bin/external/llvm-project/mlir/mlir-opt.

Next we can run our first lowering, which is already built-in to mlir-opt, and which generates the long program above.

$ bazel run @llvm-project//mlir:mlir-opt -- --convert-math-to-funcs=convert-ctlz $(pwd)/ctlz.mlir
<... snip ...>
module {
func.func @main(%arg0: i32) {
%0 = call @__mlir_math_ctlz_i32(%arg0) : (i32) -> i32
return
}
<... snip ...>


Each pass gets its own command line flag, some are grouped into pipelines, and the --pass-pipeline command line flag can be used to provide a (serialized version of) an ordered list of passes to run on the input MLIR.1
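To make the "serialized ordered list" idea concrete, here is a minimal Python sketch that renders a nested list of passes into the textual form accepted by --pass-pipeline. The helper name render_pipeline is hypothetical (it is not part of MLIR), and the pass names are the ones used in a test later in this article.

```python
def render_pipeline(anchor, passes):
    """Render a nested pass list in the textual --pass-pipeline syntax.

    Each entry of `passes` is either a pass name (optionally with {options})
    or a (nested_anchor_op, nested_passes) tuple for passes that run on a
    nested operation such as func.func.
    """
    parts = []
    for entry in passes:
        if isinstance(entry, tuple):
            nested_anchor, nested_passes = entry
            parts.append(render_pipeline(nested_anchor, nested_passes))
        else:
            parts.append(entry)
    return f"{anchor}({','.join(parts)})"


# Hypothetical usage: the pipeline is anchored on the top-level module op.
pipeline = render_pipeline(
    "builtin.module",
    [
        "convert-math-to-funcs{convert-ctlz}",
        ("func.func", ["convert-scf-to-cf", "convert-arith-to-llvm"]),
        "convert-func-to-llvm",
        "convert-cf-to-llvm",
        "reconcile-unrealized-casts",
    ],
)
print(f"--pass-pipeline={pipeline}")
```

The nesting mirrors the IR structure: passes listed inside func.func(...) run on each function, while the others run on the whole module.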

We won’t cover the internal workings of the math-to-funcs pass in this or a future article, but next time we will actually write our own, simpler pass that does something nontrivial. Until then, I’ll explain a bit about how testing works in MLIR, using these two ctlz programs as example test cases.

For those who are interested, the MLIR documentation contains a complete list of passes owned by the upstream MLIR project, which can be used by invoking the corresponding command line flag or nesting it inside of a larger --pass-pipeline.

Lit, FileCheck, and Bazel again

The LLVM and MLIR projects both use the same testing framework, which is split into two parts. The first is lit (LLVM Integrated Tester; though I don’t know why it’s called “integrated”), which handles test discovery and running. The second is FileCheck, which handles test assertions and reporting.

I don’t know why they’re two separate tools, but they are primarily end-to-end testing tools, as opposed to unit testing tools. Because end-to-end testing in a compiler toolchain implies the inputs and outputs are essentially big strings (programs) in unknown languages (user-defined MLIR dialects), these tools basically have you express the test setup and assertions in comments inside of the file representing the input program to be tested. An example might look like this:

// RUN: mlir-opt %s --convert-math-to-funcs=convert-ctlz | FileCheck %s

func.func @main(%arg0: i32) -> i32 {
// CHECK-NOT: math.ctlz
// CHECK: call
%0 = math.ctlz %arg0 : i32
func.return %0 : i32
}


I added this test in this commit. This trivial function calls math.ctlz on its input and promptly returns the result. The interesting parts are the comments, which define the test command and assertions.

Warning: you may run into a python issue where python cannot find the lit module. See https://github.com/j2kun/mlir-tutorial/issues/8, wherein I realized too late that bazel uses the system Python by default. tl;dr: you can either run pip install lit on your system Python, or else cherry-pick a commit in that issue to use a bazel-managed Python.

A lit test file contains some number of lines with RUN: as the lead of a comment, and the text after that describes a shell script to run, with some magic strings instructing lit to make substitutions. In this case %s is the current file path, but there is a table of default substitutions and one can add their own custom substitutions in a config file (see later).

In typical unix fashion, all lit does is check for the exit status of the RUN command to determine if the test passes or fails. Hence, this test pipes the output of mlir-opt to the FileCheck program, again passing in the current file path, which contains the assertions to check for.
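As a rough mental model, the two steps lit performs for a RUN line can be sketched in a few lines of Python. This is only an illustration under simplifying assumptions: the function names are hypothetical, only the %s substitution is handled, and real lit supports many more substitutions and multi-line RUN commands.

```python
import subprocess


def expand_run_line(line, test_path):
    """Strip the `RUN:` prefix and substitute %s with the test file path."""
    marker = "RUN:"
    command = line[line.index(marker) + len(marker):].strip()
    return command.replace("%s", test_path)


def command_succeeds(command):
    """Run the command in a shell; as in lit, only the exit status matters."""
    return subprocess.run(command, shell=True).returncode == 0


command = expand_run_line(
    "// RUN: mlir-opt %s --convert-math-to-funcs=convert-ctlz | FileCheck %s",
    "tests/ctlz.mlir",
)
print(command)
```

Both occurrences of %s expand to the same path, which is why a single file can serve as the input to mlir-opt and as the source of assertions for FileCheck.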

FileCheck is a bit more complicated than lit, but in brief it takes the input passed to stdin, scans for CHECK comments in the file passed as the CLI argument, and then for each CHECK comment, it does some logic to determine if the assertion passes. The simplest kind of assertion is a // CHECK: foo which searches for a line containing the pattern foo. By default the pattern is matched as a literal string, and a regular expression can be embedded by wrapping it in {{...}}. Similarly, a // CHECK-NOT: assertion asserts the pattern does not occur, or more precisely, that it does not occur between the matches of the neighboring CHECK assertions (or the start or end of the input). Beyond that, the main thing that is enforced is that multiple CHECK assertions match the input file in the same order that the CHECK comments occur. So if you had

// RUN: mlir-opt %s --convert-math-to-funcs=convert-ctlz | FileCheck %s

func.func @main(%arg0: i32) -> i32 {
// CHECK: call
// CHECK: foo
// CHECK: return
%0 = math.ctlz %arg0 : i32
func.return %0 : i32
}


Then it would expect that there are three lines (possibly with other lines between them) that match these patterns in order. A line matching call and a line matching return would fail unless there is a line between them matching foo.
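The ordering rule can be modeled with a short Python sketch. It is a deliberate simplification: the function name is hypothetical, patterns are plain substrings, and matching proceeds line by line, whereas real FileCheck resumes searching from the end of the previous match. The lowered string below is a hand-written stand-in for the output of mlir-opt.

```python
def checks_pass(output, check_patterns):
    """Return True if each pattern is found on a line after the previous match."""
    lines = output.splitlines()
    next_line = 0
    for pattern in check_patterns:
        for index in range(next_line, len(lines)):
            if pattern in lines[index]:
                next_line = index + 1
                break
        else:
            # No remaining line contains this pattern, so the check fails.
            return False
    return True


lowered = """\
func.func @main(%arg0: i32) -> i32 {
  %0 = call @__mlir_math_ctlz_i32(%arg0) : (i32) -> i32
  return %0 : i32
}"""
print(checks_pass(lowered, ["call", "return"]))         # True
print(checks_pass(lowered, ["call", "foo", "return"]))  # False
print(checks_pass(lowered, ["return", "call"]))         # False
```

The last call fails even though both strings occur in the output, because they occur in the wrong order.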

FileCheck can do a lot more, like use regular expressions to capture variable names and then refer to them in later CHECK assertions. With this, one can give a much more precise test on the ctlz lowering, expecting a relatively rigid structure of the output function, as in this commit. I won’t give full details here, but you can read the FileCheck documentation here and intuitively tell that an expression like %[[ARGCMP:.*]] captures a variable name where ARGCMP is how it is referred to in later assertions, while .* is the regular expression used to capture it (and % is an anchor that ensures it only applies to a variable name).
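To give a feel for how such captures behave, here is a small Python sketch that translates [[NAME:regex]] definitions and [[NAME]] uses into Python regular expressions. It is a toy model, not FileCheck's implementation: the function names are hypothetical, text outside the brackets is treated literally, and using an undefined name simply raises a KeyError.

```python
import re

# Matches [[NAME:regex]] (a definition) or [[NAME]] (a later use).
VARIABLE = re.compile(r"\[\[(\w+)(?::(.*?))?\]\]")


def to_regex(pattern, bound):
    """Translate one CHECK pattern into a compiled Python regex."""
    parts, position = [], 0
    for match in VARIABLE.finditer(pattern):
        parts.append(re.escape(pattern[position:match.start()]))
        name, expression = match.group(1), match.group(2)
        if expression is not None:
            parts.append(f"(?P<{name}>{expression})")  # capture a new value
        else:
            parts.append(re.escape(bound[name]))  # reuse a captured value
        position = match.end()
    parts.append(re.escape(pattern[position:]))
    return re.compile("".join(parts))


def checks_pass(lines, patterns):
    """Match patterns in order, carrying captured variables forward."""
    bound, index = {}, 0
    for pattern in patterns:
        regex = to_regex(pattern, bound)
        while index < len(lines):
            match = regex.search(lines[index])
            index += 1
            if match:
                bound.update(match.groupdict())
                break
        else:
            return False
    return True


patterns = [
    "%[[ZERO:[a-z0-9_]+]] = arith.constant 0",
    "arith.cmpi eq, %arg0, %[[ZERO]]",
]
good = ["%c0_i32 = arith.constant 0 : i32", "%0 = arith.cmpi eq, %arg0, %c0_i32 : i32"]
bad = ["%c0_i32 = arith.constant 0 : i32", "%0 = arith.cmpi eq, %arg0, %other : i32"]
print(checks_pass(good, patterns))  # True
print(checks_pass(bad, patterns))   # False
```

The benefit is that the test no longer depends on the exact SSA value names the printer happens to choose, only on how values flow between operations.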

Running these tests requires a bit of finicky configuration. lit can be run like a normal python program, and if all the executables invoked in the RUN directives are in the PATH environment variable, then it will just work. However, in any build system the executables are in exotic places, which is especially true in Bazel.

Aside: In typical CMake-oriented MLIR projects, there are actually two config files, one called something like lit.site.cfg.py.in, which has variables that CMake substitutes in pointing to the build artifact paths, and one called lit.cfg.py which configures lit to use those paths. In my opinion the Bazel configuration is marginally simpler, but I am biased because I’m more familiar with it.

In bazel a single test would correspond to a build target in a BUILD file of the following form:

py_test(
name = "my_test_file.mlir.test",
srcs = ["@llvm-project//llvm:lit"],
args = ["-v", "tests/my_test_file.mlir"],
data = ["@llvm-project//llvm:FileCheck", ..., ":my_test_file.mlir"],
main = "lit.py",
)


This tells bazel to run lit with the right arguments, and crucially, the data argument allows us to pull in binary targets corresponding to the commands used in the RUN directives. Then two things happen. First, lit runs in a working directory determined arbitrarily by bazel (with the data stuff pulled in somehow). We need to understand the directory structure of this working directory in order to configure lit properly. Then, when lit runs, it looks for a file called lit.cfg.py in the directory containing the test (and recursively upward to the project root), loads it, and uses that to set the PATH and other configuration.
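The upward search for a config file can be pictured with the following Python sketch. It is an assumption-laden model rather than lit's actual code: find_lit_config is a hypothetical helper, and it walks all the way up to the filesystem root instead of stopping at a project root.

```python
import tempfile
from pathlib import Path


def find_lit_config(test_file, config_name="lit.cfg.py"):
    """Return the nearest config at or above the test's directory, or None."""
    directory = Path(test_file).resolve().parent
    for candidate_dir in [directory, *directory.parents]:
        candidate = candidate_dir / config_name
        if candidate.is_file():
            return candidate
    return None


# Build a throwaway directory tree: tests/lit.cfg.py and tests/ctlz/ctlz.mlir.
with tempfile.TemporaryDirectory() as root:
    tests = Path(root) / "tests"
    nested = tests / "ctlz"
    nested.mkdir(parents=True)
    (tests / "lit.cfg.py").write_text("# config\n")
    (nested / "ctlz.mlir").write_text("// RUN: ...\n")
    found = find_lit_config(nested / "ctlz.mlir")
    print(found.relative_to(Path(root).resolve()))
```

A consequence of this lookup is that one config file placed in tests/ covers every test in that directory and its subdirectories.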

In our case, the lit.cfg.py looks like this

import os
from pathlib import Path
from lit.formats import ShTest

config.name = "mlir_tutorial"
config.test_format = ShTest()
config.suffixes = [".mlir"]

runfiles_dir = Path(os.environ["RUNFILES_DIR"])
tool_relpaths = [
"llvm-project/mlir",
"llvm-project/llvm",
]

config.environment["PATH"] = (
":".join(str(runfiles_dir.joinpath(Path(path))) for path in tool_relpaths)
+ ":"
+ os.environ["PATH"]
)


Two weird things are happening here. First, config is an undefined variable at first glance, but the lit documentation states that an instance is inserted into the module scope when lit runs. Second, we are using the RUNFILES_DIR environment variable as the base for the paths we will construct pointing to the binaries. RUNFILES_DIR is defined by bazel, and it is generally different from the working directory of the binary run by native.py_test. It contains a directory tree for the current project (mlir_tutorial) as well as all dependent projects defined in the WORKSPACE, so long as some targets from those projects were included in the data option of the test rule.
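To see what the PATH expression in the config evaluates to, here is the same computation pulled out into a standalone Python function and fed made-up values. The function name build_path and the sandbox directory are hypothetical; bazel chooses the real RUNFILES_DIR.

```python
from pathlib import PurePosixPath


def build_path(runfiles_dir, tool_relpaths, existing_path):
    """Prepend each tool directory under the runfiles tree to the PATH."""
    tool_dirs = [str(PurePosixPath(runfiles_dir) / rel) for rel in tool_relpaths]
    return ":".join(tool_dirs + [existing_path])


# Made-up environment standing in for what bazel would provide.
fake_env = {"RUNFILES_DIR": "/sandbox/test.runfiles", "PATH": "/usr/bin:/bin"}
path = build_path(
    fake_env["RUNFILES_DIR"],
    ["llvm-project/mlir", "llvm-project/llvm"],
    fake_env["PATH"],
)
print(path)
# /sandbox/test.runfiles/llvm-project/mlir:/sandbox/test.runfiles/llvm-project/llvm:/usr/bin:/bin
```

Because the tool directories come first, a RUN line that says mlir-opt or FileCheck resolves to the binaries bazel built rather than anything installed on the system.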

Once this is all worked out, one can define individual test targets for each lit test. However, since that is laborious, instead what I did in this commit was define all of the above configuration together with a bazel macro that will search for all .mlir files in a given directory and create test targets for them. So in this project, a new .mlir file added to tests/ will be automatically run when you run bazel test //.... Then, tests/BUILD contains the glob_lit_tests macro invocation, and a filegroup that describes all the tools and files that should be included in the data to run them.

# tests/BUILD

# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tests:lit.cfg.py",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:count",
"@llvm-project//llvm:not",
"@llvm-project//mlir:mlir-opt",
],
)

glob_lit_tests()
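The real macro is written in Starlark, but its effect can be approximated in plain Python. The sketch below is a model under assumptions, not the macro itself: lit_test_targets is a hypothetical name, and the fields of each dictionary simply mirror the single py_test target shown earlier.

```python
import tempfile
from pathlib import Path


def lit_test_targets(test_dir, package="tests", data=":test_utilities"):
    """Describe one test target per .mlir file, mimicking a globbing macro."""
    targets = []
    for mlir_file in sorted(Path(test_dir).glob("*.mlir")):
        targets.append({
            "name": f"{mlir_file.name}.test",
            "args": ["-v", f"{package}/{mlir_file.name}"],
            "data": [data, f":{mlir_file.name}"],
        })
    return targets


# Throwaway directory with two tests and one unrelated file.
with tempfile.TemporaryDirectory() as test_dir:
    for filename in ["ctlz.mlir", "ctlz_runner.mlir", "README.md"]:
        (Path(test_dir) / filename).write_text("")
    targets = lit_test_targets(test_dir)

for target in targets:
    print(target["name"], target["args"])
```

Only the two .mlir files produce targets, which is the behavior that lets a newly added test file get picked up without editing the BUILD file.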


Bonus: functional testing

The previous lowering of math.ctlz to a software implementation has a very detailed test, but in MLIR tests of lowerings like this are primarily syntactic in nature. That is, the test does not assert that the lowering itself is functionally correct. The author may have written assertions that align with the generated code while the generated code has a bug or otherwise does not compute what it is supposed to compute.

One way around this is to continue compiling the MLIR code down through LLVM to machine code, running it, and asserting something about the output (presumably printed to stdout). While this is possible in lit, since RUN can run anything, it does require pulling in quite a few more dependencies. A slightly more lightweight means to achieve this is to use mlir-cpu-runner, which JIT-compiles and executes code written in the lowest-level MLIR dialects (in particular, the llvm dialect, which is the “exit” dialect before going to LLVM).

Here’s what such a test might look like in lit for our ctlz lowering pass, which tests that 7, as a 32-bit integer, has 29 leading zeros. I added the test in this commit. Notably, I had to add the mlir-cpu-runner binary to that list of test_utilities mentioned in the previous section, or else the test will fail with the inability to find the mlir-cpu-runner binary.

// RUN: mlir-opt %s \
// RUN:   --pass-pipeline="builtin.module( \
// RUN:      convert-math-to-funcs{convert-ctlz}, \
// RUN:      func.func(convert-scf-to-cf,convert-arith-to-llvm), \
// RUN:      convert-func-to-llvm, \
// RUN:      convert-cf-to-llvm, \
// RUN:      reconcile-unrealized-casts)" \
// RUN: | mlir-cpu-runner -e test_7i32_to_29 -entry-point-result=i32 > %t
// RUN: FileCheck %s --check-prefix=CHECK_TEST_7i32_TO_29 < %t

func.func @test_7i32_to_29() -> i32 {
%arg = arith.constant 7 : i32
%0 = math.ctlz %arg : i32
func.return %0 : i32
}
// CHECK_TEST_7i32_TO_29: 29


The RUN command is quite a bit more complicated. First, we need to run more passes than just convert-math-to-funcs in order to get the code down to the LLVM dialect, which is what mlir-cpu-runner supports. The --pass-pipeline flag allows you to build a more complex chain of passes on the command line. Then the result is piped to mlir-cpu-runner, which takes as command line flags the top level function to run and the type of the result. Finally, the output is redirected to %t, which is a lit substitution that represents a per-test temporary file. In this case, it is used so that if this first command fails, its error message is displayed in the test failure, rather than the subsequent failure of FileCheck to parse an empty input from the pipe.

Then, a second RUN command runs FileCheck, again using the current file for the test assertions, reading the input to check from %t, and adding the special --check-prefix flag so that it only runs a subset of the CHECK assertions in the file (allowing us to add a second test in the same file, as in the next commit that runs a similar test for an i64 input). Then, because mlir-cpu-runner prints the result of the function to stdout, the CHECK assertion just expects the output to be 29 for input 7.
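When writing functional tests like this, it helps to have an independent way to compute the expected values. Below is a small Python reference for count-leading-zeros that follows the same shift-and-count idea as the software lowering. It is my own sketch for sanity checking, not code taken from MLIR, and the function name and width parameter are assumptions.

```python
def ctlz(value, width=32):
    """Count leading zero bits of `value` viewed as a `width`-bit integer."""
    mask = (1 << width) - 1
    value &= mask
    if value == 0:
        return width  # by convention, zero has `width` leading zeros
    top_bit = 1 << (width - 1)
    count = 0
    for _ in range(width):
        if value & top_bit:
            break  # the highest bit is set, so no more leading zeros
        count += 1
        value = (value << 1) & mask
    return count


print(ctlz(7))            # 29
print(ctlz(7, width=64))  # 61
print(ctlz(0))            # 32
```

The value 29 agrees with the expectation in the i32 test, and 61 is what the same reasoning gives for 7 as a 64-bit integer.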

It may also be interesting to the reader to see what MLIR outputs when I run the full pass pipeline (but not mlir-cpu-runner) on the input. Here it is:

module attributes {llvm.data_layout = ""} {
llvm.func @test_7i32_to_29() -> i32 {
%0 = llvm.mlir.constant(7 : i32) : i32
%1 = llvm.mlir.constant(29 : i32) : i32
llvm.return %1 : i32
}
llvm.func linkonce_odr @__mlir_math_ctlz_i32(%arg0: i32) -> i32 attributes {sym_visibility = "private"} {
%0 = llvm.mlir.constant(32 : i32) : i32
%1 = llvm.mlir.constant(0 : i32) : i32
%2 = llvm.icmp "eq" %arg0, %1 : i32
llvm.cond_br %2, ^bb1, ^bb2
^bb1:  // pred: ^bb0
llvm.br ^bb10(%0 : i32)
^bb2:  // pred: ^bb0
%3 = llvm.mlir.constant(1 : index) : i64
%4 = llvm.mlir.constant(1 : i32) : i32
%5 = llvm.mlir.constant(32 : index) : i64
%6 = llvm.mlir.constant(0 : i32) : i32
llvm.br ^bb3(%3, %arg0, %6 : i64, i32, i32)
^bb3(%7: i64, %8: i32, %9: i32):  // 2 preds: ^bb2, ^bb8
%10 = llvm.icmp "slt" %7, %5 : i64
llvm.cond_br %10, ^bb4, ^bb9
^bb4:  // pred: ^bb3
%11 = llvm.icmp "slt" %8, %1 : i32
llvm.cond_br %11, ^bb5, ^bb6
^bb5:  // pred: ^bb4
llvm.br ^bb7(%8, %9 : i32, i32)
^bb6:  // pred: ^bb4
%12 = llvm.add %9, %4  : i32
%13 = llvm.shl %8, %4  : i32
llvm.br ^bb7(%13, %12 : i32, i32)
^bb7(%14: i32, %15: i32):  // 2 preds: ^bb5, ^bb6
llvm.br ^bb8
^bb8:  // pred: ^bb7
%16 = llvm.add %7, %3  : i64
llvm.br ^bb3(%16, %14, %15 : i64, i32, i32)
^bb9:  // pred: ^bb3
llvm.br ^bb10(%9 : i32)
^bb10(%17: i32):  // 2 preds: ^bb1, ^bb9
llvm.br ^bb11
^bb11:  // pred: ^bb10
llvm.return %17 : i32
}
}


The main new thing here, besides all of the llvm dialect operations, is the ^bb1 syntax, which is the label identifier for those basic block syntax structures mentioned earlier. With the basic syntax and testing down, next time we will define a custom lowering and explore the MLIR API from that perspective. Then we’ll dive into defining a new dialect.

Thanks to Patrick Schmidt for feedback on a draft of this article.