# MLIR — A Global Optimization and Dataflow Analysis

In this article we’ll implement a global optimization pass, and show how to use the dataflow analysis framework to verify the results of our optimization.

## The noisy arithmetic problem

This demonstration is based on a simplified model of computation relevant to the HEIR project. You don’t need to be familiar with that project to follow this article, but if you’re wondering why someone would ever want the kind of optimization I’m going to write, that project is why.

The basic model is “noisy integer arithmetic.” That is, a program can have integer types of bounded width, and each integer is associated with some unknown (but bounded with high probability) symmetric random noise. You can imagine the “real” integer being the top 5 bits of a 32-bit integer, and the bottom 27 bits storing the noise. When a new integer is created, it magically has a random signed 12-bit integer added to it. When you apply operations to combine integers, the noise grows. Adding two integers adds their noise, and at worst you get one more bit of noise. Scaling an integer by a statically-known constant scales the noise by a constant. Multiplying two integers multiplies their noise values, and you get twice the bits of noise. As long as your program stays below 27 bits of noise, you can still “recover” the original 5-bit integer at the end of the program. Such a program is called legal, and otherwise, the output is random junk and the program is called illegal.

Finally, there is an expensive operation called reduce_noise that can explicitly reduce the noise of a noisy integer back to the base level of 12 bits. This operation has a statically known cost relative to the standard operations.
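To make the model concrete, here is a toy Python sketch of the noise bookkeeping. The constants and function names are mine, chosen to match the rules above; this is an illustration of the semantics, not code from any real project.

```python
# A toy model of noisy integer arithmetic that tracks only the noise
# bit-count of each value. All names here are illustrative.
INITIAL_NOISE = 12  # bits of noise added when an integer is created
MAX_NOISE = 26      # a program is legal iff noise never exceeds this

def encode():
    return INITIAL_NOISE

def add(noise_lhs, noise_rhs):
    # Adding two noisy integers adds their noise values, which in bits
    # means at worst one extra bit over the larger input.
    return 1 + max(noise_lhs, noise_rhs)

def mul(noise_lhs, noise_rhs):
    # Multiplying noisy integers multiplies the noise: bit counts add.
    return noise_lhs + noise_rhs

def reduce_noise(noise):
    # The expensive op: resets noise back to the base level.
    return INITIAL_NOISE

# A single multiplication of two fresh values is already close to the limit:
x, y = encode(), encode()
z = mul(x, y)
assert z == 24
assert mul(z, z) > MAX_NOISE  # a second squaring would be illegal
```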

Note that starting from two noisy integers, each with 12 bits of noise, a single multiplication op brings you jarringly close to ruin. You would have at most 24 bits of noise, which is close to the maximum of 26. But the input IR we start with may have arbitrary computations that do not respect the noise limits. The goal of our optimization is to rewrite an MLIR region so that the noisy integer math never exceeds the noise limit at any given step.

A trivial way to do that would be to insert reduce_noise ops greedily, whenever an op would bring the noise of a value too high. Inserting such ops may be necessary, but a more suitable goal would be to minimize the overall cost of the program, subject to the constraint that the noisy arithmetic is legal. One could do this by inserting reduce_noise ops more strategically, or by rewriting the program to reduce the need for reduce_noise ops, or both. We’ll focus on the former: finding the best place to insert reduce_noise ops without rewriting the rest of the program.

## The noisy dialect

We previously wrote about defining a new dialect, and the noisy dialect we created for this article has little new to show. This commit defines the types and ops, hard coding the 32-bit width and 5-bit semantic input type for simplicity, as well as the values of 12 bits of initial noise and 26 bits of max noise.

Note that the noise bound is not expressed on the type as an attribute. If it were, we’d run into a few problems: first, whenever you insert a reduce_noise op, you’d have to update the types on all the downstream ops. Second, it would prevent you from expressing control flow, since the noise bound cannot be statically inferred from the source code when there are two possible paths that could result in different noise values.

So instead, we need a way to compute the noise values, and associate them with each SSA value, and account for control flow. This is what an analysis pass is designed to do.

## An analysis pass is just a class

The typical use of an analysis pass is to construct a data structure that encodes global information about a program, which can then be re-used during different parts of a pass. I imagined there would be more infrastructure around analysis passes in MLIR, but it’s quite simple. You define a C++ class with a constructor that takes an Operation *, and construct it basically whenever you need it. The only infrastructure for it involves storing and caching the constructed analysis within a pass, and understanding when an analysis needs to be recomputed (always between passes, by default).

By way of example, over at the HEIR project I made a simple analysis that chooses a unique variable name for every SSA value in a program, which I then used to generate code in an output language that needed variable names.

For this article we’ll see two analysis passes. One will formulate and solve the optimization problem that decides where to insert reduce_noise operations. This will be one of the “class that does anything” kind of analysis pass. The other analysis pass will rely on MLIR’s data flow analysis framework to propagate the noise model through the IR. This one will actually not require us to write an analysis from scratch, but instead will be implemented by means of the existing IntegerRangeAnalysis, which only requires us to implement an interface on each op that describes how the op affects the noise. This will be used in our pass to verify that the inserted reduce_noise operations ensure, if nothing else, that the noise never exceeds the maximum allowable noise.

## Reusing IntegerRangeAnalysis

Data flow analysis is a classical static analysis technique for propagating information through a program’s IR. It is part of the body of work for which Frances Allen won the Turing Award. This article gives a good introduction and additional details, but I will paraphrase it briefly here in the context of IntegerRangeAnalysis.

The basic idea is that you want to get information about what possible values an integer-typed value can have at any point in a given program. If you see x = 7, then you know exactly what x is. If you see something like

func.func @f(%x : i8) {
  %1 = arith.extsi %x : i8 to i32
  %2 = arith.addi %1, %1 : i32
  return
}


then you know that %2 can be at most a signed 9-bit integer, because it started as an 8-bit integer, and adding two such integers together can’t fill up more than one extra bit.

In such cases, one can find optimizations, like the int-range-optimizations pass in MLIR, which looks at comparison ops arith.cmpi and determines if it can replace them with constants. It does this by looking at the integer range analysis for the two operands. E.g., given the op x > y, if you know y‘s maximum value is less than x‘s minimum value, then you can replace it with a constant true.
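Both inferences can be sketched with a toy interval type. Interval and fold_sgt below are illustrative stand-ins of my own, not MLIR's actual ConstantIntRanges API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Interval:
    """A stand-in for a known [lo, hi] bound on an integer value."""
    lo: int
    hi: int

    def add(self, other):
        # Transfer function for addition: add the endpoints.
        return Interval(self.lo + other.lo, self.hi + other.hi)

# An i8 value sign-extended to i32 still lies in [-128, 127], and adding
# it to itself fits in a signed 9-bit integer:
x = Interval(-128, 127)
assert x.add(x) == Interval(-256, 254)

def fold_sgt(x: Interval, y: Interval):
    """Return True/False if x > y is decidable from the bounds, else None."""
    if x.lo > y.hi:
        return True
    if x.hi <= y.lo:
        return False
    return None

# If y's maximum is below x's minimum, the comparison folds to true:
assert fold_sgt(Interval(10, 20), Interval(0, 5)) is True
```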

Computing the data flow analysis requires two ingredients, called a transfer function and a join operation. The transfer function describes what the output integer range should be for a given op and a given set of input integer ranges. This can be an arbitrary function. The join operation describes how to combine two or more integer ranges when you get to a point in the program where different branches of control flow merge. For example,

def fn(branch):
  x = 7
  if branch:
    y = x * x
  else:
    y = 2 * x
  return y


The value of y just before returning cannot be known exactly, but in one branch you know it’s 14, and in another it’s 49. So the final value of y could be estimated as being in the range [14, 49]. Here the join function computes the smallest integer range containing both estimates. [Aside: it could instead use the set {14, 49} to be more precise, but that is not what IntegerRangeAnalysis happens to do]

In order for a data flow analysis to work properly, the values being propagated and the join function must together form a semilattice: a partially-ordered set in which every two elements have a least upper bound, that upper bound is computed by join, and join itself is associative, commutative, and idempotent. For the propagation to terminate, the semilattice must also have finite height (no infinite strictly-increasing chains). This is often expressed by having distinct “top” and “bottom” elements as defaults. “Top” represents “could be anything,” and sometimes expresses that a more precise bound would be too computationally expensive to continue to track. “Bottom” usually represents an uninitialized value.
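Those laws are easy to spot-check for the integer-range join on a small universe of intervals, with None playing the role of “bottom.” This is an illustrative sketch of my own, not MLIR code:

```python
import itertools

def join(a, b):
    """Least upper bound of two (lo, hi) intervals; None is 'bottom'."""
    if a is None:
        return b
    if b is None:
        return a
    return (min(a[0], b[0]), max(a[1], b[1]))

# Spot-check the semilattice laws on a small universe of intervals.
universe = [None] + [(lo, hi) for lo in range(-2, 3) for hi in range(lo, 3)]
for a, b, c in itertools.product(universe, repeat=3):
    assert join(a, b) == join(b, a)                    # commutative
    assert join(a, join(b, c)) == join(join(a, b), c)  # associative
    assert join(a, a) == a                             # idempotent
```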

Once you have this, MLIR provides a general algorithm to propagate values through the IR via a technique called Kildall’s method, which iteratively applies the transfer function to each op and joins the results where control flow paths merge, until the process reaches a fixed point.
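Here is a minimal sketch of that fixed-point iteration applied to the Python example above, with hand-written transfer functions and an explicit worklist. All names are illustrative; this is the shape of Kildall's method, not MLIR's implementation:

```python
# A miniature Kildall-style fixed-point propagation: each "value" has a
# transfer function over the current environment of (lo, hi) estimates.

def join(a, b):
    if a is None: return b
    if b is None: return a
    return (min(a[0], b[0]), max(a[1], b[1]))

# Program: x = 7; if branch: y1 = x*x else: y2 = 2*x; y = phi(y1, y2)
transfer = {
    "x":  lambda env: (7, 7),
    "y1": lambda env: (env["x"][0] ** 2, env["x"][1] ** 2),
    "y2": lambda env: (2 * env["x"][0], 2 * env["x"][1]),
    # phi node: join the estimates arriving from the two branches
    "y":  lambda env: join(env["y1"], env["y2"]),
}
deps = {"x": ["y1", "y2"], "y1": ["y"], "y2": ["y"], "y": []}

env = {v: None for v in transfer}  # everything starts at bottom
worklist = list(transfer)
while worklist:
    v = worklist.pop()
    old = env[v]
    try:
        new = transfer[v](env)
    except TypeError:
        continue  # an input is still uninitialized; a dep update revisits us
    env[v] = join(old, new)
    if env[v] != old:
        worklist.extend(deps[v])  # re-process users whose inputs changed

assert env["y"] == (14, 49)
```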

Here are MLIR’s official docs on dataflow analysis, and here is the RFC where the current data flow solver framework was introduced. In our situation, we want to use the solver framework with the existing IntegerRangeAnalysis, which only asks that we implement the transfer function by implementing InferIntRangeInterface on our ops.

This commit does just that. This requires adding the DeclareOpInterfaceMethods<InferIntRangeInterface> to all relevant ops. That in turn generates function declarations for

void MyOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> inputRanges, SetIntRangeFn setResultRange);


ConstantIntRanges is a simple class holding a min and max integer value. inputRanges represents the known bounds on the inputs to the operation in question, and SetIntRangeFn is the callback used to record the result range.

For example, for AddOp we can implement it as

ConstantIntRanges unionPlusOne(ArrayRef<ConstantIntRanges> inputRanges) {
  auto lhsRange = inputRanges[0];
  auto rhsRange = inputRanges[1];
  auto joined = lhsRange.rangeUnion(rhsRange);
  return ConstantIntRanges::fromUnsigned(joined.umin(), joined.umax() + 1);
}

void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> inputRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(), unionPlusOne(inputRanges));
}


A MulOp is similarly implemented by summing the maxes. Meanwhile, EncodeOp and ReduceNoiseOp each set the initial range to [0, 12]. So the min will always be zero, and we really only care about the max.
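To make the transfer functions concrete, here is a sketch of the same rules in plain Python (the semantics only, not the C++ interface); chaining muls from two freshly encoded values doubles the max at each step:

```python
# Replaying the interface implementations in Python: each op's transfer
# function maps input (min, max) ranges to an output range.
INITIAL_NOISE = 12

def encode_range(_inputs):
    # EncodeOp (and ReduceNoiseOp) pin the output range to [0, 12].
    return (0, INITIAL_NOISE)

def add_range(lhs, rhs):
    # AddOp: union of the input ranges, plus one.
    return (min(lhs[0], rhs[0]), max(lhs[1], rhs[1]) + 1)

def mul_range(lhs, rhs):
    # MulOp: sum of the maxes.
    return (0, lhs[1] + rhs[1])

# A chain of squarings starting from two freshly encoded values:
r = mul_range(encode_range(None), encode_range(None))
assert r == (0, 24)
r = mul_range(r, r)
assert r == (0, 48)   # already exceeds the maximum noise of 26
r = mul_range(r, r)
assert r == (0, 96)
```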

The next commit defines an empty pass that will contain our analyses and optimizations, and this commit shows how the integer range analysis is used to validate an IR’s noise growth. In short, you load the IntegerRangeAnalysis and its dependent DeadCodeAnalysis, run the solver, and then walk the IR, asking the solver via lookupState to give the resulting value range for each op’s result, and comparing it against the maximum.

void runOnOperation() {
  Operation *module = getOperation();

  DataFlowSolver solver;
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::IntegerRangeAnalysis>();
  if (failed(solver.initializeAndRun(module)))
    signalPassFailure();

  auto result = module->walk([&](Operation *op) {
    if (!llvm::isa<noisy::AddOp, noisy::SubOp, noisy::MulOp,
                   noisy::ReduceNoiseOp>(*op)) {
      return WalkResult::advance();
    }
    const dataflow::IntegerValueRangeLattice *opRange =
        solver.lookupState<dataflow::IntegerValueRangeLattice>(
            op->getResult(0));
    if (!opRange || opRange->getValue().isUninitialized()) {
      op->emitOpError()
          << "Found op without a set integer range; did the analysis fail?";
      return WalkResult::interrupt();
    }

    ConstantIntRanges range = opRange->getValue().getValue();
    if (range.umax().getZExtValue() > MAX_NOISE) {
      op->emitOpError() << "Found op after which the noise exceeds the "
                           "allowable maximum of "
                        << MAX_NOISE
                        << "; it was: " << range.umax().getZExtValue()
                        << "\n";
      return WalkResult::interrupt();
    }

    return WalkResult::advance();
  });

  if (result.wasInterrupted())
    signalPassFailure();
}

Finally, in this commit we add a test that exercises it:

func.func @test_op_syntax() -> i5 {
%0 = arith.constant 3 : i5
%1 = arith.constant 4 : i5
%2 = noisy.encode %0 : i5 -> !noisy.i32
%3 = noisy.encode %1 : i5 -> !noisy.i32
%4 = noisy.mul %2, %3 : !noisy.i32
%5 = noisy.mul %4, %4 : !noisy.i32
%6 = noisy.mul %5, %5 : !noisy.i32
%7 = noisy.mul %6, %6 : !noisy.i32
%8 = noisy.decode %7 : !noisy.i32 -> i5
return %8 : i5
}


Running tutorial-opt --noisy-reduce-noise on this file produces the following error:

mlir-tutorial/tests/noisy_reduce_noise.mlir:11:8: error: 'noisy.mul' op Found op after which the noise exceeds the allowable maximum of 26; it was: 48
%5 = noisy.mul %4, %4 : !noisy.i32
     ^
mlir-tutorial/tests/noisy_reduce_noise.mlir:11:8: note: see current operation: %5 = "noisy.mul"(%4, %4) : (!noisy.i32, !noisy.i32) -> !noisy.i32


And if you run in debug mode with --debug --debug-only=int-range-analysis, you will see the per-op propagations printed to the terminal:

$ bazel run tools:tutorial-opt -- --noisy-reduce-noise-optimizer $PWD/tests/noisy_reduce_noise.mlir --debug --debug-only=int-range-analysis
Inferring ranges for %c3_i5 = arith.constant 3 : i5
Inferred range unsigned : [3, 3] signed : [3, 3]
Inferring ranges for %c4_i5 = arith.constant 4 : i5
Inferred range unsigned : [4, 4] signed : [4, 4]
Inferring ranges for %0 = noisy.encode %c3_i5 : i5 -> !noisy.i32
Inferred range unsigned : [0, 12] signed : [0, 12]
Inferring ranges for %1 = noisy.encode %c4_i5 : i5 -> !noisy.i32
Inferred range unsigned : [0, 12] signed : [0, 12]
Inferring ranges for %2 = noisy.mul %0, %1 : !noisy.i32
Inferred range unsigned : [0, 24] signed : [0, 24]
Inferring ranges for %3 = noisy.mul %2, %2 : !noisy.i32
Inferred range unsigned : [0, 48] signed : [0, 48]
Inferring ranges for %4 = noisy.mul %3, %3 : !noisy.i32
Inferred range unsigned : [0, 96] signed : [0, 96]
Inferring ranges for %5 = noisy.mul %4, %4 : !noisy.i32
Inferred range unsigned : [0, 192] signed : [0, 192]


As a quick aside, there was one minor upstream problem preventing me from reusing IntegerRangeAnalysis, which I patched in https://github.com/llvm/llvm-project/pull/72007. This means I also had to update the LLVM commit hash used by this project in this commit.

## An ILP optimization pass

Next, we build an analysis that solves a global optimization problem to decide where to insert reduce_noise ops efficiently. As mentioned earlier, this is a “do anything” kind of analysis, so we put all of the logic into the analysis’s constructor.

[Aside: I wouldn’t normally do this, because constructors don’t have return values so it’s hard to signal failure; but the API for the analysis specifies the constructor takes as input the Operation * to analyze, and I would expect any properly constructed object to be “ready to use.” Maybe someone who knows C++ better will comment and shed some wisdom for me.]

This commit sets up the analysis shell and interface.

class ReduceNoiseAnalysis {
public:
ReduceNoiseAnalysis(Operation *op);
~ReduceNoiseAnalysis() = default;

/// Return true if a reduce_noise op should be inserted after the given
/// operation, according to the solution to the optimization problem.
bool shouldInsertReduceNoise(Operation *op) const {
return solution.lookup(op);
}

private:
llvm::DenseMap<Operation *, bool> solution;
};


This commit adds a workspace dependency on Google’s or-tools package (“OR” stands for Operations Research here, a.k.a. discrete optimization), which comes bundled with a number of nice solvers, and an API for formulating optimization problems. And this commit implements the actual solver model.

Now, this model is quite a bit of code, and this article is not the best place to give a full-fledged introduction to linear programming, modeling techniques, or the OR-tools API. What I’ll do instead is explain the model in detail here, and give a few small notes on how it translates to the OR-tools C++ API. If you want a gentler background on linear programming, see my article series about diet optimization (part 1, part 2).

All linear programs specify a linear function as an objective to minimize, along with a set of linear equalities and inequalities that constrain the solution. In a standard linear program, the variables must be continuously valued. In a mixed-integer linear program, some of those variables are allowed to be discrete integers, which, it turns out, makes it possible to solve many more problems, but requires completely different optimization techniques and may result in exponentially slow runtime. So many techniques in operations research relate to modeling a problem in such a way that the number of integer variables is relatively small.

Our linear model starts by defining some basic variables. Some variables in the model represent “decisions” that we can make, and others represent “state” that reacts to the decisions via constraints.

• For each operation $x$, a $\{0, 1\}$-valued variable $\textup{InsertReduceNoise}_x$. Such a variable is 1 if and only if we insert a reduce_noise op after the operation $x$.
• For each SSA value $v$ that is an input or output of a noisy op, a continuous-valued variable $\textup{NoiseAt}_v$. This represents an upper bound on the noise at value $v$.

In particular, the solver’s performance will get worse as the number of binary variables increases, which in this case corresponds to the number of noisy ops.

The objective function, with a small caveat explained later, is simply the sum of the decision variables, and we’d like to minimize it. Each reduce_noise op is considered equally expensive, and there is no special nuance here about scheduling them in parallel or in serial.

Next, we add constraints. First, $0 \leq \textup{NoiseAt}_v \leq 26$, which asserts that no SSA value can exceed the max noise. Second, we need to enforce that an encode op fixes the noise of its output to 12, i.e., for each encode op $x$, we add the constraint $\textup{NoiseAt}_{\textup{result}(x)} = 12$.

Finally, we need constraints that say that if you choose to insert a reduce_noise op, then the noise is reset to 12, otherwise it is set to the appropriate function of the inputs. This is where the modeling gets a little tricky, but multiplication is easier so let’s start there.

Fix a multiplication op $x$, its two input SSA values $\textup{LHS}, \textup{RHS}$, and its output $\textup{RES}$. As a piecewise function, we want a constraint like:

$\textup{NoiseAt}_\textup{RES} = \begin{cases} \textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS} & \text{ if } \textup{InsertReduceNoise}_x = 0 \\ 12 & \text{ if } \textup{InsertReduceNoise}_x = 1 \end{cases}$

This isn’t linear, but we can combine the two branches to

\begin{aligned} \textup{NoiseAt}_\textup{RES} &= (1 - \textup{InsertReduceNoise}_x) (\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS}) \\ & + 12 \cdot \textup{InsertReduceNoise}_x \end{aligned}

This does the classic trick of using a bit as a controlled multiplexer, but it’s still not linear. We can make it linear, however, by replacing this one constraint with four constraints, and an auxiliary constant $C=100$ that we know is larger than the possible range of values that the $\textup{NoiseAt}_v$ variables can attain. Those four linear constraints are:

\begin{aligned} \textup{NoiseAt}_\textup{RES} &\geq 12 \cdot \textup{InsertReduceNoise}_x \\ \textup{NoiseAt}_\textup{RES} &\leq 12 + C(1 - \textup{InsertReduceNoise}_x) \\ \textup{NoiseAt}_\textup{RES} &\geq (\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS}) - C \cdot \textup{InsertReduceNoise}_x \\ \textup{NoiseAt}_\textup{RES} &\leq (\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS}) + C \cdot \textup{InsertReduceNoise}_x \end{aligned}

Setting the decision variable to zero results in the first two equations being trivially satisfied. Setting it to 1 causes the first two equations to be equivalent to $\textup{NoiseAt}_\textup{RES} = 12$. Likewise, the second two constraints are trivial when the decision variable is 1, and force the output noise to be equal to the sum of the two input noises when set to zero.
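We can sanity-check this linearization by brute force over small values: for each setting of the decision variable, the four inequalities admit exactly the output value the piecewise definition prescribes. This is a self-contained Python check of the math, not OR-tools code:

```python
C = 100  # big-M constant, larger than any attainable NoiseAt value

def feasible(res, lhs, rhs, d):
    """Do the four linear constraints hold for this assignment?"""
    return (res >= 12 * d
            and res <= 12 + C * (1 - d)
            and res >= (lhs + rhs) - C * d
            and res <= (lhs + rhs) + C * d)

for lhs in range(0, 27):
    for rhs in range(0, 27):
        for d in (0, 1):
            intended = 12 if d == 1 else lhs + rhs
            # The only feasible output value is the intended one.
            allowed = [res for res in range(0, 54) if feasible(res, lhs, rhs, d)]
            assert allowed == [intended]
```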

The addition op is handled similarly, except that the term $(\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS})$ is replaced by something non-linear, namely $1 + \max(\textup{NoiseAt}_{LHS}, \textup{NoiseAt}_{RHS})$. We can still handle that, but it requires an extra modeling trick. We introduce a new variable $Z_v$ for each add op (where $v$ denotes the op’s result), and two constraints:

\begin{aligned} Z_v &\geq 1 + \textup{NoiseAt}_{LHS} \\ Z_v &\geq 1 + \textup{NoiseAt}_{RHS} \end{aligned}

Together these ensure that $Z_v$ is at least 1 plus the max of the two input noises, but they don’t force equality. To get effective equality, we add each $Z_v$ to the minimization objective (alongside the sum of the decision variables) with a small penalty coefficient, so the solver pushes each $Z_v$ down to its minimum. Since that minimum is exactly “1 plus the max,” the solver will have no trouble optimizing them, and this behaves like an equality constraint.

[Aside: Whenever you do this trick, you have to convince yourself that the solver won’t somehow be able to increase $Z_v$ as a trade-off against lower values of other objective terms, and produce a lower overall objective value. Solvers are mischievous and cannot be trusted. In our case, there is no risk: increasing $Z_v$ above its minimum value only increases the noise estimate propagated through add ops, which could only force the solver to insert more reduce_noise ops, never fewer!]

Then, the constraint for an add op uses $Z_v$ in place of the $(\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS})$ term used for the mul op.

The only other minor aspect of this solver model: the constraints above enforce the consistency of the noise after any inserted reduce_noise op, but they don’t bound the noise of an op’s raw output before it is fed into a reduce_noise op. We handle this by adding the constraints $(\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS}) \leq 26$ and $Z_v \leq 26$ for multiplication and addition ops, respectively.
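Since the full model is hard to eyeball, it helps to have a brute-force reference for tiny programs. The sketch below is my own (not part of the tutorial repo): it enumerates every set of insertion points on a straight-line chain of squarings and returns the cheapest legal one, which an ILP solution for the same chain should match in cost:

```python
from itertools import product

INITIAL_NOISE, MAX_NOISE = 12, 26

def best_insertions(num_muls):
    """Cheapest reduce_noise insertions making a chain of squarings legal.

    Op i squares the running value; decisions[i] == 1 means a reduce_noise
    op is inserted right after op i.
    """
    best = None
    for decisions in product((0, 1), repeat=num_muls):
        noise, legal = INITIAL_NOISE, True
        for d in decisions:
            noise = noise + noise  # mul: noise maxes add
            if noise > MAX_NOISE:
                legal = False
                break
            if d:
                noise = INITIAL_NOISE  # reduce_noise resets the level
        if legal and (best is None or sum(decisions) < sum(best)):
            best = decisions
    return best

# Four squarings: every mul but the last must be followed by a reduce_noise,
# since squaring even a freshly reduced value already yields 24 bits.
assert best_insertions(4) == (1, 1, 1, 0)
```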

When converting this to the OR-tools C++ API, as we did in this commit, a few minor things to note:

• You can specify upper and lower bounds on a variable at variable creation time, rather than as separate constraints. You’ll see this in solver->MakeNumVar(min, max, name).
• Constraints must be specified in the form min <= expr <= max, where min and max are constants and expr is a linear combination of variables, meaning that one has to manually re-arrange and simplify all the equations above so the variables are all on one side and the constants on the other. (The OR-tools Python API is more expressive, but we don’t have it here.)
• The constraints and the objective are specified by SetCoefficient, which sets the coefficient of a variable in a linear combination one at a time.
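For instance, one way to put the second mul constraint from earlier into that form (my rearrangement; the commit may organize it differently) is to move all variables to one side:

$\textup{NoiseAt}_\textup{RES} \leq 12 + C(1 - \textup{InsertReduceNoise}_x)$

becomes

$-\infty \leq \textup{NoiseAt}_\textup{RES} + C \cdot \textup{InsertReduceNoise}_x \leq 12 + C$

where with $C = 100$ the upper bound is the constant $112$, and each variable’s coefficient ($1$ and $C$) is supplied via SetCoefficient.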

Finally, this commit implements the part of the pass that uses the solver’s output to insert new reduce_noise ops. And this commit adds some more tests.

An example of its use:

// This test checks that the solver can find a single insertion point
// for a reduce_noise op that handles two branches, each of which would
// also need a reduce_noise op if handled separately.
func.func @test_single_insertion_branching() -> i5 {
%0 = arith.constant 3 : i5
%1 = arith.constant 4 : i5
%2 = noisy.encode %0 : i5 -> !noisy.i32
%3 = noisy.encode %1 : i5 -> !noisy.i32
// Noise: 12
%4 = noisy.mul %2, %3 : !noisy.i32
// Noise: 24

// branch 1
%b1 = noisy.add %4, %3 : !noisy.i32
// Noise: 25
%b2 = noisy.add %b1, %3 : !noisy.i32
// Noise: 25
%b3 = noisy.add %b2, %3 : !noisy.i32
// Noise: 26
%b4 = noisy.add %b3, %3 : !noisy.i32
// Noise: 27

// branch 2
%c1 = noisy.sub %4, %2 : !noisy.i32
// Noise: 25
%c2 = noisy.sub %c1, %3 : !noisy.i32
// Noise: 25
%c3 = noisy.sub %c2, %3 : !noisy.i32
// Noise: 26
%c4 = noisy.sub %c3, %3 : !noisy.i32
// Noise: 27

%x1 = noisy.decode %b4 : !noisy.i32 -> i5
%x2 = noisy.decode %c4 : !noisy.i32 -> i5
%x3 = arith.addi %x1, %x2 : i5
return %x3 : i5
}


And the output:

  func.func @test_single_insertion_branching() -> i5 {
%c3_i5 = arith.constant 3 : i5
%c4_i5 = arith.constant 4 : i5
%0 = noisy.encode %c3_i5 : i5 -> !noisy.i32
%1 = noisy.encode %c4_i5 : i5 -> !noisy.i32
%2 = noisy.mul %0, %1 : !noisy.i32
%3 = noisy.reduce_noise %2 : !noisy.i32
%4 = noisy.add %3, %1 : !noisy.i32
%5 = noisy.add %4, %1 : !noisy.i32
%6 = noisy.add %5, %1 : !noisy.i32
%7 = noisy.add %6, %1 : !noisy.i32
%8 = noisy.sub %3, %0 : !noisy.i32
%9 = noisy.sub %8, %1 : !noisy.i32
%10 = noisy.sub %9, %1 : !noisy.i32
%11 = noisy.sub %10, %1 : !noisy.i32
%12 = noisy.decode %7 : !noisy.i32 -> i5
%13 = noisy.decode %11 : !noisy.i32 -> i5
%14 = arith.addi %12, %13 : i5
return %14 : i5
}


# MLIR — Lowering through LLVM

In the last article we lowered our custom poly dialect to standard MLIR dialects. In this article we’ll continue lowering it to LLVM IR, exporting it out of MLIR to LLVM, and then compiling to x86 machine code.

## Defining a Pipeline

The first step in lowering to machine code is to lower to an “exit dialect.” That is, a dialect from which there is a code-gen tool that converts MLIR code to some non-MLIR format. In our case, we’re targeting x86, so the exit dialect is the LLVM dialect, and the code-gen tool is the binary mlir-translate (more on that later). Lowering to an exit dialect, as it turns out, is not all that simple.

One of the things I’ve struggled with when learning MLIR is how to compose all the different lowerings into a pipeline that ends in the result I want, especially when other people wrote those pipelines. When starting from a high level dialect like linalg (linear algebra), there can be dozens of lowerings involved, and some of them can reintroduce ops from dialects you thought you completely lowered already, or have complex pre-conditions or relationships to other lowerings. There are also simply too many lowerings to easily scan.

There are two ways to specify a lowering from the MLIR binary. One is completely on the command line. You can use the --pass-pipeline flag with a tiny DSL, like this

bazel run //tools:tutorial-opt -- foo.mlir \
--pass-pipeline='builtin.module(func.func(cse,canonicalize),convert-func-to-llvm)'


Above, the thing that looks like a function call “anchors” a sequence of passes to operate on a particular op (allowing them to run in parallel across instances of that op).

Thankfully, you can also declare the pipeline in code, and wrap it up in a single flag. The above might be equivalently defined as follows:

void customPipelineBuilder(mlir::OpPassManager &pm) {
  // Anchor cse and canonicalize to run on each func.func, matching the
  // --pass-pipeline string above.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createConvertFuncToLLVMPass());  // runs on builtin.module by default
}

int main(int argc, char **argv) {
mlir::DialectRegistry registry;
<... register dialects ...>

mlir::PassPipelineRegistration<>(
"my-pipeline", "A custom pipeline", customPipelineBuilder);

return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "Tutorial opt main", registry));
}


We’ll do the actually-runnable analogue of this in the next section when we lower the poly dialect to LLVM.

## Lowering Poly to LLVM

In this section we’ll define a pipeline lowering poly to LLVM and show the MLIR along each step. Strap in, there’s going to be a lot of MLIR code in this article.

The process I’ve used to build up a big pipeline is rather toilsome and incremental. Basically: start from an empty pipeline and the starting MLIR, look for the “highest level” op present, add to the pipeline a pass that lowers it, and if that pass fails, figure out which pass is required before it. Then repeat until you reach your target.

In this commit we define a pass pipeline --poly-to-llvm that includes only the --poly-to-standard pass defined last time, along with a canonicalization pass. Then we start from this IR:

$ cat $PWD/tests/poly_to_llvm.mlir

func.func @test_poly_fn(%arg : i32) -> i32 {
%tens = tensor.splat %arg : tensor<10xi32>
%input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10>
%0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
%1 = poly.add %0, %input : !poly.poly<10>
%2 = poly.mul %1, %1 : !poly.poly<10>
%3 = poly.sub %2, %input : !poly.poly<10>
%4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32
return %4 : i32
}

$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir

module {
func.func @test_poly_fn(%arg0: i32) -> i32 {
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : tensor<10xi32>
%c0_i32 = arith.constant 0 : i32
%cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
%splat = tensor.splat %arg0 : tensor<10xi32>
%padded = tensor.pad %cst_0 low[0] high[7] {
^bb0(%arg1: index):
tensor.yield %c0_i32 : i32
} : tensor<3xi32> to tensor<10xi32>
%0 = arith.addi %padded, %splat : tensor<10xi32>
%1 = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
%4 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
%5 = arith.addi %arg1, %arg3 : index
%6 = arith.remui %5, %c10 : index
%extracted = tensor.extract %0[%arg3] : tensor<10xi32>
%extracted_1 = tensor.extract %0[%arg1] : tensor<10xi32>
%7 = arith.muli %extracted_1, %extracted : i32
%extracted_2 = tensor.extract %arg4[%6] : tensor<10xi32>
%8 = arith.addi %7, %extracted_2 : i32
%inserted = tensor.insert %8 into %arg4[%6] : tensor<10xi32>
scf.yield %inserted : tensor<10xi32>
}
scf.yield %4 : tensor<10xi32>
}
%2 = arith.subi %1, %splat : tensor<10xi32>
%3 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
%4 = arith.subi %c11, %arg1 : index
%5 = arith.muli %arg0, %arg2 : i32
%extracted = tensor.extract %2[%4] : tensor<10xi32>
%6 = arith.addi %5, %extracted : i32
scf.yield %6 : i32
}
return %3 : i32
}
}


Let’s naively start from the top and see what happens. One way you can try this more interactively is to run a tentative pass to add like bazel run //tools:tutorial-opt -- --poly-to-llvm --pass-to-try PWD/tests/poly_to_llvm.mlir, and it will run the hard-coded pipeline followed by the new pass. The first op that can be lowered is func.func, and there is a convert-func-to-llvm pass, which we add in this commit. It turns out to process a lot more than just the func op: module attributes {llvm.data_layout = ""} { llvm.func @test_poly_fn(%arg0: i32) -> i32 { %0 = llvm.mlir.constant(11 : index) : i64 %1 = builtin.unrealized_conversion_cast %0 : i64 to index %2 = llvm.mlir.constant(1 : index) : i64 %3 = builtin.unrealized_conversion_cast %2 : i64 to index %4 = llvm.mlir.constant(10 : index) : i64 %5 = builtin.unrealized_conversion_cast %4 : i64 to index %6 = llvm.mlir.constant(0 : index) : i64 %7 = builtin.unrealized_conversion_cast %6 : i64 to index %cst = arith.constant dense<0> : tensor<10xi32> %8 = llvm.mlir.constant(0 : i32) : i32 %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32> %splat = tensor.splat %arg0 : tensor<10xi32> %padded = tensor.pad %cst_0 low[0] high[7] { ^bb0(%arg1: index): tensor.yield %8 : i32 } : tensor<3xi32> to tensor<10xi32> %9 = arith.addi %padded, %splat : tensor<10xi32> %10 = scf.for %arg1 = %7 to %5 step %3 iter_args(%arg2 = %cst) -> (tensor<10xi32>) { %13 = builtin.unrealized_conversion_cast %arg1 : index to i64 %14 = scf.for %arg3 = %7 to %5 step %3 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) { %15 = builtin.unrealized_conversion_cast %arg3 : index to i64 %16 = llvm.add %13, %15 : i64 %17 = llvm.urem %16, %4 : i64 %18 = builtin.unrealized_conversion_cast %17 : i64 to index %extracted = tensor.extract %9[%arg3] : tensor<10xi32> %extracted_1 = tensor.extract %9[%arg1] : tensor<10xi32> %19 = llvm.mul %extracted_1, %extracted : i32 %extracted_2 = tensor.extract %arg4[%18] : tensor<10xi32> %20 = llvm.add %19, %extracted_2 : i32 %inserted 
= tensor.insert %20 into %arg4[%18] : tensor<10xi32> scf.yield %inserted : tensor<10xi32> } scf.yield %14 : tensor<10xi32> } %11 = arith.subi %10, %splat : tensor<10xi32> %12 = scf.for %arg1 = %3 to %1 step %3 iter_args(%arg2 = %8) -> (i32) { %13 = builtin.unrealized_conversion_cast %arg1 : index to i64 %14 = llvm.sub %0, %13 : i64 %15 = builtin.unrealized_conversion_cast %14 : i64 to index %16 = llvm.mul %arg0, %arg2 : i32 %extracted = tensor.extract %11[%15] : tensor<10xi32> %17 = llvm.add %16, %extracted : i32 scf.yield %17 : i32 } llvm.return %12 : i32 } }  Notably, this pass converted most of the arithmetic operations—though not the tensor-generating ones that use the dense attribute—and inserted a number of unrealized_conversion_cast ops for the resulting type conflicts, which we’ll have to get rid of eventually, but we can’t now because the values on both sides of the type conversion are still used. Next we’ll lower arith to LLVM using the suggestively-named convert-arith-to-llvm in this commit. However, it has no effect on the resulting IR, and the arith ops remain. What gives? It turns out that arith ops that operate on tensors are not supported by convert-arith-to-llvm. To deal with this, we need a special pass called convert-elementwise-to-linalg, which lowers these ops to linalg.generic ops. (I plan to cover linalg.generic in a future tutorial). 
We add it in this commit, and this is the diff between the above IR and the new output (lines starting with > are the new IR):

```diff
> #map = affine_map<(d0) -> (d0)>
19c20,24
< %9 = arith.addi %padded, %splat : tensor<10xi32>
---
> %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
> ^bb0(%in: i32, %in_1: i32, %out: i32):
>   %13 = arith.addi %in, %in_1 : i32
>   linalg.yield %13 : i32
> } -> tensor<10xi32>
37c42,46
< %11 = arith.subi %10, %splat : tensor<10xi32>
---
> %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%10, %splat : tensor<10xi32>, tensor<10xi32>) outs(%10 : tensor<10xi32>) {
> ^bb0(%in: i32, %in_1: i32, %out: i32):
>   %13 = arith.subi %in, %in_1 : i32
>   linalg.yield %13 : i32
> } -> tensor<10xi32>
49a59
>
```

This is nice, but now you can see we have a new problem: lowering linalg, and re-lowering the arith ops that were inserted by the convert-elementwise-to-linalg pass.
So adding back the arith pass at the end in this commit replaces the two inserted arith ops, giving this IR:

```mlir
#map = affine_map<(d0) -> (d0)>
module attributes {llvm.data_layout = ""} {
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(11 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %8 = llvm.mlir.constant(0 : i32) : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %padded = tensor.pad %cst_0 low[0] high[7] {
    ^bb0(%arg1: index):
      tensor.yield %8 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
    ^bb0(%in: i32, %in_1: i32, %out: i32):
      %13 = llvm.add %in, %in_1 : i32
      linalg.yield %13 : i32
    } -> tensor<10xi32>
    %10 = scf.for %arg1 = %7 to %5 step %3 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = scf.for %arg3 = %7 to %5 step %3 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
        %15 = builtin.unrealized_conversion_cast %arg3 : index to i64
        %16 = llvm.add %13, %15 : i64
        %17 = llvm.urem %16, %4 : i64
        %18 = builtin.unrealized_conversion_cast %17 : i64 to index
        %extracted = tensor.extract %9[%arg3] : tensor<10xi32>
        %extracted_1 = tensor.extract %9[%arg1] : tensor<10xi32>
        %19 = llvm.mul %extracted_1, %extracted : i32
        %extracted_2 = tensor.extract %arg4[%18] : tensor<10xi32>
        %20 = llvm.add %19, %extracted_2 : i32
        %inserted = tensor.insert %20 into %arg4[%18] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %14 : tensor<10xi32>
    }
    %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%10, %splat : tensor<10xi32>, tensor<10xi32>) outs(%10 : tensor<10xi32>) {
    ^bb0(%in: i32, %in_1: i32, %out: i32):
      %13 = llvm.sub %in, %in_1 : i32
      linalg.yield %13 : i32
    } -> tensor<10xi32>
    %12 = scf.for %arg1 = %3 to %1 step %3 iter_args(%arg2 = %8) -> (i32) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = llvm.sub %0, %13 : i64
      %15 = builtin.unrealized_conversion_cast %14 : i64 to index
      %16 = llvm.mul %arg0, %arg2 : i32
      %extracted = tensor.extract %11[%15] : tensor<10xi32>
      %17 = llvm.add %16, %extracted : i32
      scf.yield %17 : i32
    }
    llvm.return %12 : i32
  }
}
```

However, there are still two arith ops remaining: the arith.constant ops using dense attributes to define tensors. Looking at the list of available passes, there’s no obvious pass that handles them. As we’ll see, it’s the bufferization pass that handles this (we discussed it briefly in the previous article), and we should try to lower as much as possible before bufferizing, since in general bufferizing makes optimization passes harder.

The next thing to try lowering is tensor.splat. There’s no obvious pass for it either (tensor-to-linalg doesn’t lower it), and searching the LLVM codebase for SplatOp we find this pattern, which suggests tensor.splat is also lowered during bufferization, and produces linalg.map ops. So we’ll have to lower those as well eventually. But what tensor-to-linalg does lower is the next op: tensor.pad.
It replaces a pad with the following:

```diff
< %padded = tensor.pad %cst_0 low[0] high[7] {
< ^bb0(%arg1: index):
<   tensor.yield %8 : i32
< } : tensor<3xi32> to tensor<10xi32>
< %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
---
> %9 = tensor.empty() : tensor<10xi32>
> %10 = linalg.fill ins(%8 : i32) outs(%9 : tensor<10xi32>) -> tensor<10xi32>
> %inserted_slice = tensor.insert_slice %cst_0 into %10[0] [3] [1] : tensor<3xi32> into tensor<10xi32>
> %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%inserted_slice, %splat : tensor<10xi32>, tensor<10xi32>) outs(%inserted_slice : tensor<10xi32>) {
```

We add that pass in this commit.

Next we have all of these linalg ops, as well as the scf.for loops. The linalg ops can lower naively as for loops, so we’ll do that first. There are three options:

• -convert-linalg-to-affine-loops: lower the operations from the linalg dialect into affine loops
• -convert-linalg-to-loops: lower the operations from the linalg dialect into loops
• -convert-linalg-to-parallel-loops: lower the operations from the linalg dialect into parallel loops

I want to make this initial pipeline as simple and naive as possible, so convert-linalg-to-loops it is. However, the pass does nothing on this IR (added in this commit). If you run the pass with --debug you can find an explanation:

```
** Failure : expected linalg op with buffer semantics
```

So linalg must be bufferized before it can be lowered to loops. However, we can tackle the scf-to-LLVM step with a combination of two passes, convert-scf-to-cf and convert-cf-to-llvm, in this commit. The lowering added arith.cmpi ops for the loop predicates, so moving arith-to-llvm to the end of the pipeline fixes that (commit).
In the end we get this IR:

```mlir
#map = affine_map<(d0) -> (d0)>
module attributes {llvm.data_layout = ""} {
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(11 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %8 = llvm.mlir.constant(0 : i32) : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %9 = tensor.empty() : tensor<10xi32>
    %10 = linalg.fill ins(%8 : i32) outs(%9 : tensor<10xi32>) -> tensor<10xi32>
    %inserted_slice = tensor.insert_slice %cst_0 into %10[0] [3] [1] : tensor<3xi32> into tensor<10xi32>
    %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%inserted_slice, %splat : tensor<10xi32>, tensor<10xi32>) outs(%inserted_slice : tensor<10xi32>) {
    ^bb0(%in: i32, %in_4: i32, %out: i32):
      %43 = llvm.add %in, %in_4 : i32
      linalg.yield %43 : i32
    } -> tensor<10xi32>
    cf.br ^bb1(%7, %cst : index, tensor<10xi32>)
  ^bb1(%12: index, %13: tensor<10xi32>):  // 2 preds: ^bb0, ^bb5
    %14 = builtin.unrealized_conversion_cast %12 : index to i64
    %15 = llvm.icmp "slt" %14, %4 : i64
    llvm.cond_br %15, ^bb2, ^bb6
  ^bb2:  // pred: ^bb1
    %16 = builtin.unrealized_conversion_cast %12 : index to i64
    cf.br ^bb3(%7, %13 : index, tensor<10xi32>)
  ^bb3(%17: index, %18: tensor<10xi32>):  // 2 preds: ^bb2, ^bb4
    %19 = builtin.unrealized_conversion_cast %17 : index to i64
    %20 = llvm.icmp "slt" %19, %4 : i64
    llvm.cond_br %20, ^bb4, ^bb5
  ^bb4:  // pred: ^bb3
    %21 = builtin.unrealized_conversion_cast %17 : index to i64
    %22 = llvm.add %16, %21 : i64
    %23 = llvm.urem %22, %4 : i64
    %24 = builtin.unrealized_conversion_cast %23 : i64 to index
    %extracted = tensor.extract %11[%17] : tensor<10xi32>
    %extracted_1 = tensor.extract %11[%12] : tensor<10xi32>
    %25 = llvm.mul %extracted_1, %extracted : i32
    %extracted_2 = tensor.extract %18[%24] : tensor<10xi32>
    %26 = llvm.add %25, %extracted_2 : i32
    %inserted = tensor.insert %26 into %18[%24] : tensor<10xi32>
    %27 = llvm.add %19, %2 : i64
    %28 = builtin.unrealized_conversion_cast %27 : i64 to index
    cf.br ^bb3(%28, %inserted : index, tensor<10xi32>)
  ^bb5:  // pred: ^bb3
    %29 = llvm.add %14, %2 : i64
    %30 = builtin.unrealized_conversion_cast %29 : i64 to index
    cf.br ^bb1(%30, %18 : index, tensor<10xi32>)
  ^bb6:  // pred: ^bb1
    %31 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%13, %splat : tensor<10xi32>, tensor<10xi32>) outs(%13 : tensor<10xi32>) {
    ^bb0(%in: i32, %in_4: i32, %out: i32):
      %43 = llvm.sub %in, %in_4 : i32
      linalg.yield %43 : i32
    } -> tensor<10xi32>
    cf.br ^bb7(%3, %8 : index, i32)
  ^bb7(%32: index, %33: i32):  // 2 preds: ^bb6, ^bb8
    %34 = builtin.unrealized_conversion_cast %32 : index to i64
    %35 = llvm.icmp "slt" %34, %0 : i64
    llvm.cond_br %35, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %36 = builtin.unrealized_conversion_cast %32 : index to i64
    %37 = llvm.sub %0, %36 : i64
    %38 = builtin.unrealized_conversion_cast %37 : i64 to index
    %39 = llvm.mul %arg0, %33 : i32
    %extracted_3 = tensor.extract %31[%38] : tensor<10xi32>
    %40 = llvm.add %39, %extracted_3 : i32
    %41 = llvm.add %34, %2 : i64
    %42 = builtin.unrealized_conversion_cast %41 : i64 to index
    cf.br ^bb7(%42, %40 : index, i32)
  ^bb9:  // pred: ^bb7
    llvm.return %33 : i32
  }
}
```

## Bufferization

Last time I wrote at some length about the dialect conversion framework, and how it’s complicated mostly because of the bufferization passes, which were split across multiple passes and created type conflicts that required materialization and later resolution. Well, it turns out those bufferization passes are now deprecated.
These are the official docs, but the short story is that the network of bufferization passes was replaced by a one-shot-bufferize pass, plus some boilerplate cleanup passes. This is the sequence of passes recommended by the docs, along with an option to ensure that function signatures are bufferized as well:

```cpp
// One-shot bufferize, from
// https://mlir.llvm.org/docs/Bufferization/#ownership-based-buffer-deallocation
bufferization::OneShotBufferizationOptions bufferizationOptions;
bufferizationOptions.bufferizeFunctionBoundaries = true;
manager.addPass(
    bufferization::createOneShotBufferizePass(bufferizationOptions));
manager.addPass(memref::createExpandReallocPass());
manager.addPass(bufferization::createOwnershipBasedBufferDeallocationPass());
manager.addPass(createCanonicalizerPass());
manager.addPass(bufferization::createBufferDeallocationSimplificationPass());
manager.addPass(bufferization::createLowerDeallocationsPass());
manager.addPass(createCSEPass());
manager.addPass(createCanonicalizerPass());
```

This pipeline exists upstream in a helper, but the helper, and even some of the passes above, were added after the commit we’ve pinned to. At this point it’s worth updating the upstream MLIR commit, which I did in this commit; it required a few additional fixes to deal with API updates (starting with this commit and ending with this commit). Then this commit adds the one-shot bufferization pass and helpers to the pipeline.

Even after all of this, I was dismayed to learn that the pass still did not lower the linalg operators. After a bit of digging, I realized it was because the func-to-llvm pass was running too early in the pipeline, so this commit moves it later. Since we’re building this a bit backwards and naively, let’s comment out the tail end of the pipeline to see the result of bufferization.
Here is the IR just before convert-linalg-to-loops, omitting the rest of the pipeline:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  func.func @test_poly_fn(%arg0: i32) -> i32 {
    %c0_i32 = arith.constant 0 : i32
    %c11 = arith.constant 11 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %0 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %1 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    linalg.map outs(%alloc : memref<10xi32>)
      () {
        linalg.yield %arg0 : i32
      }
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc_0 : memref<10xi32>)
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %1, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%alloc_0, %alloc : memref<10xi32>, memref<10xi32>) outs(%alloc_0 : memref<10xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %3 = arith.addi %in, %in_2 : i32
      linalg.yield %3 : i32
    }
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %0, %alloc_1 : memref<10xi32> to memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      scf.for %arg2 = %c0 to %c10 step %c1 {
        %3 = arith.addi %arg1, %arg2 : index
        %4 = arith.remui %3, %c10 : index
        %5 = memref.load %alloc_0[%arg2] : memref<10xi32>
        %6 = memref.load %alloc_0[%arg1] : memref<10xi32>
        %7 = arith.muli %6, %5 : i32
        %8 = memref.load %alloc_1[%4] : memref<10xi32>
        %9 = arith.addi %7, %8 : i32
        memref.store %9, %alloc_1[%4] : memref<10xi32>
      }
    }
    linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%alloc_1, %alloc : memref<10xi32>, memref<10xi32>) outs(%alloc_1 : memref<10xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %3 = arith.subi %in, %in_2 : i32
      linalg.yield %3 : i32
    }
    %2 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
      %3 = arith.subi %c11, %arg1 : index
      %4 = arith.muli %arg0, %arg2 : i32
      %5 = memref.load %alloc_1[%3] : memref<10xi32>
      %6 = arith.addi %4, %5 : i32
      scf.yield %6 : i32
    }
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    return %2 : i32
  }
}
```

And here it is after convert-linalg-to-loops (omitting the conversions to LLVM), wherein there are no longer any linalg ops, just loops:

```mlir
module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  func.func @test_poly_fn(%arg0: i32) -> i32 {
    %c0 = arith.constant 0 : index
    %c10 = arith.constant 10 : index
    %c1 = arith.constant 1 : index
    %c0_i32 = arith.constant 0 : i32
    %c11 = arith.constant 11 : index
    %0 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %1 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      memref.store %arg0, %alloc[%arg1] : memref<10xi32>
    }
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      memref.store %c0_i32, %alloc_0[%arg1] : memref<10xi32>
    }
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %1, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      %3 = memref.load %alloc_0[%arg1] : memref<10xi32>
      %4 = memref.load %alloc[%arg1] : memref<10xi32>
      %5 = arith.addi %3, %4 : i32
      memref.store %5, %alloc_0[%arg1] : memref<10xi32>
    }
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %0, %alloc_1 : memref<10xi32> to memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      scf.for %arg2 = %c0 to %c10 step %c1 {
        %3 = arith.addi %arg1, %arg2 : index
        %4 = arith.remui %3, %c10 : index
        %5 = memref.load %alloc_0[%arg2] : memref<10xi32>
        %6 = memref.load %alloc_0[%arg1] : memref<10xi32>
        %7 = arith.muli %6, %5 : i32
        %8 = memref.load %alloc_1[%4] : memref<10xi32>
        %9 = arith.addi %7, %8 : i32
        memref.store %9, %alloc_1[%4] : memref<10xi32>
      }
    }
    scf.for %arg1 = %c0 to %c10 step %c1 {
      %3 = memref.load %alloc_1[%arg1] : memref<10xi32>
      %4 = memref.load %alloc[%arg1] : memref<10xi32>
      %5 = arith.subi %3, %4 : i32
      memref.store %5, %alloc_1[%arg1] : memref<10xi32>
    }
    %2 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
      %3 = arith.subi %c11, %arg1 : index
      %4 = arith.muli %arg0, %arg2 : i32
      %5 = memref.load %alloc_1[%3] : memref<10xi32>
      %6 = arith.addi %4, %5 : i32
      scf.yield %6 : i32
    }
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    return %2 : i32
  }
}
```

And then finally, here is the result of the rest of the pipeline, as defined so far.
```mlir
module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(0 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(10 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(1 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : i32) : i32
    %7 = llvm.mlir.constant(11 : index) : i64
    %8 = builtin.unrealized_conversion_cast %7 : i64 to index
    %9 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %10 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    cf.br ^bb1(%1 : index)
  ^bb1(%11: index):  // 2 preds: ^bb0, ^bb2
    %12 = builtin.unrealized_conversion_cast %11 : index to i64
    %13 = llvm.icmp "slt" %12, %2 : i64
    llvm.cond_br %13, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    memref.store %arg0, %alloc[%11] : memref<10xi32>
    %14 = llvm.add %12, %4 : i64
    %15 = builtin.unrealized_conversion_cast %14 : i64 to index
    cf.br ^bb1(%15 : index)
  ^bb3:  // pred: ^bb1
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    cf.br ^bb4(%1 : index)
  ^bb4(%16: index):  // 2 preds: ^bb3, ^bb5
    %17 = builtin.unrealized_conversion_cast %16 : index to i64
    %18 = llvm.icmp "slt" %17, %2 : i64
    llvm.cond_br %18, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    memref.store %6, %alloc_0[%16] : memref<10xi32>
    %19 = llvm.add %17, %4 : i64
    %20 = builtin.unrealized_conversion_cast %19 : i64 to index
    cf.br ^bb4(%20 : index)
  ^bb6:  // pred: ^bb4
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %10, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    cf.br ^bb7(%1 : index)
  ^bb7(%21: index):  // 2 preds: ^bb6, ^bb8
    %22 = builtin.unrealized_conversion_cast %21 : index to i64
    %23 = llvm.icmp "slt" %22, %2 : i64
    llvm.cond_br %23, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %24 = memref.load %alloc_0[%21] : memref<10xi32>
    %25 = memref.load %alloc[%21] : memref<10xi32>
    %26 = llvm.add %24, %25 : i32
    memref.store %26, %alloc_0[%21] : memref<10xi32>
    %27 = llvm.add %22, %4 : i64
    %28 = builtin.unrealized_conversion_cast %27 : i64 to index
    cf.br ^bb7(%28 : index)
  ^bb9:  // pred: ^bb7
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %9, %alloc_1 : memref<10xi32> to memref<10xi32>
    cf.br ^bb10(%1 : index)
  ^bb10(%29: index):  // 2 preds: ^bb9, ^bb14
    %30 = builtin.unrealized_conversion_cast %29 : index to i64
    %31 = llvm.icmp "slt" %30, %2 : i64
    llvm.cond_br %31, ^bb11, ^bb15
  ^bb11:  // pred: ^bb10
    %32 = builtin.unrealized_conversion_cast %29 : index to i64
    cf.br ^bb12(%1 : index)
  ^bb12(%33: index):  // 2 preds: ^bb11, ^bb13
    %34 = builtin.unrealized_conversion_cast %33 : index to i64
    %35 = llvm.icmp "slt" %34, %2 : i64
    llvm.cond_br %35, ^bb13, ^bb14
  ^bb13:  // pred: ^bb12
    %36 = builtin.unrealized_conversion_cast %33 : index to i64
    %37 = llvm.add %32, %36 : i64
    %38 = llvm.urem %37, %2 : i64
    %39 = builtin.unrealized_conversion_cast %38 : i64 to index
    %40 = memref.load %alloc_0[%33] : memref<10xi32>
    %41 = memref.load %alloc_0[%29] : memref<10xi32>
    %42 = llvm.mul %41, %40 : i32
    %43 = memref.load %alloc_1[%39] : memref<10xi32>
    %44 = llvm.add %42, %43 : i32
    memref.store %44, %alloc_1[%39] : memref<10xi32>
    %45 = llvm.add %34, %4 : i64
    %46 = builtin.unrealized_conversion_cast %45 : i64 to index
    cf.br ^bb12(%46 : index)
  ^bb14:  // pred: ^bb12
    %47 = llvm.add %30, %4 : i64
    %48 = builtin.unrealized_conversion_cast %47 : i64 to index
    cf.br ^bb10(%48 : index)
  ^bb15:  // pred: ^bb10
    cf.br ^bb16(%1 : index)
  ^bb16(%49: index):  // 2 preds: ^bb15, ^bb17
    %50 = builtin.unrealized_conversion_cast %49 : index to i64
    %51 = llvm.icmp "slt" %50, %2 : i64
    llvm.cond_br %51, ^bb17, ^bb18
  ^bb17:  // pred: ^bb16
    %52 = memref.load %alloc_1[%49] : memref<10xi32>
    %53 = memref.load %alloc[%49] : memref<10xi32>
    %54 = llvm.sub %52, %53 : i32
    memref.store %54, %alloc_1[%49] : memref<10xi32>
    %55 = llvm.add %50, %4 : i64
    %56 = builtin.unrealized_conversion_cast %55 : i64 to index
    cf.br ^bb16(%56 : index)
  ^bb18:  // pred: ^bb16
    cf.br ^bb19(%5, %6 : index, i32)
  ^bb19(%57: index, %58: i32):  // 2 preds: ^bb18, ^bb20
    %59 = builtin.unrealized_conversion_cast %57 : index to i64
    %60 = llvm.icmp "slt" %59, %7 : i64
    llvm.cond_br %60, ^bb20, ^bb21
  ^bb20:  // pred: ^bb19
    %61 = builtin.unrealized_conversion_cast %57 : index to i64
    %62 = llvm.sub %7, %61 : i64
    %63 = builtin.unrealized_conversion_cast %62 : i64 to index
    %64 = llvm.mul %arg0, %58 : i32
    %65 = memref.load %alloc_1[%63] : memref<10xi32>
    %66 = llvm.add %64, %65 : i32
    %67 = llvm.add %59, %4 : i64
    %68 = builtin.unrealized_conversion_cast %67 : i64 to index
    cf.br ^bb19(%68, %66 : index, i32)
  ^bb21:  // pred: ^bb19
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    llvm.return %58 : i32
  }
}
```

The remaining problems:

• There are cf.br ops left in there, meaning convert-cf-to-llvm was unable to convert them, leaving behind a bunch of un-removable index types (see the third bullet).
• We still have a memref.subview, which is not supported in LLVM.
• We have a bunch of casts like builtin.unrealized_conversion_cast %7 : i64 to index, because index is not part of LLVM.

The second needs a special pass, memref-expand-strided-metadata, which reduces more complicated memref ops to simpler ones that can be lowered. The third is fixed by using finalize-memref-to-llvm, which lowers memref to llvm.struct, llvm.ptr, and llvm.array types, with index becoming i64 along the way. A final reconcile-unrealized-casts removes the cast operations, provided they can safely be removed. Both are combined in this commit.
However, the first one still eluded me for a while, until I figured out through embarrassing trial and error that, again, func-to-llvm was too early in the pipeline. Moving it to the end (this commit) resulted in cf.br being lowered to llvm.br and llvm.cond_br. Finally, this commit adds a set of standard cleanup passes, including constant propagation, common subexpression elimination, dead code elimination, and canonicalization. The final IR looks like this:

```mlir
module {
  llvm.func @free(!llvm.ptr)
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.mlir.global private constant @__constant_3xi32(dense<[2, 3, 4]> : tensor<3xi32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<3 x i32>
  llvm.mlir.global private constant @__constant_10xi32(dense<0> : tensor<10xi32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<10 x i32>
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(64 : index) : i64
    %2 = llvm.mlir.constant(3 : index) : i64
    %3 = llvm.mlir.constant(0 : index) : i64
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = llvm.mlir.constant(1 : index) : i64
    %6 = llvm.mlir.constant(11 : index) : i64
    %7 = llvm.mlir.addressof @__constant_10xi32 : !llvm.ptr
    %8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
    %9 = llvm.mlir.addressof @__constant_3xi32 : !llvm.ptr
    %10 = llvm.getelementptr %9[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<3 x i32>
    %11 = llvm.mlir.zero : !llvm.ptr
    %12 = llvm.getelementptr %11[10] : (!llvm.ptr) -> !llvm.ptr, i32
    %13 = llvm.ptrtoint %12 : !llvm.ptr to i64
    %14 = llvm.add %13, %1 : i64
    %15 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %16 = llvm.ptrtoint %15 : !llvm.ptr to i64
    %17 = llvm.sub %1, %5 : i64
    %18 = llvm.add %16, %17 : i64
    %19 = llvm.urem %18, %1 : i64
    %20 = llvm.sub %18, %19 : i64
    %21 = llvm.inttoptr %20 : i64 to !llvm.ptr
    llvm.br ^bb1(%3 : i64)
  ^bb1(%22: i64):  // 2 preds: ^bb0, ^bb2
    %23 = llvm.icmp "slt" %22, %4 : i64
    llvm.cond_br %23, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %24 = llvm.getelementptr %21[%22] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    llvm.store %arg0, %24 : i32, !llvm.ptr
    %25 = llvm.add %22, %5 : i64
    llvm.br ^bb1(%25 : i64)
  ^bb3:  // pred: ^bb1
    %26 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %27 = llvm.ptrtoint %26 : !llvm.ptr to i64
    %28 = llvm.add %27, %17 : i64
    %29 = llvm.urem %28, %1 : i64
    %30 = llvm.sub %28, %29 : i64
    %31 = llvm.inttoptr %30 : i64 to !llvm.ptr
    llvm.br ^bb4(%3 : i64)
  ^bb4(%32: i64):  // 2 preds: ^bb3, ^bb5
    %33 = llvm.icmp "slt" %32, %4 : i64
    llvm.cond_br %33, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    %34 = llvm.getelementptr %31[%32] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    llvm.store %0, %34 : i32, !llvm.ptr
    %35 = llvm.add %32, %5 : i64
    llvm.br ^bb4(%35 : i64)
  ^bb6:  // pred: ^bb4
    %36 = llvm.mul %2, %5 : i64
    %37 = llvm.getelementptr %11[1] : (!llvm.ptr) -> !llvm.ptr, i32
    %38 = llvm.ptrtoint %37 : !llvm.ptr to i64
    %39 = llvm.mul %36, %38 : i64
    "llvm.intr.memcpy"(%31, %10, %39) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
    llvm.br ^bb7(%3 : i64)
  ^bb7(%40: i64):  // 2 preds: ^bb6, ^bb8
    %41 = llvm.icmp "slt" %40, %4 : i64
    llvm.cond_br %41, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %42 = llvm.getelementptr %31[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %43 = llvm.load %42 : !llvm.ptr -> i32
    %44 = llvm.getelementptr %21[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %45 = llvm.load %44 : !llvm.ptr -> i32
    %46 = llvm.add %43, %45 : i32
    llvm.store %46, %42 : i32, !llvm.ptr
    %47 = llvm.add %40, %5 : i64
    llvm.br ^bb7(%47 : i64)
  ^bb9:  // pred: ^bb7
    %48 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %49 = llvm.ptrtoint %48 : !llvm.ptr to i64
    %50 = llvm.add %49, %17 : i64
    %51 = llvm.urem %50, %1 : i64
    %52 = llvm.sub %50, %51 : i64
    %53 = llvm.inttoptr %52 : i64 to !llvm.ptr
    %54 = llvm.mul %4, %5 : i64
    %55 = llvm.mul %54, %38 : i64
    "llvm.intr.memcpy"(%53, %8, %55) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
    llvm.br ^bb10(%3 : i64)
  ^bb10(%56: i64):  // 2 preds: ^bb9, ^bb14
    %57 = llvm.icmp "slt" %56, %4 : i64
    llvm.cond_br %57, ^bb11, ^bb15
  ^bb11:  // pred: ^bb10
    llvm.br ^bb12(%3 : i64)
  ^bb12(%58: i64):  // 2 preds: ^bb11, ^bb13
    %59 = llvm.icmp "slt" %58, %4 : i64
    llvm.cond_br %59, ^bb13, ^bb14
  ^bb13:  // pred: ^bb12
    %60 = llvm.add %56, %58 : i64
    %61 = llvm.urem %60, %4 : i64
    %62 = llvm.getelementptr %31[%58] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %63 = llvm.load %62 : !llvm.ptr -> i32
    %64 = llvm.getelementptr %31[%56] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %65 = llvm.load %64 : !llvm.ptr -> i32
    %66 = llvm.mul %65, %63 : i32
    %67 = llvm.getelementptr %53[%61] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %68 = llvm.load %67 : !llvm.ptr -> i32
    %69 = llvm.add %66, %68 : i32
    llvm.store %69, %67 : i32, !llvm.ptr
    %70 = llvm.add %58, %5 : i64
    llvm.br ^bb12(%70 : i64)
  ^bb14:  // pred: ^bb12
    %71 = llvm.add %56, %5 : i64
    llvm.br ^bb10(%71 : i64)
  ^bb15:  // pred: ^bb10
    llvm.br ^bb16(%3 : i64)
  ^bb16(%72: i64):  // 2 preds: ^bb15, ^bb17
    %73 = llvm.icmp "slt" %72, %4 : i64
    llvm.cond_br %73, ^bb17, ^bb18
  ^bb17:  // pred: ^bb16
    %74 = llvm.getelementptr %53[%72] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %75 = llvm.load %74 : !llvm.ptr -> i32
    %76 = llvm.getelementptr %21[%72] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %77 = llvm.load %76 : !llvm.ptr -> i32
    %78 = llvm.sub %75, %77 : i32
    llvm.store %78, %74 : i32, !llvm.ptr
    %79 = llvm.add %72, %5 : i64
    llvm.br ^bb16(%79 : i64)
  ^bb18:  // pred: ^bb16
    llvm.br ^bb19(%5, %0 : i64, i32)
  ^bb19(%80: i64, %81: i32):  // 2 preds: ^bb18, ^bb20
    %82 = llvm.icmp "slt" %80, %6 : i64
    llvm.cond_br %82, ^bb20, ^bb21
  ^bb20:  // pred: ^bb19
    %83 = llvm.sub %6, %80 : i64
    %84 = llvm.mul %arg0, %81 : i32
    %85 = llvm.getelementptr %53[%83] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %86 = llvm.load %85 : !llvm.ptr -> i32
    %87 = llvm.add %84, %86 : i32
    %88 = llvm.add %80, %5 : i64
    llvm.br ^bb19(%88, %87 : i64, i32)
  ^bb21:  // pred: ^bb19
    llvm.call @free(%15) : (!llvm.ptr) -> ()
    llvm.call @free(%26) : (!llvm.ptr) -> ()
    llvm.call @free(%48) : (!llvm.ptr) -> ()
    llvm.return %81 : i32
  }
}
```

## Exiting MLIR

In the
MLIR parlance, the LLVM dialect is an “exit” dialect, meaning after lowering to LLVM you run a different tool that generates code used by an external system. In our case, this will be LLVM’s internal representation (“LLVM IR”, which is different from the LLVM MLIR dialect). The “code gen” step is often called translation in the MLIR docs. Here are the official docs on generating LLVM IR. The codegen tool is called mlir-translate, and it has an --mlir-to-llvmir option. Running the commands below on our output IR above gives the IR in this gist.

```shell
$ bazel build @llvm-project//mlir:mlir-translate
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir
```
Next, to compile LLVM IR with LLVM directly, we use the llc tool, which has a --filetype=obj option to emit an object file. Without that flag, you see a textual representation of the machine code:

```shell
$ bazel build @llvm-project//llvm:llc
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | bazel-bin/external/llvm-project/llvm/llc
```

which prints the textual representation:

```
	.text
	.file	"LLVMDialectModule"
	.globl	test_poly_fn                    # -- Begin function test_poly_fn
	.p2align	4, 0x90
	.type	test_poly_fn,@function
test_poly_fn:                           # @test_poly_fn
	.cfi_startproc
# %bb.0:
	pushq	%rbp
	.cfi_def_cfa_offset 16
	pushq	%r15
	.cfi_def_cfa_offset 24
	pushq	%r14
	.cfi_def_cfa_offset 32
	pushq	%r13
	.cfi_def_cfa_offset 40
	pushq	%r12
	.cfi_def_cfa_offset 48
	pushq	%rbx
	.cfi_def_cfa_offset 56
	pushq	%rax
	.cfi_def_cfa_offset 64
	.cfi_offset %rbx, -56
	.cfi_offset %r12, -48
<snip>
```

And finally, we can save the object file and compile and link it to a C main that calls the function and prints the result. Here is tests/poly_to_llvm_main.c:

```c
#include <stdio.h>

// This is the function we want to call from LLVM
int test_poly_fn(int x);

int main(int argc, char *argv[]) {
  int i = 1;
  int result = test_poly_fn(i);
  printf("Result: %d\n", result);
  return 0;
}
```
```shell
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | bazel-bin/external/llvm-project/llvm/llc --filetype=obj > poly_to_llvm.o
$ clang -c tests/poly_to_llvm_main.c && clang poly_to_llvm_main.o poly_to_llvm.o -o a.out
$ ./a.out
Result: 320
```


The polynomial computed by the test function is the rather bizarre:

```
((1 + x + x**2 + x**3 + x**4 + x**5 + x**6 + x**7 + x**8 + x**9) + (2 + 3*x + 4*x**2))**2 - (1 + x + x**2 + x**3 + x**4 + x**5 + x**6 + x**7 + x**8 + x**9)
```


But computing this mod $x^{10} - 1$ in Sympy, we get… 351. Uh oh. So somewhere along the way we messed up the lowering.
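As a sanity check on that expected value, here's a short standalone Python script (a hypothetical helper, not part of the repo; the name cyclic_mul is mine) that redoes the computation for the input 1 with a naive cyclic convolution. Evaluating a polynomial at $x = 1$ just sums its coefficients, since reduction mod $x^{10} - 1$ doesn't change the value at $x = 1$:

```python
# Recompute the expected test output by hand. Polynomials in the ring
# Z[x]/(x^10 - 1) are lists of 10 coefficients; multiplication wraps
# indices mod 10.
def cyclic_mul(a, b, n=10):
    out = [0] * n
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            out[(i + j) % n] += ai * bj
    return out

u = [1] * 10                       # the splat of the input value 1: 1 + x + ... + x^9
v = [2, 3, 4] + [0] * 7            # 2 + 3x + 4x^2
s = [x + y for x, y in zip(u, v)]  # the poly.add
p = cyclic_mul(s, s)               # the poly.mul (squaring)
r = [x - y for x, y in zip(p, u)]  # the poly.sub
print(sum(r))                      # poly.eval at x = 1 is the coefficient sum; prints 351
```

This agrees with Sympy's 351, so the disagreement is with the compiled binary's 320, not the math.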

Before we fix it, let’s encode the whole process as a lit test. This requires making the binaries mlir-translate, llc, and clang available to the test runner, and setting up the RUN pipeline inside the test file. This is contained in this commit. Note %t tells lit to generate a test-unique temporary file.

// RUN: tutorial-opt --poly-to-llvm %s | mlir-translate --mlir-to-llvmir | llc -filetype=obj > %t
// RUN: clang -c poly_to_llvm_main.c
// RUN: clang poly_to_llvm_main.o %t -o a.out
// RUN: ./a.out | FileCheck %s

// CHECK: 351
func.func @test_poly_fn(%arg : i32) -> i32 {
  %tens = tensor.splat %arg : tensor<10xi32>
  %input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10>
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.add %0, %input : !poly.poly<10>
  %2 = poly.mul %1, %1 : !poly.poly<10>
  %3 = poly.sub %2, %input : !poly.poly<10>
  %4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32
  return %4 : i32
}


Then running this as a normal lit test fails as

error: CHECK: expected string not found in input
# | // CHECK: 351
# |           ^
# | <stdin>:1:1: note: scanning from here
# | Result: 320
# | ^
# | <stdin>:1:9: note: possible intended match here
# | Result: 320


## Fixing the bug in the lowering

How do you find the bug in a lowering? Apparently there are some folks doing this via formal verification, formalizing the dialects and the lowerings in Lean to prove correctness. I don’t have the time for that in this tutorial, so instead I simplify/expand the tests and squint.

Simplifying the test to computing the simpler function $t \mapsto 1 + t + t^2$, I see that an input of 1 gives 2, when it should be 3. This suggests that eval is lowered wrong, at least. An input of 5 gives 4, so they’re all off by one term. This likely means I have an off-by-one error in the loop that eval lowers to, and indeed that’s the problem. I was essentially doing this, where $N$ is the degree of the polynomial ($N-1$ is the largest legal index into the tensor):

accum = 0
for 1 <= i < N+1
index = N+1 - i
accum = accum * point + coeffs[index]


When it should have been

accum = 0
for 1 <= i < N+1
index = N - i
accum = accum * point + coeffs[index]
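To see why that one-index shift produced the wrong answers above, here is a small Python simulation of both loops (my own sketch; I model the out-of-bounds read at index $N$ as returning 0, an assumption made purely for illustration):

```python
def eval_buggy(coeffs, point):
    # index = N + 1 - i walks indices N..1: it reads one slot past the
    # end and never reads coeffs[0], the constant term.
    N = len(coeffs)
    accum = 0
    for i in range(1, N + 1):
        index = N + 1 - i
        c = coeffs[index] if index < N else 0  # simulate the OOB read as 0
        accum = accum * point + c
    return accum

def eval_fixed(coeffs, point):
    # index = N - i walks indices N-1..0: standard Horner's method.
    N = len(coeffs)
    accum = 0
    for i in range(1, N + 1):
        accum = accum * point + coeffs[N - i]
    return accum

coeffs = [1, 1, 1]  # 1 + t + t^2
print(eval_buggy(coeffs, 1))  # 2: one term is dropped
print(eval_fixed(coeffs, 1))  # 3: correct
```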


This commit fixes it, and now all the tests pass.

Aside: I don’t know of a means to do more resilient testing in MLIR. For example, I don’t know of a fuzz testing or property testing option. In my main project, HEIR, I hand-rolled a simple test generation routine that reads a config file and spits out MLIR test files. That allows me to jot down a bunch of tests quickly, but doesn’t give the full power of a system like hypothesis, which is a framework I’m quite fond of. I think the xDSL project would work in my favor here, but I have yet to dive into that, and as far as I know it requires re-defining all your custom dialects in Python before you can use it. More on that in a future article.

## Taking a step back

This article showed a bit of a naive approach to building up a dialect conversion pipeline, where we just greedily looked for ops to lower and inserted the relevant passes somewhat haphazardly. That worked out OK, but some lower-level passes (converting to LLVM) confounded the overall pipeline when placed too early.

A better approach is to identify the highest level operations and lower those first. But that is only really possible if you already know which passes are available and what they do. For example, elementwise-to-linalg takes something that seems low-level a priori (elementwise arith ops), yet we didn’t realize we needed it until noticing that convert-arith-to-llvm silently ignored those ops. Similarly, the implications of converting func to LLVM (which appears to handle more than just the ops in func) were not clear until we tried it and ran into problems.

I don’t have a particularly good solution here besides trial and error. But I appreciate good tools, so I will amuse myself with some ideas.

Since most passes have per-op patterns, it seems like one could write a tool that analyzes an IR, simulates running all possible rewrite patterns from the standard list of passes, and checks which ones succeeded (i.e., the “match” was successful, though most patterns are a combined matchAndRewrite), and what op types were generated as a result. Then once you get to an IR that you don’t know how to lower further, you could run the tool and it would tell you all of your options.

An even more aggressive tool could construct a complete graph of op-to-op conversions. You could identify an op to lower and a subset of legal dialects or ops (similar to a ConversionTarget), and it would report all possible paths from the op to your desired target.

I imagine performance is an obstacle here, and I also wonder to what extent this could be statically analyzed at least coarsely. For example, instead of running the rewrites, you could statically analyze the implementations of rewrite patterns lowering FooOp for occurrences of create<BarOp>, and just include an edge FooOp -> BarOp as a conservative estimate (it might not generate BarOp for every input).
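To make the graph idea concrete, here is a toy sketch. Every edge below is hypothetical, invented for illustration rather than derived from real MLIR passes; the point is only the shape of the tool: a conservative op-to-op graph plus a breadth-first search from an op to a legality predicate.

```python
from collections import deque

# Hypothetical conservative edges: "FooOp" -> "BarOp" means some rewrite
# pattern lowering FooOp may create a BarOp.
EDGES = {
    "poly.add": ["arith.addi"],
    "poly.eval": ["scf.for", "arith.muli", "arith.addi"],
    "arith.addi": ["llvm.add"],
    "arith.muli": ["llvm.mul"],
    "scf.for": ["cf.br", "cf.cond_br"],
}

def lowering_path(start, is_legal):
    # BFS for a shortest chain of conversions ending at a legal op.
    queue = deque([[start]])
    seen = {start}
    while queue:
        path = queue.popleft()
        if is_legal(path[-1]):
            return path
        for nxt in EDGES.get(path[-1], []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(path + [nxt])
    return None  # no known sequence of patterns reaches the target

print(lowering_path("poly.add", lambda op: op.startswith("llvm.")))
# ['poly.add', 'arith.addi', 'llvm.add']
```

A real version would have to extract the edges from the pattern implementations, which is exactly the hard part discussed above.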

## Plans

At this point we’ve covered the basic ground of MLIR: defining a new dialect, writing some optimization passes, doing some lowerings, and compiling and running a proper binary. The topics we could cover from here on out are rather broad, but here’s a few options:

• Writing an analysis pass.
• Writing a global program optimization pass with a cost model, in contrast to the local rewrites we’ve done so far.
• Defining a custom interface (trait with input/output) to write a generic pass that applies to multiple dialects.
• Explaining the linear algebra dialect and linalg.generic.
• Exploring some of the front ends for MLIR, like Polygeist and ClangIR.
• Exploring the Python bindings or C API for MLIR.
• Exploring some of the peripheral and in-progress projects around MLIR like PDLL, IRDL, xDSL, etc.
• Diving into the details of various upstream MLIR optimizations, like the polyhedral analysis or constant propagation passes.
• Working with sparse tensors.
• Covering more helper tools like the diagnostic infrastructure and bytecode generation.

# MLIR — Dialect Conversion

In previous articles we defined a dialect, and wrote various passes to optimize and canonicalize a program using that dialect. However, one of the main tenets of MLIR is “incremental lowering,” the idea that there are lots of levels of IR granularity, and you incrementally lower different parts of the IR, only discarding information when it’s no longer useful for optimizations. In this article we’ll see the first step of that: lowering the poly dialect to a combination of standard MLIR dialects, using the so-called dialect conversion infrastructure to accomplish it.

Side note: Victor Guerra kindly contributed a CMake build overlay to the mlir-tutorial project, in this PR, which I will be maintaining for the rest of the tutorial series. I still don’t know CMake well enough to give a tutorial on it, but I did work through parts of the official tutorial which seems excellent.

## The type obstacle

If not for types, dialect conversion (a.k.a. lowering) would be essentially the same as a normal pass: you write some rewrite patterns and apply them to the IR. You’d typically have one rewrite pattern per operation that needs to be lowered.

Types make this significantly more complex, and I’ll demonstrate the issue by example with poly.

In poly we have a poly.add operation that takes two values of poly.poly type, and returns a value of poly.poly type. We want to lower poly.add to, say, a vectorized loop of basic arithmetic operations. We can do this, and we will do it below, by rewriting poly.add as an arith.addi. However, arith.addi doesn’t know about poly.poly and likely never will.

Aside: if we had to make arith extensible to know about poly, we’d need to upstream a change to the arith.addi op’s operands to allow types that implement some sort of interface like “integer-like, or containers of integer-like things,” but when I suggested that to some MLIR folks they gently demurred. I believe this is because the MLIR community feels that the arith dialect is too overloaded.

So, in addition to lowering the op, we need to lower the poly.poly<N> type to something like tensor<Nxi32>. And this is where the type obstacle comes into play. Once you change a specific value’s type, say, while lowering an op that produces that value as output, then all downstream users of that value are still expecting the old type and are technically invalid until they are lowered. In between each pass MLIR runs verifiers to ensure the IR is valid, so without some special handling, this means all the types and ops need to be converted in one pass, or else these verifiers would fail. But managing all that with standard rewrite rules would be hard: for each op rewrite rule, you’d have to continually check if the arguments and results have been converted yet or not.

MLIR handles this situation with a wrapper around standard passes that they call the dialect conversion framework. Official docs here. It requires the user of the framework to inherit from different classes than normal rewrites, set up some additional metadata, and separate type conversion from op conversion in a specific way we’ll see shortly. But at a high level, this framework works by lowering ops in a certain sorted order, converting the types as it goes, and giving the op converters access to both the original types of each op as well as what the in-progress converted types look like at the time the op is visited by the framework. Each op-based rewrite pattern is expected to make that op type-legal after it’s visited, but it need not worry about downstream ops.

Finally, the dialect conversion framework keeps track of any type conflicts that aren’t resolved, and if any remain at the end of the pass, one of two things happens. The conversion framework allows one to optionally implement what’s called a type materializer, which inserts new intermediate ops that resolve type conflicts. So the first possibility is that the dialect conversion framework uses your type materializer hooks to patch up the IR, and the pass ends successfully. If those hooks fail, or if you didn’t define any hooks, then the pass fails and complains that it couldn’t fix the type conflicts.

Part of the complexity of this infrastructure also has to do with one of the harder lowering pipelines in upstream MLIR: the bufferization pipeline. This pipeline essentially converts an IR that uses ops with “value semantics” into one with “pointer semantics.” For example, the tensor type and its associated operations have value semantics, meaning each op is semantically creating a brand new tensor in its output, and all operations are pure (with some caveats). On the other hand, memref has pointer semantics, meaning it’s modeling the physical hardware more closely, requires explicit memory allocation, and supports operations that mutate memory locations.

Because bufferization is complicated, it is split into multiple sub-passes that handle bufferization issues specific to each of the relevant upstream MLIR dialects (see in the docs, e.g., arith-bufferize, func-bufferize, etc.). Each of these bufferization passes creates some internally-unresolvable type conflicts, which require custom type materializations to resolve. And to juggle all these issues across all the relevant dialects, the MLIR folks built a dedicated dialect called bufferization to house the intermediate ops. You’ll notice ops like to_memref and to_tensor that serve this role. And then there is a finalizing-bufferize pass whose role is to clean up any lingering bufferization/materialization ops.
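The value-vs-pointer distinction can be caricatured in a few lines of Python (an analogy only; real bufferization also has to reason about ownership, aliasing, and copy elision):

```python
# Value semantics (tensor-like): every op returns a fresh, immutable
# result and leaves its inputs untouched.
def add_value(a, b):
    return tuple(x + y for x, y in zip(a, b))

# Pointer semantics (memref-like): the caller explicitly allocates
# storage, and ops mutate it in place.
def add_into(a, b, out):
    for i in range(len(out)):
        out[i] = a[i] + b[i]

a, b = (1, 2, 3), (10, 20, 30)
c = add_value(a, b)   # a and b are unchanged

buf = [0, 0, 0]       # explicit "allocation"
add_into(a, b, buf)   # buf is mutated to hold the sum
```

Bufferization is, loosely, the rewrite from the first style into the second.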

There was a talk at the 2020-11-19 MLIR Open Design meeting called “Type Conversions the Not-So-Hard Way” (slides) that helped me understand these details, but I couldn’t understand the talk much before I tried a few naive lowering attempts, and then even after I felt I understood the talk I ran into some issues around type materialization. So my aim in the rest of this article is to explain the clarity I’ve found, such as it may be, and make it easier for the reader to understand the content of that talk.

## Lowering Poly without Type Materializations

The first commit sets up the pass shell, similar to the previous passes in the series, though it is located in a new directory lib/Conversion. The only caveat to this commit is that it adds dependent dialects:

let dependentDialects = [
  "mlir::arith::ArithDialect",
  "mlir::tutorial::poly::PolyDialect",
  "mlir::tensor::TensorDialect",
];


In particular, a lowering must depend in this way on any dialects that contain operations or types that the lowering will create, to ensure that MLIR loads those dialects before trying to run the pass.

The second commit defines a ConversionTarget, which tells MLIR how to determine what ops are within scope of the lowering. Specifically, it allows you to declare an entire dialect as “illegal” after the lowering is complete. We’ll do this for the poly dialect.

struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
  using PolyToStandardBase::PolyToStandardBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    auto *module = getOperation();

    ConversionTarget target(*context);  // <-- new thing

    RewritePatternSet patterns(context);

    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {  // <-- new thing
      signalPassFailure();
    }
  }
};


ConversionTarget can also declare specific ops as illegal, or even conditionally legal, using a callback to specify exactly what counts as legal. And then instead of applyPatternsAndFoldGreedily, we use applyPartialConversion to kick off the lowering.

Aside: We use applyPartialConversion instead of the more natural sounding alternative applyFullConversion for a silly reason: the error messages are better. If applyFullConversion fails, even in debug mode, you don’t get a lot of information about what went wrong. In partial mode, you can see the steps and what type conflicts it wasn’t able to patch up at the end. As far as I can tell, partial conversion is a strict generalization of full conversion, and I haven’t found a reason to use applyFullConversion ever.

After adding the conversion target, running --poly-to-standard on any random poly program will result in an error:

$ bazel run tools:tutorial-opt -- --poly-to-standard $PWD/tests/poly_syntax.mlir
tests/poly_syntax.mlir:14:10: error: failed to legalize operation 'poly.add' that was explicitly marked illegal
%0 = poly.add %arg0, %arg1 : !poly.poly<10>


In the next commit we define a subclass of TypeConverter to convert types from poly to other dialects.

class PolyToStandardTypeConverter : public TypeConverter {
 public:
  PolyToStandardTypeConverter(MLIRContext *ctx) {
    addConversion([](Type type) { return type; });
    addConversion([ctx](PolynomialType type) -> Type {
      int degreeBound = type.getDegreeBound();
      IntegerType elementTy =
          IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);
      return RankedTensorType::get({degreeBound}, elementTy);
    });
  }
};


We call addConversion for each specific type we need to convert, and then the default addConversion([](Type type) { return type; }); signifies “no work needed, this type is legal.” By itself this type converter does nothing, because it is only used in conjunction with a specific kind of rewrite pattern, a subclass of OpConversionPattern.
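The way chained addConversion callbacks interact can be mimicked with a toy stand-in (plain Python, my own names). If I understand the framework correctly, MLIR tries the most recently registered conversion first and falls through when a callback declines, which is why the catch-all identity conversion can be registered first:

```python
class ToyTypeConverter:
    def __init__(self):
        self.conversions = []

    def add_conversion(self, fn):
        self.conversions.append(fn)

    def convert(self, ty):
        # Most recently added callbacks take precedence; a callback
        # returning None means "not my type, try the next one."
        for fn in reversed(self.conversions):
            result = fn(ty)
            if result is not None:
                return result
        return None

def convert_poly(ty):
    # Types here are just strings for illustration.
    if ty.startswith("!poly.poly<"):
        degree_bound = ty[len("!poly.poly<"):-1]
        return f"tensor<{degree_bound}xi32>"
    return None

tc = ToyTypeConverter()
tc.add_conversion(lambda ty: ty)  # default: the type is already legal
tc.add_conversion(convert_poly)   # poly.poly<N> -> tensor<Nxi32>
print(tc.convert("!poly.poly<10>"))  # tensor<10xi32>
print(tc.convert("i32"))             # i32
```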

struct ConvertAdd : public OpConversionPattern<AddOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      AddOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // ...
    return success();
  }
};


We add this, along with the new function call to RewritePatternSet::add, in this commit. This call differs from a normal RewritePatternSet::add in that it takes the type converter as input and passes it along to the constructor of ConvertAdd, whose parent class takes the type converter as input and stores it as a member variable.

void runOnOperation() {
  ...
  RewritePatternSet patterns(context);
  PolyToStandardTypeConverter typeConverter(context);
  patterns.add<ConvertAdd>(typeConverter, context);
  ...
}


An OpConversionPattern‘s matchAndRewrite method has two new arguments: the OpAdaptor and the ConversionPatternRewriter. The first, OpAdaptor, is an alias for AddOp::Adaptor, which is part of the generated C++ code and holds the type-converted operands during dialect conversion. It’s called an “adaptor” because it uses the tablegen-defined names for an op’s arguments and results in its method names, rather than generic getOperand equivalents. Meanwhile, the AddOp argument (the same as in a normal rewrite pattern) contains the original, un-type-converted operands and results. The ConversionPatternRewriter is like a PatternRewriter, but it has additional methods relevant to dialect conversion, such as convertRegionTypes, a helper for applying type conversions to ops with nested regions. All modifications to the IR must go through the ConversionPatternRewriter.

In the next commit, we implement the lowering for poly.add.

struct ConvertAdd : public OpConversionPattern<AddOp> {
  <...>

  LogicalResult matchAndRewrite(
      AddOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    arith::AddIOp addOp = rewriter.create<arith::AddIOp>(
        op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
    rewriter.replaceOp(op.getOperation(), {addOp});
    return success();
  }
};


The main thing to notice is that we don’t have to convert the types ourselves. The conversion framework converted the operands in advance, put them on the OpAdaptor, and we can just focus on creating the lowered op.

Mathematically, however, this lowering is so simple because I’m taking advantage of the overflow semantics of 32-bit integer addition to avoid computing modular remainders. Or rather, I’m interpreting addi to have natural unsigned overflow semantics, though I think there is some ongoing discussion about how precise these semantics should be. Also, arith.addi is defined to operate on tensors of integers elementwise, in addition to operating on individual integers.
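Concretely, here is a sketch (my own illustration) of what “elementwise addition with 32-bit wraparound” means for the coefficient tensors:

```python
MASK = (1 << 32) - 1  # arithmetic mod 2^32

def add_coeffs(a, b):
    # Elementwise addition of coefficient lists with 32-bit wraparound,
    # like arith.addi applied to a tensor<Nxi32>.
    return [(x + y) & MASK for x, y in zip(a, b)]

print(add_coeffs([2**31, 5], [2**31, 7]))  # [0, 12]: the first entry wraps
```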

However, there is one hiccup when we try to test it on the following IR

func.func @test_lower_add(%0 : !poly.poly<10>, %1 : !poly.poly<10>) {
  %2 = poly.add %0, %1: !poly.poly<10>
  return
}


The pass fails here because of a type conflict between the function arguments and the converted add op. We get an error that looks like

error: failed to legalize unresolved materialization from '!poly.poly<10>' to 'tensor<10xi32>' that remained live after conversion
%2 = poly.add %0, %1: !poly.poly<10>
^
poly_to_standard.mlir:6:8: note: see current operation: %0 = "builtin.unrealized_conversion_cast"(%arg1) : (!poly.poly<10>) -> tensor<10xi32>
poly_to_standard.mlir:6:8: note: see existing live user here: %2 = arith.addi %1, %0 : tensor<10xi32>


If you look at the debug output (run the pass with --debug, i.e., bazel run tools:tutorial-opt -- --poly-to-standard --debug $PWD/tests/poly_to_standard.mlir) you’ll see a log of what the conversion framework is trying to do. Eventually it spits out this IR just before the end:

func.func @test_lower_add(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) {
  %0 = builtin.unrealized_conversion_cast %arg1 : !poly.poly<10> to tensor<10xi32>
  %1 = builtin.unrealized_conversion_cast %arg0 : !poly.poly<10> to tensor<10xi32>
  %2 = arith.addi %1, %0 : tensor<10xi32>
  %3 = poly.add %arg0, %arg1 : !poly.poly<10>
  return
}

This builtin.unrealized_conversion_cast is the internal stand-in for a type conflict before attempting to use user-defined materialization hooks to fix the conflicts (which we didn’t implement). It’s basically a forced type coercion. From the docs: “This operation should not be attributed any special representational or execution semantics, and is generally only intended to be used to satisfy the temporary intermixing of type systems during the conversion of one type system to another.”

So at this point we have two options to make the pass succeed:

• Convert the types in the function ops as well.
• Add a custom type materializer to replace the unrealized conversion cast.

The second option doesn’t really make sense, since we want to get rid of poly.poly entirely in this pass. But if we had to, we would use poly.from_tensor and poly.to_tensor, and we’ll show how to do that in the next section. For now, we’ll convert the function types.

In principle, one would have to lower a structural op like func.func in every conversion pass that introduces a new type that can show up in a function signature. Ditto for ops like scf.if, scf.for, etc. Luckily, MLIR is general enough to handle it for you, but it requires opting in via some extra boilerplate in this commit.
I won’t copy the boilerplate here, but basically you call helpers that use interfaces to define patterns to support type conversions for all function-looking ops (func/call/return), and all if/then/else-looking ops. We just call those helpers and give them our type converter, and it gives us patterns to add to our RewritePatternSet. After that, the pass runs and we get

func.func @test_lower_add(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) {
  %0 = arith.addi %arg0, %arg1 : tensor<10xi32>
  return
}

I wanted to highlight one other aspect of dialect conversions, which is that they use the folding capabilities of the dialect being converted. So for an IR like

func.func @test_lower_add_and_fold() {
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.constant dense<[3, 4, 5]> : tensor<3xi32> : !poly.poly<10>
  %2 = poly.add %0, %1: !poly.poly<10>
  return
}

Instead of lowering poly.add to arith.addi, and using a lowering for poly.constant added in this commit, we get

func.func @test_lower_add_and_fold() {
  %cst = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
  %cst_0 = arith.constant dense<[3, 4, 5]> : tensor<3xi32>
  %cst_1 = arith.constant dense<[5, 7, 9]> : tensor<3xi32>
  return
}

So as it lowers, it eagerly tries to fold the results, and that can result in some constant propagation. Next, we lower the rest of the ops:

• sub, from_tensor, to_tensor. The from_tensor lowering is slightly interesting, in that it allows the user to input just the lowest-degree coefficients of the polynomial, and then pads the higher degree terms with zero. This results in lowering to a tensor.pad op.
• constant, which is interesting in that it lowers to another poly op (arith.constant + poly.from_tensor), which is then recursively converted via the rewrite rule for from_tensor.
• mul, which lowers as a standard naive double for loop using the scf.for op.
• eval, which lowers as a single for loop and Horner’s method.
(Note it does not support complex inputs, despite the support we added for that in the previous article.)

Then in this commit you can see how it lowers a combination of all the poly ops to an equivalent program that uses just tensor, scf, and arith. For the input test:

func.func @test_lower_many(%arg : !poly.poly<10>, %point : i32) -> i32 {
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.add %0, %arg : !poly.poly<10>
  %2 = poly.mul %1, %1 : !poly.poly<10>
  %3 = poly.sub %2, %arg : !poly.poly<10>
  %4 = poly.eval %3, %point: (!poly.poly<10>, i32) -> i32
  return %4 : i32
}

It produces the following IR:

func.func @test_lower_many(%arg0: tensor<10xi32>, %arg1: i32) -> i32 {
  %cst = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
  %c0_i32 = arith.constant 0 : i32
  %padded = tensor.pad %cst low[0] high[7] {
  ^bb0(%arg2: index):
    tensor.yield %c0_i32 : i32
  } : tensor<3xi32> to tensor<10xi32>
  %0 = arith.addi %padded, %arg0 : tensor<10xi32>
  %cst_0 = arith.constant dense<0> : tensor<10xi32>
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %c1 = arith.constant 1 : index
  %1 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %cst_0) -> (tensor<10xi32>) {
    %4 = scf.for %arg4 = %c0 to %c10 step %c1 iter_args(%arg5 = %arg3) -> (tensor<10xi32>) {
      %5 = arith.addi %arg2, %arg4 : index
      %6 = arith.remui %5, %c10 : index
      %extracted = tensor.extract %0[%arg2] : tensor<10xi32>
      %extracted_3 = tensor.extract %0[%arg4] : tensor<10xi32>
      %7 = arith.muli %extracted, %extracted_3 : i32
      %extracted_4 = tensor.extract %arg5[%6] : tensor<10xi32>
      %8 = arith.addi %7, %extracted_4 : i32
      %inserted = tensor.insert %8 into %arg5[%6] : tensor<10xi32>
      scf.yield %inserted : tensor<10xi32>
    }
    scf.yield %4 : tensor<10xi32>
  }
  %2 = arith.subi %1, %arg0 : tensor<10xi32>
  %c1_1 = arith.constant 1 : index
  %c11 = arith.constant 11 : index
  %c0_i32_2 = arith.constant 0 : i32
  %3 = scf.for %arg2 = %c1_1 to %c11 step %c1_1 iter_args(%arg3 = %c0_i32_2) -> (i32) {
    %4 = arith.subi %c11, %arg2 : index
    %5 = arith.muli %arg1, %arg3 : i32
    %extracted = tensor.extract %2[%4] : tensor<10xi32>
    %6 = arith.addi %5, %extracted : i32
    scf.yield %6 : i32
  }
  return %3 : i32
}

## Materialization hooks

The previous section didn’t use the materialization hook infrastructure mentioned earlier, simply because we don’t need it. It’s only necessary when type conflicts must persist across multiple passes, and we have no trouble lowering all of poly in one pass. But for demonstration purposes, this commit (reverted in the same PR) shows how one would implement the materialization hooks. I removed the lowering for poly.sub, reconfigured the ConversionTarget so that poly.sub, to_tensor, and from_tensor are no longer illegal, and then added two hooks on the TypeConverter to insert new from_tensor and to_tensor ops when they are needed:

// Convert from a tensor type to a poly type: use from_tensor
addSourceMaterialization([](OpBuilder &builder, Type type,
                            ValueRange inputs, Location loc) -> Value {
  return builder.create<poly::FromTensorOp>(loc, type, inputs[0]);
});

// Convert from a poly type to a tensor type: use to_tensor
addTargetMaterialization([](OpBuilder &builder, Type type,
                            ValueRange inputs, Location loc) -> Value {
  return builder.create<poly::ToTensorOp>(loc, type, inputs[0]);
});

Now the same lowering above succeeds, except anywhere there is a poly.sub it is surrounded by from_tensor and to_tensor. The output for test_lower_many now looks like this:

func.func @test_lower_many(%arg0: tensor<10xi32>, %arg1: i32) -> i32 {
  %2 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %cst_0) -> (tensor<10xi32>) {
    %7 = scf.for %arg4 = %c0 to %c10 step %c1 iter_args(%arg5 = %arg3) -> (tensor<10xi32>) {
      ...
    }
    scf.yield %7 : tensor<10xi32>
  }
  %3 = poly.from_tensor %2 : tensor<10xi32> -> !poly.poly<10>
  %4 = poly.sub %3, %0 : !poly.poly<10>
  %5 = poly.to_tensor %4 : !poly.poly<10> -> tensor<10xi32>
  %c1_1 = arith.constant 1 : index
  ...
}

## Correctness?
Lowering poly is not entirely trivial, because you’re implementing polynomial math. The tests I wrote are specific, in that they put tight constraints on the output IR, but what if I’m just dumb and I got the algorithm wrong? It stands to reason that I would want an extra layer of certainty to guard against my own worst enemy. One of the possible solutions here is to use mlir-cpu-runner, like we did when we lowered ctlz to LLVM. Another is to lower all the way to machine code and run it on the actual hardware instead of through an interpreter. And that’s what we’ll do next time. Until then, the poly lowerings might be buggy. Feel free to point out any bugs in GitHub issues or comments on this blog.

# MLIR — Canonicalizers and Declarative Rewrite Patterns

In a previous article we defined folding functions, and used them to enable some canonicalization and the sccp constant propagation pass for the poly dialect. This time we’ll see how to add more general canonicalization patterns. The code for this article is in this pull request, and as usual the commits are organized to be read in order.

## Why is Canonicalization Needed?

MLIR provides folding as a mechanism to simplify an IR, which can result in simpler, more efficient ops (e.g., replacing multiplication by a power of 2 with a shift) or removing ops entirely (e.g., deleting $y = x+0$ and using $x$ for $y$ downstream). It also has a special hook for constant materialization.

General canonicalization differs from folding in MLIR in that it can transform many operations. In the MLIR docs they call this “DAG-to-DAG” rewriting, where DAG stands for “directed acyclic graph” in the graph theory sense, but really just means a section of the IR.

Canonicalizers can be written in the standard way: declare the op has a canonicalizer in tablegen and then implement a generated C++ function declaration. The official docs for that are here.
Or you can do it all declaratively in tablegen, the docs for that are here. We’ll do both in this article.

Aside: there is a third way, to use a new system called PDLL, but I haven’t figured out how to use that yet. It should be noted that PDLL is under active development, and in the meantime the tablegen-based approach in this article (called “DRR” for Declarative Rewrite Rules in the MLIR docs) is considered to be in maintenance mode, but not yet deprecated. I’ll try to cover PDLL in a future article.

## Canonicalizers in C++

Reusing our poly dialect, we’ll start with the binary polynomial operations, adding let hasCanonicalizer = 1; to the op base class in this commit, which generates the following method headers on each of the binary op classes:

// PolyOps.h.inc
static void getCanonicalizationPatterns(
    ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);

The body of this method asks to add custom rewrite patterns to the input results set, and we can define those patterns however we feel in the C++.

The first canonicalization pattern we’ll write in this commit is for the simple identity $x^2 - y^2 = (x+y)(x-y)$, which is useful because it replaces a multiplication with an addition. The only caveat is that this canonicalization is only more efficient if the squares have no other downstream uses.

// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
struct DifferenceOfSquares : public OpRewritePattern<SubOp> {
  DifferenceOfSquares(mlir::MLIRContext *context)
      : OpRewritePattern<SubOp>(context, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(SubOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    if (!lhs.hasOneUse() || !rhs.hasOneUse()) {
      return failure();
    }

    auto rhsMul = rhs.getDefiningOp<MulOp>();
    auto lhsMul = lhs.getDefiningOp<MulOp>();
    if (!rhsMul || !lhsMul) {
      return failure();
    }

    bool rhsMulOpsAgree = rhsMul.getLhs() == rhsMul.getRhs();
    bool lhsMulOpsAgree = lhsMul.getLhs() == lhsMul.getRhs();
    if (!rhsMulOpsAgree || !lhsMulOpsAgree) {
      return failure();
    }

    auto x = lhsMul.getLhs();
    auto y = rhsMul.getLhs();
    AddOp newAdd = rewriter.create<AddOp>(op.getLoc(), x, y);
    SubOp newSub = rewriter.create<SubOp>(op.getLoc(), x, y);
    MulOp newMul = rewriter.create<MulOp>(op.getLoc(), newAdd, newSub);
    rewriter.replaceOp(op, {newMul});
    // We don't need to remove the original ops because MLIR already has
    // canonicalization patterns that remove unused ops.
    return success();
  }
};

The test in the same commit shows the impact:

// Input:
func.func @test_difference_of_squares(
    %0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
  %2 = poly.mul %0, %0 : !poly.poly<3>
  %3 = poly.mul %1, %1 : !poly.poly<3>
  %4 = poly.sub %2, %3 : !poly.poly<3>
  %5 = poly.add %4, %4 : !poly.poly<3>
  return %5 : !poly.poly<3>
}

// Output:
// bazel run tools:tutorial-opt -- --canonicalize $FILE
func.func @test_difference_of_squares(%arg0: !poly.poly<3>, %arg1: !poly.poly<3>) -> !poly.poly<3> {
  %0 = poly.add %arg0, %arg1 : !poly.poly<3>
  %1 = poly.sub %arg0, %arg1 : !poly.poly<3>
  %2 = poly.mul %0, %1 : !poly.poly<3>
  %3 = poly.add %2, %2 : !poly.poly<3>
  return %3 : !poly.poly<3>
}
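As a sanity check that the pattern is safe for poly, note that $x^2 - y^2 = (x+y)(x-y)$ holds in any commutative ring, including polynomials mod $x^3 - 1$. A quick numeric spot check in Python (the helpers are mine):

```python
def pmul(a, b, n=3):
    # Multiply coefficient lists mod x^n - 1 (like poly.mul on poly<3>).
    out = [0] * n
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            out[(i + j) % n] += ai * bj
    return out

def padd(a, b):
    return [x + y for x, y in zip(a, b)]

def psub(a, b):
    return [x - y for x, y in zip(a, b)]

x, y = [1, 2, 3], [4, 0, 1]
lhs = psub(pmul(x, x), pmul(y, y))   # x^2 - y^2
rhs = pmul(padd(x, y), psub(x, y))   # (x+y)(x-y)
assert lhs == rhs
```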


Other than this pattern being used in the getCanonicalizationPatterns function, there is nothing new here compared to the previous article on rewrite patterns.

## Canonicalizers in Tablegen

The above canonicalization is really a simple kind of optimization attached to the canonicalization pass. It seems that this is how many minor optimizations end up being implemented, making the --canonicalize pass a heavyweight and powerful pass. However, the name “canonicalize” also suggests that it should be used to put the IR into a canonical form so that later passes don’t have to check as many cases.

So let’s implement a canonicalization like that as well. We’ll add one that has poly interact with the complex dialect, and implement a canonicalization that ensures complex conjugation always comes after polynomial evaluation. This works because of the identity $f(\overline{z}) = \overline{f(z)}$ for all polynomials $f$ and all complex numbers $z$.
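A quick numeric check of that identity using Python’s built-in complex type (exact here, since conjugation commutes with each underlying float operation):

```python
def eval_poly(coeffs, z):
    # Horner's method; coeffs[i] is the (real) coefficient of z^i.
    accum = 0
    for c in reversed(coeffs):
        accum = accum * z + c
    return accum

f = [1.0, 2.0, 3.0]      # 1 + 2z + 3z^2
z = complex(1.5, -2.0)
assert eval_poly(f, z.conjugate()) == eval_poly(f, z).conjugate()
```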

This commit adds support for complex inputs in the poly.eval op. Note that in MLIR, complex types are forced to be floating point because all the op verifiers that construct complex numbers require it. The complex type itself, however, suggests a complex<i32> is perfectly legal, so it seems nobody in the MLIR community needs Gaussian integers yet.

The rule itself is implemented in tablegen in this commit. The main tablegen code is:

```tablegen
include "PolyOps.td"
include "mlir/Dialect/Complex/IR/ComplexOps.td"
include "mlir/IR/PatternBase.td"

def LiftConjThroughEval : Pat<
  (Poly_EvalOp $f, (ConjOp $z)),
  (ConjOp (Poly_EvalOp $f, $z))
>;
```


The two new pieces here are the Pat class (source) that defines a rewrite pattern, and the parenthetical notation that defines sub-trees of the IR being matched and rewritten. The source documentation on the Pattern parent class is quite well written, so read that for extra detail, and the normal docs provide a higher level view with additional semantics.

But the short story here is that the inputs to Pat are two “IR tree” objects (MLIR calls them “DAG nodes”), and each node in the tree is specified by parentheses ( ), with the first thing in the parentheses being the name of an operation (the tablegen name, e.g., Poly_EvalOp, which comes from PolyOps.td), and the remaining arguments being the op’s arguments or attributes. Naturally, the nodes can nest, and that corresponds to a match applied to the argument. I.e., (Poly_EvalOp $f, (ConjOp $z)) means “an eval op whose first argument is anything (bind that to $f) and whose second argument is the output of a ConjOp whose input is anything (bind that to $z).”

When running tablegen with the -gen-rewriters option, this generates this code, which is not much more than a thorough version of the pattern we’d write manually. Then in this commit we show how to include it in the codebase. We still have to tell MLIR which pass to add the generated patterns to. You can add each pattern by name, or use the populateWithGenerated function to add them all.
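Concretely, wiring in all the generated patterns looks something like this sketch (assuming the EvalOp name from this article; populateWithGenerated is the function tablegen emits alongside the generated pattern classes):

```cpp
// Sketch: the header generated by -gen-rewriters declares populateWithGenerated,
// which adds every tablegen-defined pattern from the .td file to the set.
void EvalOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  populateWithGenerated(results);
}
```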

As another example, this commit reimplements the difference of squares pattern in tablegen. This one uses three additional features: a pattern that generates multiple new ops (which uses Pattern instead of Pat), binding the ops to names, and constraints that control when the pattern may be run.

```tablegen
// PolyPatterns.td
def HasOneUse: Constraint<CPred<"$_self.hasOneUse()">, "has one use">;

// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
def DifferenceOfSquares : Pattern<
  (Poly_SubOp (Poly_MulOp:$lhs $x, $x), (Poly_MulOp:$rhs $y, $y)),
  [
    (Poly_AddOp:$sum $x, $y),
    (Poly_SubOp:$diff $x, $y),
    (Poly_MulOp:$res $sum, $diff),
  ],
  [(HasOneUse:$lhs), (HasOneUse:$rhs)]
>;
```


The HasOneUse constraint merely injects the quoted C++ code into a generated if guard, with $_self as a magic string to substitute in the argument when it’s used. Then notice the syntax of (Poly_MulOp:$lhs $x, $x): the colon binds $lhs to refer to the op as a whole (or, via method overloads, its result value), so that it can be passed to the constraint. Similarly, the generated ops are all given names so they can be fed as the arguments of other generated ops. Finally, the second argument of Pattern is a list of generated ops to replace the matched input IR, rather than a single node as for Pat.

The benefit of doing this is significantly less boilerplate related to type casting, checking for nulls, and emitting error messages. But because you still occasionally need to inject custom C++ code, and to inspect the generated C++ when debugging, it helps to be fluent in both styles. I don’t know how to check a constraint like “has a single use” in pure DRR tablegen.