
In the last article we lowered our custom poly dialect to standard MLIR dialects. In this article we’ll continue by lowering it to the LLVM dialect, exporting it out of MLIR to LLVM IR, and then compiling it to x86 machine code.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Defining a Pipeline

The first step in lowering to machine code is to lower to an “exit dialect.” That is, a dialect from which there is a code-gen tool that converts MLIR code to some non-MLIR format. In our case, we’re targeting x86, so the exit dialect is the LLVM dialect, and the code-gen tool is the binary mlir-translate (more on that later). Lowering to an exit dialect, as it turns out, is not all that simple.

One of the things I’ve struggled with when learning MLIR is how to compose all the different lowerings into a pipeline that ends in the result I want, especially when other people wrote those pipelines. When starting from a high level dialect like linalg (linear algebra), there can be dozens of lowerings involved, and some of them can reintroduce ops from dialects you thought you completely lowered already, or have complex pre-conditions or relationships to other lowerings. There are also simply too many lowerings to easily scan.

There are two ways to specify a pass pipeline to an mlir-opt-like binary. One is entirely on the command line, using the --pass-pipeline flag with a tiny DSL, like this:

bazel run //tools:tutorial-opt -- foo.mlir \
  --pass-pipeline='builtin.module(func.func(cse,canonicalize),convert-func-to-llvm)'

Above, the syntax that looks like a function call is “anchoring” a sequence of passes to operate on a particular op (which also allows the pass manager to run them in parallel across instances of that op).

Thankfully, you can also declare the pipeline in code, and wrap it up in a single flag. The above might be equivalently defined as follows:

void customPipelineBuilder(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createConvertFuncToLLVMPass());  // runs on builtin.module by default
}

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  <... register dialects ...>

  mlir::PassPipelineRegistration<>(
      "my-pipeline", "A custom pipeline", customPipelineBuilder);

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Tutorial opt main", registry));
}

We’ll do the actually-runnable analogue of this in the next section when we lower the poly dialect to LLVM.

Lowering Poly to LLVM

In this section we’ll define a pipeline lowering poly to LLVM and show the MLIR along each step. Strap in, there’s going to be a lot of MLIR code in this article.

The process I’ve used to build up a big pipeline is rather toilsome and incremental. Basically: start from an empty pipeline and the starting MLIR, look for the “highest level” op present, add to the pipeline a pass that lowers it, and if that pass fails, figure out what pass is required before it. Then repeat until you reach your target.
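Concretely, this means registering a named pipeline whose builder starts out nearly empty and gains one pass at a time. Here is a minimal sketch of that skeleton; the function names are illustrative, and the real registration is in the pull request.

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Skeleton for the incremental workflow: the builder starts (nearly) empty
// and gains one pass at a time as each lowering gap is discovered.
void polyToLLVMPipelineBuilder(mlir::OpPassManager &manager) {
  // Passes get appended here, one at a time, as the lowering takes shape.
}

void registerPolyToLLVMPipeline() {
  mlir::PassPipelineRegistration<>(
      "poly-to-llvm", "Lower poly to the LLVM dialect",
      polyToLLVMPipelineBuilder);
}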

In this commit we define a pass pipeline --poly-to-llvm that includes only the --poly-to-standard pass defined last time, along with a canonicalization pass. Then we start from this IR:

$ cat $PWD/tests/poly_to_llvm.mlir

func.func @test_poly_fn(%arg : i32) -> i32 {
  %tens = tensor.splat %arg : tensor<10xi32>
  %input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10>
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.add %0, %input : !poly.poly<10>
  %2 = poly.mul %1, %1 : !poly.poly<10>
  %3 = poly.sub %2, %input : !poly.poly<10>
  %4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32
  return %4 : i32
}

$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir

module {
  func.func @test_poly_fn(%arg0: i32) -> i32 {
    %c11 = arith.constant 11 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %c0_i32 = arith.constant 0 : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %padded = tensor.pad %cst_0 low[0] high[7] {
    ^bb0(%arg1: index):
      tensor.yield %c0_i32 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %0 = arith.addi %padded, %splat : tensor<10xi32>
    %1 = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
      %4 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
        %5 = arith.addi %arg1, %arg3 : index
        %6 = arith.remui %5, %c10 : index
        %extracted = tensor.extract %0[%arg3] : tensor<10xi32>
        %extracted_1 = tensor.extract %0[%arg1] : tensor<10xi32>
        %7 = arith.muli %extracted_1, %extracted : i32
        %extracted_2 = tensor.extract %arg4[%6] : tensor<10xi32>
        %8 = arith.addi %7, %extracted_2 : i32
        %inserted = tensor.insert %8 into %arg4[%6] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %4 : tensor<10xi32>
    }
    %2 = arith.subi %1, %splat : tensor<10xi32>
    %3 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
      %4 = arith.subi %c11, %arg1 : index
      %5 = arith.muli %arg0, %arg2 : i32
      %extracted = tensor.extract %2[%4] : tensor<10xi32>
      %6 = arith.addi %5, %extracted : i32
      scf.yield %6 : i32
    }
    return %3 : i32
  }
}

Let’s naively start from the top and see what happens. One way to experiment more interactively is to append a tentative pass on the command line, like bazel run //tools:tutorial-opt -- --poly-to-llvm --pass-to-try $PWD/tests/poly_to_llvm.mlir, which runs the hard-coded pipeline followed by the new pass.

The first op that can be lowered is func.func, and there is a convert-func-to-llvm pass, which we add in this commit. It turns out to process a lot more than just the func op:

module attributes {llvm.data_layout = ""} {
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(11 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %8 = llvm.mlir.constant(0 : i32) : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %padded = tensor.pad %cst_0 low[0] high[7] {
    ^bb0(%arg1: index):
      tensor.yield %8 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %9 = arith.addi %padded, %splat : tensor<10xi32>
    %10 = scf.for %arg1 = %7 to %5 step %3 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = scf.for %arg3 = %7 to %5 step %3 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
        %15 = builtin.unrealized_conversion_cast %arg3 : index to i64
        %16 = llvm.add %13, %15  : i64
        %17 = llvm.urem %16, %4  : i64
        %18 = builtin.unrealized_conversion_cast %17 : i64 to index
        %extracted = tensor.extract %9[%arg3] : tensor<10xi32>
        %extracted_1 = tensor.extract %9[%arg1] : tensor<10xi32>
        %19 = llvm.mul %extracted_1, %extracted  : i32
        %extracted_2 = tensor.extract %arg4[%18] : tensor<10xi32>
        %20 = llvm.add %19, %extracted_2  : i32
        %inserted = tensor.insert %20 into %arg4[%18] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %14 : tensor<10xi32>
    }
    %11 = arith.subi %10, %splat : tensor<10xi32>
    %12 = scf.for %arg1 = %3 to %1 step %3 iter_args(%arg2 = %8) -> (i32) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = llvm.sub %0, %13  : i64
      %15 = builtin.unrealized_conversion_cast %14 : i64 to index
      %16 = llvm.mul %arg0, %arg2  : i32
      %extracted = tensor.extract %11[%15] : tensor<10xi32>
      %17 = llvm.add %16, %extracted  : i32
      scf.yield %17 : i32
    }
    llvm.return %12 : i32
  }
}

Notably, this pass converted most of the arithmetic operations (though not the tensor-generating ones that use the dense attribute) and inserted a number of unrealized_conversion_cast ops for the resulting type conflicts. We’ll have to get rid of those casts eventually, but we can’t yet, because the values on both sides of each type conversion are still in use.

Next we’ll lower arith to LLVM using the suggestively-named convert-arith-to-llvm in this commit. However, it has no effect on the resulting IR, and the arith ops remain. What gives? It turns out that arith ops that operate on tensors are not supported by convert-arith-to-llvm. To deal with this, we need a special pass called convert-elementwise-to-linalg, which lowers these ops to linalg.generic ops. (I plan to cover linalg.generic in a future tutorial).

We add it in this commit; here is the diff between the above IR and the new output (lines prefixed with > are the new IR):

> #map = affine_map<(d0) -> (d0)>
19c20,24
<     %9 = arith.addi %padded, %splat : tensor<10xi32>
---
>     %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
>     ^bb0(%in: i32, %in_1: i32, %out: i32):
>       %13 = arith.addi %in, %in_1 : i32
>       linalg.yield %13 : i32
>     } -> tensor<10xi32>
37c42,46
<     %11 = arith.subi %10, %splat : tensor<10xi32>
---
>     %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%10, %splat : tensor<10xi32>, tensor<10xi32>) outs(%10 : tensor<10xi32>) {
>     ^bb0(%in: i32, %in_1: i32, %out: i32):
>       %13 = arith.subi %in, %in_1 : i32
>       linalg.yield %13 : i32
>     } -> tensor<10xi32>
49a59
>

This is nice, but now we have two new problems: lowering linalg, and re-lowering the arith ops that were inserted by the convert-elementwise-to-linalg pass. Adding the arith pass back at the end of the pipeline in this commit replaces the two inserted arith ops, giving this IR:

#map = affine_map<(d0) -> (d0)>
module attributes {llvm.data_layout = ""} {
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(11 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %8 = llvm.mlir.constant(0 : i32) : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %padded = tensor.pad %cst_0 low[0] high[7] {
    ^bb0(%arg1: index):
      tensor.yield %8 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
    ^bb0(%in: i32, %in_1: i32, %out: i32):
      %13 = llvm.add %in, %in_1  : i32
      linalg.yield %13 : i32
    } -> tensor<10xi32>
    %10 = scf.for %arg1 = %7 to %5 step %3 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = scf.for %arg3 = %7 to %5 step %3 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
        %15 = builtin.unrealized_conversion_cast %arg3 : index to i64
        %16 = llvm.add %13, %15  : i64
        %17 = llvm.urem %16, %4  : i64
        %18 = builtin.unrealized_conversion_cast %17 : i64 to index
        %extracted = tensor.extract %9[%arg3] : tensor<10xi32>
        %extracted_1 = tensor.extract %9[%arg1] : tensor<10xi32>
        %19 = llvm.mul %extracted_1, %extracted  : i32
        %extracted_2 = tensor.extract %arg4[%18] : tensor<10xi32>
        %20 = llvm.add %19, %extracted_2  : i32
        %inserted = tensor.insert %20 into %arg4[%18] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %14 : tensor<10xi32>
    }
    %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%10, %splat : tensor<10xi32>, tensor<10xi32>) outs(%10 : tensor<10xi32>) {
    ^bb0(%in: i32, %in_1: i32, %out: i32):
      %13 = llvm.sub %in, %in_1  : i32
      linalg.yield %13 : i32
    } -> tensor<10xi32>
    %12 = scf.for %arg1 = %3 to %1 step %3 iter_args(%arg2 = %8) -> (i32) {
      %13 = builtin.unrealized_conversion_cast %arg1 : index to i64
      %14 = llvm.sub %0, %13  : i64
      %15 = builtin.unrealized_conversion_cast %14 : i64 to index
      %16 = llvm.mul %arg0, %arg2  : i32
      %extracted = tensor.extract %11[%15] : tensor<10xi32>
      %17 = llvm.add %16, %extracted  : i32
      scf.yield %17 : i32
    }
    llvm.return %12 : i32
  }
}

However, there are still two arith ops remaining: the arith.constant ops that use dense attributes to define tensors. Scanning the list of available passes turns up no obvious pass that handles them. As we’ll see, it’s bufferization that handles them (we discussed it briefly in the previous article), and we should lower as much as possible before bufferizing, since in general bufferization makes optimization passes harder to apply.

The next thing to try lowering is tensor.splat. There’s no obvious pass for this one either (tensor-to-linalg doesn’t lower it), but searching the LLVM codebase for SplatOp turns up this pattern, which suggests that tensor.splat is also lowered during bufferization and produces linalg.map ops. So we’ll have to lower those eventually as well.

But what tensor-to-linalg does lower is the next op: tensor.pad. It replaces the pad with the following:

<     %padded = tensor.pad %cst_0 low[0] high[7] {
<     ^bb0(%arg1: index):
<       tensor.yield %8 : i32
<     } : tensor<3xi32> to tensor<10xi32>
<     %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
---
>     %9 = tensor.empty() : tensor<10xi32>
>     %10 = linalg.fill ins(%8 : i32) outs(%9 : tensor<10xi32>) -> tensor<10xi32>
>     %inserted_slice = tensor.insert_slice %cst_0 into %10[0] [3] [1] : tensor<3xi32> into tensor<10xi32>
>     %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%inserted_slice, %splat : tensor<10xi32>, tensor<10xi32>) outs(%inserted_slice : tensor<10xi32>) {

We add that pass in this commit.
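To keep our bearings, here is roughly what the pipeline builder contains at this point. The pass factory names are the upstream MLIR ones; the poly-to-standard factory is the one from the previous article (its exact name may differ), and the authoritative ordering is in the linked commits.

void polyToLLVMPipelineBuilder(mlir::OpPassManager &manager) {
  // Lower poly to standard MLIR dialects, then clean up.
  manager.addPass(mlir::tutorial::poly::createPolyToStandard());
  manager.addPass(mlir::createCanonicalizerPass());

  // Lower func (and, as we saw, quite a bit more) to the LLVM dialect.
  manager.addPass(mlir::createConvertFuncToLLVMPass());

  // Rewrite elementwise arith ops on tensors as linalg.generic, and lower
  // tensor.pad via tensor-to-linalg.
  manager.addPass(mlir::createConvertElementwiseToLinalgPass());
  manager.addPass(mlir::createConvertTensorToLinalgPass());

  // Re-lower the scalar arith ops introduced inside the linalg.generic bodies.
  manager.addPass(mlir::createArithToLLVMConversionPass());
}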

Next we have all of these linalg ops, as well as the scf.for loops. The linalg ops can be lowered naively to loops, so we’ll do that first. There are three options:

1. convert-linalg-to-loops, which lowers linalg ops to plain scf.for loops,
2. convert-linalg-to-parallel-loops, which lowers them to scf.parallel loops, and
3. convert-linalg-to-affine-loops, which lowers them to affine.for loops.

I want to make this initial pipeline as simple and naive as possible, so convert-linalg-to-loops it is. However, the pass does nothing on this IR (added in this commit). If you run the pass with --debug you can find an explanation:

 ** Failure : expected linalg op with buffer semantics

So linalg must be bufferized before it can be lowered to loops.

However, we can tackle the scf-to-LLVM step with a combination of two passes, convert-scf-to-cf and convert-cf-to-llvm, added in this commit. That lowering introduces arith.cmpi ops for the loop predicates, so moving arith-to-llvm to the end of the pipeline fixes that (commit). In the end we get this IR:

#map = affine_map<(d0) -> (d0)>
module attributes {llvm.data_layout = ""} {
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(11 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(1 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = builtin.unrealized_conversion_cast %6 : i64 to index
    %cst = arith.constant dense<0> : tensor<10xi32>
    %8 = llvm.mlir.constant(0 : i32) : i32
    %cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %splat = tensor.splat %arg0 : tensor<10xi32>
    %9 = tensor.empty() : tensor<10xi32>
    %10 = linalg.fill ins(%8 : i32) outs(%9 : tensor<10xi32>) -> tensor<10xi32>
    %inserted_slice = tensor.insert_slice %cst_0 into %10[0] [3] [1] : tensor<3xi32> into tensor<10xi32>
    %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%inserted_slice, %splat : tensor<10xi32>, tensor<10xi32>) outs(%inserted_slice : tensor<10xi32>) {
    ^bb0(%in: i32, %in_4: i32, %out: i32):
      %43 = llvm.add %in, %in_4  : i32
      linalg.yield %43 : i32
    } -> tensor<10xi32>
    cf.br ^bb1(%7, %cst : index, tensor<10xi32>)
  ^bb1(%12: index, %13: tensor<10xi32>):  // 2 preds: ^bb0, ^bb5
    %14 = builtin.unrealized_conversion_cast %12 : index to i64
    %15 = llvm.icmp "slt" %14, %4 : i64
    llvm.cond_br %15, ^bb2, ^bb6
  ^bb2:  // pred: ^bb1
    %16 = builtin.unrealized_conversion_cast %12 : index to i64
    cf.br ^bb3(%7, %13 : index, tensor<10xi32>)
  ^bb3(%17: index, %18: tensor<10xi32>):  // 2 preds: ^bb2, ^bb4
    %19 = builtin.unrealized_conversion_cast %17 : index to i64
    %20 = llvm.icmp "slt" %19, %4 : i64
    llvm.cond_br %20, ^bb4, ^bb5
  ^bb4:  // pred: ^bb3
    %21 = builtin.unrealized_conversion_cast %17 : index to i64
    %22 = llvm.add %16, %21  : i64
    %23 = llvm.urem %22, %4  : i64
    %24 = builtin.unrealized_conversion_cast %23 : i64 to index
    %extracted = tensor.extract %11[%17] : tensor<10xi32>
    %extracted_1 = tensor.extract %11[%12] : tensor<10xi32>
    %25 = llvm.mul %extracted_1, %extracted  : i32
    %extracted_2 = tensor.extract %18[%24] : tensor<10xi32>
    %26 = llvm.add %25, %extracted_2  : i32
    %inserted = tensor.insert %26 into %18[%24] : tensor<10xi32>
    %27 = llvm.add %19, %2  : i64
    %28 = builtin.unrealized_conversion_cast %27 : i64 to index
    cf.br ^bb3(%28, %inserted : index, tensor<10xi32>)
  ^bb5:  // pred: ^bb3
    %29 = llvm.add %14, %2  : i64
    %30 = builtin.unrealized_conversion_cast %29 : i64 to index
    cf.br ^bb1(%30, %18 : index, tensor<10xi32>)
  ^bb6:  // pred: ^bb1
    %31 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%13, %splat : tensor<10xi32>, tensor<10xi32>) outs(%13 : tensor<10xi32>) {
    ^bb0(%in: i32, %in_4: i32, %out: i32):
      %43 = llvm.sub %in, %in_4  : i32
      linalg.yield %43 : i32
    } -> tensor<10xi32>
    cf.br ^bb7(%3, %8 : index, i32)
  ^bb7(%32: index, %33: i32):  // 2 preds: ^bb6, ^bb8
    %34 = builtin.unrealized_conversion_cast %32 : index to i64
    %35 = llvm.icmp "slt" %34, %0 : i64
    llvm.cond_br %35, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %36 = builtin.unrealized_conversion_cast %32 : index to i64
    %37 = llvm.sub %0, %36  : i64
    %38 = builtin.unrealized_conversion_cast %37 : i64 to index
    %39 = llvm.mul %arg0, %33  : i32
    %extracted_3 = tensor.extract %31[%38] : tensor<10xi32>
    %40 = llvm.add %39, %extracted_3  : i32
    %41 = llvm.add %34, %2  : i64
    %42 = builtin.unrealized_conversion_cast %41 : i64 to index
    cf.br ^bb7(%42, %40 : index, i32)
  ^bb9:  // pred: ^bb7
    llvm.return %33 : i32
  }
}

Bufferization

Last time I wrote at some length about the dialect conversion framework, and how it’s complicated mostly because of bufferization, which was split across multiple passes and introduced type conflicts that required special type materializations and later resolution.

Well, it turns out those bufferization passes are now deprecated. These are the official docs, but the short story is that the network of bufferization passes was replaced by a single one-shot-bufferize pass, plus some boilerplate cleanup passes. This is the sequence of passes recommended by the docs, along with an option to ensure that function signatures are bufferized as well:

  // One-shot bufferize, from
  // https://mlir.llvm.org/docs/Bufferization/#ownership-based-buffer-deallocation
  bufferization::OneShotBufferizationOptions bufferizationOptions;
  bufferizationOptions.bufferizeFunctionBoundaries = true;
  manager.addPass(
      bufferization::createOneShotBufferizePass(bufferizationOptions));
  manager.addPass(memref::createExpandReallocPass());
  manager.addPass(bufferization::createOwnershipBasedBufferDeallocationPass());
  manager.addPass(createCanonicalizerPass());
  manager.addPass(bufferization::createBufferDeallocationSimplificationPass());
  manager.addPass(bufferization::createLowerDeallocationsPass());
  manager.addPass(createCSEPass());
  manager.addPass(createCanonicalizerPass());

This pipeline exists upstream in a helper function, but that helper (and even some of the individual passes above) was added after the MLIR commit we’ve pinned to. At this point it’s worth updating the upstream MLIR commit, which I did in this commit; that required a few additional fixes to deal with API updates (starting with this commit and ending with this commit). Then this commit adds the one-shot bufferization pass and its cleanup helpers to the pipeline.

Even after all of this, I was dismayed to find that convert-linalg-to-loops still did not lower the linalg ops. After a bit of digging, I realized it was because the func-to-llvm pass was running too early in the pipeline, so this commit moves it later. Since we’re building this up a bit backwards and naively, let’s comment out the tail end of the pipeline to see the result of bufferization.

Here is the IR after bufferization but before convert-linalg-to-loops, omitting the rest of the pipeline:

#map = affine_map<(d0) -> (d0)>
module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  func.func @test_poly_fn(%arg0: i32) -> i32 {
    %c0_i32 = arith.constant 0 : i32
    %c11 = arith.constant 11 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %0 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %1 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    linalg.map outs(%alloc : memref<10xi32>)
      () {
        linalg.yield %arg0 : i32
      }
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc_0 : memref<10xi32>)
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %1, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%alloc_0, %alloc : memref<10xi32>, memref<10xi32>) outs(%alloc_0 : memref<10xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %3 = arith.addi %in, %in_2 : i32
      linalg.yield %3 : i32
    }
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %0, %alloc_1 : memref<10xi32> to memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      scf.for %arg2 = %c0 to %c10 step %c1 {
        %3 = arith.addi %arg1, %arg2 : index
        %4 = arith.remui %3, %c10 : index
        %5 = memref.load %alloc_0[%arg2] : memref<10xi32>
        %6 = memref.load %alloc_0[%arg1] : memref<10xi32>
        %7 = arith.muli %6, %5 : i32
        %8 = memref.load %alloc_1[%4] : memref<10xi32>
        %9 = arith.addi %7, %8 : i32
        memref.store %9, %alloc_1[%4] : memref<10xi32>
      }
    }
    linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%alloc_1, %alloc : memref<10xi32>, memref<10xi32>) outs(%alloc_1 : memref<10xi32>) {
    ^bb0(%in: i32, %in_2: i32, %out: i32):
      %3 = arith.subi %in, %in_2 : i32
      linalg.yield %3 : i32
    }
    %2 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
      %3 = arith.subi %c11, %arg1 : index
      %4 = arith.muli %arg0, %arg2 : i32
      %5 = memref.load %alloc_1[%3] : memref<10xi32>
      %6 = arith.addi %4, %5 : i32
      scf.yield %6 : i32
    }
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    return %2 : i32
  }
}

And here it is after convert-linalg-to-loops (still omitting the conversions to LLVM); there are no longer any linalg ops, just loops:

module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  func.func @test_poly_fn(%arg0: i32) -> i32 {
    %c0 = arith.constant 0 : index
    %c10 = arith.constant 10 : index
    %c1 = arith.constant 1 : index
    %c0_i32 = arith.constant 0 : i32
    %c11 = arith.constant 11 : index
    %0 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %1 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      memref.store %arg0, %alloc[%arg1] : memref<10xi32>
    }
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      memref.store %c0_i32, %alloc_0[%arg1] : memref<10xi32>
    }
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %1, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      %3 = memref.load %alloc_0[%arg1] : memref<10xi32>
      %4 = memref.load %alloc[%arg1] : memref<10xi32>
      %5 = arith.addi %3, %4 : i32
      memref.store %5, %alloc_0[%arg1] : memref<10xi32>
    }
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %0, %alloc_1 : memref<10xi32> to memref<10xi32>
    scf.for %arg1 = %c0 to %c10 step %c1 {
      scf.for %arg2 = %c0 to %c10 step %c1 {
        %3 = arith.addi %arg1, %arg2 : index
        %4 = arith.remui %3, %c10 : index
        %5 = memref.load %alloc_0[%arg2] : memref<10xi32>
        %6 = memref.load %alloc_0[%arg1] : memref<10xi32>
        %7 = arith.muli %6, %5 : i32
        %8 = memref.load %alloc_1[%4] : memref<10xi32>
        %9 = arith.addi %7, %8 : i32
        memref.store %9, %alloc_1[%4] : memref<10xi32>
      }
    }
    scf.for %arg1 = %c0 to %c10 step %c1 {
      %3 = memref.load %alloc_1[%arg1] : memref<10xi32>
      %4 = memref.load %alloc[%arg1] : memref<10xi32>
      %5 = arith.subi %3, %4 : i32
      memref.store %5, %alloc_1[%arg1] : memref<10xi32>
    }
    %2 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
      %3 = arith.subi %c11, %arg1 : index
      %4 = arith.muli %arg0, %arg2 : i32
      %5 = memref.load %alloc_1[%3] : memref<10xi32>
      %6 = arith.addi %4, %5 : i32
      scf.yield %6 : i32
    }
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    return %2 : i32
  }
}

And then finally, here is the output of the rest of the pipeline as defined so far:

module {
  memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
  memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(0 : index) : i64
    %1 = builtin.unrealized_conversion_cast %0 : i64 to index
    %2 = llvm.mlir.constant(10 : index) : i64
    %3 = builtin.unrealized_conversion_cast %2 : i64 to index
    %4 = llvm.mlir.constant(1 : index) : i64
    %5 = builtin.unrealized_conversion_cast %4 : i64 to index
    %6 = llvm.mlir.constant(0 : i32) : i32
    %7 = llvm.mlir.constant(11 : index) : i64
    %8 = builtin.unrealized_conversion_cast %7 : i64 to index
    %9 = memref.get_global @__constant_10xi32 : memref<10xi32>
    %10 = memref.get_global @__constant_3xi32 : memref<3xi32>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    cf.br ^bb1(%1 : index)
  ^bb1(%11: index):  // 2 preds: ^bb0, ^bb2
    %12 = builtin.unrealized_conversion_cast %11 : index to i64
    %13 = llvm.icmp "slt" %12, %2 : i64
    llvm.cond_br %13, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    memref.store %arg0, %alloc[%11] : memref<10xi32>
    %14 = llvm.add %12, %4  : i64
    %15 = builtin.unrealized_conversion_cast %14 : i64 to index
    cf.br ^bb1(%15 : index)
  ^bb3:  // pred: ^bb1
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    cf.br ^bb4(%1 : index)
  ^bb4(%16: index):  // 2 preds: ^bb3, ^bb5
    %17 = builtin.unrealized_conversion_cast %16 : index to i64
    %18 = llvm.icmp "slt" %17, %2 : i64
    llvm.cond_br %18, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    memref.store %6, %alloc_0[%16] : memref<10xi32>
    %19 = llvm.add %17, %4  : i64
    %20 = builtin.unrealized_conversion_cast %19 : i64 to index
    cf.br ^bb4(%20 : index)
  ^bb6:  // pred: ^bb4
    %subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
    memref.copy %10, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
    cf.br ^bb7(%1 : index)
  ^bb7(%21: index):  // 2 preds: ^bb6, ^bb8
    %22 = builtin.unrealized_conversion_cast %21 : index to i64
    %23 = llvm.icmp "slt" %22, %2 : i64
    llvm.cond_br %23, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %24 = memref.load %alloc_0[%21] : memref<10xi32>
    %25 = memref.load %alloc[%21] : memref<10xi32>
    %26 = llvm.add %24, %25  : i32
    memref.store %26, %alloc_0[%21] : memref<10xi32>
    %27 = llvm.add %22, %4  : i64
    %28 = builtin.unrealized_conversion_cast %27 : i64 to index
    cf.br ^bb7(%28 : index)
  ^bb9:  // pred: ^bb7
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
    memref.copy %9, %alloc_1 : memref<10xi32> to memref<10xi32>
    cf.br ^bb10(%1 : index)
  ^bb10(%29: index):  // 2 preds: ^bb9, ^bb14
    %30 = builtin.unrealized_conversion_cast %29 : index to i64
    %31 = llvm.icmp "slt" %30, %2 : i64
    llvm.cond_br %31, ^bb11, ^bb15
  ^bb11:  // pred: ^bb10
    %32 = builtin.unrealized_conversion_cast %29 : index to i64
    cf.br ^bb12(%1 : index)
  ^bb12(%33: index):  // 2 preds: ^bb11, ^bb13
    %34 = builtin.unrealized_conversion_cast %33 : index to i64
    %35 = llvm.icmp "slt" %34, %2 : i64
    llvm.cond_br %35, ^bb13, ^bb14
  ^bb13:  // pred: ^bb12
    %36 = builtin.unrealized_conversion_cast %33 : index to i64
    %37 = llvm.add %32, %36  : i64
    %38 = llvm.urem %37, %2  : i64
    %39 = builtin.unrealized_conversion_cast %38 : i64 to index
    %40 = memref.load %alloc_0[%33] : memref<10xi32>
    %41 = memref.load %alloc_0[%29] : memref<10xi32>
    %42 = llvm.mul %41, %40  : i32
    %43 = memref.load %alloc_1[%39] : memref<10xi32>
    %44 = llvm.add %42, %43  : i32
    memref.store %44, %alloc_1[%39] : memref<10xi32>
    %45 = llvm.add %34, %4  : i64
    %46 = builtin.unrealized_conversion_cast %45 : i64 to index
    cf.br ^bb12(%46 : index)
  ^bb14:  // pred: ^bb12
    %47 = llvm.add %30, %4  : i64
    %48 = builtin.unrealized_conversion_cast %47 : i64 to index
    cf.br ^bb10(%48 : index)
  ^bb15:  // pred: ^bb10
    cf.br ^bb16(%1 : index)
  ^bb16(%49: index):  // 2 preds: ^bb15, ^bb17
    %50 = builtin.unrealized_conversion_cast %49 : index to i64
    %51 = llvm.icmp "slt" %50, %2 : i64
    llvm.cond_br %51, ^bb17, ^bb18
  ^bb17:  // pred: ^bb16
    %52 = memref.load %alloc_1[%49] : memref<10xi32>
    %53 = memref.load %alloc[%49] : memref<10xi32>
    %54 = llvm.sub %52, %53  : i32
    memref.store %54, %alloc_1[%49] : memref<10xi32>
    %55 = llvm.add %50, %4  : i64
    %56 = builtin.unrealized_conversion_cast %55 : i64 to index
    cf.br ^bb16(%56 : index)
  ^bb18:  // pred: ^bb16
    cf.br ^bb19(%5, %6 : index, i32)
  ^bb19(%57: index, %58: i32):  // 2 preds: ^bb18, ^bb20
    %59 = builtin.unrealized_conversion_cast %57 : index to i64
    %60 = llvm.icmp "slt" %59, %7 : i64
    llvm.cond_br %60, ^bb20, ^bb21
  ^bb20:  // pred: ^bb19
    %61 = builtin.unrealized_conversion_cast %57 : index to i64
    %62 = llvm.sub %7, %61  : i64
    %63 = builtin.unrealized_conversion_cast %62 : i64 to index
    %64 = llvm.mul %arg0, %58  : i32
    %65 = memref.load %alloc_1[%63] : memref<10xi32>
    %66 = llvm.add %64, %65  : i32
    %67 = llvm.add %59, %4  : i64
    %68 = builtin.unrealized_conversion_cast %67 : i64 to index
    cf.br ^bb19(%68, %66 : index, i32)
  ^bb21:  // pred: ^bb19
    memref.dealloc %alloc : memref<10xi32>
    memref.dealloc %alloc_0 : memref<10xi32>
    memref.dealloc %alloc_1 : memref<10xi32>
    llvm.return %58 : i32
  }
}

The remaining problems:

1. The cf.br and cf.cond_br ops are still not lowered to the LLVM dialect.
2. Some of the more complicated memref ops, like memref.subview, have no direct lowering.
3. The remaining memref ops and types need to be lowered to the LLVM dialect, and the unrealized_conversion_cast ops need to be removed.

The second needs a special pass, memref-expand-strided-metadata, which reduces the more complicated memref ops to simpler ones that can be lowered. The third is fixed by finalize-memref-to-llvm, which lowers memref ops and types to the LLVM dialect (a memref becomes a descriptor built from llvm.struct, llvm.ptr, and llvm.array holding the pointer, sizes, and strides). A final reconcile-unrealized-casts removes the cast operations, provided they can safely be removed. These passes are combined in this commit.

However, the first problem eluded me for a while, until I figured out through embarrassing trial and error that, once again, func-to-llvm was running too early in the pipeline. Moving it to the end (this commit) resulted in cf.br and cf.cond_br being lowered to llvm.br and llvm.cond_br.
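Putting the tail of the pipeline together, the portion after bufferization now looks roughly like the sketch below. As before, the factory names are the upstream MLIR ones, the authoritative ordering is in the commits, and the important point is that func-to-llvm and the other to-LLVM conversions sit at the very end.

void polyToLLVMPipelineBuilder(mlir::OpPassManager &manager) {
  // ... poly-to-standard, the linalg-related passes, and the one-shot
  // bufferization sequence shown earlier ...

  // Lower the (now bufferized) linalg ops to loops, and scf to unstructured cf.
  manager.addPass(mlir::createConvertLinalgToLoopsPass());
  manager.addPass(mlir::createConvertSCFToCFPass());

  // Expand complicated memref ops, then lower memref, cf, arith, and func
  // to the LLVM dialect.
  manager.addPass(mlir::memref::createExpandStridedMetadataPass());
  manager.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
  manager.addPass(mlir::createConvertControlFlowToLLVMPass());
  manager.addPass(mlir::createArithToLLVMConversionPass());
  manager.addPass(mlir::createConvertFuncToLLVMPass());

  // Remove the leftover casts; the standard cleanups (SCCP, CSE, symbol DCE,
  // canonicalize) come next, added in the following commit.
  manager.addPass(mlir::createReconcileUnrealizedCastsPass());
}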

Finally, this commit adds a set of standard cleanup passes, including constant propagation, common subexpression elimination, dead code elimination, and canonicalization. The final IR looks like this

module {
  llvm.func @free(!llvm.ptr)
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.mlir.global private constant @__constant_3xi32(dense<[2, 3, 4]> : tensor<3xi32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<3 x i32>
  llvm.mlir.global private constant @__constant_10xi32(dense<0> : tensor<10xi32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<10 x i32>
  llvm.func @test_poly_fn(%arg0: i32) -> i32 {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(64 : index) : i64
    %2 = llvm.mlir.constant(3 : index) : i64
    %3 = llvm.mlir.constant(0 : index) : i64
    %4 = llvm.mlir.constant(10 : index) : i64
    %5 = llvm.mlir.constant(1 : index) : i64
    %6 = llvm.mlir.constant(11 : index) : i64
    %7 = llvm.mlir.addressof @__constant_10xi32 : !llvm.ptr
    %8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
    %9 = llvm.mlir.addressof @__constant_3xi32 : !llvm.ptr
    %10 = llvm.getelementptr %9[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<3 x i32>
    %11 = llvm.mlir.zero : !llvm.ptr
    %12 = llvm.getelementptr %11[10] : (!llvm.ptr) -> !llvm.ptr, i32
    %13 = llvm.ptrtoint %12 : !llvm.ptr to i64
    %14 = llvm.add %13, %1  : i64
    %15 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %16 = llvm.ptrtoint %15 : !llvm.ptr to i64
    %17 = llvm.sub %1, %5  : i64
    %18 = llvm.add %16, %17  : i64
    %19 = llvm.urem %18, %1  : i64
    %20 = llvm.sub %18, %19  : i64
    %21 = llvm.inttoptr %20 : i64 to !llvm.ptr
    llvm.br ^bb1(%3 : i64)
  ^bb1(%22: i64):  // 2 preds: ^bb0, ^bb2
    %23 = llvm.icmp "slt" %22, %4 : i64
    llvm.cond_br %23, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %24 = llvm.getelementptr %21[%22] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    llvm.store %arg0, %24 : i32, !llvm.ptr
    %25 = llvm.add %22, %5  : i64
    llvm.br ^bb1(%25 : i64)
  ^bb3:  // pred: ^bb1
    %26 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %27 = llvm.ptrtoint %26 : !llvm.ptr to i64
    %28 = llvm.add %27, %17  : i64
    %29 = llvm.urem %28, %1  : i64
    %30 = llvm.sub %28, %29  : i64
    %31 = llvm.inttoptr %30 : i64 to !llvm.ptr
    llvm.br ^bb4(%3 : i64)
  ^bb4(%32: i64):  // 2 preds: ^bb3, ^bb5
    %33 = llvm.icmp "slt" %32, %4 : i64
    llvm.cond_br %33, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    %34 = llvm.getelementptr %31[%32] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    llvm.store %0, %34 : i32, !llvm.ptr
    %35 = llvm.add %32, %5  : i64
    llvm.br ^bb4(%35 : i64)
  ^bb6:  // pred: ^bb4
    %36 = llvm.mul %2, %5  : i64
    %37 = llvm.getelementptr %11[1] : (!llvm.ptr) -> !llvm.ptr, i32
    %38 = llvm.ptrtoint %37 : !llvm.ptr to i64
    %39 = llvm.mul %36, %38  : i64
    "llvm.intr.memcpy"(%31, %10, %39) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
    llvm.br ^bb7(%3 : i64)
  ^bb7(%40: i64):  // 2 preds: ^bb6, ^bb8
    %41 = llvm.icmp "slt" %40, %4 : i64
    llvm.cond_br %41, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %42 = llvm.getelementptr %31[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %43 = llvm.load %42 : !llvm.ptr -> i32
    %44 = llvm.getelementptr %21[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %45 = llvm.load %44 : !llvm.ptr -> i32
    %46 = llvm.add %43, %45  : i32
    llvm.store %46, %42 : i32, !llvm.ptr
    %47 = llvm.add %40, %5  : i64
    llvm.br ^bb7(%47 : i64)
  ^bb9:  // pred: ^bb7
    %48 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
    %49 = llvm.ptrtoint %48 : !llvm.ptr to i64
    %50 = llvm.add %49, %17  : i64
    %51 = llvm.urem %50, %1  : i64
    %52 = llvm.sub %50, %51  : i64
    %53 = llvm.inttoptr %52 : i64 to !llvm.ptr
    %54 = llvm.mul %4, %5  : i64
    %55 = llvm.mul %54, %38  : i64
    "llvm.intr.memcpy"(%53, %8, %55) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
    llvm.br ^bb10(%3 : i64)
  ^bb10(%56: i64):  // 2 preds: ^bb9, ^bb14
    %57 = llvm.icmp "slt" %56, %4 : i64
    llvm.cond_br %57, ^bb11, ^bb15
  ^bb11:  // pred: ^bb10
    llvm.br ^bb12(%3 : i64)
  ^bb12(%58: i64):  // 2 preds: ^bb11, ^bb13
    %59 = llvm.icmp "slt" %58, %4 : i64
    llvm.cond_br %59, ^bb13, ^bb14
  ^bb13:  // pred: ^bb12
    %60 = llvm.add %56, %58  : i64
    %61 = llvm.urem %60, %4  : i64
    %62 = llvm.getelementptr %31[%58] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %63 = llvm.load %62 : !llvm.ptr -> i32
    %64 = llvm.getelementptr %31[%56] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %65 = llvm.load %64 : !llvm.ptr -> i32
    %66 = llvm.mul %65, %63  : i32
    %67 = llvm.getelementptr %53[%61] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %68 = llvm.load %67 : !llvm.ptr -> i32
    %69 = llvm.add %66, %68  : i32
    llvm.store %69, %67 : i32, !llvm.ptr
    %70 = llvm.add %58, %5  : i64
    llvm.br ^bb12(%70 : i64)
  ^bb14:  // pred: ^bb12
    %71 = llvm.add %56, %5  : i64
    llvm.br ^bb10(%71 : i64)
  ^bb15:  // pred: ^bb10
    llvm.br ^bb16(%3 : i64)
  ^bb16(%72: i64):  // 2 preds: ^bb15, ^bb17
    %73 = llvm.icmp "slt" %72, %4 : i64
    llvm.cond_br %73, ^bb17, ^bb18
  ^bb17:  // pred: ^bb16
    %74 = llvm.getelementptr %53[%72] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %75 = llvm.load %74 : !llvm.ptr -> i32
    %76 = llvm.getelementptr %21[%72] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %77 = llvm.load %76 : !llvm.ptr -> i32
    %78 = llvm.sub %75, %77  : i32
    llvm.store %78, %74 : i32, !llvm.ptr
    %79 = llvm.add %72, %5  : i64
    llvm.br ^bb16(%79 : i64)
  ^bb18:  // pred: ^bb16
    llvm.br ^bb19(%5, %0 : i64, i32)
  ^bb19(%80: i64, %81: i32):  // 2 preds: ^bb18, ^bb20
    %82 = llvm.icmp "slt" %80, %6 : i64
    llvm.cond_br %82, ^bb20, ^bb21
  ^bb20:  // pred: ^bb19
    %83 = llvm.sub %6, %80  : i64
    %84 = llvm.mul %arg0, %81  : i32
    %85 = llvm.getelementptr %53[%83] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %86 = llvm.load %85 : !llvm.ptr -> i32
    %87 = llvm.add %84, %86  : i32
    %88 = llvm.add %80, %5  : i64
    llvm.br ^bb19(%88, %87 : i64, i32)
  ^bb21:  // pred: ^bb19
    llvm.call @free(%15) : (!llvm.ptr) -> ()
    llvm.call @free(%26) : (!llvm.ptr) -> ()
    llvm.call @free(%48) : (!llvm.ptr) -> ()
    llvm.return %81 : i32
  }
}

Exiting MLIR

In MLIR parlance, the LLVM dialect is an “exit” dialect, meaning that after lowering to it you run a different tool that generates code consumed by an external system. In our case, that will be LLVM’s own internal representation (“LLVM IR,” which is distinct from the MLIR LLVM dialect). The code-gen step is usually called translation in the MLIR docs. Here are the official docs on generating LLVM IR.

The codegen tool is called mlir-translate, and it has an --mlir-to-llvmir option. Running the command below on our output IR above gives the IR in this gist.

$  bazel build @llvm-project//mlir:mlir-translate
$  bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir
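As an aside, the same translation can be done programmatically instead of shelling out to mlir-translate. Here is a minimal sketch using MLIR’s C++ translation API; this is not part of the tutorial’s code, just an illustration of the entry points.

#include <memory>

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"

// Translate an MLIR module that is already entirely in the LLVM dialect
// into an llvm::Module. Returns nullptr (and emits diagnostics) on failure.
std::unique_ptr<llvm::Module> exportToLLVMIR(mlir::ModuleOp module,
                                             llvm::LLVMContext &llvmContext) {
  // The dialect-to-LLVM-IR translation interfaces must be registered first.
  mlir::registerBuiltinDialectTranslation(*module->getContext());
  mlir::registerLLVMDialectTranslation(*module->getContext());
  return mlir::translateModuleToLLVMIR(module, llvmContext);
}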

Next, to compile the LLVM IR with LLVM directly, we use the llc tool, which has a --filetype=obj option to emit an object file. Without that flag, it prints a textual representation of the machine code:

$ bazel build @llvm-project//llvm:llc
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | bazel-bin/external/llvm-project/llvm/llc

# textual representation
        .text
        .file   "LLVMDialectModule"
        .globl  test_poly_fn                    # -- Begin function test_poly_fn
        .p2align        4, 0x90
        .type   test_poly_fn,@function
test_poly_fn:                           # @test_poly_fn
        .cfi_startproc
# %bb.0:
        pushq   %rbp
        .cfi_def_cfa_offset 16
        pushq   %r15
        .cfi_def_cfa_offset 24
        pushq   %r14
        .cfi_def_cfa_offset 32
        pushq   %r13
        .cfi_def_cfa_offset 40
        pushq   %r12
        .cfi_def_cfa_offset 48
        pushq   %rbx
        .cfi_def_cfa_offset 56
        pushq   %rax
        .cfi_def_cfa_offset 64
        .cfi_offset %rbx, -56
        .cfi_offset %r12, -48
<snip>

And finally, we can save the object file, then compile and link it against a C main function that calls our function and prints the result.

$ cat tests/poly_to_llvm_main.c
#include <stdio.h>

// This is the function we want to call from LLVM
int test_poly_fn(int x);

int main(int argc, char *argv[]) {
  int i = 1;
  int result = test_poly_fn(i);
  printf("Result: %d\n", result);
  return 0;
}
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | bazel-bin/external/llvm-project/llvm/llc --filetype=obj > poly_to_llvm.o
$ clang -c tests/poly_to_llvm_main.c && clang poly_to_llvm_main.o poly_to_llvm.o -o a.out
$ ./a.out
Result: 320

The polynomial computed by the test function is the rather bizarre:

((1+x+x**2+x**3+x**4+x**5+x**6+x**7+x**8+x**9) + (2 + 3*x + 4*x**2))**2 - (1+x+x**2+x**3+x**4+x**5+x**6+x**7+x**8+x**9)

But computing this mod $x^{10} - 1$ in Sympy, we get… 351. (As a sanity check: at the input $x = 1$ used in the C main, the first parenthesized term is $10 + 9 = 19$, and $19^2 - 10 = 351$.) Uh oh. So somewhere along the way we messed up the lowering.

Before we fix it, let’s encode the whole process as a lit test. This requires making the binaries mlir-translate, llc, and clang available to the test runner, and setting up the RUN pipeline inside the test file. This is contained in this commit. Note %t tells lit to generate a test-unique temporary file.

// RUN: tutorial-opt --poly-to-llvm %s | mlir-translate --mlir-to-llvmir | llc -filetype=obj > %t
// RUN: clang -c poly_to_llvm_main.c
// RUN: clang poly_to_llvm_main.o %t -o a.out
// RUN: ./a.out | FileCheck %s

// CHECK: 351
func.func @test_poly_fn(%arg : i32) -> i32 {
  %tens = tensor.splat %arg : tensor<10xi32>
  %input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10>
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.add %0, %input : !poly.poly<10>
  %2 = poly.mul %1, %1 : !poly.poly<10>
  %3 = poly.sub %2, %input : !poly.poly<10>
  %4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32
  return %4 : i32
}

Then running this as a normal lit test fails with:

error: CHECK: expected string not found in input
# | // CHECK: 351
# |           ^
# | <stdin>:1:1: note: scanning from here
# | Result: 320
# | ^
# | <stdin>:1:9: note: possible intended match here
# | Result: 320

Fixing the bug in the lowering

How do you find the bug in a lowering? Apparently some folks are doing this via formal verification, formalizing the dialects and the lowerings in Lean to prove correctness. I don’t have time for that in this tutorial, so instead I simplify/expand the tests and squint.

Simplifying the test to compute the function $t \mapsto 1 + t + t^2$, I see that an input of 1 gives 2, when it should be 3. This suggests that, at the very least, eval is lowered incorrectly. An input of 5 gives 4, so the results are all off by one term. This likely means I have an off-by-one error in the loop that eval lowers to, and indeed that’s the problem. I was essentially doing this, where $N$ is the size of the coefficient tensor ($N-1$ is the largest legal index into it):

accum = 0
for 1 <= i < N+1
  index = N+1 - i
  accum = accum * point + coeffs[index]

When it should have been

accum = 0
for 1 <= i < N+1
  index = N - i
  accum = accum * point + coeffs[index]
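To double-check the corrected index arithmetic, here is the same loop written as ordinary standalone C++; this is just an illustration of the intended Horner evaluation, not the generated code.

#include <cstdint>
#include <iostream>
#include <vector>

// Horner's method over N coefficients, constant term at index 0. The loop
// mirrors the corrected pseudocode: for i in [1, N], the index N - i walks
// from N-1 down to 0, so every coefficient is visited exactly once.
int32_t evalPoly(const std::vector<int32_t> &coeffs, int32_t point) {
  int32_t accum = 0;
  int64_t N = static_cast<int64_t>(coeffs.size());
  for (int64_t i = 1; i < N + 1; ++i) {
    accum = accum * point + coeffs[N - i];
  }
  return accum;
}

int main() {
  // 1 + t + t^2 at t = 1 should give 3 (the simplified test case above).
  std::cout << evalPoly({1, 1, 1}, 1) << "\n";
  return 0;
}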

This commit fixes it, and now all the tests pass.

Aside: I don’t know of a way to do more robust testing in MLIR. For example, I don’t know of a fuzz testing or property testing option. In my main project, HEIR, I hand-rolled a simple test-generation routine that reads a config file and spits out MLIR test files. That lets me jot down a bunch of tests quickly, but it doesn’t give the full power of a system like hypothesis, which is a framework I’m quite fond of. I think the xDSL project could work in my favor here, but I have yet to dive into it, and as far as I know it requires re-defining all your custom dialects in Python. More on that in a future article.

Taking a step back

This article showed a bit of a naive approach to building up a dialect conversion pipeline, where we just greedily looked for ops to lower and inserted the relevant passes somewhat haphazardly. That worked out OK, but some lower-level passes (converting to LLVM) confounded the overall pipeline when placed too early.

A better approach is to identify the highest-level operations and lower those first. But that is only really possible if you already know which passes are available and what they do. For example, convert-elementwise-to-linalg handles ops that seem rather low-level a priori, and its necessity only became clear once we noticed that convert-arith-to-llvm silently ignored arith ops operating on tensors. Similarly, the implications of converting func to LLVM (which handles much more than just the ops in the func dialect) were not clear until we tried it and ran into problems.

I don’t have a particularly good solution here besides trial and error. But I appreciate good tools, so I will amuse myself with some ideas.

Since most passes have per-op patterns, it seems like one could write a tool that analyzes an IR, simulates running all possible rewrite patterns from the standard list of passes, and checks which ones succeeded (i.e., the “match” was successful, though most patterns are a combined matchAndRewrite), and what op types were generated as a result. Then once you get to an IR that you don’t know how to lower further, you could run the tool and it would tell you all of your options.

An even more aggressive tool could construct a complete graph of op-to-op conversions. You could identify an op to lower and a subset of legal dialects or ops (similar to a ConversionTarget), and it would report all possible paths from the op to your desired target.

I imagine performance is an obstacle here, and I also wonder to what extent this could be statically analyzed at least coarsely. For example, instead of running the rewrites, you could statically analyze the implementations of rewrite patterns lowering FooOp for occurrences of create<BarOp>, and just include an edge FooOp -> BarOp as a conservative estimate (it might not generate BarOp for every input).

Plans

At this point we’ve covered the basic ground of MLIR: defining a new dialect, writing some optimization passes, doing some lowerings, and compiling and running a proper binary. The topics we could cover from here on out are rather broad, but here are a few options:



DOI: https://doi.org/10.59350/73rqq-e1844