MLIR — A Global Optimization and Dataflow Analysis

In this article we’ll implement a global optimization pass, and show how to use the dataflow analysis framework to verify the results of our optimization.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

The noisy arithmetic problem

This demonstration is based on a simplified model of computation relevant to the HEIR project. You don’t need to be familiar with that project to follow this article, but if you’re wondering why someone would ever want the kind of optimization I’m going to write, that project is why.

The basic model is “noisy integer arithmetic.” That is, a program can have integer types of bounded width, and each integer is associated with some unknown (but bounded with high probability) symmetric random noise. You can imagine the “real” integer being the top 5 bits of a 32-bit integer, and the bottom 27 bits storing the noise. When a new integer is created, it magically has a random signed 12-bit integer added to it. When you apply operations to combine integers, the noise grows. Adding two integers adds their noise, and at worst you get one more bit of noise. Scaling an integer by a statically-known constant scales the noise by a constant. Multiplying two integers multiplies their noise values, and you get twice the bits of noise. As long as your program stays below 27 bits of noise, you can still “recover” the original 5-bit integer at the end of the program. Such a program is called legal, and otherwise, the output is random junk and the program is called illegal.

Finally, there is an expensive operation called reduce_noise that can explicitly reduce the noise of a noisy integer back to the base level of 12 bits. This operation has a statically known cost relative to the standard operations.

Note that starting from two noisy integers, each with 12 bits of noise, a single multiplication op brings you jarringly close to ruin. You would have at most 24 bits of noise, which is close to the maximum of 26. But the input IR we start with may have arbitrary computations that do not respect the noise limits. The goal of our optimization is to rewrite an MLIR region so that the noisy integer math never exceeds the noise limit at any given step.

A trivial way to do that would be to insert reduce_noise ops greedily, whenever an op would bring the noise of a value too high. Inserting such ops may be necessary, but a more suitable goal would be to minimize the overall cost of the program, subject to the constraint that the noisy arithmetic is legal. One could do this by inserting reduce_noise ops more strategically, or by rewriting the program to reduce the need for reduce_noise ops, or both. We’ll focus on the former: finding the best place to insert reduce_noise ops without rewriting the rest of the program.

The noisy dialect

We previously wrote about defining a new dialect, and the noisy dialect we created for this article has little new to show. This commit defines the types and ops, hard coding the 32-bit width and 5-bit semantic input type for simplicity, as well as the values of 12 bits of initial noise and 26 bits of max noise.

Note that the noise bound is not expressed on the type as an attribute. If it were, we’d run into a few problems: first, whenever you insert a reduce_noise op, you’d have to update the types on all the downstream ops. Second, it would prevent you from expressing control flow, since the noise bound cannot be statically inferred from the source code when there are two possible paths that could result in different noise values.

So instead, we need a way to compute the noise values, and associate them with each SSA value, and account for control flow. This is what an analysis pass is designed to do.

An analysis pass is just a class

The typical use of an analysis pass is to construct a data structure that encodes global information about a program, which can then be re-used during different parts of a pass. I imagined there would be more infrastructure around analysis passes in MLIR, but it’s quite simple. You define a C++ class with a constructor that takes an Operation *, and construct it basically whenever you need it. The only infrastructure for it involves storing and caching the constructed analysis within a pass, and understanding when an analysis needs to be recomputed (always between passes, by default).
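
For concreteness, here is a minimal sketch of that shape; MyAnalysis and MyPass are hypothetical names rather than code from the tutorial repo, and getAnalysis<T>() is the caching hook mentioned above.

#include "mlir/Pass/Pass.h"

// A hypothetical analysis: the only requirement is a constructor that takes
// the operation to analyze.
struct MyAnalysis {
  MyAnalysis(mlir::Operation *op) {
    // Walk `op` and build whatever data structures the pass will query.
  }
};

struct MyPass
    : mlir::PassWrapper<MyPass, mlir::OperationPass<mlir::ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyPass)

  void runOnOperation() override {
    // Constructs MyAnalysis(getOperation()) on first use, then caches it
    // until the pass manager decides it must be recomputed.
    MyAnalysis &analysis = getAnalysis<MyAnalysis>();
    // ... use `analysis` to drive the pass ...
  }
};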

By way of example, over at the HEIR project I made a simple analysis that chooses a unique variable name for every SSA value in a program, which I then used to generate code in an output language that needed variable names.

For this article we’ll see two analysis passes. One will formulate and solve the optimization problem that decides where to insert reduce_noise operations. This will be one of the “class that does anything” kind of analysis pass. The other analysis pass will rely on MLIR’s data flow analysis framework to propagate the noise model through the IR. This one will actually not require us to write an analysis from scratch, but instead will be implemented by means of the existing IntegerRangeAnalysis, which only requires us to implement an interface on each op that describes how the op affects the noise. This will be used in our pass to verify that the inserted reduce_noise operations ensure, if nothing else, that the noise never exceeds the maximum allowable noise.

We’ll start with the data flow analysis.

Reusing IntegerRangeAnalysis

Data flow analysis is a classical static analysis technique for propagating information through a program’s IR. It is one part of Frances Allen’s Turing Award. This article gives a good introduction and additional details, but I will paraphrase it briefly here in the context of IntegerRangeAnalysis.

The basic idea is that you want to get information about what possible values an integer-typed value can have at any point in a given program. If you see x = 7, then you know exactly what x is. If you see something like

func.func @example(%x : i8) -> i32 {
  %1 = arith.extsi %x : i8 to i32
  %2 = arith.addi %1, %1 : i32
  return %2 : i32
}

then you know that %2 can be at most a signed 9-bit integer, because it started as an 8-bit integer, and adding two such integers together can’t fill up more than one extra bit.

In such cases, one can find optimizations, like the int-range-optimizations pass in MLIR, which looks at comparison ops arith.cmpi and determines if it can replace them with constants. It does this by looking at the integer range analysis for the two operands. E.g., given the op x > y, if you know y's maximum value is less than x's minimum value, then you can replace it with a constant true.

Computing the data flow analysis requires two ingredients called a transfer function and a join operation. The transfer function describes what the output integer range should be for a given op and a given set of input integer ranges. This can be an arbitrary function. The join operation describes how to combine two or more integer ranges when you get to a point in the program where different branches of control flow merge. For example,

def fn(branch):
  x = 7
  if branch:
     y = x * x
  else:
     y = 2*x
  return y

The value of y just before returning cannot be known exactly, but in one branch you know it’s 14, and in another it’s 49. So the final value of y could be estimated as being in the range [14, 49]. Here the join function computes the smallest integer range containing both estimates. [Aside: it could instead use the set {14, 49} to be more precise, but that is not what IntegerRangeAnalysis happens to do]
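
To make the join concrete, here is a tiny self-contained illustration in plain C++ (not MLIR's ConstantIntRanges) of computing the smallest range containing the two branch estimates from the example above.

#include <algorithm>
#include <cassert>

// The interval join: the smallest range containing both inputs.
struct IntRange {
  int min, max;
};

IntRange join(IntRange a, IntRange b) {
  return {std::min(a.min, b.min), std::max(a.max, b.max)};
}

int main() {
  // The two branches of fn() produce y = 49 and y = 14, respectively.
  IntRange thenBranch{49, 49}, elseBranch{14, 14};
  IntRange merged = join(thenBranch, elseBranch);
  assert(merged.min == 14 && merged.max == 49);
  return 0;
}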

In order for a data flow analysis to work properly, the values being propagated and the join function must together form a semilattice, which is a partially-ordered set in which every two elements have a least upper bound, that upper bound is computed by join, and join itself must be associative, commutative, and idempotent. For the analysis to terminate, the semilattice must also have finite height, i.e., there are no infinite chains of strictly increasing elements. This is often expressed by having distinct “top” and “bottom” elements as defaults. “Top” represents “could be anything,” and sometimes expresses that a more precise bound would be too computationally expensive to continue to track. “Bottom” usually represents an uninitialized value.

Once you have this, then MLIR provides a general algorithm to propagate values through the IR via a technique called Kildall’s method, which iteratively updates the SSA values, applying the transfer function and joining at the merging of control flow paths, until the process reaches a fixed point.

Here are MLIR’s official docs on dataflow analysis, and here is the RFC where the current data flow solver framework was introduced. In our situation, we want to use the solver framework with the existing IntegerRangeAnalysis, which only asks that we implement the transfer function by implementing InferIntRangeInterface on our ops.

This commit does just that. This requires adding the DeclareOpInterfaceMethods<InferIntRangeInterface> to all relevant ops. That in turn generates function declarations for

void MyOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> inputRanges, SetIntRangeFn setResultRange);

ConstantIntRanges is a simple value class holding minimum and maximum integer bounds (in both signed and unsigned flavors). inputRanges represents the known bounds on the inputs to the operation in question, and SetIntRangeFn is the callback used to produce the result.

For example, for AddOp we can implement it as

ConstantIntRanges unionPlusOne(ArrayRef<ConstantIntRanges> inputRanges) {
  auto lhsRange = inputRanges[0];
  auto rhsRange = inputRanges[1];
  auto joined = lhsRange.rangeUnion(rhsRange);
  return ConstantIntRanges::fromUnsigned(joined.umin(), joined.umax() + 1);
}

void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> inputRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(), unionPlusOne(inputRanges));
}

A MulOp is similarly implemented by summing the maxes. Meanwhile, EncodeOp and ReduceNoiseOp each set the initial range to [0, 12]. So the min will always be zero, and we really only care about the max.
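
Here is a sketch of what the MulOp transfer function looks like under that description; the commit's exact code may differ in details.

// A minimal sketch (not verbatim from the commit) of the MulOp transfer
// function: the noise bound of the product is the sum of the two input
// noise bounds.
void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> inputRanges,
                              SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = inputRanges[0];
  const ConstantIntRanges &rhs = inputRanges[1];
  // Multiplying noisy values adds their noise bits: [0, lhsMax + rhsMax].
  APInt newMax = lhs.umax() + rhs.umax();
  setResultRange(getResult(),
                 ConstantIntRanges::fromUnsigned(
                     APInt::getZero(newMax.getBitWidth()), newMax));
}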

The next commit defines an empty pass that will contain our analyses and optimizations, and this commit shows how the integer range analysis is used to validate an IR’s noise growth. In short, you load the IntegerRangeAnalysis and its dependent DeadCodeAnalysis, run the solver, and then walk the IR, asking the solver via lookupState to give the resulting value range for each op’s result, and comparing it against the maximum.

void runOnOperation() {
    Operation *module = getOperation();

    DataFlowSolver solver;
    solver.load<dataflow::DeadCodeAnalysis>();
    solver.load<dataflow::IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(module)))
      signalPassFailure();

    auto result = module->walk([&](Operation *op) {
      if (!llvm::isa<noisy::AddOp, noisy::SubOp, noisy::MulOp,
                     noisy::ReduceNoiseOp>(*op)) {
        return WalkResult::advance();
      }
      const dataflow::IntegerValueRangeLattice *opRange =
          solver.lookupState<dataflow::IntegerValueRangeLattice>(
              op->getResult(0));
      if (!opRange || opRange->getValue().isUninitialized()) {
        op->emitOpError()
            << "Found op without a set integer range; did the analysis fail?";
        return WalkResult::interrupt();
      }

      ConstantIntRanges range = opRange->getValue().getValue();
      if (range.umax().getZExtValue() > MAX_NOISE) {
        op->emitOpError() << "Found op after which the noise exceeds the "
                             "allowable maximum of "
                          << MAX_NOISE
                          << "; it was: " << range.umax().getZExtValue()
                          << "\n";
        return WalkResult::interrupt();
      }

      return WalkResult::advance();
    });

    if (result.wasInterrupted())
      signalPassFailure();
}

Finally, in this commit we add a test that exercises it:

func.func @test_op_syntax() -> i5 {
  %0 = arith.constant 3 : i5
  %1 = arith.constant 4 : i5
  %2 = noisy.encode %0 : i5 -> !noisy.i32
  %3 = noisy.encode %1 : i5 -> !noisy.i32
  %4 = noisy.mul %2, %3 : !noisy.i32
  %5 = noisy.mul %4, %4 : !noisy.i32
  %6 = noisy.mul %5, %5 : !noisy.i32
  %7 = noisy.mul %6, %6 : !noisy.i32
  %8 = noisy.decode %7 : !noisy.i32 -> i5
  return %8 : i5
}

Running tutorial-opt --noisy-reduce-noise on this file produces the following error:

 error: 'noisy.mul' op Found op after which the noise exceeds the allowable maximum of 26; it was: 48

  %5 = noisy.mul %4, %4 : !noisy.i32
       ^
mlir-tutorial/tests/noisy_reduce_noise.mlir:11:8: note: see current operation: %5 = "noisy.mul"(%4, %4) : (!noisy.i32, !noisy.i32) -> !noisy.i32

And if you run in debug mode with --debug --debug-only=int-range-analysis, you will see the per-op propagations printed to the terminal

$  bazel run tools:tutorial-opt -- --noisy-reduce-noise-optimizer $PWD/tests/noisy_reduce_noise.mlir --debug --debug-only=int-range-analysis
Inferring ranges for %c3_i5 = arith.constant 3 : i5
Inferred range unsigned : [3, 3] signed : [3, 3]
Inferring ranges for %c4_i5 = arith.constant 4 : i5
Inferred range unsigned : [4, 4] signed : [4, 4]
Inferring ranges for %0 = noisy.encode %c3_i5 : i5 -> !noisy.i32
Inferred range unsigned : [0, 12] signed : [0, 12]
Inferring ranges for %1 = noisy.encode %c4_i5 : i5 -> !noisy.i32
Inferred range unsigned : [0, 12] signed : [0, 12]
Inferring ranges for %2 = noisy.mul %0, %1 : !noisy.i32
Inferred range unsigned : [0, 24] signed : [0, 24]
Inferring ranges for %3 = noisy.mul %2, %2 : !noisy.i32
Inferred range unsigned : [0, 48] signed : [0, 48]
Inferring ranges for %4 = noisy.mul %3, %3 : !noisy.i32
Inferred range unsigned : [0, 96] signed : [0, 96]
Inferring ranges for %5 = noisy.mul %4, %4 : !noisy.i32
Inferred range unsigned : [0, 192] signed : [0, 192]

As a quick aside, there was one minor upstream problem preventing me from reusing IntegerRangeAnalysis, which I patched in https://github.com/llvm/llvm-project/pull/72007. This means I also had to update the LLVM commit hash used by this project in this commit.

An ILP optimization pass

Next, we build an analysis that solves a global optimization to insert reduce_noise ops efficiently. As mentioned earlier, this is a “do anything” kind of analysis, so we put all of the logic into the analysis’s constructor.

[Aside: I wouldn’t normally do this, because constructors don’t have return values so it’s hard to signal failure; but the API for the analysis specifies the constructor takes as input the Operation * to analyze, and I would expect any properly constructed object to be “ready to use.” Maybe someone who knows C++ better will comment and shed some wisdom for me.]

This commit sets up the analysis shell and interface.

class ReduceNoiseAnalysis {
 public:
  ReduceNoiseAnalysis(Operation *op);
  ~ReduceNoiseAnalysis() = default;

  /// Return true if a reduce_noise op should be inserted after the given
  /// operation, according to the solution to the optimization problem.
  bool shouldInsertReduceNoise(Operation *op) const {
    return solution.lookup(op);
  }

 private:
  llvm::DenseMap<Operation *, bool> solution;
};
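
To preview how this interface gets used, here is a rough sketch of the consuming pass that a later commit implements. The pass name, the walk structure, and the assumption that ReduceNoiseOp has a generated builder taking a result type and an input value are all mine, not the repo's exact code.

void ReduceNoiseOptimizerPass::runOnOperation() {
  // Constructed (and cached) by the pass infrastructure; this runs the solver.
  ReduceNoiseAnalysis &analysis = getAnalysis<ReduceNoiseAnalysis>();
  getOperation()->walk([&](Operation *op) {
    if (!analysis.shouldInsertReduceNoise(op))
      return;
    OpBuilder builder(op->getContext());
    builder.setInsertionPointAfter(op);
    auto reduceOp = builder.create<noisy::ReduceNoiseOp>(
        op->getLoc(), op->getResult(0).getType(), op->getResult(0));
    // Route all other uses of the original result through the new op.
    op->getResult(0).replaceAllUsesExcept(reduceOp.getResult(),
                                          reduceOp.getOperation());
  });
}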

This commit adds a workspace dependency on Google’s or-tools package (“OR” stands for Operations Research here, a.k.a. discrete optimization), which comes bundled with a number of nice solvers, and an API for formulating optimization problems. And this commit implements the actual solver model.

Now this model is quite a bit of code, and this article is not the best place to give a full-fledged introduction to linear programming, modeling techniques, or the OR-tools API. What I’ll do instead is explain the model in detail here, and give a few small notes on how it translates to the OR-tools C++ API. If you want a gentler background on linear programming, see my article series about diet optimization (part 1, part 2).

All linear programs specify a linear function as an objective to minimize, along with a set of linear equalities and inequalities that constrain the solution. In a standard linear program, the variables must be continuously valued. In a mixed-integer linear program, some of those variables are allowed to be discrete integers, which, it turns out, makes it possible to solve many more problems, but requires completely different optimization techniques and may result in exponentially slow runtime. So many techniques in operations research relate to modeling a problem in such a way that the number of integer variables is relatively small.

Our linear model starts by defining some basic variables. Some variables in the model represent “decisions” that we can make, and others represent “state” that reacts to the decisions via constraints.

  • For each operation $x$, a $\{0, 1\}$-valued variable $\textup{InsertReduceNoise}_x$. Such a variable is 1 if and only if we insert a reduce_noise op after the operation $x$.
  • For each SSA-value $v$ that is input or output to a noisy op, a continuous-valued variable $\textup{NoiseAt}_v$. This represents the upper bound of the noise at value $v$.

In particular, the solver’s performance will get worse as the number of binary variables increases, which in this case corresponds to the number of noisy ops.

The objective function, with a small caveat explained later, is simply the sum of the decision variables, and we’d like to minimize it. Each reduce_noise op is considered equally expensive, and there is no special nuance here about scheduling them in parallel or in serial.

Next, we add constraints. First, $0 \leq \textup{NoiseAt}_v \leq 26$, which asserts that no SSA value can exceed the max noise. Second, we need to enforce that an encode op fixes the noise of its output to 12, i.e., for each encode op $x$, we add the constraint $\textup{NoiseAt}_{\textup{result}(x)} = 12$.

Finally, we need constraints that say that if you choose to insert a reduce_noise op, then the noise is reset to 12, otherwise it is set to the appropriate function of the inputs. This is where the modeling gets a little tricky, but multiplication is easier so let’s start there.

Fix a multiplication op $x$, its two input SSA values $\textup{LHS}, \textup{RHS}$, and its output $\textup{RES}$. As a piecewise function, we want a constraint like:

\[ \textup{NoiseAt}_\textup{RES} = \begin{cases} \textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS} & \text{ if } \textup{InsertReduceNoise}_x = 0 \\ 12 & \text{ if } \textup{InsertReduceNoise}_x = 1 \end{cases} \]

This isn’t linear, but we can combine the two branches to

\[ \begin{aligned} \textup{NoiseAt}_\textup{RES} &= (1 - \textup{InsertReduceNoise}_x) (\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS}) \\ & + 12 \textup{ InsertReduceNoise}_x \end{aligned} \]

This does the classic trick of using a bit as a controlled multiplexer, but it’s still not linear. We can make it linear, however, by replacing this one constraint with four constraints, and an auxiliary constant $C=100$ that we know is larger than the possible range of values that the $\textup{NoiseAt}_v$ variables can attain. Those four linear constraints are:

\[ \begin{aligned}
\textup{NoiseAt}_\textup{RES} &\geq 12 \textup{ InsertReduceNoise}_x \\
\textup{NoiseAt}_\textup{RES} &\leq 12 + C(1 - \textup{InsertReduceNoise}_x) \\
\textup{NoiseAt}_\textup{RES} &\geq (\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS}) - C \textup{ InsertReduceNoise}_x \\
\textup{NoiseAt}_\textup{RES} &\leq (\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS}) + C \textup{ InsertReduceNoise}_x \\
\end{aligned} \]

Setting the decision variable to zero results in the first two equations being trivially satisfied. Setting it to 1 causes the first two equations to be equivalent to $\textup{NoiseAt}_\textup{RES} = 12$. Likewise, the second two constraints are trivial when the decision variable is 1, and force the output noise to be equal to the sum of the two input noises when set to zero.

The addition op is handled similarly, except that the term $(\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS})$ is replaced by something non-linear, namely $1 + \max(\textup{NoiseAt}_{LHS}, \textup{NoiseAt}_{RHS})$. We can still handle that, but it requires an extra modeling trick. We introduce a new variable $Z_x$ for each add op $x$, and two constraints:

\[ \begin{aligned} Z_x &\geq 1 + \textup{NoiseAt}_{LHS} \\ Z_x &\geq 1 + \textup{NoiseAt}_{RHS} \end{aligned} \]

Together these ensure that $Z_x$ is at least 1 plus the max of the two input noises, but they don’t force equality. To achieve that, we add $Z_x$ to the minimization objective (alongside the sum of the decision variables) with a small penalty to ensure the solver tries to minimize these variables. Since they have trivially minimal values equal to “1 plus the max,” the solver will have no trouble optimizing them, and this will be effectively an equality constraint.
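
In symbols, the objective becomes something like the following, where $\varepsilon$ is a small positive weight (the exact value is an implementation detail not pinned down here):

\[ \min \sum_{\textup{ops } x} \textup{InsertReduceNoise}_x + \varepsilon \sum_{\textup{add ops } x} Z_x \]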

[Aside: Whenever you do this trick, you have to convince yourself that the solver won’t somehow be able to increase $Z_x$ as a trade-off against lower values of other objective terms, and produce a lower overall objective value. Solvers are mischievous and cannot be trusted. In our case, there is no risk: if you were to increase $Z_x$ above its minimum value, that would only increase the noise propagation through add ops, meaning the solver would have to compensate by potentially adding even more reduce_noise ops!]

Then, the constraints for an add op use $Z_x$ in place of the term $(\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS})$ from the mul op.
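
Spelled out, the four constraints for an add op $x$ with result $\textup{RES}$ become:

\[ \begin{aligned}
\textup{NoiseAt}_\textup{RES} &\geq 12 \textup{ InsertReduceNoise}_x \\
\textup{NoiseAt}_\textup{RES} &\leq 12 + C(1 - \textup{InsertReduceNoise}_x) \\
\textup{NoiseAt}_\textup{RES} &\geq Z_x - C \textup{ InsertReduceNoise}_x \\
\textup{NoiseAt}_\textup{RES} &\leq Z_x + C \textup{ InsertReduceNoise}_x \\
\end{aligned} \]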

The only other minor aspect of this solver model is that the constraints above keep the noise propagation consistent after an optional reduce_noise op, but when a reduce_noise op is inserted, they don’t bound the noise of the op’s result before it is fed into the reduce_noise. We can fix this by adding new constraints expressing $\textup{NoiseAt}_{LHS} + \textup{NoiseAt}_{RHS} \leq 26$ and $Z_x \leq 26$ for multiplication and addition ops, respectively.

When converting this to the OR-tools C++ API, as we did in this commit, there are a few minor things to note (see the sketch after this list):

  • You can specify upper and lower bounds on a variable at variable creation time, rather than as separate constraints. You’ll see this in solver->MakeNumVar(min, max, name).
  • Constraints must be specified in the form min <= expr <= max, where min and max are constants and expr is a linear combination of variables, meaning that one has to manually re-arrange and simplify all the equations above so the variables are all on one side and the constants on the other. (The OR-tools Python API is more expressive, but we don’t have it here.)
  • The constraints and the objective are specified by SetCoefficient, which sets the coefficient of a variable in a linear combination one at a time.
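
To make these notes concrete, here is a minimal, self-contained sketch of the OR-tools pattern in play. The variable and constraint mirror the model above, but this is illustrative only, not the commit's code, and the choice of the "SCIP" backend is an assumption.

#include <memory>

#include "ortools/linear_solver/linear_solver.h"

using operations_research::MPConstraint;
using operations_research::MPObjective;
using operations_research::MPSolver;
using operations_research::MPVariable;

void buildToyModel() {
  std::unique_ptr<MPSolver> solver(MPSolver::CreateSolver("SCIP"));

  // Decision variable: whether to insert a reduce_noise op (binary).
  MPVariable *insert = solver->MakeBoolVar("InsertReduceNoise_x");
  // Noise variable, with its bounds given at creation time.
  MPVariable *noise = solver->MakeNumVar(0.0, 26.0, "NoiseAt_v");

  // NoiseAt_v >= 12 * InsertReduceNoise_x, rearranged so all variables are
  // on one side: -inf <= 12 * insert - noise <= 0.
  MPConstraint *c = solver->MakeRowConstraint(-solver->infinity(), 0.0);
  c->SetCoefficient(insert, 12.0);
  c->SetCoefficient(noise, -1.0);

  // Objective: minimize the number of inserted reduce_noise ops.
  MPObjective *objective = solver->MutableObjective();
  objective->SetCoefficient(insert, 1.0);
  objective->SetMinimization();

  solver->Solve();
}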

Finally, this commit implements the part of the pass that uses the solver’s output to insert new reduce_noise ops. And this commit adds some more tests.

An example of its use:

// This test checks that the solver can find a single insertion point
// for a reduce_noise op that handles two branches, each of which would
// also need a reduce_noise op if handled separately.
func.func @test_single_insertion_branching() -> i5 {
  %0 = arith.constant 3 : i5
  %1 = arith.constant 4 : i5
  %2 = noisy.encode %0 : i5 -> !noisy.i32
  %3 = noisy.encode %1 : i5 -> !noisy.i32
  // Noise: 12
  %4 = noisy.mul %2, %3 : !noisy.i32
  // Noise: 24

  // branch 1
  %b1 = noisy.add %4, %3 : !noisy.i32
  // Noise: 25
  %b2 = noisy.add %b1, %3 : !noisy.i32
  // Noise: 25
  %b3 = noisy.add %b2, %3 : !noisy.i32
  // Noise: 26
  %b4 = noisy.add %b3, %3 : !noisy.i32
  // Noise: 27

  // branch 2
  %c1 = noisy.sub %4, %2 : !noisy.i32
  // Noise: 25
  %c2 = noisy.sub %c1, %3 : !noisy.i32
  // Noise: 25
  %c3 = noisy.sub %c2, %3 : !noisy.i32
  // Noise: 26
  %c4 = noisy.sub %c3, %3 : !noisy.i32
  // Noise: 27

  %x1 = noisy.decode %b4 : !noisy.i32 -> i5
  %x2 = noisy.decode %c4 : !noisy.i32 -> i5
  %x3 = arith.addi %x1, %x2 : i5
  return %x3 : i5
}

And the output:

  func.func @test_single_insertion_branching() -> i5 {
    %c3_i5 = arith.constant 3 : i5
    %c4_i5 = arith.constant 4 : i5
    %0 = noisy.encode %c3_i5 : i5 -> !noisy.i32
    %1 = noisy.encode %c4_i5 : i5 -> !noisy.i32
    %2 = noisy.mul %0, %1 : !noisy.i32
    %3 = noisy.reduce_noise %2 : !noisy.i32
    %4 = noisy.add %3, %1 : !noisy.i32
    %5 = noisy.add %4, %1 : !noisy.i32
    %6 = noisy.add %5, %1 : !noisy.i32
    %7 = noisy.add %6, %1 : !noisy.i32
    %8 = noisy.sub %3, %0 : !noisy.i32
    %9 = noisy.sub %8, %1 : !noisy.i32
    %10 = noisy.sub %9, %1 : !noisy.i32
    %11 = noisy.sub %10, %1 : !noisy.i32
    %12 = noisy.decode %7 : !noisy.i32 -> i5
    %13 = noisy.decode %11 : !noisy.i32 -> i5
    %14 = arith.addi %12, %13 : i5
    return %14 : i5
  }

MLIR — Lowering through LLVM

In the last article we lowered our custom poly dialect to standard MLIR dialects. In this article we’ll continue lowering it to LLVM IR, exporting it out of MLIR to LLVM, and then compiling to x86 machine code.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Defining a Pipeline

The first step in lowering to machine code is to lower to an “exit dialect.” That is, a dialect from which there is a code-gen tool that converts MLIR code to some non-MLIR format. In our case, we’re targeting x86, so the exit dialect is the LLVM dialect, and the code-gen tool is the binary mlir-translate (more on that later). Lowering to an exit dialect, as it turns out, is not all that simple.

One of the things I’ve struggled with when learning MLIR is how to compose all the different lowerings into a pipeline that ends in the result I want, especially when other people wrote those pipelines. When starting from a high level dialect like linalg (linear algebra), there can be dozens of lowerings involved, and some of them can reintroduce ops from dialects you thought you completely lowered already, or have complex pre-conditions or relationships to other lowerings. There are also simply too many lowerings to easily scan.

There are two ways to specify a pass pipeline to an MLIR opt-like binary. One is entirely on the command line: you can use the --pass-pipeline flag with a tiny DSL, like this

bazel run //tools:tutorial-opt -- foo.mlir \
  --pass-pipeline='builtin.module(func.func(cse,canonicalize),convert-func-to-llvm)'

Above, the thing that looks like a function call is “anchoring” a sequence of passes to operate on a particular op (allowing them to run in parallel across ops).

Thankfully, you can also declare the pipeline in code, and wrap it up in a single flag. The above might be equivalently defined as follows:

void customPipelineBuilder(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createConvertFuncToLLVMPass());  // runs on builtin.module by default
}

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  <... register dialects ...>

  mlir::PassPipelineRegistration<>(
      "my-pipeline", "A custom pipeline", customPipelineBuilder);

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Tutorial opt main", registry));
}

We’ll do the actually-runnable analogue of this in the next section when we lower the poly dialect to LLVM.

Lowering Poly to LLVM

In this section we’ll define a pipeline lowering poly to LLVM and show the MLIR along each step. Strap in, there’s going to be a lot of MLIR code in this article.

The process I’ve used to build up a big pipeline is rather toilsome and incremental. Basically, start from an empty pipeline and the starting MLIR, then look for the “highest level” op you can think of, add to the pipeline a pass that lowers it, and if that pass fails, figure out what pass is required before it. Then repeat until you have achieved your target.

In this commit we define a pass pipeline --poly-to-llvm that includes only the --poly-to-standard pass defined last time, along with a canonicalization pass. Then we start from this IR:

$ cat $PWD/tests/poly_to_llvm.mlir

func.func @test_poly_fn(%arg : i32) -> i32 {
  %tens = tensor.splat %arg : tensor<10xi32>
  %input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10>
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.add %0, %input : !poly.poly<10>
  %2 = poly.mul %1, %1 : !poly.poly<10>
  %3 = poly.sub %2, %input : !poly.poly<10>
  %4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32
  return %4 : i32
}

$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir

module {
  func.func @test_poly_fn(%arg0: i32) -> i32 {
    %c11 = arith.constant 11 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %c0_i32 = arith.constant 0 : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %padded = tensor.pad %cst_0 low[0] high[7] {
    ^bb0(%arg1: index):
      tensor.yield %c0_i32 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %0 = arith.addi %padded, %splat : tensor<10xi32>
    %1 = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
      %4 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
        %5 = arith.addi %arg1, %arg3 : index
        %6 = arith.remui %5, %c10 : index
        %extracted = tensor.extract %0[%arg3] : tensor<10xi32>
        %extracted_1 = tensor.extract %0[%arg1] : tensor<10xi32>
        %7 = arith.muli %extracted_1, %extracted : i32
        %extracted_2 = tensor.extract %arg4[%6] : tensor<10xi32>
        %8 = arith.addi %7, %extracted_2 : i32
        %inserted = tensor.insert %8 into %arg4[%6] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %4 : tensor<10xi32>
    }
    %2 = arith.subi %1, %splat : tensor<10xi32>
    %3 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
      %4 = arith.subi %c11, %arg1 : index
      %5 = arith.muli %arg0, %arg2 : i32
      %extracted = tensor.extract %2[%4] : tensor<10xi32>
      %6 = arith.addi %5, %extracted : i32
      scf.yield %6 : i32
    }
    return %3 : i32
  }
}

Let’s naively start from the top and see what happens. One way to try this more interactively is to tack a tentative pass onto the command line, like bazel run //tools:tutorial-opt -- --poly-to-llvm --pass-to-try $PWD/tests/poly_to_llvm.mlir, which runs the hard-coded pipeline followed by the new pass.

The first op that can be lowered is func.func, and there is a convert-func-to-llvm pass, which we add in this commit. It turns out to process a lot more than just the func op:

module attributes {llvm.data_layout = ""} {
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(11 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %8 = llvm.mlir.constant(0 : i32) : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %padded = tensor.pad %cst_0 low[0] high[7] {
    ^bb0(%arg1: index):
      tensor.yield %8 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %9 = arith.addi %padded, %splat : tensor<10xi32>
    %10 = scf.for %arg1 = %7 to %5 step %3 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = scf.for %arg3 = %7 to %5 step %3 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
        %15 = builtin.unrealized_conversion_cast %arg3 : index to i64
        %16 = llvm.add %13, %15  : i64
        %17 = llvm.urem %16, %4  : i64
        %18 = builtin.unrealized_conversion_cast %17 : i64 to index
        %extracted = tensor.extract %9[%arg3] : tensor<10xi32>
        %extracted_1 = tensor.extract %9[%arg1] : tensor<10xi32>
        %19 = llvm.mul %extracted_1, %extracted  : i32
        %extracted_2 = tensor.extract %arg4[%18] : tensor<10xi32>
        %20 = llvm.add %19, %extracted_2  : i32
        %inserted = tensor.insert %20 into %arg4[%18] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %14 : tensor<10xi32>
    }
    %11 = arith.subi %10, %splat : tensor<10xi32>
    %12 = scf.for %arg1 = %3 to %1 step %3 iter_args(%arg2 = %8) -> (i32) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = llvm.sub %0, %13  : i64
      %15 = builtin.unrealized_conversion_cast %14 : i64 to index
      %16 = llvm.mul %arg0, %arg2  : i32
      %extracted = tensor.extract %11[%15] : tensor<10xi32>
      %17 = llvm.add %16, %extracted  : i32
      scf.yield %17 : i32
    }
    llvm.return %12 : i32
  }
}

Notably, this pass converted most of the arithmetic operations—though not the tensor-generating ones that use the dense attribute—and inserted a number of unrealized_conversion_cast ops for the resulting type conflicts, which we’ll have to get rid of eventually, but we can’t now because the values on both sides of the type conversion are still used.

Next we’ll lower arith to LLVM using the suggestively-named convert-arith-to-llvm in this commit. However, it has no effect on the resulting IR, and the arith ops remain. What gives? It turns out that arith ops that operate on tensors are not supported by convert-arith-to-llvm. To deal with this, we need a special pass called convert-elementwise-to-linalg, which lowers these ops to linalg.generic ops. (I plan to cover linalg.generic in a future tutorial).

We add it in this commit, and this is the diff between the above IR and the new output (right > is the new IR):

> #map = affine_map<(d0) -> (d0)>
19c20,24
<     %9 = arith.addi %padded, %splat : tensor<10xi32>
---
>     %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
>     ^bb0(%in: i32, %in_1: i32, %out: i32):
>       %13 = arith.addi %in, %in_1 : i32
>       linalg.yield %13 : i32
>     } -> tensor<10xi32>
37c42,46
<     %11 = arith.subi %10, %splat : tensor<10xi32>
---
>     %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%10, %splat : tensor<10xi32>, tensor<10xi32>) outs(%10 : tensor<10xi32>) {
>     ^bb0(%in: i32, %in_1: i32, %out: i32):
>       %13 = arith.subi %in, %in_1 : i32
>       linalg.yield %13 : i32
>     } -> tensor<10xi32>
49a59
>

This is nice, but now you can see we have a new problem: lowering linalg, and re-lowering the arith ops that were inserted by the convert-elementwise-to-linalg pass. So adding back the arith pass at the end in this commit replaces the two inserted arith ops, giving this IR

#map = affine_map<(d0) -> (d0)>
module attributes {llvm.data_layout = ""} {
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(11 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %8 = llvm.mlir.constant(0 : i32) : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %padded = tensor.pad %cst_0 low[0] high[7] {
    ^bb0(%arg1: index):
      tensor.yield %8 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
    ^bb0(%in: i32, %in_1: i32, %out: i32):
      %13 = llvm.add %in, %in_1  : i32
      linalg.yield %13 : i32
    } -> tensor<10xi32>
    %10 = scf.for %arg1 = %7 to %5 step %3 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = scf.for %arg3 = %7 to %5 step %3 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
        %15 = builtin.unrealized_conversion_cast %arg3 : index to i64
        %16 = llvm.add %13, %15  : i64
        %17 = llvm.urem %16, %4  : i64
        %18 = builtin.unrealized_conversion_cast %17 : i64 to index
        %extracted = tensor.extract %9[%arg3] : tensor<10xi32>
        %extracted_1 = tensor.extract %9[%arg1] : tensor<10xi32>
        %19 = llvm.mul %extracted_1, %extracted  : i32
        %extracted_2 = tensor.extract %arg4[%18] : tensor<10xi32>
        %20 = llvm.add %19, %extracted_2  : i32
        %inserted = tensor.insert %20 into %arg4[%18] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %14 : tensor<10xi32>
    }
    %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%10, %splat : tensor<10xi32>, tensor<10xi32>) outs(%10 : tensor<10xi32>) {
    ^bb0(%in: i32, %in_1: i32, %out: i32):
      %13 = llvm.sub %in, %in_1  : i32
      linalg.yield %13 : i32
    } -> tensor<10xi32>
    %12 = scf.for %arg1 = %3 to %1 step %3 iter_args(%arg2 = %8) -> (i32) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = llvm.sub %0, %13  : i64
      %15 = builtin.unrealized_conversion_cast %14 : i64 to index
      %16 = llvm.mul %arg0, %arg2  : i32
      %extracted = tensor.extract %11[%15] : tensor<10xi32>
      %17 = llvm.add %16, %extracted  : i32
      scf.yield %17 : i32
    }
    llvm.return %12 : i32
  }
}

However, there are still two arith ops remaining: the arith.constant ops that use dense attributes to define tensors. There’s no obvious pass that handles this from looking at the list of available passes. As we’ll see, it’s the bufferization pass that handles this (we discussed it briefly in the previous article), and we should try to lower as much as possible before bufferizing. In general, bufferizing makes optimization passes harder.

The next thing to try lowering is tensor.splat. Again there’s no obvious pass (tensor-to-linalg doesn’t lower it), but searching the LLVM codebase for SplatOp we find this pattern, which suggests tensor.splat is also lowered during bufferization, producing linalg.map ops. So we’ll have to lower those as well eventually.

But what tensor-to-linalg does lower is the next op: tensor.pad. It replaces a pad with the following

<     %padded = tensor.pad %cst_0 low[0] high[7] {
<     ^bb0(%arg1: index):
<       tensor.yield %8 : i32
<     } : tensor<3xi32> to tensor<10xi32>
<     %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
---
>     %9 = tensor.empty() : tensor<10xi32>
>     %10 = linalg.fill ins(%8 : i32) outs(%9 : tensor<10xi32>) -> tensor<10xi32>
>     %inserted_slice = tensor.insert_slice %cst_0 into %10[0] [3] [1] : tensor<3xi32> into tensor<10xi32>
>     %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%inserted_slice, %splat : tensor<10xi32>, tensor<10xi32>) outs(%inserted_slice : tensor<10xi32>) {

We add that pass in this commit.

Next we have all of these linalg ops, as well as the scf.for loops. The linalg ops can lower naively as for loops, so we’ll do that first. There are three options:

  • -convert-linalg-to-affine-loops : Lower the operations from the linalg dialect into affine loops
  • -convert-linalg-to-loops : Lower the operations from the linalg dialect into loops
  • -convert-linalg-to-parallel-loops : Lower the operations from the linalg dialect into parallel loops

I want to make this initial pipeline as simple and naive as possible, so convert-linalg-to-loops it is. However, the pass does nothing on this IR (added in this commit). If you run the pass with --debug you can find an explanation:

 ** Failure : expected linalg op with buffer semantics

So linalg must be bufferized before it can be lowered to loops.

However, we can tackle the scf to LLVM step with a combination of two passes: convert-scf-to-cf and convert-cf-to-llvm, in this commit. The lowering added arith.cmpi for the loop predicates, so moving arith-to-llvm to the end of the pipeline fixes that (commit). In the end we get this IR:
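
In code, this part of the pipeline looks roughly like the following sketch; the pipeline-builder name is a placeholder, and the create* functions are the upstream registrations of these passes as of roughly the pinned LLVM commit.

void polyToLLVMPipelineBuilder(mlir::OpPassManager &manager) {
  // ... the passes added so far ...
  manager.addPass(mlir::createConvertSCFToCFPass());
  manager.addPass(mlir::createConvertControlFlowToLLVMPass());
  // Re-run arith-to-llvm last to catch the arith.cmpi ops that the scf-to-cf
  // lowering introduces for loop predicates.
  manager.addPass(mlir::createArithToLLVMConversionPass());
}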

#map = affine_map<(d0) -> (d0)>
module attributes {llvm.data_layout = ""} {
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(11 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %8 = llvm.mlir.constant(0 : i32) : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %9 = tensor.empty() : tensor<10xi32>
    %10 = linalg.fill ins(%8 : i32) outs(%9 : tensor<10xi32>) -> tensor<10xi32>
    %inserted_slice = tensor.insert_slice %cst_0 into %10[0] [3] [1] : tensor<3xi32> into tensor<10xi32>
    %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%inserted_slice, %splat : tensor<10xi32>, tensor<10xi32>) outs(%inserted_slice : tensor<10xi32>) {
    ^bb0(%in: i32, %in_4: i32, %out: i32):
      %43 = llvm.add %in, %in_4  : i32
      linalg.yield %43 : i32
    } -> tensor<10xi32>
    cf.br ^bb1(%7, %cst : index, tensor<10xi32>)
  ^bb1(%12: index, %13: tensor<10xi32>):  // 2 preds: ^bb0, ^bb5
    %14 = builtin.unrealized_conversion_cast %12 : index to i64
    %15 = llvm.icmp "slt" %14, %4 : i64
    llvm.cond_br %15, ^bb2, ^bb6
  ^bb2:  // pred: ^bb1
    %16 = builtin.unrealized_conversion_cast %12 : index to i64
    cf.br ^bb3(%7, %13 : index, tensor<10xi32>)
  ^bb3(%17: index, %18: tensor<10xi32>):  // 2 preds: ^bb2, ^bb4
    %19 = builtin.unrealized_conversion_cast %17 : index to i64
    %20 = llvm.icmp "slt" %19, %4 : i64
    llvm.cond_br %20, ^bb4, ^bb5
  ^bb4:  // pred: ^bb3
    %21 = builtin.unrealized_conversion_cast %17 : index to i64
    %22 = llvm.add %16, %21  : i64
    %23 = llvm.urem %22, %4  : i64
    %24 = builtin.unrealized_conversion_cast %23 : i64 to index
    %extracted = tensor.extract %11[%17] : tensor<10xi32>
    %extracted_1 = tensor.extract %11[%12] : tensor<10xi32>
    %25 = llvm.mul %extracted_1, %extracted  : i32
    %extracted_2 = tensor.extract %18[%24] : tensor<10xi32>
    %26 = llvm.add %25, %extracted_2  : i32
    %inserted = tensor.insert %26 into %18[%24] : tensor<10xi32>
    %27 = llvm.add %19, %2  : i64
    %28 = builtin.unrealized_conversion_cast %27 : i64 to index
    cf.br ^bb3(%28, %inserted : index, tensor<10xi32>)
  ^bb5:  // pred: ^bb3
    %29 = llvm.add %14, %2  : i64
    %30 = builtin.unrealized_conversion_cast %29 : i64 to index
    cf.br ^bb1(%30, %18 : index, tensor<10xi32>)
  ^bb6:  // pred: ^bb1
    %31 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%13, %splat : tensor<10xi32>, tensor<10xi32>) outs(%13 : tensor<10xi32>) {
    ^bb0(%in: i32, %in_4: i32, %out: i32):
      %43 = llvm.sub %in, %in_4  : i32
      linalg.yield %43 : i32
    } -> tensor<10xi32>
    cf.br ^bb7(%3, %8 : index, i32)
  ^bb7(%32: index, %33: i32):  // 2 preds: ^bb6, ^bb8
    %34 = builtin.unrealized_conversion_cast %32 : index to i64
    %35 = llvm.icmp "slt" %34, %0 : i64
    llvm.cond_br %35, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %36 = builtin.unrealized_conversion_cast %32 : index to i64
    %37 = llvm.sub %0, %36  : i64
    %38 = builtin.unrealized_conversion_cast %37 : i64 to index
    %39 = llvm.mul %arg0, %33  : i32
    %extracted_3 = tensor.extract %31[%38] : tensor<10xi32>
    %40 = llvm.add %39, %extracted_3  : i32
    %41 = llvm.add %34, %2  : i64
    %42 = builtin.unrealized_conversion_cast %41 : i64 to index
    cf.br ^bb7(%42, %40 : index, i32)
  ^bb9:  // pred: ^bb7
    llvm.return %33 : i32
  }
}

Bufferization

Last time I wrote at some length about the dialect conversion framework, and how it’s complicated mostly because of the bufferization passes, which were split across multiple passes that created type conflicts requiring materialization and later resolution.

Well, turns out these bufferization passes are now deprecated. These are the official docs, but the short story is that the network of bufferization passes was replaced by a one-shot-bufferize pass, with some boilerplate cleanup passes. This is the sequence of passes that is recommended by the docs, along with an option to ensure that function signatures are bufferized as well:

  // One-shot bufferize, from
  // https://mlir.llvm.org/docs/Bufferization/#ownership-based-buffer-deallocation
  bufferization::OneShotBufferizationOptions bufferizationOptions;
  bufferizationOptions.bufferizeFunctionBoundaries = true;
  manager.addPass(
      bufferization::createOneShotBufferizePass(bufferizationOptions));
  manager.addPass(memref::createExpandReallocPass());
  manager.addPass(bufferization::createOwnershipBasedBufferDeallocationPass());
  manager.addPass(createCanonicalizerPass());
  manager.addPass(bufferization::createBufferDeallocationSimplificationPass());
  manager.addPass(bufferization::createLowerDeallocationsPass());
  manager.addPass(createCSEPass());
  manager.addPass(createCanonicalizerPass());

This pipeline exists upstream in a helper, but it was added after the commit we’ve pinned to. Moreover, some of the passes above were added after the commit we’ve pinned to! At this point it’s worth updating the upstream MLIR commit, which I did in this commit, and it required a few additional fixes to deal with API updates (starting with this commit and ending with this commit). Then this commit adds the one-shot bufferization pass and helpers to the pipeline.

Even after all of this, I was dismayed to learn that the pass still did not lower the linalg operators. After a bit of digging, I realized it was because the func-to-llvm pass was running too early in the pipeline. So this commit moves the pass later. Just because we’re building this a bit backwards and naively, let’s comment out the tail end of the pipeline to see the result of bufferization.

Before convert-linalg-to-loops, but omitting the rest:

#map = affine_map<(d0) -> (d0)>
module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  func.func @test_poly_fn(%arg0: i32) -> i32 {
    %c0_i32 = arith.constant 0 : i32
    %c11 = arith.constant 11 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %0 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %1 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    linalg.map outs(%alloc : memref<10xi32>)
      () {
        linalg.yield %arg0 : i32
      }
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc_0 : memref<10xi32>)
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %1, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%alloc_0, %alloc : memref<10xi32>, memref<10xi32>) outs(%alloc_0 : memref<10xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %3 = arith.addi %in, %in_2 : i32
      linalg.yield %3 : i32
    }
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %0, %alloc_1 : memref<10xi32> to memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      scf.for %arg2 = %c0 to %c10 step %c1 {
        %3 = arith.addi %arg1, %arg2 : index
        %4 = arith.remui %3, %c10 : index
        %5 = memref.load %alloc_0[%arg2] : memref<10xi32>
        %6 = memref.load %alloc_0[%arg1] : memref<10xi32>
        %7 = arith.muli %6, %5 : i32
        %8 = memref.load %alloc_1[%4] : memref<10xi32>
        %9 = arith.addi %7, %8 : i32
        memref.store %9, %alloc_1[%4] : memref<10xi32>
      }
    }
    linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%alloc_1, %alloc : memref<10xi32>, memref<10xi32>) outs(%alloc_1 : memref<10xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %3 = arith.subi %in, %in_2 : i32
      linalg.yield %3 : i32
    }
    %2 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
      %3 = arith.subi %c11, %arg1 : index
      %4 = arith.muli %arg0, %arg2 : i32
      %5 = memref.load %alloc_1[%3] : memref<10xi32>
      %6 = arith.addi %4, %5 : i32
      scf.yield %6 : i32
    }
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    return %2 : i32
  }
}

After convert-linalg-to-loops, omitting the conversions to LLVM, wherein there are no longer any linalg ops, just loops:

module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  func.func @test_poly_fn(%arg0: i32) -> i32 {
    %c0 = arith.constant 0 : index
    %c10 = arith.constant 10 : index
    %c1 = arith.constant 1 : index
    %c0_i32 = arith.constant 0 : i32
    %c11 = arith.constant 11 : index
    %0 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %1 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      memref.store %arg0, %alloc[%arg1] : memref<10xi32>
    }
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      memref.store %c0_i32, %alloc_0[%arg1] : memref<10xi32>
    }
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %1, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      %3 = memref.load %alloc_0[%arg1] : memref<10xi32>
      %4 = memref.load %alloc[%arg1] : memref<10xi32>
      %5 = arith.addi %3, %4 : i32
      memref.store %5, %alloc_0[%arg1] : memref<10xi32>
    }
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %0, %alloc_1 : memref<10xi32> to memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      scf.for %arg2 = %c0 to %c10 step %c1 {
        %3 = arith.addi %arg1, %arg2 : index
        %4 = arith.remui %3, %c10 : index
        %5 = memref.load %alloc_0[%arg2] : memref<10xi32>
        %6 = memref.load %alloc_0[%arg1] : memref<10xi32>
        %7 = arith.muli %6, %5 : i32
        %8 = memref.load %alloc_1[%4] : memref<10xi32>
        %9 = arith.addi %7, %8 : i32
        memref.store %9, %alloc_1[%4] : memref<10xi32>
      }
    }
    scf.for %arg1 = %c0 to %c10 step %c1 {
      %3 = memref.load %alloc_1[%arg1] : memref<10xi32>
      %4 = memref.load %alloc[%arg1] : memref<10xi32>
      %5 = arith.subi %3, %4 : i32
      memref.store %5, %alloc_1[%arg1] : memref<10xi32>
    }
    %2 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
      %3 = arith.subi %c11, %arg1 : index
      %4 = arith.muli %arg0, %arg2 : i32
      %5 = memref.load %alloc_1[%3] : memref<10xi32>
      %6 = arith.addi %4, %5 : i32
      scf.yield %6 : i32
    }
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    return %2 : i32
  }
}

And then finally, the rest of the pipeline, as defined so far.

module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(0 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(10 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(1 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : i32) : i32
    %7 = llvm.mlir.constant(11 : index) : i64
    %8 = builtin.unrealized_conversion_cast %7 : i64 to index
    %9 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %10 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    cf.br ^bb1(%1 : index)
  ^bb1(%11: index):  // 2 preds: ^bb0, ^bb2
    %12 = builtin.unrealized_conversion_cast %11 : index to i64
    %13 = llvm.icmp "slt" %12, %2 : i64
    llvm.cond_br %13, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    memref.store %arg0, %alloc[%11] : memref<10xi32>
    %14 = llvm.add %12, %4  : i64
    %15 = builtin.unrealized_conversion_cast %14 : i64 to index
    cf.br ^bb1(%15 : index)
  ^bb3:  // pred: ^bb1
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    cf.br ^bb4(%1 : index)
  ^bb4(%16: index):  // 2 preds: ^bb3, ^bb5
    %17 = builtin.unrealized_conversion_cast %16 : index to i64
    %18 = llvm.icmp "slt" %17, %2 : i64
    llvm.cond_br %18, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    memref.store %6, %alloc_0[%16] : memref<10xi32>
    %19 = llvm.add %17, %4  : i64
    %20 = builtin.unrealized_conversion_cast %19 : i64 to index
    cf.br ^bb4(%20 : index)
  ^bb6:  // pred: ^bb4
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %10, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    cf.br ^bb7(%1 : index)
  ^bb7(%21: index):  // 2 preds: ^bb6, ^bb8
    %22 = builtin.unrealized_conversion_cast %21 : index to i64
    %23 = llvm.icmp "slt" %22, %2 : i64
    llvm.cond_br %23, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %24 = memref.load %alloc_0[%21] : memref<10xi32>
    %25 = memref.load %alloc[%21] : memref<10xi32>
    %26 = llvm.add %24, %25  : i32
    memref.store %26, %alloc_0[%21] : memref<10xi32>
    %27 = llvm.add %22, %4  : i64
    %28 = builtin.unrealized_conversion_cast %27 : i64 to index
    cf.br ^bb7(%28 : index)
  ^bb9:  // pred: ^bb7
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %9, %alloc_1 : memref<10xi32> to memref<10xi32>
    cf.br ^bb10(%1 : index)
  ^bb10(%29: index):  // 2 preds: ^bb9, ^bb14
    %30 = builtin.unrealized_conversion_cast %29 : index to i64
    %31 = llvm.icmp "slt" %30, %2 : i64
    llvm.cond_br %31, ^bb11, ^bb15
  ^bb11:  // pred: ^bb10
    %32 = builtin.unrealized_conversion_cast %29 : index to i64
    cf.br ^bb12(%1 : index)
  ^bb12(%33: index):  // 2 preds: ^bb11, ^bb13
    %34 = builtin.unrealized_conversion_cast %33 : index to i64
    %35 = llvm.icmp "slt" %34, %2 : i64
    llvm.cond_br %35, ^bb13, ^bb14
  ^bb13:  // pred: ^bb12
    %36 = builtin.unrealized_conversion_cast %33 : index to i64
    %37 = llvm.add %32, %36  : i64
    %38 = llvm.urem %37, %2  : i64
    %39 = builtin.unrealized_conversion_cast %38 : i64 to index
    %40 = memref.load %alloc_0[%33] : memref<10xi32>
    %41 = memref.load %alloc_0[%29] : memref<10xi32>
    %42 = llvm.mul %41, %40  : i32
    %43 = memref.load %alloc_1[%39] : memref<10xi32>
    %44 = llvm.add %42, %43  : i32
    memref.store %44, %alloc_1[%39] : memref<10xi32>
    %45 = llvm.add %34, %4  : i64
    %46 = builtin.unrealized_conversion_cast %45 : i64 to index
    cf.br ^bb12(%46 : index)
  ^bb14:  // pred: ^bb12
    %47 = llvm.add %30, %4  : i64
    %48 = builtin.unrealized_conversion_cast %47 : i64 to index
    cf.br ^bb10(%48 : index)
  ^bb15:  // pred: ^bb10
    cf.br ^bb16(%1 : index)
  ^bb16(%49: index):  // 2 preds: ^bb15, ^bb17
    %50 = builtin.unrealized_conversion_cast %49 : index to i64
    %51 = llvm.icmp "slt" %50, %2 : i64
    llvm.cond_br %51, ^bb17, ^bb18
  ^bb17:  // pred: ^bb16
    %52 = memref.load %alloc_1[%49] : memref<10xi32>
    %53 = memref.load %alloc[%49] : memref<10xi32>
    %54 = llvm.sub %52, %53  : i32
    memref.store %54, %alloc_1[%49] : memref<10xi32>
    %55 = llvm.add %50, %4  : i64
    %56 = builtin.unrealized_conversion_cast %55 : i64 to index
    cf.br ^bb16(%56 : index)
  ^bb18:  // pred: ^bb16
    cf.br ^bb19(%5, %6 : index, i32)
  ^bb19(%57: index, %58: i32):  // 2 preds: ^bb18, ^bb20
    %59 = builtin.unrealized_conversion_cast %57 : index to i64
    %60 = llvm.icmp "slt" %59, %7 : i64
    llvm.cond_br %60, ^bb20, ^bb21
  ^bb20:  // pred: ^bb19
    %61 = builtin.unrealized_conversion_cast %57 : index to i64
    %62 = llvm.sub %7, %61  : i64
    %63 = builtin.unrealized_conversion_cast %62 : i64 to index
    %64 = llvm.mul %arg0, %58  : i32
    %65 = memref.load %alloc_1[%63] : memref<10xi32>
    %66 = llvm.add %64, %65  : i32
    %67 = llvm.add %59, %4  : i64
    %68 = builtin.unrealized_conversion_cast %67 : i64 to index
    cf.br ^bb19(%68, %66 : index, i32)
  ^bb21:  // pred: ^bb19
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    llvm.return %58 : i32
  }
}

The remaining problems:

  • There are cf.br ops left in there, meaning lower-cf-to-llvm was unable to convert them, leaving in a bunch of un-removable index types (see third bullet).
  • We still have a memref.subview that is not supported in LLVM.
  • We have a bunch of casts like builtin.unrealized_conversion_cast %7 : i64 to index, which are because index is not part of LLVM.

The second needs a special pass, memref-expand-strided-metadata, which reduces more complicated memref ops to simpler ones that can be lowered. The third is handled by finalize-memref-to-llvm, which lowers memref to LLVM types like llvm.ptr, llvm.struct, and llvm.array. A final reconcile-unrealized-casts removes the cast operations, provided they can safely be removed. These are combined in this commit.
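
In pipeline-building code, the additions amount to something like this sketch (assuming the poly-to-llvm pipeline function has an OpPassManager named pm; the create-function names come from upstream MLIR and can shift between versions):

// Rewrite strided memref ops (like the subview) into simpler memref ops.
pm.addPass(memref::createExpandStridedMetadataPass());
// Lower the memref ops themselves to the LLVM dialect.
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
// Remove builtin.unrealized_conversion_cast ops that can safely be removed.
pm.addPass(createReconcileUnrealizedCastsPass());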

However, the first one still eluded me for a while, until I figured out through embarrassing trial and error that again func-to-llvm was too early in the pipeline. Moving it to the end (this commit) resulted in cf.br being lowered to llvm.br and llvm.cond_br.

Finally, this commit adds a set of standard cleanup passes, including constant propagation, common subexpression elimination, dead code elimination, and canonicalization. The final IR looks like this

module {
  llvm.func @free(!llvm.ptr)
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.mlir.global private constant @__constant_3xi32(dense<[2, 3, 4]> : tensor<3xi32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<3 x i32>
  llvm.mlir.global private constant @__constant_10xi32(dense<0> : tensor<10xi32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<10 x i32>
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(64 : index) : i64
    %2 = llvm.mlir.constant(3 : index) : i64
    %3 = llvm.mlir.constant(0 : index) : i64
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = llvm.mlir.constant(1 : index) : i64
    %6 = llvm.mlir.constant(11 : index) : i64
    %7 = llvm.mlir.addressof @__constant_10xi32 : !llvm.ptr
    %8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
    %9 = llvm.mlir.addressof @__constant_3xi32 : !llvm.ptr
    %10 = llvm.getelementptr %9[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<3 x i32>
    %11 = llvm.mlir.zero : !llvm.ptr
    %12 = llvm.getelementptr %11[10] : (!llvm.ptr) -> !llvm.ptr, i32
    %13 = llvm.ptrtoint %12 : !llvm.ptr to i64
    %14 = llvm.add %13, %1  : i64
    %15 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %16 = llvm.ptrtoint %15 : !llvm.ptr to i64
    %17 = llvm.sub %1, %5  : i64
    %18 = llvm.add %16, %17  : i64
    %19 = llvm.urem %18, %1  : i64
    %20 = llvm.sub %18, %19  : i64
    %21 = llvm.inttoptr %20 : i64 to !llvm.ptr
    llvm.br ^bb1(%3 : i64)
  ^bb1(%22: i64):  // 2 preds: ^bb0, ^bb2
    %23 = llvm.icmp "slt" %22, %4 : i64
    llvm.cond_br %23, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %24 = llvm.getelementptr %21[%22] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    llvm.store %arg0, %24 : i32, !llvm.ptr
    %25 = llvm.add %22, %5  : i64
    llvm.br ^bb1(%25 : i64)
  ^bb3:  // pred: ^bb1
    %26 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %27 = llvm.ptrtoint %26 : !llvm.ptr to i64
    %28 = llvm.add %27, %17  : i64
    %29 = llvm.urem %28, %1  : i64
    %30 = llvm.sub %28, %29  : i64
    %31 = llvm.inttoptr %30 : i64 to !llvm.ptr
    llvm.br ^bb4(%3 : i64)
  ^bb4(%32: i64):  // 2 preds: ^bb3, ^bb5
    %33 = llvm.icmp "slt" %32, %4 : i64
    llvm.cond_br %33, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    %34 = llvm.getelementptr %31[%32] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    llvm.store %0, %34 : i32, !llvm.ptr
    %35 = llvm.add %32, %5  : i64
    llvm.br ^bb4(%35 : i64)
  ^bb6:  // pred: ^bb4
    %36 = llvm.mul %2, %5  : i64
    %37 = llvm.getelementptr %11[1] : (!llvm.ptr) -> !llvm.ptr, i32
    %38 = llvm.ptrtoint %37 : !llvm.ptr to i64
    %39 = llvm.mul %36, %38  : i64
    "llvm.intr.memcpy"(%31, %10, %39) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
    llvm.br ^bb7(%3 : i64)
  ^bb7(%40: i64):  // 2 preds: ^bb6, ^bb8
    %41 = llvm.icmp "slt" %40, %4 : i64
    llvm.cond_br %41, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %42 = llvm.getelementptr %31[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %43 = llvm.load %42 : !llvm.ptr -> i32
    %44 = llvm.getelementptr %21[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %45 = llvm.load %44 : !llvm.ptr -> i32
    %46 = llvm.add %43, %45  : i32
    llvm.store %46, %42 : i32, !llvm.ptr
    %47 = llvm.add %40, %5  : i64
    llvm.br ^bb7(%47 : i64)
  ^bb9:  // pred: ^bb7
    %48 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %49 = llvm.ptrtoint %48 : !llvm.ptr to i64
    %50 = llvm.add %49, %17  : i64
    %51 = llvm.urem %50, %1  : i64
    %52 = llvm.sub %50, %51  : i64
    %53 = llvm.inttoptr %52 : i64 to !llvm.ptr
    %54 = llvm.mul %4, %5  : i64
    %55 = llvm.mul %54, %38  : i64
    "llvm.intr.memcpy"(%53, %8, %55) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
    llvm.br ^bb10(%3 : i64)
  ^bb10(%56: i64):  // 2 preds: ^bb9, ^bb14
    %57 = llvm.icmp "slt" %56, %4 : i64
    llvm.cond_br %57, ^bb11, ^bb15
  ^bb11:  // pred: ^bb10
    llvm.br ^bb12(%3 : i64)
  ^bb12(%58: i64):  // 2 preds: ^bb11, ^bb13
    %59 = llvm.icmp "slt" %58, %4 : i64
    llvm.cond_br %59, ^bb13, ^bb14
  ^bb13:  // pred: ^bb12
    %60 = llvm.add %56, %58  : i64
    %61 = llvm.urem %60, %4  : i64
    %62 = llvm.getelementptr %31[%58] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %63 = llvm.load %62 : !llvm.ptr -> i32
    %64 = llvm.getelementptr %31[%56] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %65 = llvm.load %64 : !llvm.ptr -> i32
    %66 = llvm.mul %65, %63  : i32
    %67 = llvm.getelementptr %53[%61] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %68 = llvm.load %67 : !llvm.ptr -> i32
    %69 = llvm.add %66, %68  : i32
    llvm.store %69, %67 : i32, !llvm.ptr
    %70 = llvm.add %58, %5  : i64
    llvm.br ^bb12(%70 : i64)
  ^bb14:  // pred: ^bb12
    %71 = llvm.add %56, %5  : i64
    llvm.br ^bb10(%71 : i64)
  ^bb15:  // pred: ^bb10
    llvm.br ^bb16(%3 : i64)
  ^bb16(%72: i64):  // 2 preds: ^bb15, ^bb17
    %73 = llvm.icmp "slt" %72, %4 : i64
    llvm.cond_br %73, ^bb17, ^bb18
  ^bb17:  // pred: ^bb16
    %74 = llvm.getelementptr %53[%72] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %75 = llvm.load %74 : !llvm.ptr -> i32
    %76 = llvm.getelementptr %21[%72] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %77 = llvm.load %76 : !llvm.ptr -> i32
    %78 = llvm.sub %75, %77  : i32
    llvm.store %78, %74 : i32, !llvm.ptr
    %79 = llvm.add %72, %5  : i64
    llvm.br ^bb16(%79 : i64)
  ^bb18:  // pred: ^bb16
    llvm.br ^bb19(%5, %0 : i64, i32)
  ^bb19(%80: i64, %81: i32):  // 2 preds: ^bb18, ^bb20
    %82 = llvm.icmp "slt" %80, %6 : i64
    llvm.cond_br %82, ^bb20, ^bb21
  ^bb20:  // pred: ^bb19
    %83 = llvm.sub %6, %80  : i64
    %84 = llvm.mul %arg0, %81  : i32
    %85 = llvm.getelementptr %53[%83] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %86 = llvm.load %85 : !llvm.ptr -> i32
    %87 = llvm.add %84, %86  : i32
    %88 = llvm.add %80, %5  : i64
    llvm.br ^bb19(%88, %87 : i64, i32)
  ^bb21:  // pred: ^bb19
    llvm.call @free(%15) : (!llvm.ptr) -> ()
    llvm.call @free(%26) : (!llvm.ptr) -> ()
    llvm.call @free(%48) : (!llvm.ptr) -> ()
    llvm.return %81 : i32
  }
}

Exiting MLIR

In MLIR parlance, the LLVM dialect is an “exit” dialect, meaning that after lowering to it you run a different tool that generates code consumed by an external system. In our case, this will be LLVM’s internal representation (“LLVM IR,” which is different from the LLVM MLIR dialect). The “code gen” step is often called translation in the MLIR docs. Here are the official docs on generating LLVM IR.

The codegen tool is called mlir-translate, and it has an --mlir-to-llvmir option. Running the command below on our output IR above gives the IR in this gist.

$  bazel build @llvm-project//mlir:mlir-translate
$  bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir

Next, to compile LLVM IR with LLVM directly, we use the llc tool, whose --filetype=obj option emits an object file. Without that flag, you can see a textual representation of the machine code:

$ bazel build @llvm-project//llvm:llc
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | bazel-bin/external/llvm-project/llvm/llc

# textual representation
        .text
        .file   "LLVMDialectModule"
        .globl  test_poly_fn                    # -- Begin function test_poly_fn
        .p2align        4, 0x90
        .type   test_poly_fn,@function
test_poly_fn:                           # @test_poly_fn
        .cfi_startproc
# %bb.0:
        pushq   %rbp
        .cfi_def_cfa_offset 16
        pushq   %r15
        .cfi_def_cfa_offset 24
        pushq   %r14
        .cfi_def_cfa_offset 32
        pushq   %r13
        .cfi_def_cfa_offset 40
        pushq   %r12
        .cfi_def_cfa_offset 48
        pushq   %rbx
        .cfi_def_cfa_offset 56
        pushq   %rax
        .cfi_def_cfa_offset 64
        .cfi_offset %rbx, -56
        .cfi_offset %r12, -48
<snip>

And finally, we can save the object file and compile and link it to a C main that calls the function and prints the result.

$ cat tests/poly_to_llvm_main.c
#include <stdio.h>

// This is the function we want to call from LLVM
int test_poly_fn(int x);

int main(int argc, char *argv[]) {
  int i = 1;
  int result = test_poly_fn(i);
  printf("Result: %d\n", result);
  return 0;
}
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | bazel-bin/external/llvm-project/llvm/llc --filetype=obj > poly_to_llvm.o
$ clang -c tests/poly_to_llvm_main.c && clang poly_to_llvm_main.o poly_to_llvm.o -o a.out
$ ./a.out
Result: 320

The polynomial computed by the test function is the rather bizarre:

((1+x+x**2+x**3+x**4+x**5+x**6+x**7+x**8+x**9) + (2 + 3*x + 4*x**2))**2 - (1+x+x**2+x**3+x**4+x**5+x**6+x**7+x**8+x**9)

But computing this mod $x^{10} - 1$ in SymPy, we get… 351. Uh oh. So somewhere along the way we messed up the lowering.
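
As a sanity check on what the answer should be: with an input of 1, the splatted polynomial has all ten coefficients equal to 1 and the evaluation point is also 1. Evaluating mod $x^{10} - 1$ at $x = 1$ just sums the coefficients, so the expected result is

\[ (10 + (2 + 3 + 4))^2 - 10 = 19^2 - 10 = 351. \]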

Before we fix it, let’s encode the whole process as a lit test. This requires making the binaries mlir-translate, llc, and clang available to the test runner, and setting up the RUN pipeline inside the test file. This is contained in this commit. Note %t tells lit to generate a test-unique temporary file.

// RUN: tutorial-opt --poly-to-llvm %s | mlir-translate --mlir-to-llvmir | llc -filetype=obj > %t
// RUN: clang -c poly_to_llvm_main.c
// RUN: clang poly_to_llvm_main.o %t -o a.out
// RUN: ./a.out | FileCheck %s

// CHECK: 351
func.func @test_poly_fn(%arg : i32) -> i32 {
  %tens = tensor.splat %arg : tensor<10xi32>
  %input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10>
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.add %0, %input : !poly.poly<10>
  %2 = poly.mul %1, %1 : !poly.poly<10>
  %3 = poly.sub %2, %input : !poly.poly<10>
  %4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32
  return %4 : i32
}

Then running this as a normal lit test fails with

error: CHECK: expected string not found in input
# | // CHECK: 351
# |           ^
# | <stdin>:1:1: note: scanning from here
# | Result: 320
# | ^
# | <stdin>:1:9: note: possible intended match here
# | Result: 320

Fixing the bug in the lowering

How do you find the bug in a lowering? Apparently there are some folks doing this via formal verification, formalizing the dialects and the lowerings in Lean to prove correctness. I don’t have the time for that in this tutorial, so instead I simplify/expand the tests and squint.

Simplifying the test to compute the simpler function $t \mapsto 1 + t + t^2$, I see that an input of 1 gives 2, when it should be 3. This suggests that eval is lowered wrong, at least. An input of 5 gives 4, so they’re all off by one term. This likely means I have an off-by-one error in the loop that eval lowers to, and indeed that’s the problem. I was essentially doing this, where $N$ is the degree of the polynomial ($N-1$ is the largest legal index into the tensor):

accum = 0
for 1 <= i < N+1
  index = N+1 - i
  accum = accum * point + coeffs[index]

When it should have been

accum = 0
for 1 <= i < N+1
  index = N - i
  accum = accum * point + coeffs[index]

This commit fixes it, and now all the tests pass.
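
For reference, the corrected loop is just Horner’s method over the coefficient tensor. Here is a plain C++ analogue of what the lowering should compute (a hypothetical helper for illustration, not code from the repo):

#include <cstdint>
#include <vector>

// Evaluate coeffs[0] + coeffs[1]*x + ... + coeffs[N-1]*x^(N-1) by Horner's
// method, mirroring the loop the eval lowering generates: i runs from 1 to N,
// and the index N - i walks the coefficients from highest degree down to 0.
int32_t evalPoly(const std::vector<int32_t> &coeffs, int32_t point) {
  int32_t accum = 0;
  const size_t N = coeffs.size();
  for (size_t i = 1; i < N + 1; ++i) {
    size_t index = N - i;  // the fixed index computation
    accum = accum * point + coeffs[index];
  }
  return accum;
}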

Aside: I don’t know of a means to do more resilient testing in MLIR. For example, I don’t know of a fuzz testing or property testing option. In my main project, HEIR, I hand-rolled a simple test generation routine that reads a config file and spits out MLIR test files. That allows me to jot down a bunch of tests quickly, but doesn’t give the full power of a system like hypothesis, which is a framework I’m quite fond of. I think the xDSL project would work in my favor here, but I have yet to dive into that, and as far as I know it requires re-defining all your custom dialects in Python to use it. More on that in a future article.

Taking a step back

This article showed a bit of a naive approach to building up a dialect conversion pipeline, where we just greedily looked for ops to lower and inserted the relevant passes somewhat haphazardly. That worked out OK, but some lower-level passes (converting to LLVM) confounded the overall pipeline when placed too early.

A better approach is to identify the highest level operations and lower those first. But that is only really possible if you already know which passes are available and what they do. For example, elementwise-to-linalg takes something that seems low-level a priori, and I didn’t appreciate why it was needed until noticing that convert-arith-to-llvm silently ignored arith ops operating on tensors. Similarly, the implications of converting func to LLVM (which appears to handle more than just the ops in func) were not clear until we tried it and ran into problems.

I don’t have a particularly good solution here besides trial and error. But I appreciate good tools, so I will amuse myself with some ideas.

Since most passes have per-op patterns, it seems like one could write a tool that analyzes an IR, simulates running all possible rewrite patterns from the standard list of passes, and checks which ones succeeded (i.e., the “match” was successful, though most patterns are a combined matchAndRewrite), and what op types were generated as a result. Then once you get to an IR that you don’t know how to lower further, you could run the tool and it would tell you all of your options.

An even more aggressive tool could construct a complete graph of op-to-op conversions. You could identify an op to lower and a subset of legal dialects or ops (similar to a ConversionTarget), and it would report all possible paths from the op to your desired target.

I imagine performance is an obstacle here, and I also wonder to what extent this could be statically analyzed at least coarsely. For example, instead of running the rewrites, you could statically analyze the implementations of rewrite patterns lowering FooOp for occurrences of create<BarOp>, and just include an edge FooOp -> BarOp as a conservative estimate (it might not generate BarOp for every input).

Plans

At this point we’ve covered the basic ground of MLIR: defining a new dialect, writing some optimization passes, doing some lowerings, and compiling and running a proper binary. The topics we could cover from here on out are rather broad, but here’s a few options:

  • Writing an analysis pass.
  • Writing a global program optimization pass with a cost model, in contrast to the local rewrites we’ve done so far.
  • Defining a custom interface (trait with input/output) to write a generic pass that applies to multiple dialects.
  • Explaining the linear algebra dialect and linalg.generic.
  • Exploring some of the front ends for MLIR, like Polygeist and ClangIR.
  • Exploring the Python bindings or C API for MLIR.
  • Exploring some of the peripheral and in-progress projects around MLIR like PDLL, IRDL, xDSL, etc.
  • Diving into the details of various upstream MLIR optimizations, like the polyhedral analysis or constant propagation passes.
  • Working with sparse tensors.
  • Covering more helper tools like the diagnostic infrastructure and bytecode generation.

MLIR — Dialect Conversion

Table of Contents

In previous articles we defined a dialect, and wrote various passes to optimize and canonicalize a program using that dialect. However, one of the main tenets of MLIR is “incremental lowering,” the idea that there are lots of levels of IR granularity, and you incrementally lower different parts of the IR, only discarding information when it’s no longer useful for optimizations. In this article we’ll see the first step of that: lowering the poly dialect to a combination of standard MLIR dialects, using the so-called dialect conversion infrastructure to accomplish it.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Side note: Victor Guerra kindly contributed a CMake build overlay to the mlir-tutorial project, in this PR, which I will be maintaining for the rest of the tutorial series. I still don’t know CMake well enough to give a tutorial on it, but I did work through parts of the official tutorial which seems excellent.

The type obstacle

If not for types, dialect conversion (a.k.a. lowering) would be essentially the same as a normal pass: you write some rewrite patterns and apply them to the IR. You’d typically have one rewrite pattern per operation that needs to be lowered.

Types make this significantly more complex, and I’ll demonstrate the issue by example with poly.

In poly we have a poly.add operation that takes two values of poly.poly type, and returns a value of poly.poly type. We want to lower poly.add to, say, a vectorized loop of basic arithmetic operations. We can do this, and we will do it below, by rewriting poly.add as an arith.addi. However, arith.addi doesn’t know about poly.poly and likely never will.

Aside: if we had to make arith extensible to know about poly, we’d need to upstream a change to the arith.addi op’s operands to allow types that implement some sort of interface like “integer-like, or containers of integer-like things,” but when I suggested that to some MLIR folks they gently demurred. I believe this is because the MLIR community feels that the arith dialect is too overloaded.

So, in addition to lowering the op, we need to lower the poly.poly<N> type to something like tensor<Nxi32>. And this is where the type obstacle comes into play. Once you change a specific value’s type, say, while lowering an op that produces that value as output, then all downstream users of that value are still expecting the old type and are technically invalid until they are lowered. In between each pass MLIR runs verifiers to ensure the IR is valid, so without some special handling, this means all the types and ops need to be converted in one pass, or else these verifiers would fail. But managing all that with standard rewrite rules would be hard: for each op rewrite rule, you’d have to continually check if the arguments and results have been converted yet or not.

MLIR handles this situation with a wrapper around standard passes that they call the dialect conversion framework. Official docs here. They require the user of the framework to inherit from different classes than for normal rewrites, set up some additional metadata, and separate type conversion from op conversion in a specific way we’ll see shortly. But at a high level, this framework works by lowering ops in a certain sorted order, converting the types as they go, and giving the op converters access to both the original types of each op as well as what the in-progress converted types look like at the time the op is visited by the framework. Each op-based rewrite pattern is expected to make that op type-legal after it’s visited, but it need not worry about downstream ops.

Finally, the dialect conversion framework keeps track of any type conflicts that aren’t resolved, and if any remain at the end of the pass, one of two things happens. The conversion framework allows one to optionally implement what’s called a type materializer, which inserts new intermediate ops that resolve type conflicts. So the first possibility is that the dialect conversion framework uses your type materializer hooks to patch up the IR, and the pass ends successfully. If those hooks fail, or if you didn’t define any hooks, then the pass fails and complains that it couldn’t fix the type conflicts.

Part of the complexity of this infrastructure also has to do with one of the harder lowering pipelines in upstream MLIR: the bufferization pipeline. This pipeline essentially converts an IR that uses ops with “value semantics” into one with “pointer semantics.” For example, the tensor type and its associated operations have value semantics, meaning each op is semantically creating a brand new tensor in its output, and all operations are pure (with some caveats). On the other hand, memref has pointer semantics, meaning it’s modeling the physical hardware more closely, requires explicit memory allocation, and supports operations that mutate memory locations.

Because bufferization is complicated, it is split into multiple sub-passes that handle bufferization issues specific to each of the relevant upstream MLIR dialects (see in the docs, e.g., arith-bufferize, func-bufferize, etc.). Each of these bufferization passes creates some internally-unresolvable type conflicts, which require custom type materializations to resolve. And to juggle all these issues across all the relevant dialects, the MLIR folks built a dedicated dialect called bufferization to house the intermediate ops. You’ll notice ops like to_memref and to_tensor that serve this role. And then there is a finalizing-bufferize pass whose role is to clean up any lingering bufferization/materialization ops.

There was a talk at the 2020-11-19 MLIR Open Design meeting called “Type Conversions the Not-So-Hard Way” (slides) that helped me understand these details, but I couldn’t understand the talk much before I tried a few naive lowering attempts, and then even after I felt I understood the talk I ran into some issues around type materialization. So my aim in the rest of this article is to explain the clarity I’ve found, such as it may be, and make it easier for the reader to understand the content of that talk.

Lowering Poly without Type Materializations

The first commit sets up the pass shell, similar to the previous passes in the series, though it is located in a new directory lib/Conversion. The only caveat to this commit is that it adds dependent dialects:

let dependentDialects = [
  "mlir::arith::ArithDialect",
  "mlir::tutorial::poly::PolyDialect",
  "mlir::tensor::TensorDialect",
];

In particular, a lowering must depend in this way on any dialects that contain operations or types that the lowering will create, to ensure that MLIR loads those dialects before trying to run the pass.

The second commit defines a ConversionTarget, which tells MLIR how to determine what ops are within scope of the lowering. Specifically, it allows you to declare an entire dialect as “illegal” after the lowering is complete. We’ll do this for the poly dialect.

struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
  using PolyToStandardBase::PolyToStandardBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    auto *module = getOperation();

    ConversionTarget target(*context);       // <--- new thing
    target.addIllegalDialect<PolyDialect>();

    RewritePatternSet patterns(context);

    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {   // <-- new thing
      signalPassFailure();
    }
  }
};

ConversionTarget can also declare specific ops as illegal, or even conditionally legal, using a callback to specify exactly what counts as legal. And then instead of applyPatternsAndFoldGreedily, we use applyPartialConversion to kick off the lowering.
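
For example, here is a sketch of the kinds of legality declarations available (not the code from the commit; the callback consults the type converter we’ll define below):

// Entire dialects can be marked legal, i.e., allowed in the output.
target.addLegalDialect<arith::ArithDialect, tensor::TensorDialect>();

// Individual ops can be marked illegal on their own.
target.addIllegalOp<AddOp>();

// Or conditionally legal via a callback: here func.func is legal only once
// its signature no longer mentions any poly types.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
  return typeConverter.isSignatureLegal(op.getFunctionType());
});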

Aside: We use applyPartialConversion instead of the more natural sounding alternative applyFullConversion for a silly reason: the error messages are better. If applyFullConversion fails, even in debug mode, you don’t get a lot of information about what went wrong. In partial mode, you can see the steps and what type conflicts it wasn’t able to patch up at the end. As far as I can tell, partial conversion is a strict generalization of full conversion, and I haven’t found a reason to use applyFullConversion ever.

After adding the conversion target, running --poly-to-standard on any random poly program will result in an error:

$ bazel run tools:tutorial-opt -- --poly-to-standard $PWD/tests/poly_syntax.mlir
tests/poly_syntax.mlir:14:10: error: failed to legalize operation 'poly.add' that was explicitly marked illegal
    %0 = poly.add %arg0, %arg1 : !poly.poly<10>

In the next commit we define a subclass of TypeConverter to convert types from poly to other dialects.

class PolyToStandardTypeConverter : public TypeConverter {
 public:
  PolyToStandardTypeConverter(MLIRContext *ctx) {
    addConversion([](Type type) { return type; });
    addConversion([ctx](PolynomialType type) -> Type {
      int degreeBound = type.getDegreeBound();
      IntegerType elementTy =
          IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);
      return RankedTensorType::get({degreeBound}, elementTy);
    });
  }
};

We call addConversion for each specific type we need to convert, and then the default addConversion([](Type type) { return type; }); signifies “no work needed, this type is legal.” By itself this type converter does nothing, because it is only used in conjunction with a specific kind of rewrite pattern, a subclass of OpConversionPattern.

struct ConvertAdd : public OpConversionPattern<AddOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      AddOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // ...
    return success();
  }
};

We add this, along with the new function call to RewritePatternSet::add, in this commit. This call differs from a normal RewritePatternSet::add in that it takes the type converter as input and passes it along to the constructor of ConvertAdd, whose parent class takes the type converter as input and stores it as a member variable.

void runOnOperation() {
  ...
  RewritePatternSet patterns(context);
  PolyToStandardTypeConverter typeConverter(context);
  patterns.add<ConvertAdd>(typeConverter, context);
}

An OpConversionPattern’s matchAndRewrite method has two new arguments: the OpAdaptor and the ConversionPatternRewriter. The first, OpAdaptor, is an alias for AddOp::Adaptor, which is part of the generated C++ code and holds the type-converted operands during dialect conversion. It’s called “adaptor” because it uses the table-gen defined names for an op’s arguments and results in the method names, rather than generic getOperand equivalents. Meanwhile, the AddOp argument (same as a normal rewrite pattern) contains the original, un-type-converted operands and results. The ConversionPatternRewriter is like a PatternRewriter, but it has additional methods relevant to dialect conversion, such as convertRegionTypes, which is a helper for applying type conversions for ops with nested regions. All modifications to the IR must go through the ConversionPatternRewriter.
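
Concretely, inside ConvertAdd’s matchAndRewrite the two views of the same operand look like this (a small sketch):

// op is the original AddOp; adaptor holds the already-converted operands.
Value original = op.getLhs();        // still has type !poly.poly<10>
Value converted = adaptor.getLhs();  // already converted, e.g. tensor<10xi32>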

In the next commit, we implement the lowering for poly.add.

struct ConvertAdd : public OpConversionPattern<AddOp> {
  <...>
  LogicalResult matchAndRewrite(AddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    arith::AddIOp addOp = rewriter.create<arith::AddIOp>(
        op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
    rewriter.replaceOp(op.getOperation(), {addOp});
    return success();
  }
};

The main thing to notice is that we don’t have to convert the types ourselves. The conversion framework converted the operands in advance, put them on the OpAdaptor, and we can just focus on creating the lowered op.

Mathematically, however, this lowering is so simple because I’m taking advantage of the overflow semantics of 32-bit integer addition to avoid computing modular remainders. Or rather, I’m interpreting addi to have natural unsigned overflow semantics, though I think there is some ongoing discussion about how precise these semantics should be. Also, arith.addi is defined to operate on tensors of integers elementwise, in addition to operating on individual integers.

However, there is one hiccup when we try to test it on the following IR

func.func @test_lower_add(%0 : !poly.poly<10>, %1 : !poly.poly<10>) {
  %2 = poly.add %0, %1: !poly.poly<10>
  return
}

The pass fails here because of a type conflict between the function arguments and the converted add op. We get an error that looks like

error: failed to legalize unresolved materialization from '!poly.poly<10>' to 'tensor<10xi32>' that remained live after conversion
  %2 = poly.add %0, %1: !poly.poly<10>
       ^
poly_to_standard.mlir:6:8: note: see current operation: %0 = "builtin.unrealized_conversion_cast"(%arg1) : (!poly.poly<10>) -> tensor<10xi32>
poly_to_standard.mlir:6:8: note: see existing live user here: %2 = arith.addi %1, %0 : tensor<10xi32>

If you look at the debug output (run the pass with --debug, i.e., bazel run tools:tutorial-opt -- --poly-to-standard --debug $PWD/tests/poly_to_standard.mlir) you’ll see a log of what the conversion framework is trying to do. Eventually it spits out this IR just before the end:

func.func @test_lower_add(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) {
  %0 = builtin.unrealized_conversion_cast %arg1 : !poly.poly<10> to tensor<10xi32>
  %1 = builtin.unrealized_conversion_cast %arg0 : !poly.poly<10> to tensor<10xi32>
  %2 = arith.addi %1, %0 : tensor<10xi32>
  %3 = poly.add %arg0, %arg1 : !poly.poly<10>
  return
}

This builtin.unrealized_conversion_cast is the internal stand-in for a type conflict before attempting to use user-defined materialization hooks to fix the conflicts (which we didn’t implement). It’s basically a forced type coercion. From the docs:

This operation should not be attributed any special representational or execution semantics, and is generally only intended to be used to satisfy the temporary intermixing of type systems during the conversion of one type system to another.

So at this point we have two options to make the pass succeed:

  • Convert the types in the function ops as well.
  • Add a custom type materializer to replace the unrealized conversion cast.

The second option doesn’t really make sense, since we want to get rid of poly.poly entirely in this pass. But if we had to, we would use poly.from_tensor and poly.to_tensor, and we’ll show how to do that in the next section. For now, we’ll convert the function types.

In principle, one would have to lower a structural op like func.func in every conversion pass that introduces a new type that can show up in a function signature. Ditto for ops like scf.if, scf.for, etc. Luckily, MLIR is general enough to handle it for you, but it requires opting in via some extra boilerplate in this commit. I won’t copy the full boilerplate here, but basically you call helpers that use interfaces to define patterns to support type conversions for all function-looking ops (func/call/return), and all if/then/else-looking ops. We just call those helpers and give them our type converter, and it gives us patterns to add to our RewritePatternSet.
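
Roughly, those helper calls look like this (upstream function names; the commit has the complete set, along with dynamic legality rules like the func.func callback shown earlier):

// Patterns that rewrite func.func signatures, func.return, and func.call
// according to our type converter.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);
populateCallOpTypeConversionPattern(patterns, typeConverter);
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);

// Structural type conversions (and legality rules) for scf.if, scf.for, etc.
scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, target);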

After that, the pass runs and we get

  func.func @test_lower_add(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) {
    %0 = arith.addi %arg0, %arg1 : tensor<10xi32>
    return
  }

I wanted to highlight one other aspect of dialect conversions, which is that they use the folding capabilities of the dialect being converted. So for an IR like

func.func @test_lower_add_and_fold() {
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.constant dense<[3, 4, 5]> : tensor<3xi32> : !poly.poly<10>
  %2 = poly.add %0, %1: !poly.poly<10>
  return
}

With a lowering for poly.constant added in this commit, instead of lowering poly.add to arith.addi, we get

func.func @test_lower_add_and_fold() {
  %cst = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
  %cst_0 = arith.constant dense<[3, 4, 5]> : tensor<3xi32>
  %cst_1 = arith.constant dense<[5, 7, 9]> : tensor<3xi32>
  return
}

So as it lowers, it eagerly tries to fold the results, and that can result in some constant propagation.

Next, we lower the rest of the ops:

  • sub, from_tensor, to_tensor. The from_tensor lowering is slightly interesting, in that it allows the user to input just the lowest-degree coefficients of the polynomial, and then pads the higher degree terms with zero. This results in lowering to a tensor.pad op.
  • constant, which is interesting in that it lowers to another poly op (arith.constant + poly.from_tensor), which is then recursively converted via the rewrite rule for from_tensor (see the sketch after this list).
  • mul, which lowers as a standard naive double for loop using the scf.for op
  • eval, which lowers as a single for loop and Horner’s method. (Note it does not support complex inputs, despite the support we added for that in the previous article)
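
For a flavor of that recursive trick in the constant case, here is a rough sketch of the idea (the accessor name getCoefficients is illustrative, not copied from the commit):

struct ConvertConstant : public OpConversionPattern<ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    // Materialize the coefficient attribute as an arith.constant, then wrap it
    // in poly.from_tensor; the from_tensor pattern finishes the lowering.
    auto coeffs =
        rewriter.create<arith::ConstantOp>(op.getLoc(), op.getCoefficients());
    rewriter.replaceOpWithNewOp<FromTensorOp>(op, op.getResult().getType(), coeffs);
    return success();
  }
};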

Then in this commit you can see how it lowers a combination of all the poly ops to an equivalent program that uses just tensor, scf, and arith. For the input test:

func.func @test_lower_many(%arg : !poly.poly<10>, %point : i32) -> i32 {
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.add %0, %arg : !poly.poly<10>
  %2 = poly.mul %1, %1 : !poly.poly<10>
  %3 = poly.sub %2, %arg : !poly.poly<10>
  %4 = poly.eval %3, %point: (!poly.poly<10>, i32) -> i32
  return %4 : i32
}

It produces the following IR:

  func.func @test_lower_many(%arg0: tensor<10xi32>, %arg1: i32) -> i32 {
    %cst = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %c0_i32 = arith.constant 0 : i32
    %padded = tensor.pad %cst low[0] high[7] {
    ^bb0(%arg2: index):
      tensor.yield %c0_i32 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %0 = arith.addi %padded, %arg0 : tensor<10xi32>
    %cst_0 = arith.constant dense<0> : tensor<10xi32>
    %c0 = arith.constant 0 : index
    %c10 = arith.constant 10 : index
    %c1 = arith.constant 1 : index
    %1 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %cst_0) -> (tensor<10xi32>) {
      %4 = scf.for %arg4 = %c0 to %c10 step %c1 iter_args(%arg5 = %arg3) -> (tensor<10xi32>) {
        %5 = arith.addi %arg2, %arg4 : index
        %6 = arith.remui %5, %c10 : index
        %extracted = tensor.extract %0[%arg2] : tensor<10xi32>
        %extracted_3 = tensor.extract %0[%arg4] : tensor<10xi32>
        %7 = arith.muli %extracted, %extracted_3 : i32
        %extracted_4 = tensor.extract %arg5[%6] : tensor<10xi32>
        %8 = arith.addi %7, %extracted_4 : i32
        %inserted = tensor.insert %8 into %arg5[%6] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %4 : tensor<10xi32>
    }
    %2 = arith.subi %1, %arg0 : tensor<10xi32>
    %c1_1 = arith.constant 1 : index
    %c11 = arith.constant 11 : index
    %c0_i32_2 = arith.constant 0 : i32
    %3 = scf.for %arg2 = %c1_1 to %c11 step %c1_1 iter_args(%arg3 = %c0_i32_2) -> (i32) {
      %4 = arith.subi %c11, %arg2 : index
      %5 = arith.muli %arg1, %arg3 : i32
      %extracted = tensor.extract %2[%4] : tensor<10xi32>
      %6 = arith.addi %5, %extracted : i32
      scf.yield %6 : i32
    }
    return %3 : i32
  }

Materialization hooks

The previous section didn’t use the materialization hook infrastructure mentioned earlier, simply because we don’t need it. It’s only necessary when type conflicts must persist across multiple passes, and we have no trouble lowering all of poly in one pass. But for demonstration purposes, this commit (reverted in the same PR) shows how one would implement the materialization hooks. I removed the lowering for poly.sub, reconfigured the ConversionTarget so that poly.sub, to_tensor, and from_tensor are no longer illegal, and then added two hooks on the TypeConverter to insert new from_tensor and to_tensor ops when they are needed:

// Convert from a tensor type to a poly type: use from_tensor
addSourceMaterialization([](OpBuilder &builder, Type type,
                            ValueRange inputs, Location loc) -> Value {
  return builder.create<poly::FromTensorOp>(loc, type, inputs[0]);
});

// Convert from a poly type to a tensor type: use to_tensor
addTargetMaterialization([](OpBuilder &builder, Type type,
                            ValueRange inputs, Location loc) -> Value {
  return builder.create<poly::ToTensorOp>(loc, type, inputs[0]);
});

Now the same lowering above succeeds, except anywhere there is a poly.sub it is surrounded by from_tensor and to_tensor. The output for test_lower_many now looks like this:

  func.func @test_lower_many(%arg0: tensor<10xi32>, %arg1: i32) -> i32 {
    %2 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %cst_0) -> (tensor<10xi32>) {
      %7 = scf.for %arg4 = %c0 to %c10 step %c1 iter_args(%arg5 = %arg3) -> (tensor<10xi32>) {
         ...
      }
      scf.yield %7 : tensor<10xi32>
    }
    %3 = poly.from_tensor %2 : tensor<10xi32> -> !poly.poly<10>
    %4 = poly.sub %3, %0 : !poly.poly<10>
    %5 = poly.to_tensor %4 : !poly.poly<10> -> tensor<10xi32>
    %c1_1 = arith.constant 1 : index
    ...
  }

Correctness?

Lowering poly is not entirely trivial, because you’re implementing polynomial math. The tests I wrote are specific, in that they put tight constraints on the output IR, but what if I’m just dumb and I got the algorithm wrong? It stands to reason that I would want an extra layer of certainty to guard against my own worst enemy.

One of the possible solutions here is to use mlir-cpu-runner, like we did when we lowered ctlz to LLVM. Another is to lower all the way to machine code and run it on the actual hardware instead of through an interpreter. And that’s what we’ll do next time. Until then, the poly lowerings might be buggy. Feel free to point out any bugs in GitHub issues or comments on this blog.

Socks, a matching game based on an additive combinatorics problem

Can you find a set of cards among these six, such that the socks on the chosen cards can be grouped into matching pairs? (Duplicate pairs of the same sock are OK)

Spoilers

If the cards are indexed as

1   2   3
4   5   6

Then the following three subsets work: $\{ 1, 2, 4, 5, 6 \}$, $\{ 2, 3, 6 \}$, and $\{ 1, 3, 4, 5 \}$. There might be more but I don’t see them.

This is the objective of the game Socks, a card game originally designed by Anna Varvak, a math professor at Soka University of America. I tweaked the rules slightly, and designed the deck above, and you can buy a copy. (I’m selling it with Anna’s permission)

The game consists of 63 cards, one for each possible nonempty subset of 6 distinct socks. Duplicate pairs of the same kind of sock are OK, so long as they can be grouped into pairs (e.g., 3 yellow socks is not allowed, but 4 is). Players deal some cards and race to find sets, and the player at the end of the game with the most sets wins. The natural math question is: how many cards are needed to guarantee there’s a set?

Mathematically, the cards can hence be viewed as elements of the additive group $(\mathbb{Z}/2\mathbb{Z})^6$ of length-6 binary vectors. An index corresponds to a sock type, and the value is one if the sock is present on the card, and zero otherwise.

I use the “additive group” structure because a set in Socks is a subset of $(\mathbb{Z}/2\mathbb{Z})^6$ that sums to zero in the group. In programmer terms, a card is a length-6 bitstring and a “sum” is a bitwise XOR. Either way, “summing” two cards can be thought of as producing the card consisting of only the socks that show up an odd number of times on the two cards. One implication here is that any set requires at least three cards: the all-zero card (no socks) is not in the deck, so a single card never sums to zero, and since each group element is its own inverse, a two-card set would require the same card to appear twice in the deck, which it doesn’t.
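
If you want to check a hand by computer, the bitstring view makes it a short brute force. Here is a minimal C++ sketch (cards encoded as 6-bit masks; how you encode the pictured deck is up to you):

#include <cstdint>
#include <vector>

// Return a bitmask over `cards` selecting a nonempty subset whose XOR is zero,
// or 0 if no such subset exists. Each card is a 6-bit mask of its socks.
// Any hit automatically has at least 3 cards, since the deck has no zero card
// and a hand has no duplicates.
uint32_t findZeroXorSubset(const std::vector<uint8_t> &cards) {
  const uint32_t numSubsets = 1u << cards.size();  // fine for small hands
  for (uint32_t subset = 1; subset < numSubsets; ++subset) {
    uint8_t sum = 0;
    for (size_t i = 0; i < cards.size(); ++i)
      if (subset & (1u << i))
        sum ^= cards[i];
    if (sum == 0)
      return subset;  // a valid "Socks" set
  }
  return 0;
}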

Then the natural math question becomes: how many cards do you need to guarantee there is a valid set? In math terms, what is the smallest integer $M$ such that every size-$M$ subset $S \subset (\mathbb{Z}/2\mathbb{Z})^6$ contains a nonempty zero-summing subset?

Theorem: $M = 7$.

Proof. The following set of size 6 has no zero-summing subset

\[ \{ (1, 0, \dots, 0), (0, 1, 0, \dots, 0), \dots, (0, 0, \dots, 1) \} \]

So $M > 6$. On the other hand, consider any set $S$ of size 7. There are $2^7 - 1$ nonempty subsets of $S$, but only $2^6$ possible group elements for their sums, so there must be two distinct subsets of $S$ that have the same sum. Suppose they are $X = \{x_1, \dots, x_r\}$ and $Y = \{y_1, \dots, y_s\}$, and $g = x_1 + \dots + x_r = y_1 + \dots + y_s$. Then adding the two equations, and noting that every element $x$ satisfies $x = -x$ in this group, we get

\[ x_1 + \dots + x_r + y_1 + \dots + y_s = g + g = 0 \]

In other words, we would like to use the “set” $\{ x_1, \dots, x_r, y_1, \dots, y_s \}$ and declare victory, but we can’t because some of the $x_i$ may coincide with some of the $y_j$. That is, $X$ and $Y$ can overlap, and we can’t “use” an element twice. But because the two sets are distinct (not equal), they cannot overlap completely. For any values that overlap, say $x_1 = y_2$, their sum is $x_1 + y_2 = 2x_1 = 0$, and those two elements can be removed without changing the sum. Hence, the final zero-summing subset is the symmetric difference $X \triangle Y$.

$\square$

Unfortunately, as I’ve played Socks, I’ve found that it’s no easier to find a zero-summing subset than it is to find two subsets that have the same sum. So while knowing this proof helps me win the hearts and minds of my opponents, it doesn’t help me win the game.

The next natural goal for a mathematician is to ask the same question of more general groups. The above proof argument naturally extends to $(\mathbb{Z}/2\mathbb{Z})^k$ having $M= k+1$. But for $(\mathbb{Z}/n\mathbb{Z})^k$, the problem is open.

Conjecture: Let $G = (\mathbb{Z}/n\mathbb{Z})^k$, then every set $S$ of $1 + k(n-1)$ elements of $G$ has a nonempty zero-summing subset.

The same argument from the previous theorem doesn’t quite apply: even though you can prove that two distinct subsets of $S$ have the same sum, combining the two equations now involves negated elements, which may not be in $S$.

This generalization opens the door to a decent subset of the number theory and additive combinatorics literature, where this quantity is known as Olson’s constant. Section F.3.3 of this survey by Bela Bajnok covers the literature quite well. Erdős originally worked on the problem in the 60’s. There has been a swath of results for groups with particular structure, for example in John Olson’s original 1969 paper, A Combinatorial Problem on Finite Abelian Groups, I, in which he proves the above conjecture for the special case of finite abelian $p$-groups. I will summarize a few results and conjectures below, most copied from Bajnok’s survey, and if you’re clever enough, each one could be the basis for a new card game. The hard part, it seems, is finding a theme that is cute and picking a small enough group so as to make it fun.

Theorem: For every even integer $n$, Olson’s constant for $\mathbb{Z}/n\mathbb{Z}$ is at least $1 + \lfloor \sqrt{2n - 3} \rfloor$.

For all even $n \leq 64$, this bound is known to be an equality. E.g., for $n=64$ it is 12, and for $n=50$ it is 10. There are no known values of $n$ for which this bound is not tight.

Theorem: For every prime $p$, Olson’s constant for $\mathbb{Z}/p\mathbb{Z}$ is exactly $1 + \lfloor \sqrt{2p} - 1/2 \rfloor$.

For example, for $p=53$, Olson’s constant is 10. (Deal ten cards containing…something both cute and interpretable mod 53! Good luck)

The case of odd composite $n$ is still open, though a lower bound of $1 + \lfloor (\sqrt{8n + 9} - 1)/2 \rfloor$ is known.

The case of products of cyclic groups involves various bounds, stated in Bajnok’s survey in terms of a quantity $\tau$ equal to Olson’s constant minus 1; equality in those bounds is known to hold for $k = 2, 3, 4, 5$.

Theorem: For any finite Abelian group of order $n$, Olson’s constant is less than $3 \sqrt{n} + 1$.

Conjecture [Erdős]: For any finite Abelian group of order $n$, Olson’s constant is less than $1 + \sqrt{2n}$.

Theorem: There is a constant $C$, such that for any finite Abelian group of order $n$, Olson’s constant is at most $1 + \sqrt{2n} + C \sqrt[3]{n} \log_e n$.

And finally, it is conjectured that cyclic groups have maximal Olson constant among all groups of a given order. So if you want to make your game hard, use cyclic groups. If you want to make it easy, use groups that are products of many small cyclic groups.

The origin story

The story of how I came to make my own version of Socks is a funny little mix-up.

Though the original game was designed by Anna Varvak in 2012, I made the cards shown in the picture at the beginning of this article. A friend of mine originally told me about Socks about six months ago. She had heard about it from a friend of hers, but she couldn’t find a physical copy for sale. Socks is listed at BoardGameGeek, but it had a broken website link, www.socksgame.com. I assumed it was out of print. I thought making my own version would be fun! So I found The Game Crafter, a website for on-demand printing of game components, and threw together a design.

When I finished, I noticed the generated URL was www.thegamecrafter.com/games/socks2 and I thought, “That’s weird, someone else made a game called Socks!” And voila, https://www.thegamecrafter.com/games/socks points to Varvak’s original game, still for sale.

Luckily, I was able to get in contact with Anna, and she graciously gave me permission to sell my version under the same name, “Socks.” The official rules in each game are slightly different: in hers you deal 12 cards and a set must consist of exactly 3 cards. In mine you deal 7 cards, and a set can be formed from any subset of the cards. When you restrict the subsets to have size exactly 3, the problem is slightly different (see section F.3.1 of Bajnok’s survey), and the main result that applies is

Theorem: For all positive integers $k$, every set $S \subset (\mathbb{Z}/2\mathbb{Z})^k$ with size at least $2^{k-1} + 2$ has a zero-summing subset of size 3, and there is a set of size $2^{k-1} + 1$ with no zero-summing subset of size 3.

In Anna’s version, it appears, 12 cards is not enough. You need a whopping 34 to guarantee a set exists.

Some playing notes

Some notes about what happens when actually playing the game:

  • A set is worth 1 point, no matter how many cards are in it. I have no idea if it is a better strategy to look for smaller or larger sets.
  • A set must have size at least three, and sets of size exactly three are easy-ish to find by XOR-summing pairs of cards.
  • Each sock type is in the same location on the cards, which makes it easier to visualize.
  • Often you can locate small groups of cards (say, 1-2) that are unusable because their inclusion forces an odd number of some sock. This seems to drastically diminish the size of the search space.
  • A set using all the cards is not that uncommon, so it seems to help to start by summing all the socks and identifying the socks that break that possibility, or else claim a quick set.
  • It does not appear to be easier to find sets when the cards are “sparse” (1-3 socks per card) vs “dense” (4-6 socks per card).
  • When the last card is dealt you can immediately claim “Socks”, as long as every claimed set in the game was valid. This makes the last round boring, and it is slightly more interesting to leave the last card undealt, and then “infer” what is on that card by inverting the sum of the remaining cards. This is similar to the parlor trick for SET I wrote about a while back. Doing this trick for Socks is not nearly as hard as it is for SET, but might be impressive to players who have not thought about the math.