In previous articles we defined a dialect, and wrote various passes to optimize and canonicalize a program using that dialect. However, one of the main tenets of MLIR is “incremental lowering,” the idea that there are lots of levels of IR granularity, and you incrementally lower different parts of the IR, only discarding information when it’s no longer useful for optimizations. In this article we’ll see the first step of that: lowering the poly dialect to a combination of standard MLIR dialects, using the so-called dialect conversion infrastructure to accomplish it.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Side note: Victor Guerra kindly contributed a CMake build overlay to the mlir-tutorial project, in this PR, which I will be maintaining for the rest of the tutorial series. I still don’t know CMake well enough to give a tutorial on it, but I did work through parts of the official tutorial which seems excellent.

The type obstacle

If not for types, dialect conversion (a.k.a. lowering) would be essentially the same as a normal pass: you write some rewrite patterns and apply them to the IR. You’d typically have one rewrite pattern per operation that needs to be lowered.

Types make this significantly more complex, and I’ll demonstrate the issue by example with poly.

In poly we have a poly.add operation that takes two values of poly.poly type, and returns a value of poly.poly type. We want to lower poly.add to, say, a vectorized loop of basic arithmetic operations. We can do this, and we will do it below, by rewriting poly.add as an arith.addi. However, arith.addi doesn’t know about poly.poly and likely never will.

Aside: if we had to make arith extensible to know about poly, we’d need to upstream a change to the arith.addi op’s operands to allow types that implement some sort of interface like “integer-like, or containers of integer-like things,” but when I suggested that to some MLIR folks they gently demurred. I believe this is because the MLIR community feels that the arith dialect is too overloaded.

So, in addition to lowering the op, we need to lower the poly.poly<N> type to something like tensor<Nxi32>. And this is where the type obstacle comes into play. Once you change a specific value’s type, say, while lowering an op that produces that value as output, then all downstream users of that value are still expecting the old type and are technically invalid until they are lowered. In between each pass MLIR runs verifiers to ensure the IR is valid, so without some special handling, this means all the types and ops need to be converted in one pass, or else these verifiers would fail. But managing all that with standard rewrite rules would be hard: for each op rewrite rule, you’d have to continually check if the arguments and results have been converted yet or not.

MLIR handles this situation with a wrapper around standard passes that they call the dialect conversion framework. Official docs here. It requires the user of the framework to inherit from different classes than normal rewrite patterns, set up some additional metadata, and separate type conversion from op conversion in a specific way we’ll see shortly. But at a high level, this framework works by lowering ops in a certain sorted order, converting the types as they go, and giving the op converters access to both the original types of each op as well as what the in-progress converted types look like at the time the op is visited by the framework. Each op-based rewrite pattern is expected to make that op type-legal after it’s visited, but it need not worry about downstream ops.

Finally, the dialect conversion framework keeps track of any type conflicts that aren’t resolved, and if any remain at the end of the pass, one of two things happens. The conversion framework allows one to optionally implement what’s called a type materializer, which inserts new intermediate ops that resolve type conflicts. So the first possibility is that the dialect conversion framework uses your type materializer hooks to patch up the IR, and the pass ends successfully. If those hooks fail, or if you didn’t define any hooks, then the pass fails and complains that it couldn’t fix the type conflicts.

Part of the complexity of this infrastructure also has to do with one of the harder lowering pipelines in upstream MLIR: the bufferization pipeline. This pipeline essentially converts an IR that uses ops with “value semantics” into one with “pointer semantics.” For example, the tensor type and its associated operations have value semantics, meaning each op is semantically creating a brand new tensor in its output, and all operations are pure (with some caveats). On the other hand, memref has pointer semantics, meaning it’s modeling the physical hardware more closely, requires explicit memory allocation, and supports operations that mutate memory locations.
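As a small illustration (not from the tutorial repo), compare inserting an element into a tensor with storing into a memref:

// Value semantics: tensor.insert produces a brand new tensor; %t is unchanged.
%t2 = tensor.insert %val into %t[%i] : tensor<10xi32>

// Pointer semantics: memref.store mutates the buffer %m in place.
memref.store %val, %m[%i] : memref<10xi32>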

Because bufferization is complicated, it is split into multiple sub-passes that handle bufferization issues specific to each of the relevant upstream MLIR dialects (see in the docs, e.g., arith-bufferize, func-bufferize, etc.). Each of these bufferization passes creates some internally-unresolvable type conflicts, which require custom type materializations to resolve. And to juggle all these issues across all the relevant dialects, the MLIR folks built a dedicated dialect called bufferization to house the intermediate ops. You’ll notice ops like to_memref and to_tensor that serve this role. And then there is a finalizing-bufferize pass whose role is to clean up any lingering bufferization/materialization ops.

There was a talk at the 2020-11-19 MLIR Open Design meeting called “Type Conversions the Not-So-Hard Way” (slides) that helped me understand these details, but I couldn’t understand the talk much before I tried a few naive lowering attempts, and then even after I felt I understood the talk I ran into some issues around type materialization. So my aim in the rest of this article is to explain the clarity I’ve found, such as it may be, and make it easier for the reader to understand the content of that talk.

Lowering Poly without Type Materializations

The first commit sets up the pass shell, similar to the previous passes in the series, though it is located in a new directory lib/Conversion. The only caveat to this commit is that it adds dependent dialects:

let dependentDialects = [
  "mlir::arith::ArithDialect",
  "mlir::tutorial::poly::PolyDialect",
  "mlir::tensor::TensorDialect",
];

In particular, a lowering must depend in this way on any dialects that contain operations or types that the lowering will create, to ensure that MLIR loads those dialects before trying to run the pass.

The second commit defines a ConversionTarget, which tells MLIR how to determine what ops are within scope of the lowering. Specifically, it allows you to declare an entire dialect “illegal,” meaning that any of its ops remaining after the lowering cause the pass to fail. We’ll do this for the poly dialect.

struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
  using PolyToStandardBase::PolyToStandardBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    auto *module = getOperation();

    ConversionTarget target(*context);       // <--- new thing
    target.addIllegalDialect<PolyDialect>();

    RewritePatternSet patterns(context);

    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {   // <-- new thing
      signalPassFailure();
    }
  }
};

ConversionTarget can also declare specific ops as illegal, or even conditionally legal, using a callback to specify exactly what counts as legal. And then instead of applyPatternsAndFoldGreedily, we use applyPartialConversion to kick off the lowering.
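For example, a hypothetical sketch of per-op and conditional legality (the op class names SubOp and MulOp are the ones I’d expect tablegen to generate for poly, and the condition is contrived):

// A single op can be marked illegal, rather than a whole dialect.
target.addIllegalOp<SubOp>();

// Or an op can be conditionally legal, decided by an arbitrary callback.
target.addDynamicallyLegalOp<MulOp>([](MulOp op) {
  return cast<PolynomialType>(op.getType()).getDegreeBound() < 8;
});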

Aside: We use applyPartialConversion instead of the more natural sounding alternative applyFullConversion for a silly reason: the error messages are better. If applyFullConversion fails, even in debug mode, you don’t get a lot of information about what went wrong. In partial mode, you can see the steps and what type conflicts it wasn’t able to patch up at the end. As far as I can tell, partial conversion is a strict generalization of full conversion, and I haven’t found a reason to use applyFullConversion ever.

After adding the conversion target, running --poly-to-standard on any random poly program will result in an error:

$ bazel run tools:tutorial-opt -- --poly-to-standard $PWD/tests/poly_syntax.mlir
tests/poly_syntax.mlir:14:10: error: failed to legalize operation 'poly.add' that was explicitly marked illegal
    %0 = poly.add %arg0, %arg1 : !poly.poly<10>

In the next commit we define a subclass of TypeConverter to convert types from poly to other dialects.

class PolyToStandardTypeConverter : public TypeConverter {
 public:
  PolyToStandardTypeConverter(MLIRContext *ctx) {
    addConversion([](Type type) { return type; });
    addConversion([ctx](PolynomialType type) -> Type {
      int degreeBound = type.getDegreeBound();
      IntegerType elementTy =
          IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);
      return RankedTensorType::get({degreeBound}, elementTy);
    });
  }
};

We call addConversion for each specific type we need to convert, and then the default addConversion([](Type type) { return type; }); signifies “no work needed, this type is legal.” By itself this type converter does nothing, because it is only used in conjunction with a specific kind of rewrite pattern, a subclass of OpConversionPattern.

struct ConvertAdd : public OpConversionPattern<AddOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      AddOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // ...
    return success();
  }
};

We add this, along with a call to RewritePatternSet::add, in this commit. This call differs from a normal RewritePatternSet::add in that it takes the type converter as an extra argument and passes it along to the constructor of ConvertAdd, whose parent class stores it as a member variable.

void runOnOperation() {
  ...
  RewritePatternSet patterns(context);
  PolyToStandardTypeConverter typeConverter(context);
  patterns.add<ConvertAdd>(typeConverter, context);
}

An OpConversionPattern’s matchAndRewrite method has two new arguments: the OpAdaptor and the ConversionPatternRewriter. The first, OpAdaptor, is an alias for AddOp::Adaptor, which is part of the generated C++ code and holds the type-converted operands during dialect conversion. It’s called an “adaptor” because it uses the tablegen-defined names for an op’s arguments and results in its method names, rather than generic getOperand equivalents. Meanwhile, the AddOp argument (the same as in a normal rewrite pattern) contains the original, un-type-converted operands and results. The ConversionPatternRewriter is like a PatternRewriter, but it has additional methods relevant to dialect conversion, such as convertRegionTypes, which is a helper for applying type conversions to ops with nested regions. All modifications to the IR must go through the ConversionPatternRewriter.
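To make that concrete, here are the two views of the same operand inside ConvertAdd for an op like %2 = poly.add %0, %1 : !poly.poly<10> (an illustrative fragment, not from the commit):

op.getLhs();       // the original operand, of type !poly.poly<10>
adaptor.getLhs();  // the type-converted operand, of type tensor<10xi32>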

In the next commit, we implement the lowering for poly.add.

struct ConvertAdd : public OpConversionPattern<AddOp> {
  <...>
  LogicalResult matchAndRewrite(AddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    arith::AddIOp addOp = rewriter.create<arith::AddIOp>(
        op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
    rewriter.replaceOp(op.getOperation(), {addOp});
    return success();
  }
};

The main thing to notice is that we don’t have to convert the types ourselves. The conversion framework converted the operands in advance, put them on the OpAdaptor, and we can just focus on creating the lowered op.

Mathematically, however, this lowering is only this simple because I’m taking advantage of the overflow semantics of 32-bit integer addition to avoid computing modular remainders. Or rather, I’m interpreting addi to have natural unsigned overflow semantics, though I think there is some ongoing discussion about how precise these semantics should be. Also, arith.addi is defined to operate on tensors of integers elementwise, in addition to operating on individual integers.

However, there is one hiccup when we try to test it on the following IR:

func.func @test_lower_add(%0 : !poly.poly<10>, %1 : !poly.poly<10>) {
  %2 = poly.add %0, %1: !poly.poly<10>
  return
}

The pass fails here because of a type conflict between the function arguments and the converted add op. We get an error that looks like

error: failed to legalize unresolved materialization from '!poly.poly<10>' to 'tensor<10xi32>' that remained live after conversion
  %2 = poly.add %0, %1: !poly.poly<10>
       ^
poly_to_standard.mlir:6:8: note: see current operation: %0 = "builtin.unrealized_conversion_cast"(%arg1) : (!poly.poly<10>) -> tensor<10xi32>
poly_to_standard.mlir:6:8: note: see existing live user here: %2 = arith.addi %1, %0 : tensor<10xi32>

If you look at the debug output (run the pass with --debug, i.e., bazel run tools:tutorial-opt -- --poly-to-standard --debug $PWD/tests/poly_to_standard.mlir) you’ll see a log of what the conversion framework is trying to do. Eventually it spits out this IR just before the end:

func.func @test_lower_add(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) {
  %0 = builtin.unrealized_conversion_cast %arg1 : !poly.poly<10> to tensor<10xi32>
  %1 = builtin.unrealized_conversion_cast %arg0 : !poly.poly<10> to tensor<10xi32>
  %2 = arith.addi %1, %0 : tensor<10xi32>
  %3 = poly.add %arg0, %arg1 : !poly.poly<10>
  return
}

This builtin.unrealized_conversion_cast is the framework’s internal stand-in for an unresolved type conflict: before the pass ends, the framework tries to resolve these casts using user-defined materialization hooks (which we didn’t implement). It’s basically a forced type coercion. From the docs:

This operation should not be attributed any special representational or execution semantics, and is generally only intended to be used to satisfy the temporary intermixing of type systems during the conversion of one type system to another.

So at this point we have two options to make the pass succeed:

1. Convert the types of the enclosing func.func (and its return) as well, so that the function signature uses tensor<10xi32> instead of !poly.poly<10>.
2. Implement materialization hooks that insert ops to bridge the remaining type conflicts.

The second option doesn’t really make sense, since we want to get rid of poly.poly entirely in this pass. But if we had to, we would use poly.from_tensor and poly.to_tensor, and we’ll show how to do that in the next section. For now, we’ll convert the function types.

In principle, one would have to lower a structural op like func.func in every conversion pass that introduces a new type that can show up in a function signature. Ditto for ops like scf.if, scf.for, etc. Luckily, MLIR is general enough to handle this for you, but it requires opting in via some extra boilerplate in this commit. I won’t copy all of the boilerplate here, but basically you call helpers that use interfaces to define patterns supporting type conversion for all function-like ops (func/call/return) and all if/then/else-like ops. We just call those helpers, hand them our type converter, and they give us patterns to add to our RewritePatternSet.
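For a flavor, the func.func portion looks roughly like this (a sketch of the upstream helper calls; the linked commit has the exact code, plus the analogous handling for return, call, and branch ops):

// Rewrite func.func signatures using our type converter, and mark a func.func
// legal only once its signature and body contain no more poly types.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                               typeConverter);
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
  return typeConverter.isSignatureLegal(op.getFunctionType()) &&
         typeConverter.isLegal(&op.getBody());
});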

After that, the pass runs and we get

  func.func @test_lower_add(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) {
    %0 = arith.addi %arg0, %arg1 : tensor<10xi32>
    return
  }

I wanted to highlight one other aspect of dialect conversions, which is that they use the folding capabilities of the dialect being converted. So for an IR like

func.func @test_lower_add_and_fold() {
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.constant dense<[3, 4, 5]> : tensor<3xi32> : !poly.poly<10>
  %2 = poly.add %0, %1: !poly.poly<10>
  return
}

Instead of an arith.addi applied to two lowered constants (the lowering for poly.constant is added in this commit), we get

func.func @test_lower_add_and_fold() {
  %cst = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
  %cst_0 = arith.constant dense<[3, 4, 5]> : tensor<3xi32>
  %cst_1 = arith.constant dense<[5, 7, 9]> : tensor<3xi32>
  return
}

So as it lowers, it eagerly tries to fold the results, and that can result in some constant propagation.

Next, we lower the rest of the ops (sub, mul, and eval) in a series of commits.

Then in this commit you can see how it lowers a combination of all the poly ops to an equivalent program that uses just tensor, scf, and arith. For the input test:

func.func @test_lower_many(%arg : !poly.poly<10>, %point : i32) -> i32 {
  %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
  %1 = poly.add %0, %arg : !poly.poly<10>
  %2 = poly.mul %1, %1 : !poly.poly<10>
  %3 = poly.sub %2, %arg : !poly.poly<10>
  %4 = poly.eval %3, %point: (!poly.poly<10>, i32) -> i32
  return %4 : i32
}

It produces the following IR:

  func.func @test_lower_many(%arg0: tensor<10xi32>, %arg1: i32) -> i32 {
    %cst = arith.constant dense<[2, 3, 4]> : tensor<3xi32>
    %c0_i32 = arith.constant 0 : i32
    %padded = tensor.pad %cst low[0] high[7] {
    ^bb0(%arg2: index):
      tensor.yield %c0_i32 : i32
    } : tensor<3xi32> to tensor<10xi32>
    %0 = arith.addi %padded, %arg0 : tensor<10xi32>
    %cst_0 = arith.constant dense<0> : tensor<10xi32>
    %c0 = arith.constant 0 : index
    %c10 = arith.constant 10 : index
    %c1 = arith.constant 1 : index
    %1 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %cst_0) -> (tensor<10xi32>) {
      %4 = scf.for %arg4 = %c0 to %c10 step %c1 iter_args(%arg5 = %arg3) -> (tensor<10xi32>) {
        %5 = arith.addi %arg2, %arg4 : index
        %6 = arith.remui %5, %c10 : index
        %extracted = tensor.extract %0[%arg2] : tensor<10xi32>
        %extracted_3 = tensor.extract %0[%arg4] : tensor<10xi32>
        %7 = arith.muli %extracted, %extracted_3 : i32
        %extracted_4 = tensor.extract %arg5[%6] : tensor<10xi32>
        %8 = arith.addi %7, %extracted_4 : i32
        %inserted = tensor.insert %8 into %arg5[%6] : tensor<10xi32>
        scf.yield %inserted : tensor<10xi32>
      }
      scf.yield %4 : tensor<10xi32>
    }
    %2 = arith.subi %1, %arg0 : tensor<10xi32>
    %c1_1 = arith.constant 1 : index
    %c11 = arith.constant 11 : index
    %c0_i32_2 = arith.constant 0 : i32
    %3 = scf.for %arg2 = %c1_1 to %c11 step %c1_1 iter_args(%arg3 = %c0_i32_2) -> (i32) {
      %4 = arith.subi %c11, %arg2 : index
      %5 = arith.muli %arg1, %arg3 : i32
      %extracted = tensor.extract %2[%4] : tensor<10xi32>
      %6 = arith.addi %5, %extracted : i32
      scf.yield %6 : i32
    }
    return %3 : i32
  }
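To give a flavor of how the loop-based lowerings are written in C++, here is a rough sketch of a Horner’s-method pattern for poly.eval. This is a sketch rather than the code from the linked commits: the accessor names getInput() and getPoint() are guesses at the tablegen-generated names, and the loop is written the straightforward way (highest coefficient first) rather than matching the output above line for line.

struct ConvertEval : public OpConversionPattern<EvalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      EvalOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto polyTy = cast<PolynomialType>(op.getInput().getType());
    int64_t numTerms = polyTy.getDegreeBound();  // poly<10> -> 10 coefficients

    // Loop bounds, step, and the initial accumulator value (0 : i32).
    Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub = rewriter.create<arith::ConstantIndexOp>(loc, numTerms);
    Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value last = rewriter.create<arith::ConstantIndexOp>(loc, numTerms - 1);
    Value init = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);

    // accum = 0; for i in [0, numTerms): accum = accum * point + coeffs[last - i]
    auto loop = rewriter.create<scf::ForOp>(
        loc, lb, ub, step, ValueRange{init},
        [&](OpBuilder &b, Location l, Value iv, ValueRange iterArgs) {
          // Walk coefficients from the highest degree down to the constant term.
          Value index = b.create<arith::SubIOp>(l, last, iv);
          Value scaled =
              b.create<arith::MulIOp>(l, adaptor.getPoint(), iterArgs[0]);
          Value coeff = b.create<tensor::ExtractOp>(l, adaptor.getInput(),
                                                    ValueRange{index});
          Value accum = b.create<arith::AddIOp>(l, scaled, coeff);
          b.create<scf::YieldOp>(l, accum);
        });

    rewriter.replaceOp(op, loop.getResults());
    return success();
  }
};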

Materialization hooks

The previous section didn’t use the materialization hook infrastructure mentioned earlier, simply because we don’t need it. It’s only necessary when type conflicts must persist across multiple passes, and we have no trouble lowering all of poly in one pass. But for demonstration purposes, this commit (reverted in the same PR) shows how one would implement the materialization hooks. I removed the lowering for poly.sub, reconfigured the ConversionTarget so that poly.sub, to_tensor, and from_tensor are no longer illegal, and then added two hooks on the TypeConverter to insert new from_tensor and to_tensor ops where they are needed:

// Convert from a tensor type to a poly type: use from_tensor
addSourceMaterialization([](OpBuilder &builder, Type type,
                            ValueRange inputs, Location loc) -> Value {
  return builder.create<poly::FromTensorOp>(loc, type, inputs[0]);
});

// Convert from a poly type to a tensor type: use to_tensor
addTargetMaterialization([](OpBuilder &builder, Type type,
                            ValueRange inputs, Location loc) -> Value {
  return builder.create<poly::ToTensorOp>(loc, type, inputs[0]);
});

Now the same lowering above succeeds, except anywhere there is a poly.sub it is surrounded by from_tensor and to_tensor. The output for test_lower_many now looks like this:

  func.func @test_lower_many(%arg0: tensor<10xi32>, %arg1: i32) -> i32 {
    %2 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %cst_0) -> (tensor<10xi32>) {
      %7 = scf.for %arg4 = %c0 to %c10 step %c1 iter_args(%arg5 = %arg3) -> (tensor<10xi32>) {
         ...
      }
      scf.yield %7 : tensor<10xi32>
    }
    %3 = poly.from_tensor %2 : tensor<10xi32> -> !poly.poly<10>
    %4 = poly.sub %3, %0 : !poly.poly<10>
    %5 = poly.to_tensor %4 : !poly.poly<10> -> tensor<10xi32>
    %c1_1 = arith.constant 1 : index
    ...
  }

Correctness?

Lowering poly is not entirely trivial, because you’re implementing polynomial math. The tests I wrote are specific, in that they put tight constraints on the output IR, but what if I’m just dumb and I got the algorithm wrong? It stands to reason that I would want an extra layer of certainty to guard against my own worst enemy.

One of the possible solutions here is to use mlir-cpu-runner, like we did when we lowered ctlz to LLVM. Another is to lower all the way to machine code and run it on the actual hardware instead of through an interpreter. And that’s what we’ll do next time. Until then, the poly lowerings might be buggy. Feel free to point out any bugs in GitHub issues or comments on this blog.

