In the last article we lowered our custom poly dialect to standard MLIR dialects. In this article we’ll continue by lowering it to the LLVM dialect, exporting it out of MLIR to LLVM IR, and then compiling to x86 machine code.
The code for this article is in this pull request, and as usual the commits are organized to be read in order.
Defining a Pipeline
The first step in lowering to machine code is to lower to an “exit dialect.” That is, a dialect from which there is a code-gen tool that converts MLIR code to some non-MLIR format. In our case, we’re targeting x86, so the exit dialect is the LLVM dialect, and the code-gen tool is the binary mlir-translate (more on that later). Lowering to an exit dialect, as it turns out, is not all that simple.
One of the things I’ve struggled with when learning MLIR is how to compose all the different lowerings into a pipeline that ends in the result I want, especially when other people wrote those pipelines. When starting from a high level dialect like linalg (linear algebra), there can be dozens of lowerings involved, and some of them can reintroduce ops from dialects you thought you completely lowered already, or have complex pre-conditions or relationships to other lowerings. There are also simply too many lowerings to easily scan.
There are two ways to specify a lowering from the MLIR binary. One is completely on the command line: you can use the --pass-pipeline flag with a tiny DSL, like this:
bazel run //tools:tutorial-opt -- foo.mlir \
--pass-pipeline='builtin.module(func.func(cse,canonicalize),convert-func-to-llvm)'
Above, the thing that looks like a function call is “anchoring” a sequence of passes to operate on a particular op (allowing them to run in parallel across ops).
Thankfully, you can also declare the pipeline in code, and wrap it up in a single flag. The above might be equivalently defined as follows:
void customPipelineBuilder(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm.addPass(createConvertFuncToLLVMPass());  // runs on builtin.module by default
}
int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  <... register dialects ...>

  mlir::PassPipelineRegistration<>(
      "my-pipeline", "A custom pipeline", customPipelineBuilder);

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Tutorial opt main", registry));
}
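With that registration in place, the whole pipeline hides behind a single flag, so the invocation becomes something like bazel run //tools:tutorial-opt -- foo.mlir --my-pipeline (a hypothetical invocation, using the example pipeline name registered above).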
We’ll do the actually-runnable analogue of this in the next section when we lower the poly dialect to LLVM.
Lowering Poly to LLVM
In this section we’ll define a pipeline lowering poly to LLVM and show the MLIR at each step. Strap in, there’s going to be a lot of MLIR code in this article.
The process I’ve used to build up a big pipeline is rather toilsome and incremental. Basically: start from an empty pipeline and the starting MLIR, look for the “highest level” op you can think of, add to the pipeline a pass that lowers it, and if that pass fails, figure out what pass is required before it. Then repeat until you reach your target.
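For concreteness, here’s roughly what that looks like for the pipeline we’re about to build. This is only a sketch: I’m assuming the poly-to-standard pass from last time is exposed through a factory called createPolyToStandard(), and the namespaces may differ in the actual repo.

// Sketch of the initial --poly-to-llvm pipeline builder.
void polyToLLVMPipelineBuilder(mlir::OpPassManager &manager) {
  manager.addPass(mlir::tutorial::poly::createPolyToStandard());  // assumed factory name
  manager.addPass(mlir::createCanonicalizerPass());
}

// Inside main(), next to the other registrations:
mlir::PassPipelineRegistration<>(
    "poly-to-llvm", "Lower poly to LLVM", polyToLLVMPipelineBuilder);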
In this commit we define a pass pipeline --poly-to-llvm that includes only the --poly-to-standard pass defined last time, along with a canonicalization pass. Then we start from this IR:
$ cat $PWD/tests/poly_to_llvm.mlir
func.func @test_poly_fn(%arg : i32) -> i32 {
%tens = tensor.splat %arg : tensor<10xi32>
%input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10>
%0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
%1 = poly.add %0, %input : !poly.poly<10>
%2 = poly.mul %1, %1 : !poly.poly<10>
%3 = poly.sub %2, %input : !poly.poly<10>
%4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32
return %4 : i32
}
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir
module {
func.func @test_poly_fn(%arg0: i32) -> i32 {
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : tensor<10xi32>
%c0_i32 = arith.constant 0 : i32
%cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
%splat = tensor.splat %arg0 : tensor<10xi32>
%padded = tensor.pad %cst_0 low[0] high[7] {
^bb0(%arg1: index):
tensor.yield %c0_i32 : i32
} : tensor<3xi32> to tensor<10xi32>
%0 = arith.addi %padded, %splat : tensor<10xi32>
%1 = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
%4 = scf.for %arg3 = %c0 to %c10 step %c1 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
%5 = arith.addi %arg1, %arg3 : index
%6 = arith.remui %5, %c10 : index
%extracted = tensor.extract %0[%arg3] : tensor<10xi32>
%extracted_1 = tensor.extract %0[%arg1] : tensor<10xi32>
%7 = arith.muli %extracted_1, %extracted : i32
%extracted_2 = tensor.extract %arg4[%6] : tensor<10xi32>
%8 = arith.addi %7, %extracted_2 : i32
%inserted = tensor.insert %8 into %arg4[%6] : tensor<10xi32>
scf.yield %inserted : tensor<10xi32>
}
scf.yield %4 : tensor<10xi32>
}
%2 = arith.subi %1, %splat : tensor<10xi32>
%3 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
%4 = arith.subi %c11, %arg1 : index
%5 = arith.muli %arg0, %arg2 : i32
%extracted = tensor.extract %2[%4] : tensor<10xi32>
%6 = arith.addi %5, %extracted : i32
scf.yield %6 : i32
}
return %3 : i32
}
}
Let’s naively start from the top and see what happens. One way to try this more interactively is to tack a tentative pass onto the command line, like bazel run //tools:tutorial-opt -- --poly-to-llvm --pass-to-try $PWD/tests/poly_to_llvm.mlir, which runs the hard-coded pipeline followed by the new pass.
The first op that can be lowered is func.func, and there is a convert-func-to-llvm pass, which we add in this commit. It turns out to process a lot more than just the func op:
module attributes {llvm.data_layout = ""} {
llvm.func @test_poly_fn(%arg0: i32) -> i32 {
%0 = llvm.mlir.constant(11 : index) : i64
%1 = builtin.unrealized_conversion_cast %0 : i64 to index
%2 = llvm.mlir.constant(1 : index) : i64
%3 = builtin.unrealized_conversion_cast %2 : i64 to index
%4 = llvm.mlir.constant(10 : index) : i64
%5 = builtin.unrealized_conversion_cast %4 : i64 to index
%6 = llvm.mlir.constant(0 : index) : i64
%7 = builtin.unrealized_conversion_cast %6 : i64 to index
%cst = arith.constant dense<0> : tensor<10xi32>
%8 = llvm.mlir.constant(0 : i32) : i32
%cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
%splat = tensor.splat %arg0 : tensor<10xi32>
%padded = tensor.pad %cst_0 low[0] high[7] {
^bb0(%arg1: index):
tensor.yield %8 : i32
} : tensor<3xi32> to tensor<10xi32>
%9 = arith.addi %padded, %splat : tensor<10xi32>
%10 = scf.for %arg1 = %7 to %5 step %3 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
%13 = builtin.unrealized_conversion_cast %arg1 : index to i64
%14 = scf.for %arg3 = %7 to %5 step %3 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
%15 = builtin.unrealized_conversion_cast %arg3 : index to i64
%16 = llvm.add %13, %15 : i64
%17 = llvm.urem %16, %4 : i64
%18 = builtin.unrealized_conversion_cast %17 : i64 to index
%extracted = tensor.extract %9[%arg3] : tensor<10xi32>
%extracted_1 = tensor.extract %9[%arg1] : tensor<10xi32>
%19 = llvm.mul %extracted_1, %extracted : i32
%extracted_2 = tensor.extract %arg4[%18] : tensor<10xi32>
%20 = llvm.add %19, %extracted_2 : i32
%inserted = tensor.insert %20 into %arg4[%18] : tensor<10xi32>
scf.yield %inserted : tensor<10xi32>
}
scf.yield %14 : tensor<10xi32>
}
%11 = arith.subi %10, %splat : tensor<10xi32>
%12 = scf.for %arg1 = %3 to %1 step %3 iter_args(%arg2 = %8) -> (i32) {
%13 = builtin.unrealized_conversion_cast %arg1 : index to i64
%14 = llvm.sub %0, %13 : i64
%15 = builtin.unrealized_conversion_cast %14 : i64 to index
%16 = llvm.mul %arg0, %arg2 : i32
%extracted = tensor.extract %11[%15] : tensor<10xi32>
%17 = llvm.add %16, %extracted : i32
scf.yield %17 : i32
}
llvm.return %12 : i32
}
}
Notably, this pass converted most of the arithmetic operations (though not the tensor-generating ones that use the dense attribute) and inserted a number of unrealized_conversion_cast ops for the resulting type conflicts. We’ll have to get rid of those eventually, but we can’t yet because the values on both sides of each type conversion are still used.
Next we’ll lower arith to LLVM using the suggestively-named convert-arith-to-llvm in this commit. However, it has no effect on the resulting IR, and the arith ops remain. What gives? It turns out that arith ops that operate on tensors are not supported by convert-arith-to-llvm. To deal with this, we need a special pass called convert-elementwise-to-linalg, which lowers these ops to linalg.generic ops. (I plan to cover linalg.generic in a future tutorial.)
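Concretely, that means one more pass in the pipeline builder. A sketch, assuming the upstream factory is mlir::createConvertElementwiseToLinalgPass() (double-check the header at the pinned commit):

// Rewrite elementwise arith ops on tensors into linalg.generic ops.
manager.addPass(mlir::createConvertElementwiseToLinalgPass());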
We add it in this commit, and this is the diff between the above IR and the new output (the lines marked > are the new IR):
> #map = affine_map<(d0) -> (d0)>
19c20,24
< %9 = arith.addi %padded, %splat : tensor<10xi32>
---
> %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
> ^bb0(%in: i32, %in_1: i32, %out: i32):
> %13 = arith.addi %in, %in_1 : i32
> linalg.yield %13 : i32
> } -> tensor<10xi32>
37c42,46
< %11 = arith.subi %10, %splat : tensor<10xi32>
---
> %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%10, %splat : tensor<10xi32>, tensor<10xi32>) outs(%10 : tensor<10xi32>) {
> ^bb0(%in: i32, %in_1: i32, %out: i32):
> %13 = arith.subi %in, %in_1 : i32
> linalg.yield %13 : i32
> } -> tensor<10xi32>
49a59
>
This is nice, but now you can see we have a new problem: we need to lower linalg, and to re-lower the arith ops that were inserted by the convert-elementwise-to-linalg pass. Adding the arith-to-llvm pass again at the end of the pipeline in this commit replaces the two inserted arith ops, giving this IR:
#map = affine_map<(d0) -> (d0)>
module attributes {llvm.data_layout = ""} {
llvm.func @test_poly_fn(%arg0: i32) -> i32 {
%0 = llvm.mlir.constant(11 : index) : i64
%1 = builtin.unrealized_conversion_cast %0 : i64 to index
%2 = llvm.mlir.constant(1 : index) : i64
%3 = builtin.unrealized_conversion_cast %2 : i64 to index
%4 = llvm.mlir.constant(10 : index) : i64
%5 = builtin.unrealized_conversion_cast %4 : i64 to index
%6 = llvm.mlir.constant(0 : index) : i64
%7 = builtin.unrealized_conversion_cast %6 : i64 to index
%cst = arith.constant dense<0> : tensor<10xi32>
%8 = llvm.mlir.constant(0 : i32) : i32
%cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
%splat = tensor.splat %arg0 : tensor<10xi32>
%padded = tensor.pad %cst_0 low[0] high[7] {
^bb0(%arg1: index):
tensor.yield %8 : i32
} : tensor<3xi32> to tensor<10xi32>
%9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
^bb0(%in: i32, %in_1: i32, %out: i32):
%13 = llvm.add %in, %in_1 : i32
linalg.yield %13 : i32
} -> tensor<10xi32>
%10 = scf.for %arg1 = %7 to %5 step %3 iter_args(%arg2 = %cst) -> (tensor<10xi32>) {
%13 = builtin.unrealized_conversion_cast %arg1 : index to i64
%14 = scf.for %arg3 = %7 to %5 step %3 iter_args(%arg4 = %arg2) -> (tensor<10xi32>) {
%15 = builtin.unrealized_conversion_cast %arg3 : index to i64
%16 = llvm.add %13, %15 : i64
%17 = llvm.urem %16, %4 : i64
%18 = builtin.unrealized_conversion_cast %17 : i64 to index
%extracted = tensor.extract %9[%arg3] : tensor<10xi32>
%extracted_1 = tensor.extract %9[%arg1] : tensor<10xi32>
%19 = llvm.mul %extracted_1, %extracted : i32
%extracted_2 = tensor.extract %arg4[%18] : tensor<10xi32>
%20 = llvm.add %19, %extracted_2 : i32
%inserted = tensor.insert %20 into %arg4[%18] : tensor<10xi32>
scf.yield %inserted : tensor<10xi32>
}
scf.yield %14 : tensor<10xi32>
}
%11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%10, %splat : tensor<10xi32>, tensor<10xi32>) outs(%10 : tensor<10xi32>) {
^bb0(%in: i32, %in_1: i32, %out: i32):
%13 = llvm.sub %in, %in_1 : i32
linalg.yield %13 : i32
} -> tensor<10xi32>
%12 = scf.for %arg1 = %3 to %1 step %3 iter_args(%arg2 = %8) -> (i32) {
%13 = builtin.unrealized_conversion_cast %arg1 : index to i64
%14 = llvm.sub %0, %13 : i64
%15 = builtin.unrealized_conversion_cast %14 : i64 to index
%16 = llvm.mul %arg0, %arg2 : i32
%extracted = tensor.extract %11[%15] : tensor<10xi32>
%17 = llvm.add %16, %extracted : i32
scf.yield %17 : i32
}
llvm.return %12 : i32
}
}
However, there are still two arith ops remaining: the arith.constant ops using dense attributes to define tensors. There’s no obvious pass that handles this from looking at the list of available passes. As we’ll see, it’s the bufferization pass that handles this (we discussed it briefly in the previous article), and we should try to lower as much as possible before bufferizing. In general, bufferizing makes optimization passes harder.
The next thing to try lowering is tensor.splat. There’s no obvious pass for it either (tensor-to-linalg doesn’t lower it), and searching the LLVM codebase for SplatOp we find this pattern, which suggests tensor.splat is also lowered during bufferization and produces linalg.map ops. So we’ll have to lower those as well eventually.
But what tensor-to-linalg does lower is the next op: tensor.pad. It replaces a pad with the following:
< %padded = tensor.pad %cst_0 low[0] high[7] {
< ^bb0(%arg1: index):
< tensor.yield %8 : i32
< } : tensor<3xi32> to tensor<10xi32>
< %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%padded, %splat : tensor<10xi32>, tensor<10xi32>) outs(%padded : tensor<10xi32>) {
---
> %9 = tensor.empty() : tensor<10xi32>
> %10 = linalg.fill ins(%8 : i32) outs(%9 : tensor<10xi32>) -> tensor<10xi32>
> %inserted_slice = tensor.insert_slice %cst_0 into %10[0] [3] [1] : tensor<3xi32> into tensor<10xi32>
> %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%inserted_slice, %splat : tensor<10xi32>, tensor<10xi32>) outs(%inserted_slice : tensor<10xi32>) {
We add that pass in this commit.
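In the pipeline builder this is again a one-liner, sketched here assuming the upstream factory is mlir::createConvertTensorToLinalgPass():

// Lower tensor.pad (among other tensor ops) to linalg and tensor.insert_slice.
manager.addPass(mlir::createConvertTensorToLinalgPass());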
Next we have all of these linalg ops, as well as the scf.for loops. The linalg ops can lower naively as for loops, so we’ll do that first. There are three options:

- convert-linalg-to-affine-loops: Lower the operations from the linalg dialect into affine loops
- convert-linalg-to-loops: Lower the operations from the linalg dialect into loops
- convert-linalg-to-parallel-loops: Lower the operations from the linalg dialect into parallel loops
I want to make this initial pipeline as simple and naive as possible, so convert-linalg-to-loops it is. However, the pass does nothing on this IR (added in this commit). If you run the pass with --debug you can find an explanation:

** Failure : expected linalg op with buffer semantics
So linalg must be bufferized before it can be lowered to loops. However, we can tackle the scf-to-LLVM step with a combination of two passes: convert-scf-to-cf and convert-cf-to-llvm, in this commit. The lowering added arith.cmpi for the loop predicates, so moving arith-to-llvm to the end of the pipeline fixes that (commit). In the end we get this IR:
#map = affine_map<(d0) -> (d0)>
module attributes {llvm.data_layout = ""} {
llvm.func @test_poly_fn(%arg0: i32) -> i32 {
%0 = llvm.mlir.constant(11 : index) : i64
%1 = builtin.unrealized_conversion_cast %0 : i64 to index
%2 = llvm.mlir.constant(1 : index) : i64
%3 = builtin.unrealized_conversion_cast %2 : i64 to index
%4 = llvm.mlir.constant(10 : index) : i64
%5 = builtin.unrealized_conversion_cast %4 : i64 to index
%6 = llvm.mlir.constant(0 : index) : i64
%7 = builtin.unrealized_conversion_cast %6 : i64 to index
%cst = arith.constant dense<0> : tensor<10xi32>
%8 = llvm.mlir.constant(0 : i32) : i32
%cst_0 = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
%splat = tensor.splat %arg0 : tensor<10xi32>
%9 = tensor.empty() : tensor<10xi32>
%10 = linalg.fill ins(%8 : i32) outs(%9 : tensor<10xi32>) -> tensor<10xi32>
%inserted_slice = tensor.insert_slice %cst_0 into %10[0] [3] [1] : tensor<3xi32> into tensor<10xi32>
%11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%inserted_slice, %splat : tensor<10xi32>, tensor<10xi32>) outs(%inserted_slice : tensor<10xi32>) {
^bb0(%in: i32, %in_4: i32, %out: i32):
%43 = llvm.add %in, %in_4 : i32
linalg.yield %43 : i32
} -> tensor<10xi32>
cf.br ^bb1(%7, %cst : index, tensor<10xi32>)
^bb1(%12: index, %13: tensor<10xi32>): // 2 preds: ^bb0, ^bb5
%14 = builtin.unrealized_conversion_cast %12 : index to i64
%15 = llvm.icmp "slt" %14, %4 : i64
llvm.cond_br %15, ^bb2, ^bb6
^bb2: // pred: ^bb1
%16 = builtin.unrealized_conversion_cast %12 : index to i64
cf.br ^bb3(%7, %13 : index, tensor<10xi32>)
^bb3(%17: index, %18: tensor<10xi32>): // 2 preds: ^bb2, ^bb4
%19 = builtin.unrealized_conversion_cast %17 : index to i64
%20 = llvm.icmp "slt" %19, %4 : i64
llvm.cond_br %20, ^bb4, ^bb5
^bb4: // pred: ^bb3
%21 = builtin.unrealized_conversion_cast %17 : index to i64
%22 = llvm.add %16, %21 : i64
%23 = llvm.urem %22, %4 : i64
%24 = builtin.unrealized_conversion_cast %23 : i64 to index
%extracted = tensor.extract %11[%17] : tensor<10xi32>
%extracted_1 = tensor.extract %11[%12] : tensor<10xi32>
%25 = llvm.mul %extracted_1, %extracted : i32
%extracted_2 = tensor.extract %18[%24] : tensor<10xi32>
%26 = llvm.add %25, %extracted_2 : i32
%inserted = tensor.insert %26 into %18[%24] : tensor<10xi32>
%27 = llvm.add %19, %2 : i64
%28 = builtin.unrealized_conversion_cast %27 : i64 to index
cf.br ^bb3(%28, %inserted : index, tensor<10xi32>)
^bb5: // pred: ^bb3
%29 = llvm.add %14, %2 : i64
%30 = builtin.unrealized_conversion_cast %29 : i64 to index
cf.br ^bb1(%30, %18 : index, tensor<10xi32>)
^bb6: // pred: ^bb1
%31 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%13, %splat : tensor<10xi32>, tensor<10xi32>) outs(%13 : tensor<10xi32>) {
^bb0(%in: i32, %in_4: i32, %out: i32):
%43 = llvm.sub %in, %in_4 : i32
linalg.yield %43 : i32
} -> tensor<10xi32>
cf.br ^bb7(%3, %8 : index, i32)
^bb7(%32: index, %33: i32): // 2 preds: ^bb6, ^bb8
%34 = builtin.unrealized_conversion_cast %32 : index to i64
%35 = llvm.icmp "slt" %34, %0 : i64
llvm.cond_br %35, ^bb8, ^bb9
^bb8: // pred: ^bb7
%36 = builtin.unrealized_conversion_cast %32 : index to i64
%37 = llvm.sub %0, %36 : i64
%38 = builtin.unrealized_conversion_cast %37 : i64 to index
%39 = llvm.mul %arg0, %33 : i32
%extracted_3 = tensor.extract %31[%38] : tensor<10xi32>
%40 = llvm.add %39, %extracted_3 : i32
%41 = llvm.add %34, %2 : i64
%42 = builtin.unrealized_conversion_cast %41 : i64 to index
cf.br ^bb7(%42, %40 : index, i32)
^bb9: // pred: ^bb7
llvm.return %33 : i32
}
}
Bufferization
Last time I wrote at some length about the dialect conversion framework, and how it’s complicated mostly because of bufferization, which was historically split across multiple passes and introduced type conflicts that required special type materializations and later resolution.
Well, it turns out those bufferization passes are now deprecated. These are the official docs, but the short story is that the network of bufferization passes was replaced by a single one-shot-bufferize pass, plus some boilerplate cleanup passes. This is the sequence of passes recommended by the docs, along with an option to ensure that function signatures are bufferized as well:
// One-shot bufferize, from
// https://mlir.llvm.org/docs/Bufferization/#ownership-based-buffer-deallocation
bufferization::OneShotBufferizationOptions bufferizationOptions;
bufferizationOptions.bufferizeFunctionBoundaries = true;
manager.addPass(
    bufferization::createOneShotBufferizePass(bufferizationOptions));
manager.addPass(memref::createExpandReallocPass());
manager.addPass(bufferization::createOwnershipBasedBufferDeallocationPass());
manager.addPass(createCanonicalizerPass());
manager.addPass(bufferization::createBufferDeallocationSimplificationPass());
manager.addPass(bufferization::createLowerDeallocationsPass());
manager.addPass(createCSEPass());
manager.addPass(createCanonicalizerPass());
This pipeline exists upstream in a helper, but that helper, and indeed some of the passes above, were added after the commit we’ve pinned to! At this point it’s worth updating the upstream MLIR commit, which I did in this commit; it required a few additional fixes to deal with API updates (starting with this commit and ending with this commit). Then this commit adds the one-shot bufferization pass and helpers to the pipeline.
Even after all of this, I was dismayed to learn that the pass still did not lower the linalg ops. After a bit of digging, I realized it was because the func-to-llvm pass was running too early in the pipeline, so this commit moves it later. Since we’re building this up a bit backwards and naively, let’s comment out the tail end of the pipeline to see the result of bufferization.
Before convert-linalg-to-loops, but omitting the rest of the pipeline:
#map = affine_map<(d0) -> (d0)>
module {
memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
func.func @test_poly_fn(%arg0: i32) -> i32 {
%c0_i32 = arith.constant 0 : i32
%c11 = arith.constant 11 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%0 = memref.get_global @__constant_10xi32 : memref<10xi32>
%1 = memref.get_global @__constant_3xi32 : memref<3xi32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
linalg.map outs(%alloc : memref<10xi32>)
() {
linalg.yield %arg0 : i32
}
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
linalg.fill ins(%c0_i32 : i32) outs(%alloc_0 : memref<10xi32>)
%subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
memref.copy %1, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%alloc_0, %alloc : memref<10xi32>, memref<10xi32>) outs(%alloc_0 : memref<10xi32>) {
^bb0(%in: i32, %in_2: i32, %out: i32):
%3 = arith.addi %in, %in_2 : i32
linalg.yield %3 : i32
}
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
memref.copy %0, %alloc_1 : memref<10xi32> to memref<10xi32>
scf.for %arg1 = %c0 to %c10 step %c1 {
scf.for %arg2 = %c0 to %c10 step %c1 {
%3 = arith.addi %arg1, %arg2 : index
%4 = arith.remui %3, %c10 : index
%5 = memref.load %alloc_0[%arg2] : memref<10xi32>
%6 = memref.load %alloc_0[%arg1] : memref<10xi32>
%7 = arith.muli %6, %5 : i32
%8 = memref.load %alloc_1[%4] : memref<10xi32>
%9 = arith.addi %7, %8 : i32
memref.store %9, %alloc_1[%4] : memref<10xi32>
}
}
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%alloc_1, %alloc : memref<10xi32>, memref<10xi32>) outs(%alloc_1 : memref<10xi32>) {
^bb0(%in: i32, %in_2: i32, %out: i32):
%3 = arith.subi %in, %in_2 : i32
linalg.yield %3 : i32
}
%2 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
%3 = arith.subi %c11, %arg1 : index
%4 = arith.muli %arg0, %arg2 : i32
%5 = memref.load %alloc_1[%3] : memref<10xi32>
%6 = arith.addi %4, %5 : i32
scf.yield %6 : i32
}
memref.dealloc %alloc : memref<10xi32>
memref.dealloc %alloc_0 : memref<10xi32>
memref.dealloc %alloc_1 : memref<10xi32>
return %2 : i32
}
}
After convert-linalg-to-loops (and omitting the conversions to LLVM), there are no longer any linalg ops, just loops:
module {
memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
func.func @test_poly_fn(%arg0: i32) -> i32 {
%c0 = arith.constant 0 : index
%c10 = arith.constant 10 : index
%c1 = arith.constant 1 : index
%c0_i32 = arith.constant 0 : i32
%c11 = arith.constant 11 : index
%0 = memref.get_global @__constant_10xi32 : memref<10xi32>
%1 = memref.get_global @__constant_3xi32 : memref<3xi32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
scf.for %arg1 = %c0 to %c10 step %c1 {
memref.store %arg0, %alloc[%arg1] : memref<10xi32>
}
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
scf.for %arg1 = %c0 to %c10 step %c1 {
memref.store %c0_i32, %alloc_0[%arg1] : memref<10xi32>
}
%subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
memref.copy %1, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
scf.for %arg1 = %c0 to %c10 step %c1 {
%3 = memref.load %alloc_0[%arg1] : memref<10xi32>
%4 = memref.load %alloc[%arg1] : memref<10xi32>
%5 = arith.addi %3, %4 : i32
memref.store %5, %alloc_0[%arg1] : memref<10xi32>
}
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
memref.copy %0, %alloc_1 : memref<10xi32> to memref<10xi32>
scf.for %arg1 = %c0 to %c10 step %c1 {
scf.for %arg2 = %c0 to %c10 step %c1 {
%3 = arith.addi %arg1, %arg2 : index
%4 = arith.remui %3, %c10 : index
%5 = memref.load %alloc_0[%arg2] : memref<10xi32>
%6 = memref.load %alloc_0[%arg1] : memref<10xi32>
%7 = arith.muli %6, %5 : i32
%8 = memref.load %alloc_1[%4] : memref<10xi32>
%9 = arith.addi %7, %8 : i32
memref.store %9, %alloc_1[%4] : memref<10xi32>
}
}
scf.for %arg1 = %c0 to %c10 step %c1 {
%3 = memref.load %alloc_1[%arg1] : memref<10xi32>
%4 = memref.load %alloc[%arg1] : memref<10xi32>
%5 = arith.subi %3, %4 : i32
memref.store %5, %alloc_1[%arg1] : memref<10xi32>
}
%2 = scf.for %arg1 = %c1 to %c11 step %c1 iter_args(%arg2 = %c0_i32) -> (i32) {
%3 = arith.subi %c11, %arg1 : index
%4 = arith.muli %arg0, %arg2 : i32
%5 = memref.load %alloc_1[%3] : memref<10xi32>
%6 = arith.addi %4, %5 : i32
scf.yield %6 : i32
}
memref.dealloc %alloc : memref<10xi32>
memref.dealloc %alloc_0 : memref<10xi32>
memref.dealloc %alloc_1 : memref<10xi32>
return %2 : i32
}
}
And then finally, here is the output of the rest of the pipeline, as defined so far:
module {
memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[2, 3, 4]> {alignment = 64 : i64}
memref.global "private" constant @__constant_10xi32 : memref<10xi32> = dense<0> {alignment = 64 : i64}
llvm.func @test_poly_fn(%arg0: i32) -> i32 {
%0 = llvm.mlir.constant(0 : index) : i64
%1 = builtin.unrealized_conversion_cast %0 : i64 to index
%2 = llvm.mlir.constant(10 : index) : i64
%3 = builtin.unrealized_conversion_cast %2 : i64 to index
%4 = llvm.mlir.constant(1 : index) : i64
%5 = builtin.unrealized_conversion_cast %4 : i64 to index
%6 = llvm.mlir.constant(0 : i32) : i32
%7 = llvm.mlir.constant(11 : index) : i64
%8 = builtin.unrealized_conversion_cast %7 : i64 to index
%9 = memref.get_global @__constant_10xi32 : memref<10xi32>
%10 = memref.get_global @__constant_3xi32 : memref<3xi32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
cf.br ^bb1(%1 : index)
^bb1(%11: index): // 2 preds: ^bb0, ^bb2
%12 = builtin.unrealized_conversion_cast %11 : index to i64
%13 = llvm.icmp "slt" %12, %2 : i64
llvm.cond_br %13, ^bb2, ^bb3
^bb2: // pred: ^bb1
memref.store %arg0, %alloc[%11] : memref<10xi32>
%14 = llvm.add %12, %4 : i64
%15 = builtin.unrealized_conversion_cast %14 : i64 to index
cf.br ^bb1(%15 : index)
^bb3: // pred: ^bb1
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
cf.br ^bb4(%1 : index)
^bb4(%16: index): // 2 preds: ^bb3, ^bb5
%17 = builtin.unrealized_conversion_cast %16 : index to i64
%18 = llvm.icmp "slt" %17, %2 : i64
llvm.cond_br %18, ^bb5, ^bb6
^bb5: // pred: ^bb4
memref.store %6, %alloc_0[%16] : memref<10xi32>
%19 = llvm.add %17, %4 : i64
%20 = builtin.unrealized_conversion_cast %19 : i64 to index
cf.br ^bb4(%20 : index)
^bb6: // pred: ^bb4
%subview = memref.subview %alloc_0[0] [3] [1] : memref<10xi32> to memref<3xi32, strided<[1]>>
memref.copy %10, %subview : memref<3xi32> to memref<3xi32, strided<[1]>>
cf.br ^bb7(%1 : index)
^bb7(%21: index): // 2 preds: ^bb6, ^bb8
%22 = builtin.unrealized_conversion_cast %21 : index to i64
%23 = llvm.icmp "slt" %22, %2 : i64
llvm.cond_br %23, ^bb8, ^bb9
^bb8: // pred: ^bb7
%24 = memref.load %alloc_0[%21] : memref<10xi32>
%25 = memref.load %alloc[%21] : memref<10xi32>
%26 = llvm.add %24, %25 : i32
memref.store %26, %alloc_0[%21] : memref<10xi32>
%27 = llvm.add %22, %4 : i64
%28 = builtin.unrealized_conversion_cast %27 : i64 to index
cf.br ^bb7(%28 : index)
^bb9: // pred: ^bb7
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
memref.copy %9, %alloc_1 : memref<10xi32> to memref<10xi32>
cf.br ^bb10(%1 : index)
^bb10(%29: index): // 2 preds: ^bb9, ^bb14
%30 = builtin.unrealized_conversion_cast %29 : index to i64
%31 = llvm.icmp "slt" %30, %2 : i64
llvm.cond_br %31, ^bb11, ^bb15
^bb11: // pred: ^bb10
%32 = builtin.unrealized_conversion_cast %29 : index to i64
cf.br ^bb12(%1 : index)
^bb12(%33: index): // 2 preds: ^bb11, ^bb13
%34 = builtin.unrealized_conversion_cast %33 : index to i64
%35 = llvm.icmp "slt" %34, %2 : i64
llvm.cond_br %35, ^bb13, ^bb14
^bb13: // pred: ^bb12
%36 = builtin.unrealized_conversion_cast %33 : index to i64
%37 = llvm.add %32, %36 : i64
%38 = llvm.urem %37, %2 : i64
%39 = builtin.unrealized_conversion_cast %38 : i64 to index
%40 = memref.load %alloc_0[%33] : memref<10xi32>
%41 = memref.load %alloc_0[%29] : memref<10xi32>
%42 = llvm.mul %41, %40 : i32
%43 = memref.load %alloc_1[%39] : memref<10xi32>
%44 = llvm.add %42, %43 : i32
memref.store %44, %alloc_1[%39] : memref<10xi32>
%45 = llvm.add %34, %4 : i64
%46 = builtin.unrealized_conversion_cast %45 : i64 to index
cf.br ^bb12(%46 : index)
^bb14: // pred: ^bb12
%47 = llvm.add %30, %4 : i64
%48 = builtin.unrealized_conversion_cast %47 : i64 to index
cf.br ^bb10(%48 : index)
^bb15: // pred: ^bb10
cf.br ^bb16(%1 : index)
^bb16(%49: index): // 2 preds: ^bb15, ^bb17
%50 = builtin.unrealized_conversion_cast %49 : index to i64
%51 = llvm.icmp "slt" %50, %2 : i64
llvm.cond_br %51, ^bb17, ^bb18
^bb17: // pred: ^bb16
%52 = memref.load %alloc_1[%49] : memref<10xi32>
%53 = memref.load %alloc[%49] : memref<10xi32>
%54 = llvm.sub %52, %53 : i32
memref.store %54, %alloc_1[%49] : memref<10xi32>
%55 = llvm.add %50, %4 : i64
%56 = builtin.unrealized_conversion_cast %55 : i64 to index
cf.br ^bb16(%56 : index)
^bb18: // pred: ^bb16
cf.br ^bb19(%5, %6 : index, i32)
^bb19(%57: index, %58: i32): // 2 preds: ^bb18, ^bb20
%59 = builtin.unrealized_conversion_cast %57 : index to i64
%60 = llvm.icmp "slt" %59, %7 : i64
llvm.cond_br %60, ^bb20, ^bb21
^bb20: // pred: ^bb19
%61 = builtin.unrealized_conversion_cast %57 : index to i64
%62 = llvm.sub %7, %61 : i64
%63 = builtin.unrealized_conversion_cast %62 : i64 to index
%64 = llvm.mul %arg0, %58 : i32
%65 = memref.load %alloc_1[%63] : memref<10xi32>
%66 = llvm.add %64, %65 : i32
%67 = llvm.add %59, %4 : i64
%68 = builtin.unrealized_conversion_cast %67 : i64 to index
cf.br ^bb19(%68, %66 : index, i32)
^bb21: // pred: ^bb19
memref.dealloc %alloc : memref<10xi32>
memref.dealloc %alloc_0 : memref<10xi32>
memref.dealloc %alloc_1 : memref<10xi32>
llvm.return %58 : i32
}
}
The remaining problems:

- There are cf.br ops left in there, meaning convert-cf-to-llvm was unable to convert them, leaving in a bunch of un-removable index types (see the third bullet).
- We still have a memref.subview that is not supported in LLVM.
- We have a bunch of casts like builtin.unrealized_conversion_cast %7 : i64 to index, which are there because index is not part of LLVM.
The second needs a special pass, memref-expand-strided-metadata, which reduces more complicated memref ops to simpler ones that can be lowered. The third is fixed by using finalize-memref-to-llvm, which lowers index to i64 and memref to combinations of llvm.struct, llvm.ptr, and llvm.array. A final reconcile-unrealized-casts removes the cast operations, provided they can safely be removed. Both fixes are combined in this commit.
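In the pipeline builder, these fixes amount to something like the following sketch (the factory names are what I believe upstream MLIR provides, e.g. createFinalizeMemRefToLLVMConversionPass; verify against the pinned commit):

// Expand memref.subview and other strided-metadata ops into simpler memref ops.
manager.addPass(mlir::memref::createExpandStridedMetadataPass());
// Lower memref ops and types to the LLVM dialect.
manager.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
// Remove the unrealized_conversion_cast ops once both sides agree on types.
manager.addPass(mlir::createReconcileUnrealizedCastsPass());

Where these sit relative to func-to-llvm still matters, though, as the next paragraph shows.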
However, the first one still eluded me for a while, until I figured out through embarrassing trial and error that, again, func-to-llvm was too early in the pipeline. Moving it to the end (this commit) resulted in cf.br being lowered to llvm.br and llvm.cond_br.
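After that move, the tail end of the pipeline is ordered roughly like this (a sketch of the relative ordering, not the literal code in the commit; factory names assumed from upstream MLIR):

// ... poly-to-standard, elementwise/tensor-to-linalg, bufferization,
// linalg-to-loops, expand-strided-metadata, memref-to-llvm above ...
manager.addPass(mlir::createConvertSCFToCFPass());
manager.addPass(mlir::createConvertControlFlowToLLVMPass());
manager.addPass(mlir::createArithToLLVMConversionPass());
manager.addPass(mlir::createConvertFuncToLLVMPass());
// Only after everything is in the LLVM dialect can the leftover casts go away.
manager.addPass(mlir::createReconcileUnrealizedCastsPass());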
Finally, this commit adds a set of standard cleanup passes, including constant propagation, common subexpression elimination, dead code elimination, and canonicalization. The final IR looks like this:
module {
llvm.func @free(!llvm.ptr)
llvm.func @malloc(i64) -> !llvm.ptr
llvm.mlir.global private constant @__constant_3xi32(dense<[2, 3, 4]> : tensor<3xi32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<3 x i32>
llvm.mlir.global private constant @__constant_10xi32(dense<0> : tensor<10xi32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<10 x i32>
llvm.func @test_poly_fn(%arg0: i32) -> i32 {
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = llvm.mlir.constant(64 : index) : i64
%2 = llvm.mlir.constant(3 : index) : i64
%3 = llvm.mlir.constant(0 : index) : i64
%4 = llvm.mlir.constant(10 : index) : i64
%5 = llvm.mlir.constant(1 : index) : i64
%6 = llvm.mlir.constant(11 : index) : i64
%7 = llvm.mlir.addressof @__constant_10xi32 : !llvm.ptr
%8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
%9 = llvm.mlir.addressof @__constant_3xi32 : !llvm.ptr
%10 = llvm.getelementptr %9[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<3 x i32>
%11 = llvm.mlir.zero : !llvm.ptr
%12 = llvm.getelementptr %11[10] : (!llvm.ptr) -> !llvm.ptr, i32
%13 = llvm.ptrtoint %12 : !llvm.ptr to i64
%14 = llvm.add %13, %1 : i64
%15 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
%16 = llvm.ptrtoint %15 : !llvm.ptr to i64
%17 = llvm.sub %1, %5 : i64
%18 = llvm.add %16, %17 : i64
%19 = llvm.urem %18, %1 : i64
%20 = llvm.sub %18, %19 : i64
%21 = llvm.inttoptr %20 : i64 to !llvm.ptr
llvm.br ^bb1(%3 : i64)
^bb1(%22: i64): // 2 preds: ^bb0, ^bb2
%23 = llvm.icmp "slt" %22, %4 : i64
llvm.cond_br %23, ^bb2, ^bb3
^bb2: // pred: ^bb1
%24 = llvm.getelementptr %21[%22] : (!llvm.ptr, i64) -> !llvm.ptr, i32
llvm.store %arg0, %24 : i32, !llvm.ptr
%25 = llvm.add %22, %5 : i64
llvm.br ^bb1(%25 : i64)
^bb3: // pred: ^bb1
%26 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
%27 = llvm.ptrtoint %26 : !llvm.ptr to i64
%28 = llvm.add %27, %17 : i64
%29 = llvm.urem %28, %1 : i64
%30 = llvm.sub %28, %29 : i64
%31 = llvm.inttoptr %30 : i64 to !llvm.ptr
llvm.br ^bb4(%3 : i64)
^bb4(%32: i64): // 2 preds: ^bb3, ^bb5
%33 = llvm.icmp "slt" %32, %4 : i64
llvm.cond_br %33, ^bb5, ^bb6
^bb5: // pred: ^bb4
%34 = llvm.getelementptr %31[%32] : (!llvm.ptr, i64) -> !llvm.ptr, i32
llvm.store %0, %34 : i32, !llvm.ptr
%35 = llvm.add %32, %5 : i64
llvm.br ^bb4(%35 : i64)
^bb6: // pred: ^bb4
%36 = llvm.mul %2, %5 : i64
%37 = llvm.getelementptr %11[1] : (!llvm.ptr) -> !llvm.ptr, i32
%38 = llvm.ptrtoint %37 : !llvm.ptr to i64
%39 = llvm.mul %36, %38 : i64
"llvm.intr.memcpy"(%31, %10, %39) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
llvm.br ^bb7(%3 : i64)
^bb7(%40: i64): // 2 preds: ^bb6, ^bb8
%41 = llvm.icmp "slt" %40, %4 : i64
llvm.cond_br %41, ^bb8, ^bb9
^bb8: // pred: ^bb7
%42 = llvm.getelementptr %31[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%43 = llvm.load %42 : !llvm.ptr -> i32
%44 = llvm.getelementptr %21[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%45 = llvm.load %44 : !llvm.ptr -> i32
%46 = llvm.add %43, %45 : i32
llvm.store %46, %42 : i32, !llvm.ptr
%47 = llvm.add %40, %5 : i64
llvm.br ^bb7(%47 : i64)
^bb9: // pred: ^bb7
%48 = llvm.call @malloc(%14) : (i64) -> !llvm.ptr
%49 = llvm.ptrtoint %48 : !llvm.ptr to i64
%50 = llvm.add %49, %17 : i64
%51 = llvm.urem %50, %1 : i64
%52 = llvm.sub %50, %51 : i64
%53 = llvm.inttoptr %52 : i64 to !llvm.ptr
%54 = llvm.mul %4, %5 : i64
%55 = llvm.mul %54, %38 : i64
"llvm.intr.memcpy"(%53, %8, %55) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
llvm.br ^bb10(%3 : i64)
^bb10(%56: i64): // 2 preds: ^bb9, ^bb14
%57 = llvm.icmp "slt" %56, %4 : i64
llvm.cond_br %57, ^bb11, ^bb15
^bb11: // pred: ^bb10
llvm.br ^bb12(%3 : i64)
^bb12(%58: i64): // 2 preds: ^bb11, ^bb13
%59 = llvm.icmp "slt" %58, %4 : i64
llvm.cond_br %59, ^bb13, ^bb14
^bb13: // pred: ^bb12
%60 = llvm.add %56, %58 : i64
%61 = llvm.urem %60, %4 : i64
%62 = llvm.getelementptr %31[%58] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%63 = llvm.load %62 : !llvm.ptr -> i32
%64 = llvm.getelementptr %31[%56] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%65 = llvm.load %64 : !llvm.ptr -> i32
%66 = llvm.mul %65, %63 : i32
%67 = llvm.getelementptr %53[%61] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%68 = llvm.load %67 : !llvm.ptr -> i32
%69 = llvm.add %66, %68 : i32
llvm.store %69, %67 : i32, !llvm.ptr
%70 = llvm.add %58, %5 : i64
llvm.br ^bb12(%70 : i64)
^bb14: // pred: ^bb12
%71 = llvm.add %56, %5 : i64
llvm.br ^bb10(%71 : i64)
^bb15: // pred: ^bb10
llvm.br ^bb16(%3 : i64)
^bb16(%72: i64): // 2 preds: ^bb15, ^bb17
%73 = llvm.icmp "slt" %72, %4 : i64
llvm.cond_br %73, ^bb17, ^bb18
^bb17: // pred: ^bb16
%74 = llvm.getelementptr %53[%72] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%75 = llvm.load %74 : !llvm.ptr -> i32
%76 = llvm.getelementptr %21[%72] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%77 = llvm.load %76 : !llvm.ptr -> i32
%78 = llvm.sub %75, %77 : i32
llvm.store %78, %74 : i32, !llvm.ptr
%79 = llvm.add %72, %5 : i64
llvm.br ^bb16(%79 : i64)
^bb18: // pred: ^bb16
llvm.br ^bb19(%5, %0 : i64, i32)
^bb19(%80: i64, %81: i32): // 2 preds: ^bb18, ^bb20
%82 = llvm.icmp "slt" %80, %6 : i64
llvm.cond_br %82, ^bb20, ^bb21
^bb20: // pred: ^bb19
%83 = llvm.sub %6, %80 : i64
%84 = llvm.mul %arg0, %81 : i32
%85 = llvm.getelementptr %53[%83] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%86 = llvm.load %85 : !llvm.ptr -> i32
%87 = llvm.add %84, %86 : i32
%88 = llvm.add %80, %5 : i64
llvm.br ^bb19(%88, %87 : i64, i32)
^bb21: // pred: ^bb19
llvm.call @free(%15) : (!llvm.ptr) -> ()
llvm.call @free(%26) : (!llvm.ptr) -> ()
llvm.call @free(%48) : (!llvm.ptr) -> ()
llvm.return %81 : i32
}
}
Exiting MLIR
In MLIR parlance, the LLVM dialect is an “exit” dialect, meaning that after lowering to it, you run a different tool that generates code used by an external system. In our case, this will be LLVM’s internal representation (“LLVM IR,” which is different from MLIR’s LLVM dialect). The “code gen” step is often called translation in the MLIR docs. Here are the official docs on generating LLVM IR.
The codegen tool is called mlir-translate, and it has an --mlir-to-llvmir option. Running the command below on our output IR above gives the IR in this gist.
$ bazel build @llvm-project//mlir:mlir-translate
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir
Next, to compile the LLVM IR with LLVM directly, we use the llc tool, whose --filetype=obj option emits an object file. Without that flag, you can see a textual representation of the machine code:
$ bazel build @llvm-project//llvm:llc
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | bazel-bin/external/llvm-project/llvm/llc
# textual representation
.text
.file "LLVMDialectModule"
.globl test_poly_fn # -- Begin function test_poly_fn
.p2align 4, 0x90
.type test_poly_fn,@function
test_poly_fn: # @test_poly_fn
.cfi_startproc
# %bb.0:
pushq %rbp
.cfi_def_cfa_offset 16
pushq %r15
.cfi_def_cfa_offset 24
pushq %r14
.cfi_def_cfa_offset 32
pushq %r13
.cfi_def_cfa_offset 40
pushq %r12
.cfi_def_cfa_offset 48
pushq %rbx
.cfi_def_cfa_offset 56
pushq %rax
.cfi_def_cfa_offset 64
.cfi_offset %rbx, -56
.cfi_offset %r12, -48
<snip>
And finally, we can save the object file and compile and link it to a C main that calls the function and prints the result.
$ cat tests/poly_to_llvm_main.c
#include <stdio.h>
// This is the function defined in the object file compiled from our MLIR code
int test_poly_fn(int x);
int main(int argc, char *argv[]) {
  int i = 1;
  int result = test_poly_fn(i);
  printf("Result: %d\n", result);
  return 0;
}
$ bazel run //tools:tutorial-opt -- --poly-to-llvm $PWD/tests/poly_to_llvm.mlir | ./bazel-bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | bazel-bin/external/llvm-project/llvm/llc --filetype=obj > poly_to_llvm.o
$ clang -c tests/poly_to_llvm_main.c && clang poly_to_llvm_main.o poly_to_llvm.o -o a.out
$ ./a.out
Result: 320
The polynomial computed by the test function is the rather bizarre:
((1+x+x**2+x**3+x**4+x**5+x**6+x**7+x**8+x**9) + (2 + 3*x + 4*x**2))**2 - (1+x+x**2+x**3+x**4+x**5+x**6+x**7+x**8+x**9)
But computing this mod $x^{10} - 1$ in Sympy, we get… 351. Uh oh. So somewhere along the way we messed up the lowering.
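If you want to sanity-check that 351 is the right answer without Sympy, the same computation mod $x^{10} - 1$ fits in a few lines of standalone C++ (a throwaway check written for this article, not part of the repo):

#include <array>
#include <cstdio>

int main() {
  constexpr int N = 10;
  std::array<long, N> ones{}, p{}, sum{}, prod{}, result{};
  ones.fill(1);                  // 1 + x + ... + x^9, the splat of the input 1
  p[0] = 2; p[1] = 3; p[2] = 4;  // 2 + 3x + 4x^2
  for (int i = 0; i < N; ++i) sum[i] = ones[i] + p[i];
  for (int i = 0; i < N; ++i)    // multiply mod x^10 - 1
    for (int j = 0; j < N; ++j)
      prod[(i + j) % N] += sum[i] * sum[j];
  for (int i = 0; i < N; ++i) result[i] = prod[i] - ones[i];
  long eval = 0;                 // evaluate at x = 1 via Horner's method
  for (int i = N - 1; i >= 0; --i) eval = eval * 1 + result[i];
  std::printf("%ld\n", eval);    // prints 351
  return 0;
}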
Before we fix it, let’s encode the whole process as a lit test. This requires making the binaries mlir-translate, llc, and clang available to the test runner, and setting up the RUN pipeline inside the test file. This is contained in this commit. Note %t tells lit to generate a test-unique temporary file.
// RUN: tutorial-opt --poly-to-llvm %s | mlir-translate --mlir-to-llvmir | llc -filetype=obj > %t
// RUN: clang -c poly_to_llvm_main.c
// RUN: clang poly_to_llvm_main.o %t -o a.out
// RUN: ./a.out | FileCheck %s
// CHECK: 351
func.func @test_poly_fn(%arg : i32) -> i32 {
%tens = tensor.splat %arg : tensor<10xi32>
%input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10>
%0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
%1 = poly.add %0, %input : !poly.poly<10>
%2 = poly.mul %1, %1 : !poly.poly<10>
%3 = poly.sub %2, %input : !poly.poly<10>
%4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32
return %4 : i32
}
Then running this as a normal lit test fails with:
error: CHECK: expected string not found in input
# | // CHECK: 351
# | ^
# | <stdin>:1:1: note: scanning from here
# | Result: 320
# | ^
# | <stdin>:1:9: note: possible intended match here
# | Result: 320
Fixing the bug in the lowering
How do you find the bug in a lowering? Apparently there are some folks doing this via formal verification, formalizing the dialects and the lowerings in Lean to prove correctness. I don’t have the time for that in this tutorial, so instead I simplify/expand the tests and squint.
Simplifying the test to the simpler function $t \mapsto 1 + t + t^2$, I see that an input of 1 gives 2, when it should be 3. This suggests that eval is lowered wrong, at least. An input of 5 gives 4, so they’re all off by one term. This likely means I have an off-by-one error in the loop that eval lowers to, and indeed that’s the problem. I was essentially doing this, where $N$ is the degree of the polynomial ($N-1$ is the largest legal index into the tensor):
accum = 0
for 1 <= i < N+1
index = N+1 - i
accum = accum * point + coeffs[index]
When it should have been
accum = 0
for 1 <= i < N+1
index = N - i
accum = accum * point + coeffs[index]
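To make the off-by-one concrete: with $N = 3$ and coefficients $c_0, c_1, c_2$, the corrected loop visits indices $2, 1, 0$ and computes $((0 \cdot x + c_2)x + c_1)x + c_0$, which is exactly Horner’s method, while the buggy version visited indices $3, 2, 1$, reading past the end of the tensor and never adding $c_0$.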
This commit fixes it, and now all the tests pass.
Aside: I don’t know of a means to do more resilient testing in MLIR. For example, I don’t know of a fuzz testing or property testing option. In my main project, HEIR, I hand-rolled a simple test generation routine that reads a config file and spits out MLIR test files. That allows me to jot down a bunch of tests quickly, but doesn’t give the full power of a system like hypothesis, which is a framework I’m quite fond of. I think the xDSL project would work in my favor here, but I have yet to dive into it, and as far as I know it requires re-defining all your custom dialects in Python. More on that in a future article.
Taking a step back
This article showed a bit of a naive approach to building up a dialect conversion pipeline, where we just greedily looked for ops to lower and inserted the relevant passes somewhat haphazardly. That worked out OK, but some lower-level passes (converting to LLVM) confounded the overall pipeline when placed too early.
A better approach is to identify the highest level operations and lower those first. But that is only really possible if you already know which passes are available and what they do. For example, convert-elementwise-to-linalg handles something that seems low-level a priori, and we only discovered we needed it after noticing that convert-arith-to-llvm silently ignored tensor-typed arith ops. Similarly, the implications of converting func to LLVM (which appears to handle more than just the ops in func) were not clear until we tried it and ran into problems.
I don’t have a particularly good solution here besides trial and error. But I appreciate good tools, so I will amuse myself with some ideas.
Since most passes have per-op patterns, it seems like one could write a tool that analyzes an IR, simulates running all possible rewrite patterns from the standard list of passes, and checks which ones succeeded (i.e., the “match” was successful, though most patterns are a combined matchAndRewrite), and what op types were generated as a result. Then once you get to an IR that you don’t know how to lower further, you could run the tool and it would tell you all of your options.
An even more aggressive tool could construct a complete graph of op-to-op conversions. You could identify an op to lower and a subset of legal dialects or ops (similar to a ConversionTarget), and it would report all possible paths from the op to your desired target.
I imagine performance is an obstacle here, and I also wonder to what extent this could be statically analyzed, at least coarsely. For example, instead of running the rewrites, you could statically analyze the implementations of rewrite patterns lowering FooOp for occurrences of create<BarOp>, and just include an edge FooOp -> BarOp as a conservative estimate (it might not generate BarOp for every input).
Plans
At this point we’ve covered the basic ground of MLIR: defining a new dialect, writing some optimization passes, doing some lowerings, and compiling and running a proper binary. The topics we could cover from here on out are rather broad, but here’s a few options:
- Writing an analysis pass.
- Writing a global program optimization pass with a cost model, in contrast to the local rewrites we’ve done so far.
- Defining a custom interface (trait with input/output) to write a generic pass that applies to multiple dialects.
- Explaining the linear algebra dialect and linalg.generic.
- Exploring some of the front ends for MLIR, like Polygeist and ClangIR.
- Exploring the Python bindings or C API for MLIR.
- Exploring some of the peripheral and in-progress projects around MLIR like PDLL, IRDL, xDSL, etc.
- Diving into the details of various upstream MLIR optimizations, like the polyhedral analysis or constant propagation passes.
- Working with sparse tensors.
- Covering more helper tools like the diagnostic infrastructure and bytecode generation.