MLIR — Canonicalizers and Declarative Rewrite Patterns

In a previous article we defined folding functions, and used them to enable some canonicalization and the sccp constant propagation pass for the poly dialect. This time we’ll see how to add more general canonicalization patterns.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Why is Canonicalization Needed?

MLIR provides folding as a mechanism to simplify an IR, which can result in simpler, more efficient ops (e.g., replacing multiplication by a power of 2 with a shift) or removing ops entirely (e.g., deleting $y = x+0$ and using $x$ for $y$ downstream). It also has a special hook for constant materialization.

General canonicalization differs from folding in MLIR in that it can transform many operations at once. The MLIR docs call this “DAG-to-DAG” rewriting, where DAG stands for “directed acyclic graph” in the graph theory sense, but here it really just means a section of the IR.

Canonicalizers can be written in the standard way: declare that the op has a canonicalizer in tablegen, and then implement the generated C++ function declaration. The official docs for that are here. Alternatively, you can do it all declaratively in tablegen; the docs for that are here. We’ll do both in this article.

Aside: there is a third way, to use a new system called PDLL, but I haven’t figured out how to use that yet. It should be noted that PDLL is under active development, and in the meantime the tablegen-based approach in this article (called “DRR” for Declarative Rewrite Rules in the MLIR docs) is considered to be in maintenance mode, but not yet deprecated. I’ll try to cover PDLL in a future article.

Canonicalizers in C++

Reusing our poly dialect, we’ll start with the binary polynomial operations, adding let hasCanonicalizer = 1; to the op base class in this commit, which generates the following method header on each of the binary op classes:

static void getCanonicalizationPatterns(
  ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);

The body of this method should add custom rewrite patterns to the input results set, and we can define those patterns however we like in C++.

The first canonicalization pattern we’ll write in this commit is for the simple identity $x^2 - y^2 = (x+y)(x-y)$, which is useful because it replaces a multiplication with an addition. The only caveat is that this canonicalization is only more efficient if the squares have no other downstream uses.

// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
struct DifferenceOfSquares : public OpRewritePattern<SubOp> {
  DifferenceOfSquares(mlir::MLIRContext *context)
      : OpRewritePattern<SubOp>(context, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(SubOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    if (!lhs.hasOneUse() || !rhs.hasOneUse()) {
      return failure();
    }

    auto rhsMul = rhs.getDefiningOp<MulOp>();
    auto lhsMul = lhs.getDefiningOp<MulOp>();
    if (!rhsMul || !lhsMul) {
      return failure();
    }

    bool rhsMulOpsAgree = rhsMul.getLhs() == rhsMul.getRhs();
    bool lhsMulOpsAgree = lhsMul.getLhs() == lhsMul.getRhs();

    if (!rhsMulOpsAgree || !lhsMulOpsAgree) {
      return failure();
    }

    auto x = lhsMul.getLhs();
    auto y = rhsMul.getLhs();

    AddOp newAdd = rewriter.create<AddOp>(op.getLoc(), x, y);
    SubOp newSub = rewriter.create<SubOp>(op.getLoc(), x, y);
    MulOp newMul = rewriter.create<MulOp>(op.getLoc(), newAdd, newSub);

    rewriter.replaceOp(op, {newMul});
    // We don't need to remove the original ops because MLIR already has
    // canonicalization patterns that remove unused ops.

    return success();
  }
};

The test in the same commit shows the impact:

// Input:
func.func @test_difference_of_squares(
  %0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
  %2 = poly.mul %0, %0 : !poly.poly<3>
  %3 = poly.mul %1, %1 : !poly.poly<3>
  %4 = poly.sub %2, %3 : !poly.poly<3>
  %5 = poly.add %4, %4 : !poly.poly<3>
  return %5 : !poly.poly<3>
}

// Output:
// bazel run tools:tutorial-opt -- --canonicalize $FILE
func.func @test_difference_of_squares(%arg0: !poly.poly<3>, %arg1: !poly.poly<3>) -> !poly.poly<3> {
  %0 = poly.add %arg0, %arg1 : !poly.poly<3>
  %1 = poly.sub %arg0, %arg1 : !poly.poly<3>
  %2 = poly.mul %0, %1 : !poly.poly<3>
  %3 = poly.add %2, %2 : !poly.poly<3>
  return %3 : !poly.poly<3>
}

Other than this pattern being used in the getCanonicalizationPatterns function, there is nothing new here compared to the previous article on rewrite patterns.
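As a quick sanity check on the identity behind this pattern, here is a small standalone C++ sketch. It is not part of the tutorial's codebase, and the two function names are made up for illustration. It uses uint32_t, whose arithmetic wraps modulo $2^{32}$, as a stand-in for a single coefficient, and computes both sides of the identity.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: the original form, with two multiplications and one
// subtraction. Unsigned arithmetic wraps modulo 2^32.
inline std::uint32_t squaresThenSubtract(std::uint32_t x, std::uint32_t y) {
  return x * x - y * y;
}

// Hypothetical helper: the rewritten form, with one multiplication, one
// addition, and one subtraction.
inline std::uint32_t sumTimesDifference(std::uint32_t x, std::uint32_t y) {
  return (x + y) * (x - y);
}
```

The two functions agree on every input, including ones where x - y wraps around, because the identity holds in any commutative ring. The second form trades one multiplication for an addition, which is the whole point of the rewrite.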

Canonicalizers in Tablegen

The above canonicalization is really a simple kind of optimization attached to the canonicalization pass. It seems that this is how many minor optimizations end up being implemented, making the --canonicalize pass a heavyweight and powerful pass. However, the name “canonicalize” also suggests that it should be used to put the IR into a canonical form so that later passes don’t have to check as many cases.

So let’s implement a canonicalization like that as well. We’ll add one that has poly interact with the complex dialect, and implement a canonicalization that ensures complex conjugation always comes after polynomial evaluation. This works because of the identity $f(\overline{z}) = \overline{f(z)}$ for all polynomials $f$ and all complex numbers $z$.
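To see the identity in action outside of MLIR, here is a minimal standalone C++ sketch using std::complex. The function names are hypothetical, and it assumes the polynomial has real coefficients (doubles standing in for poly's integer coefficients), which is what makes the identity hold.

```cpp
#include <cassert>
#include <complex>
#include <vector>

// Evaluates a polynomial with real coefficients (lowest degree first) at z
// using Horner's rule.
inline std::complex<double> evalPoly(const std::vector<double> &coeffs,
                                     std::complex<double> z) {
  std::complex<double> acc(0.0, 0.0);
  for (auto it = coeffs.rbegin(); it != coeffs.rend(); ++it) {
    acc = acc * z + *it;
  }
  return acc;
}

// Returns |f(conj(z)) - conj(f(z))|, which should be zero up to rounding
// when all coefficients are real.
inline double conjugationGap(const std::vector<double> &coeffs,
                             std::complex<double> z) {
  return std::abs(evalPoly(coeffs, std::conj(z)) -
                  std::conj(evalPoly(coeffs, z)));
}
```

For example, with $f(x) = 1 + 2x + 3x^2$ and $z = 1 + 2i$, both sides come out to $-6 - 16i$. Since the two orders give the same answer, picking one of them as the canonical form loses nothing.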

This commit adds support for complex inputs in the poly.eval op. Note that in MLIR, complex types are forced to be floating point because all the op verifiers that construct complex numbers require it. The complex type itself, however, suggests a complex<i32> is perfectly legal, so it seems nobody in the MLIR community needs Gaussian integers yet.

The rule itself is implemented in tablegen in this commit. The main tablegen code is:

include ""
include "mlir/Dialect/Complex/IR/"
include "mlir/IR/"

def LiftConjThroughEval : Pat<
  (Poly_EvalOp $f, (ConjOp $z)),
  (ConjOp (Poly_EvalOp $f, $z))
>;

The two new pieces here are the Pat class (source) that defines a rewrite pattern, and the parenthetical notation that defines sub-trees of the IR being matched and rewritten. The source documentation on the Pattern parent class is quite well written, so read that for extra detail, and the normal docs provide a higher level view with additional semantics.

But the short story here is that the inputs to Pat are two “IR tree” objects (MLIR calls them “DAG nodes”), and each node in the tree is specified by parentheses ( ), with the first thing in the parentheses being the name of an operation (the tablegen name, e.g., Poly_EvalOp), and the remaining arguments being the op’s arguments or attributes. Naturally, the nodes can nest, and that corresponds to a match applied to the argument. I.e., (Poly_EvalOp $f, (ConjOp $z)) means “an eval op whose first argument is anything (bind that to $f) and whose second argument is the output of a ConjOp whose input is anything (bind that to $z).”

When running tablegen with the -gen-rewriters option, this generates this code, which is not much more than a thorough version of the pattern we’d write manually. Then in this commit we show how to include it in the codebase. We still have to tell MLIR which pass to add the generated patterns to. You can add each pattern by name, or use the populateWithGenerated function to add them all.

As another example, this commit reimplements the difference of squares pattern in tablegen. This one uses three additional features: a pattern that generates multiple new ops (which uses Pattern instead of Pat), binding the ops to names, and constraints that control when the pattern may be run.

def HasOneUse: Constraint<CPred<"$_self.hasOneUse()">, "has one use">;

// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
def DifferenceOfSquares : Pattern<
  (Poly_SubOp (Poly_MulOp:$lhs $x, $x), (Poly_MulOp:$rhs $y, $y)),
  [
    (Poly_AddOp:$sum $x, $y),
    (Poly_SubOp:$diff $x, $y),
    (Poly_MulOp:$res $sum, $diff),
  ],
  [(HasOneUse:$lhs), (HasOneUse:$rhs)]
>;

The HasOneUse constraint merely injects the quoted C++ code into a generated if guard, with $_self as a magic string to substitute in the argument when it’s used.

But then notice the syntax of (Poly_MulOp:$lhs $x, $x): the colon binds $lhs to refer to the op as a whole (or, via method overloads, its result value), so that it can be passed to the constraint. Similarly, the generated ops are all given names so they can be fed as the arguments of other generated ops. Finally, the second argument of Pattern is a list of generated ops to replace the matched input IR, rather than a single node for Pat.

The benefit of doing this is significantly less boilerplate related to type casting, checking for nulls, and emitting error messages. But because you still occasionally need to inject random C++ code, and inspect the generated C++ to debug, it helps to be fluent in both styles. I don’t know how to check a constraint like “has a single use” in pure DRR tablegen.

MLIR — Verifiers

Last time we defined folders and used them to enable some canonicalization and the sccp constant propagation pass for the poly dialect. This time we’ll add some additional safety checks to the dialect in the form of verifiers.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Purpose of a verifier

Verifiers ensure the types and operations in a concrete MLIR program are well-formed. Verifiers are run before and after every pass, and help to ensure that individual passes, folders, rewrite patterns, etc., emit proper IR. This allows you to enforce invariants of each op, and it makes passes simpler because they can rely on the invariants to avoid edge case checking.

The official docs for verifiers of attributes and types are here, and for operations here.

That said, most common kinds of verification code are implemented as Traits (mixins, recall the earlier article). So we’ll start with those.

Trait-based verifiers

In the last article we added SameOperandsAndResultElementType to enable poly.add to have a mixed poly and tensor-of-poly input semantics. This technically did add a verifier to the IR, but to demonstrate this more clearly I want to restrict that behavior to assert that the input and output types must all agree (all tensors-of-polys or all polys).

This commit shows the work to do this, which is mainly changing the trait name to SameOperandsAndResultType. As a result, we get a few new generated things for free. First the verification engine uses verifyTrait to check that the types agree. There, verifyInvariants is an Operation base class method that the generated code overrides when traits inject verification, the same thing that checks the type constraints on operation types. (Note: a custom verifier instead gets a method name verify to distinguish it from verifyInvariants). Since the SameOperandsAndResultType is a generic check, this doesn’t affect the generated code.

Next, an inferReturnTypes function is generated, below shown for AddOp.

::mlir::LogicalResult AddOp::inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
  ::mlir::Builder odsBuilder(context);
  ::mlir::Type odsInferredType0 = operands[0].getType();
  inferredReturnTypes[0] = odsInferredType0;
  return ::mlir::success();
}

With a type inference hook present, we can simplify the operation’s assembly format, so that the type need only be specified once instead of three times (type, type) -> type. If we tried to simplify it before this trait, tablegen would complain that it can’t infer the types needed to build a parser.

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))";

This also requires updating all the tests to use the new assembly format (I did not try to find a way to make both the functional and abbreviated forms allowed at the same time, no big deal).

Finally, this trait adds builders that don’t need you to specify the return type. Another example for AddOp:

void AddOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs) {
  odsState.addOperands(lhs);
  odsState.addOperands(rhs);

  ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
  if (::mlir::succeeded(AddOp::inferReturnTypes(odsBuilder.getContext(),
          odsState.location, odsState.operands,
          odsState.attributes.getDictionary(odsState.getContext()),
          odsState.getRawProperties(),
          odsState.regions, inferredReturnTypes)))
    odsState.addTypes(inferredReturnTypes);
  else
    ::llvm::report_fatal_error("Failed to infer result type(s).");
}

For another example, the EvalOp can’t use SameOperandsAndResultType, because its operands require different types, but we can use AllTypesMatch which generates similar code, but restricts the verification to a subset of types. This is added in this commit.

def Poly_EvalOp : Op<Poly_Dialect, "eval", [AllTypesMatch<["point", "output"]>]> {
  let summary = "Evaluates a Polynomial at a given input value.";
  let arguments = (ins Polynomial:$input, AnyInteger:$point);
  let results = (outs AnyInteger:$output);
  let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)";
}

You can see many similar sorts of type-inference traits here and their corresponding verifiers here.

Aside: the nomenclature for verification with regard to traits is a bit confusing. There is a concept called constraint in MLIR, and the docs describe traits as subclasses of a Constraint base class. But at the time of this writing (2023-09-11) that particular claim is wrong. There’s a TraitBase base class for traits, and the Constraint base class appears to be used for verification on type declarations and attribute declarations (the let arguments = (ins ...) stuff). These go by the names “type constraints” and “attribute constraints.” AnyInteger is an example of a type constraint because it can match multiple types, and it does inherit (indirectly) from the Constraint base class. I think type constraints are a bit more complicated because the examples I see in MLIR all involve injecting C++ code through the tablegen (you’ll see stuff like CPred) and I haven’t explored how that materializes as generated code yet.

A custom verifier

I couldn’t think of any legitimate custom verification I wanted to have in poly, so to make one up arbitrarily, I will assert that EvalOp’s input must be a 32-bit integer type. I could do this with the type constraint in tablegen, but I will do it in a custom verifier instead for demonstration purposes.

Repeating our routine, we start by adding let hasVerifier = 1; to the op’s tablegen, and inspect the generated signature in the header, in this commit.

class EvalOp ... {
  ::mlir::LogicalResult verify();
};

And the implementation

LogicalResult EvalOp::verify() {
  return getPoint().getType().isInteger(32)
             ? success()
             : emitOpError("argument point must be a 32-bit integer");
}

The new thing here is emitOpError, which is required because if you just return failure, then the verifier will fail but not output any information, resulting in an empty stdout.

And then to test for failure, the lit run command should capture stderr, which is where the diagnostic is printed (here by redirecting it to a temporary file), and have FileCheck operate on that:

// tests/poly_verifier.mlir
// RUN: tutorial-opt %s 2>%t; FileCheck %s < %t

A trait-based custom verifier

We can combine these two ideas together by defining a custom trait that includes a verification hook.

Each trait in MLIR has an optional verifyTrait hook (which is checked before the custom verifier created via hasVerifier), and we can use this to define generic verifiers that apply to many ops. We’ll do this by making a verifier that extends the previous section by asserting—generically for any op—that all integer-like operands must be 32-bit. Again, this is a silly and arbitrary constraint to enforce, but it’s a demonstration.

The process of defining this is almost entirely C++, and contained in this commit. We start by defining a so-called NativeOpTrait subclass in tablegen:

def Has32BitArguments:  NativeOpTrait<"Has32BitArguments"> {
  let cppNamespace = "::mlir::tutorial::poly";
}

This has an almost trivial effect: add a template argument called ::mlir::tutorial::poly::Has32BitArguments to the generated header class for an op that has this trait. E.g., for EvalOp,

def Poly_EvalOp : Op<Poly_Dialect, "eval", [
    AllTypesMatch<["point", "output"]>,
    Has32BitArguments
]> {
  ...
}

class EvalOp : public ::mlir::Op<
    EvalOp, ::mlir::OpTrait::ZeroRegions,
    ...,
    ::mlir::tutorial::poly::Has32BitArguments,
    ...
> {
  ...
}

The rest is up to you in C++. Define an implementation of Has32BitArguments, following the curiously-recurring-template pattern required of OpTrait::TraitBase, and then implement a verifyTrait hook. In our case, iterate over all operand types, check which ones are integer-like, and then assert the width of those.

// PolyTraits.h
template <typename ConcreteType>
class Has32BitArguments : public OpTrait::TraitBase<ConcreteType, Has32BitArguments> {
 public:
  static LogicalResult verifyTrait(Operation *op) {
    for (auto type : op->getOperandTypes()) {
      // OK to skip non-integer operand types
      if (!type.isIntOrIndex()) continue;

      if (!type.isInteger(32)) {
        return op->emitOpError()
               << "requires each numeric operand to be a 32-bit integer";
      }
    }

    return success();
  }
};

This has the upside of being more generic, but the downside of requiring awkward casting to support specific ops and their named arguments. I.e., here we can’t refer to getPoint unless we do a dynamic cast to EvalOp. Doing so would make the trait less general, so a custom op-specific verifier is more appropriate for that.
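If the curiously-recurring-template pattern is unfamiliar, the following standalone C++ sketch mimics the shape of a trait with a static verification hook, with no MLIR dependency. Every name in it (ToyOperation, ToyTraitBase, ToyHas32BitArguments, ToyEvalOp) is invented for illustration, and the real OpTrait::TraitBase does more than this empty stand-in.

```cpp
#include <cassert>
#include <vector>

// A toy stand-in for an operation: just the bit widths of its integer operands.
struct ToyOperation {
  std::vector<int> operandBitWidths;
};

// Toy analogue of a trait base class: parameterized on the concrete op class
// and on the trait template itself.
template <typename ConcreteType, template <typename> class TraitType>
class ToyTraitBase {};

// A trait with a static verification hook that is generic over any op.
template <typename ConcreteType>
class ToyHas32BitArguments
    : public ToyTraitBase<ConcreteType, ToyHas32BitArguments> {
public:
  static bool verifyTrait(const ToyOperation &op) {
    for (int width : op.operandBitWidths) {
      if (width != 32) return false;
    }
    return true;
  }
};

// A concrete op opts in by inheriting from the trait instantiated on itself.
class ToyEvalOp : public ToyHas32BitArguments<ToyEvalOp> {};
```

The hook only sees the generic ToyOperation, which mirrors the tradeoff described above: it applies to any op that opts in, but it knows nothing about op-specific accessors.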

MLIR — Folders and Constant Propagation

Last time we saw how to use pre-defined MLIR traits to enable upstream MLIR passes like loop-invariant-code-motion to apply to poly programs. We left out -sccp (sparse conditional constant propagation), and so this time we’ll add what is needed to make that pass work. It requires the concept of folding.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Constant Propagation vs Canonicalization

-sccp is sparse conditional constant propagation, which attempts to infer when an operation has a constant output, and then replaces the operation with the constant value. Repeating this, it “propagates” the constants as far as possible through the program. You can think of it like eagerly computing values it can during compile time, and then sticking them into the compiled program as constants.

Here’s what it looks like for arith, where all the needed tools are implemented. For an input like:

func.func @test_arith_sccp() -> i32 {
  %0 = arith.constant 7 : i32
  %1 = arith.constant 8 : i32
  %2 = arith.addi %0, %0 : i32
  %3 = arith.muli %0, %0 : i32
  %4 = arith.addi %2, %3 : i32
  return %2 : i32
}

The output of tutorial-opt --sccp is

func.func @test_arith_sccp() -> i32 {
  %c63_i32 = arith.constant 63 : i32
  %c49_i32 = arith.constant 49 : i32
  %c14_i32 = arith.constant 14 : i32
  %c8_i32 = arith.constant 8 : i32
  %c7_i32 = arith.constant 7 : i32
  return %c14_i32 : i32
}

Note two additional facts: sccp doesn’t delete dead code, and what is not shown here is the main novelty in sccp, which is that it can propagate constants through control flow (ifs and loops).

A related concept is the idea of canonicalization, which gets its own --canonicalize pass, and which hides a lot of the heavy lifting in MLIR. Canonicalize overlaps a little bit with sccp, in that it also computes constants and materializes them in the IR. Take, for example, the --canonicalize pass on the same IR:

func.func @test_arith_sccp() -> i32 {
  %c14_i32 = arith.constant 14 : i32
  return %c14_i32 : i32
}

The intermediate constants are all pruned, and all that remains is the return value and no operations. Canonicalize cannot propagate constants through control flow, and as such should be thought of as more “local” than sccp.

Both of these, however, are supported via folding, which is the process of taking a series of ops and merging them together into simpler ops. It also requires that our dialect has some sort of constant operation, which is inserted (“materialized”) with the results of a fold. Folding and canonicalization are more general than what I’m showing here, so we’ll come back to what else they can do in a future article.

The rough outline of what is needed to support folding in this way is:

  • Adding a constant operation
  • Adding a materialization hook
  • Adding folders for each op

This would result in a situation (test case commit) as follows. Starting from

%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%1 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>
%2 = poly.mul %1, %1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %1, %1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>

We would get

%0 = poly.constant dense<[2, 8, 20, 24, 18]> : tensor<5xi32> : <10>
%1 = poly.constant dense<[1, 4, 10, 12, 9]> : tensor<5xi32> : <10>
%2 = poly.constant dense<[1, 2, 3]> : tensor<3xi32> : <10>

Making a constant op

Currently we’re imitating a constant polynomial construction by combining a from_tensor op with arith.constant. Like this:

%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>

While a constant operation might combine them into one op.

%0 = poly.constant dense<[2, 8, 20, 24, 18]> : !poly.poly<10>

The from_tensor op can also be used to build a polynomial from data, not just constants, so it’s worth having around even after we implement poly.constant.

Having a dedicated constant operation has benefits explained in the MLIR documentation on folding. What’s relevant here is that fold can be used to signal to passes like sccp that the result of an op is constant (statically known), or it can be used to say that the result of an op is equivalent to a pre-existing value created by a different op. For the constant case, a materializeConstant hook is also needed to tell MLIR how to take the constant result and turn it into a proper IR op.

The constant op itself, in this commit, comes with two new concepts, the ConstantLike trait and an argument that is an attribute constraint.

def Poly_ConstantOp : Op<Poly_Dialect, "constant", [Pure, ConstantLike]> {   // new
  let summary = "Define a constant polynomial via an attribute.";
  let arguments = (ins AnyIntElementsAttr:$coefficients);    // new
  let results = (outs Polynomial:$output);
  let assemblyFormat = "$coefficients attr-dict `:` type($output)";
}

The ConstantLike trait is checked here during folding via the constant op matcher as an assertion. [Aside: I’m not sure why the trait is specifically required, so long as the materialization function is present on the dialect; it just seems like this check is used for assertions. Perhaps it’s just a safeguard.]

Next we have the line let arguments = (ins AnyIntElementsAttr:$coefficients); This defines the input to the op as an attribute (statically defined data) rather than a previous SSA value. The AnyIntElementsAttr is itself an attribute constraint, allowing any attribute that has IntElementsAttrBase as a base class to be used (e.g., 32-bit or 64-bit integer attributes). This means that we could use all of the following syntax forms:

%10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
%11 = poly.constant dense<[2, 3, 4]> : tensor<3xi8> : !poly.poly<10>
%12 = poly.constant dense<"0x020304"> : tensor<3xi8> : !poly.poly<10>
%13 = poly.constant dense<4> : tensor<100xi32> : !poly.poly<10>

Adding folders

We add folders in these commits:

Each has the same structure: add let hasFolder = 1; to the tablegen for the op, which adds a header declaration of the following form (noting that the signature would be different if the op has more than one result value, see docs).

OpFoldResult <OpName>::fold(<OpName>::FoldAdaptor adaptor);

Then we implement it in PolyOps.cpp. The semantics of this function are such that, if the fold decides the op should be replaced with a constant, it must return an attribute representing that constant, which can be given as the input to a poly.constant. The FoldAdaptor is a shim that has the same method names as an instance of the op’s C++ class, but arguments that have been folded themselves are replaced with Attribute instances representing the constants they were folded with. This will be relevant for folding add and mul, since the body needs to actually compute the result eagerly, and needs access to the actual values to do so.

For poly.constant the implementation is trivial: you just return the input attribute.

OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) {
  return adaptor.getCoefficients();
}

The from_tensor op is similar, but has an extra cast that acts as an assertion, since the tensor might have been constructed with weird types we don’t want as input. If the dyn_cast fails, the result is nullptr, which is cast by MLIR to a failed OpFoldResult.

OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
  // Returns null if the cast failed, which corresponds to a failed fold.
  return dyn_cast<DenseIntElementsAttr>(adaptor.getInput());
}

The poly binary ops are slightly more complicated since they are actually doing work. Each of these fold methods effectively takes as input a DenseIntElementsAttr for each operand, and expects us to return another DenseIntElementsAttr for the result.

For add/sub which are elementwise operations on the coefficients, we get to use an existing upstream helper method, constFoldBinaryOp, which through some template metaprogramming wizardry, allows us to specify only the elementwise operation itself.

OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) {
  return constFoldBinaryOp<IntegerAttr, APInt>(
      adaptor.getOperands(), [&](APInt a, APInt b) { return a + b; });
}
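To make the idea of "specify only the elementwise operation" concrete, here is a toy standalone C++ sketch of the same shape. It is not how constFoldBinaryOp is implemented upstream; foldElementwise is a made-up name, plain vectors of uint32_t stand in for the dense attributes, and treating mismatched sizes as a failed fold is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical helper: applies `op` to matching coefficients of lhs and rhs.
// Returns false (a "failed fold") if the inputs have different sizes;
// otherwise writes the folded coefficients to `out` and returns true.
inline bool foldElementwise(
    const std::vector<std::uint32_t> &lhs,
    const std::vector<std::uint32_t> &rhs,
    const std::function<std::uint32_t(std::uint32_t, std::uint32_t)> &op,
    std::vector<std::uint32_t> &out) {
  if (lhs.size() != rhs.size()) return false;
  out.clear();
  out.reserve(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    out.push_back(op(lhs[i], rhs[i]));
  }
  return true;
}
```

Folding add and sub then differs only in the lambda passed in, which is the convenience the upstream helper provides.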

For mul, we have to write out the multiplication routine manually. In what’s below, I’m implementing the naive textbook polymul algorithm, which could be optimized if one expects people to start compiling programs with large, static polynomials in them.

OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
  // Operands that were not folded to constants show up as null attributes.
  auto lhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getOperands()[0]);
  auto rhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getOperands()[1]);
  if (!lhs || !rhs) return nullptr;

  auto degree = getResult().getType().cast<PolynomialType>().getDegreeBound();
  auto maxIndex = lhs.size() + rhs.size() - 1;

  SmallVector<APInt, 8> result;
  for (int i = 0; i < maxIndex; ++i) {
    result.push_back(APInt((*lhs.begin()).getBitWidth(), 0));
  }

  int i = 0;
  for (auto lhsIt = lhs.value_begin<APInt>(); lhsIt != lhs.value_end<APInt>();
       ++lhsIt) {
    int j = 0;
    for (auto rhsIt = rhs.value_begin<APInt>(); rhsIt != rhs.value_end<APInt>();
         ++rhsIt) {
      // index is modulo degree because poly's semantics are defined modulo x^N = 1.
      result[(i + j) % degree] += *rhsIt * (*lhsIt);
      ++j;
    }
    ++i;
  }

  return DenseIntElementsAttr::get(
      RankedTensorType::get(static_cast<int64_t>(result.size()),
                            IntegerType::get(getContext(), 32)),
      result);
}
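The same textbook multiplication, including the wraparound modulo $x^N = 1$, can be tried out in a standalone C++ sketch with no MLIR dependency. The function name mulModXnMinusOne is made up, uint32_t stands in for APInt, and sizing the result to at most degreeBound entries is a choice of this sketch rather than what the fold above does.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: multiplies two coefficient lists (lowest degree first)
// with coefficients modulo 2^32 and exponents wrapping modulo degreeBound,
// i.e., modulo x^N = 1 where N = degreeBound.
inline std::vector<std::uint32_t> mulModXnMinusOne(
    const std::vector<std::uint32_t> &lhs,
    const std::vector<std::uint32_t> &rhs, std::size_t degreeBound) {
  assert(degreeBound > 0);
  if (lhs.empty() || rhs.empty()) return {};

  // Without wraparound the product has lhs.size() + rhs.size() - 1 terms.
  std::size_t maxIndex = lhs.size() + rhs.size() - 1;
  std::size_t resultSize = maxIndex < degreeBound ? maxIndex : degreeBound;
  std::vector<std::uint32_t> result(resultSize, 0);

  for (std::size_t i = 0; i < lhs.size(); ++i) {
    for (std::size_t j = 0; j < rhs.size(); ++j) {
      result[(i + j) % degreeBound] += lhs[i] * rhs[j];
    }
  }
  return result;
}
```

Squaring $1 + 2x + 3x^2$ with a degree bound of 10 gives the coefficients 1, 4, 10, 12, 9, matching the folded poly.constant in the test case earlier. With a degree bound of 3 the higher terms wrap around and the result is 13, 13, 10.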

Adding a constant materializer

Finally, we add a constant materializer. This is a dialect-level feature, so we start by adding let hasConstantMaterializer = 1; to the dialect tablegen, and observing the newly generated header signature:

Operation *PolyDialect::materializeConstant(
    OpBuilder &builder, Attribute value, Type type, Location loc);

The Attribute input represents the result of each folding step above. The Type is the desired result type of the op, which is needed in cases like arith.constant where the same attribute can generate multiple different types via different interpretations of a hex string or splatting with a result tensor that has different dimensions.

In our case the implementation is trivial: just construct a constant op from the attribute.

Operation *PolyDialect::materializeConstant(
    OpBuilder &builder, Attribute value, Type type, Location loc) {
  auto coeffs = dyn_cast<DenseIntElementsAttr>(value);
  if (!coeffs)
    return nullptr;
  return builder.create<ConstantOp>(loc, type, coeffs);
}

Other kinds of folding

While this has demonstrated a generic kind of folding with respect to static constants, many folding functions in MLIR use simple matches to determine when an op can be replaced with a value from a previously computed op.

Take, for example, the complex dialect (for complex numbers). A complex.create op constructs a complex number from real and imaginary parts. A folder in that dialect checks for a pattern like complex.create(complex.re(%z), complex.im(%z)), and replaces it with %z directly. The arith dialect similarly has folds for things like (a - b) + b -> a and a + 0 -> a.

However, most work on simplifying an IR according to algebraic rules belongs in the canonicalization pass, since while it supports folding, it also supports general rewrite patterns that are allowed to delete and create ops as needed to simplify the IR. We’ll cover canonicalization in more detail in a future article. But just remember, folds may only modify the single operation being folded, use existing SSA values, and may not create new ops. So they are limited in power and decidedly local operations.
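The three possible outcomes of a fold (a constant, an existing value, or failure) can be modeled in a toy standalone C++ sketch. None of these types exist in MLIR; ToyValue, ToyFoldResult, and foldAdd are invented names that only mimic the shape of an OpFoldResult for an integer add.

```cpp
#include <cassert>
#include <cstdint>

// Toy SSA value: either a known constant or an opaque value with an id.
struct ToyValue {
  int id;
  bool isConstant;
  std::int64_t constant;  // Only meaningful when isConstant is true.
};

// Toy analogue of a fold result.
struct ToyFoldResult {
  enum class Kind { Failed, Constant, ExistingValue };
  Kind kind;
  std::int64_t constant;  // Valid when kind == Kind::Constant.
  int valueId;            // Valid when kind == Kind::ExistingValue.
};

// Folds lhs + rhs without creating any new operations.
inline ToyFoldResult foldAdd(const ToyValue &lhs, const ToyValue &rhs) {
  // Both operands are constants: report the constant to be materialized.
  if (lhs.isConstant && rhs.isConstant) {
    return {ToyFoldResult::Kind::Constant, lhs.constant + rhs.constant, -1};
  }
  // a + 0 -> a: reuse the existing value.
  if (rhs.isConstant && rhs.constant == 0) {
    return {ToyFoldResult::Kind::ExistingValue, 0, lhs.id};
  }
  // Nothing to do: the fold fails and the op is left alone.
  return {ToyFoldResult::Kind::Failed, 0, -1};
}
```

Note that foldAdd never builds a new op. It can only report a constant (which a materializer would later turn into an op) or point at a value that already exists, which is the locality restriction described above.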

MLIR — Defining a New Dialect

In the last article in the series, we migrated the passes we had written to use the tablegen code generation framework. That was a preface to using tablegen to define dialects.

In this article we’ll define a dialect that represents arithmetic on single-variable polynomials, with coefficients in $\mathbb{Z} / 2^{32} \mathbb{Z}$ (32-bit unsigned integers).
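For intuition about the coefficient ring, arithmetic in $\mathbb{Z} / 2^{32} \mathbb{Z}$ is what fixed-width unsigned integers already do in C++. The short standalone sketch below uses two made-up helper names and assumes a platform where unsigned int is 32 bits wide, so that uint32_t operands are not promoted to a wider signed type.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: coefficient addition modulo 2^32.
inline std::uint32_t addCoefficients(std::uint32_t a, std::uint32_t b) {
  return a + b;  // Unsigned overflow wraps around; it is well defined.
}

// Hypothetical helper: coefficient multiplication modulo 2^32.
inline std::uint32_t mulCoefficients(std::uint32_t a, std::uint32_t b) {
  return a * b;
}
```

For instance, adding 1 to the largest 32-bit value wraps to 0, and $2^{16} \cdot 2^{16} = 2^{32}$ is also 0 in this ring.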

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Sketching out a design

The basic dialect will define a new polynomial type, and provide operations to define polynomials by specifying their coefficients from standard MLIR types, extract data about a polynomial to store the results in standard MLIR types, and to do arithmetic operations on polynomials.

There is quite a large surface area of design options when writing a dialect. This talk by Jeff Niu and Mehdi Amini gives some indication of how one might start to think about dialect design. But in brief, a polynomial dialect would fit into the “computation” bucket (a dialect to represent number crunching).

A slide from “MLIR Dialect Design and Composition for Front-End Compilers” (timestamped link), describing a taxonomy of dialects.

Another idea relevant to starting a dialect design is to ask how a dialect will enable easy optimizations. That is the Optimizer dialect class above. But since this tutorial series is focusing on the low-level, day to day, nitty gritty details of working with MLIR, we’re going to focus on getting some custom dialect defined, and then come back to what makes a good dialect design later.

An empty dialect

We’ll start by defining an empty dialect and just look at the tablegen-generated code, which is done in this commit. The tablegen file looks like

include "mlir/IR/"

def Poly_Dialect : Dialect {
  let name = "poly";
  let summary = "A dialect for polynomial math";
  let description = [{
    The poly dialect defines types and operations for single-variable
    polynomials over integers.
  }];

  let cppNamespace = "::mlir::tutorial::poly";
}

It is almost indistinguishable from the tablegen for a Pass, except for the Dialect base class. The build rule has different flags too.

    name = "dialect_inc_gen",
    tbl_outs = [
        (["-gen-dialect-decls"], ""),
        (["-gen-dialect-defs"], ""),
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "",
    deps = [

Unlike the Pass codegen, here the dialect has us specify a codegen’ed header and implementation file, which are short enough to include directly in the article:

// bazel-bin/lib/Dialect/Poly/
namespace mlir {
namespace tutorial {

class PolyDialect : public ::mlir::Dialect {
  explicit PolyDialect(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
  ~PolyDialect() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {
    return ::llvm::StringLiteral("poly");
  }
};
} // namespace tutorial
} // namespace mlir

And the cpp:

// bazel-bin/lib/Dialect/Poly/
namespace mlir {
namespace tutorial {

PolyDialect::PolyDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<PolyDialect>()) {

PolyDialect::~PolyDialect() = default;

} // namespace tutorial
} // namespace mlir

Basically, they are empty containers that will hold the types and ops that we define next.

In this commit we register the dialect with the tutorial-opt main program, which is handled by a single API call. Now the dialect shows up in the help text of the tutorial-opt binary, though nothing else is there because we have no passes associated with it.

$ bazel run tools:tutorial-opt -- --help
Available Dialects: ..., pdl_interp, poly, quant, ...

Adding a trivial type

Next we’ll define a poly.poly type with no semantics, and in the next section we’ll focus on the semantics.

This commit defines a simple test that ensures we can parse and print our new type.

// RUN: tutorial-opt %s

module {
  func.func @main(%arg0: !poly.poly) -> !poly.poly {
    return %arg0 : !poly.poly
  }
}

Note that the exclamation mark ! sigil prefix is required for dialect-defined MLIR types such as this one; only builtin types like i32 are written without it.

Next, in this commit we add the tablegen for the poly type. The tablegen looks like

include ""
include "mlir/IR/"

// A base class for all types in this dialect
class Poly_Type<string name, string typeMnemonic> : TypeDef<Poly_Dialect, name> {
  let mnemonic = typeMnemonic;
}

def Poly : Poly_Type<"Polynomial", "poly"> {
  let summary = "A polynomial with i32 coefficients";

  let description = [{
    A type for polynomials with integer coefficients in a single-variable polynomial ring.
  }];
}

This is effectively a boilerplate shell since nothing here is specific to polynomial arithmetic. But it shows a few new things about tablegen worth mentioning.

First and most trivially, tablegen has include statements so that you can split definitions across files. I like to put each conceptual type of thing in its own tablegen file (types, ops, attributes, etc.), but conventions differ across projects.

Second, a common pattern when defining types (and ops) is to have a base class for each dialect that all concrete type definitions inherit from. This shows first of all that there is a difference between class and def in tablegen. def defines actual types, where by “actual” I mean it generates C++ code, whereas class is only an inheritance base and disappears after tablegen is done. Think of class as allowing us to refactor out common shared code among type definitions in one file. I make a stink of this because I have mixed up class and def many times and been confused by the error messages and generated code. For example, if you change the class to a def above, you’ll see the following unhelpful error message.

lib/Dialect/Poly/ error: Expected a class name, got 'Poly_Type'
def Poly_Type<string name, string typeMnemonic> : TypeDef<Poly_Dialect, name> {

Moreover, TypeDef itself is a class that takes as template arguments the dialect the type should belong to and a name field (related to the type’s eventual C++ class name), and results in Tablegen associating the generated C++ classes with the same namespace as we told the dialect to use, among other things.

Third, the new field is the mnemonic declaration. This determines the name of the type in the textual representation of the IR.

The generated code again has separate header and implementation files:


namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
namespace mlir {
namespace tutorial {
namespace poly {
class PolynomialType;
class PolynomialType : public ::mlir::Type::TypeBase<PolynomialType, ::mlir::Type, ::mlir::TypeStorage> {
  using Base::Base;
  static constexpr ::llvm::StringLiteral getMnemonic() {
    return {"poly"};
  }
};
} // namespace poly
} // namespace tutorial
} // namespace mlir





static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
  return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
    .Case(::mlir::tutorial::poly::PolynomialType::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
      value = ::mlir::tutorial::poly::PolynomialType::get(parser.getContext());
      return ::mlir::success(!!value);
    })
    .Default([&](llvm::StringRef keyword, llvm::SMLoc) {
      *mnemonic = keyword;
      return std::nullopt;
    });
}

static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
  return ::llvm::TypeSwitch<::mlir::Type, ::mlir::LogicalResult>(def)
    .Case<::mlir::tutorial::poly::PolynomialType>([&](auto t) {
      printer << ::mlir::tutorial::poly::PolynomialType::getMnemonic();
      return ::mlir::success();
    })
    .Default([](auto) { return ::mlir::failure(); });
}

namespace mlir {
namespace tutorial {
namespace poly {
} // namespace poly
} // namespace tutorial
} // namespace mlir


The name PolynomialType is generated by adding Type to the "Polynomial" template argument we passed in the tablegen file. The name of the def itself is used to refer to the class elsewhere in tablegen files, and the two can be different.

One thing to pay attention to here is that tablegen is attempting to generate a type parser and printer for us. But it’s not usable yet—we’ll come back to this. If you build at this commit you’ll see a compiler warning like generatedTypePrinter defined but not used and a hard failure if you try running the test.

A second thing to notice is that it uses the header-guards-as-function-arguments style again, here to separate the cpp file into two include-guarded sections, one that just has a list of the types defined, and the other that has implementations of functions. The first one, GET_TYPEDEF_LIST is curious because it just includes a comma-separated list of class names. This is because the PolyDialect.cpp from this commit is responsible for registering the types with the dialect, and that happens by using this include to add the C++ class names for the types as template arguments in the Dialect’s initialization function.

// PolyDialect.cpp
#include "lib/Dialect/Poly/PolyDialect.h"

#include "lib/Dialect/Poly/PolyTypes.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h"

#include "lib/Dialect/Poly/"
#include "lib/Dialect/Poly/"

namespace mlir {
namespace tutorial {
namespace poly {

void PolyDialect::initialize() {
#include "lib/Dialect/Poly/"

} // namespace poly
} // namespace tutorial
} // namespace mlir

We’ll do the same registration dance for ops, attributes, etc., later.
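To make the "comma-separated list of class names as template arguments" trick less mysterious, here is a minimal standalone sketch of the same idea without MLIR. Everything in it is hypothetical: ToyDialect, ToyFooType, ToyBarType, and the TOY_TYPEDEF_LIST macro are stand-ins I made up, and the macro plays the role of the guarded section of the generated file that would normally be pulled in with an #include.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins for tablegen-generated type classes.
struct ToyFooType {
  static const char *name() { return "foo"; }
};
struct ToyBarType {
  static const char *name() { return "bar"; }
};

// Stand-in for the generated type list: just a comma-separated list of
// class names, which is only meaningful once it is pasted somewhere.
#define TOY_TYPEDEF_LIST ToyFooType, ToyBarType

class ToyDialect {
public:
  // Registers every type passed as a template argument, left to right.
  template <typename... Types>
  void addTypes() {
    (registered.push_back(Types::name()), ...);
  }

  // The list is pasted between the angle brackets, so adding a new type
  // to the list requires no change to this function.
  void initialize() {
    addTypes<
        TOY_TYPEDEF_LIST
    >();
  }

  std::vector<std::string> registered;
};
```

Calling initialize() on a ToyDialect leaves registered holding "foo" and "bar" in that order. The real dialect does much more per type than record a name, but the pasting mechanism is the part this sketch is meant to show.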

The expected way to set up the C++ interface to the tablegen files is:

  • Create a header file PolyTypes.h, which is the only file allowed to include the tablegen-generated type header,
  • Include the tablegen-generated type implementation inside PolyDialect.cpp, along with any additional #includes needed for the auto-generated implementations. In our case, the default parser/printer uses the type switch function from llvm/include/llvm/ADT/TypeSwitch.h.
  • If needed, add a PolyTypes.cpp with any additional implementations needed for functions declared by tablegen that can’t be automatically generated.

I don’t think I’ve found the best arrangement of these files and build targets yet. What I really want is for each header to have a corresponding generated include and a corresponding .cpp file that includes the relevant generated implementation, and to connect them all with PolyDialect.cpp. However, PolyDialect::initialize does some introspection on the classes declared in the generated code to ensure they’re valid (specifically looking for details of Storage types like is_trivially_destructible, which are generated for us in the next section), and that requires the implementations to exist in the same compilation unit (I believe). From what I’ve read of other projects, people just tend to cram things into the same cpp files, leading to multi-thousand line implementation files, which I dislike.

As a final note, the build file uses a new rule called td_files to group all the tablegen files into one build target.

The last ingredient required to get this to compile and run is to make the type parser and printer usable. Thankfully it’s one line: adding let useDefaultTypePrinterParser = 1; to the dialect tablegen (see this commit). This adds the following declarations to the generated dialect header, and turns the generated parser and printer implementations into member functions on PolyDialect.

  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;

Adding a poly type parameter

When designing the poly type, we have to think about what we want to express, and how the type will be lowered to types in lower level dialects. For the sake of this tutorial we’ll restrict to single-variable polynomials, so the indeterminate can be arbitrary and implicit (we don’t need to keep track of the variable name x or y or t in the IR).

The next question is the polynomial analogue of integer bit width: degree. The difficulty here is that tracking the exact degree of a polynomial is impossible in general. The degree of a sum of two polynomials depends on the coefficient values (the highest degree terms may cancel), which may not be known statically because they are read from an external source at runtime. So at best we could track an upper bound on the degree.

The simplest approach is to require a polynomial type to declare its “degree upper bound” statically, and then define “overflow” semantics just like integers have. A natural option is to implicitly treat polynomials as members of a ring $R[x]/(x^D-1)$, where $D$ is the degree upper bound and $R$ is the ring of 32-bit unsigned integers. In less mathematical terms, this means that whenever a polynomial would contain a term $x^n$ of degree $n \geq D$, we replace $x^n$ with $x^{n \bmod D}$, which is the same thing as “declaring” $x^D=1$ and reducing polynomials via that substitution until all terms have degree less than $D$.
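As a concrete picture of that substitution, here is a small standalone sketch in plain C++ (no MLIR involved). The function name reduceModXDMinusOne and the choice of a std::vector of coefficients, lowest degree first, are my own assumptions for illustration and are not part of the dialect.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reduce a polynomial modulo x^D - 1, with coefficients in the 32-bit
// unsigned integers. coeffs[n] is the coefficient of x^n.
// Hypothetical helper for illustration only.
std::vector<uint32_t> reduceModXDMinusOne(const std::vector<uint32_t> &coeffs,
                                          std::size_t D) {
  assert(D > 0 && "the degree bound must be positive");
  std::vector<uint32_t> result(D, 0);
  for (std::size_t n = 0; n < coeffs.size(); ++n) {
    // x^n is replaced by x^(n mod D). The coefficient addition itself
    // wraps around modulo 2^32 because uint32_t arithmetic does.
    result[n % D] += coeffs[n];
  }
  return result;
}
```

For example, with $D = 4$ the polynomial $1 + 2x + 3x^2 + 4x^3 + 5x^4 + 6x^5$ reduces to $6 + 8x + 3x^2 + 4x^3$, since $x^4$ folds onto the constant term and $x^5$ folds onto $x$.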

This simple approach has the benefit of making the dialect design easy: all binary polynomial ops have the same input and output type, and lowering a poly type requires replacing poly.poly<D> with a tensor of D coefficients. Since the shapes are static and most polynomials would have the same bound, it’s probably the most performant option. The downsides are that now the programmer has to worry about overflow semantics, and if you want polynomials of larger degree you have to insert bookkeeping operations to extend the degree bound.
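To see why the input and output types can be identical, consider what a multiplication might look like once a polynomial is just a list of D coefficients. The sketch below is plain C++ under my own assumptions (the name mulWrapped, std::vector storage, and the naive quadratic loop are all hypothetical) and says nothing about how the actual lowering is written.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Multiply two polynomials in R[x]/(x^D - 1), each stored as exactly D
// coefficients with the lowest degree first. The output also has exactly
// D coefficients, mirroring "same input and output type".
// Hypothetical helper for illustration only.
std::vector<uint32_t> mulWrapped(const std::vector<uint32_t> &lhs,
                                 const std::vector<uint32_t> &rhs) {
  assert(!lhs.empty() && lhs.size() == rhs.size());
  const std::size_t D = lhs.size();
  std::vector<uint32_t> out(D, 0);
  for (std::size_t i = 0; i < D; ++i) {
    for (std::size_t j = 0; j < D; ++j) {
      // x^i * x^j = x^(i+j), which wraps around to x^((i+j) mod D).
      out[(i + j) % D] += lhs[i] * rhs[j];
    }
  }
  return out;
}
```

With $D = 3$, multiplying $1 + x$ by $x^2$ gives $x^2 + x^3$, and the $x^3$ term wraps to $1$, so the result is $1 + x^2$. No buffer ever needs to be resized, which is the static-shape benefit described above.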

On the other hand, we could allow a poly to have a growing degree. So that, e.g., a poly.mul operation with two poly.poly<7> input polynomials would have a poly.poly<14> as output. This would require more work in the dialect definition to ensure the IR is valid, and the lowering would be more complex in that you’d have to manage tensors of different sizes. It would probably also have an impact on performance, since there would be more memory allocations and copying involved (but maybe the compiler could be smart enough to avoid that).

I would like to do both so as to show how these differences materialize in lowerings and optimization passes. But for now I will start with the one that I think is easier: wrapping overflow semantics.

To make that work, we need to add an attribute to our polynomial type representing its degree upper bound. The official docs on how to do this are here.

This is where tablegen starts to be quite useful, because the changes only require two lines added to the type definition’s tablegen.

  let parameters = (ins "int":$degreeBound);
  let assemblyFormat = "`<` $degreeBound `>`";

The first line associates each instance of the type with an integer parameter (the "int" can be any string containing a literal C++ type), and a name $degreeBound that is used both for the generated C++ and to refer to it elsewhere in the tablegen file.

The second line is required because now that a poly type has this associated data, we need to be able to print it to and parse it from the textual IR representation. The simplest option is to let tablegen auto-generate the parser and printer for us, but we could have also used the line let hasCustomAssemblyFormat = 1; which would generate some headers that it expects us to implement. In the syntax of the assemblyFormat line, tokens in backticks are literally printed/parsed.
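To get a feel for what that format string asks for, here is a toy printer and parser for just the parameter part, written in plain C++. This is not the code tablegen generates: the function names are hypothetical, and unlike a real parser this one ignores whitespace, negative numbers, and source locations.

```cpp
#include <cstddef>
#include <optional>
#include <string>

// Print the parameter the way the format describes: a literal '<',
// then the integer, then a literal '>'. Hypothetical helper.
std::string printDegreeBound(int degreeBound) {
  return "<" + std::to_string(degreeBound) + ">";
}

// Parse text of the form "<digits>". Returns std::nullopt on anything
// else. At most nine digits are accepted so the value always fits in int.
std::optional<int> parseDegreeBound(const std::string &text) {
  if (text.size() < 3 || text.size() > 11 || text.front() != '<' ||
      text.back() != '>')
    return std::nullopt;
  int value = 0;
  for (std::size_t i = 1; i + 1 < text.size(); ++i) {
    if (text[i] < '0' || text[i] > '9')
      return std::nullopt;
    value = value * 10 + (text[i] - '0');
  }
  return value;
}
```

Printing 10 gives <10>, and parsing <10> gives back 10, which is the round trip the syntax test relies on.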

After adding these two lines in this commit, the generated code gets quite a bit more complicated. Here’s the entirety of both files. Some things to note:

  • PolynomialType gets a new int getDegreeBound() method, as well as a static get(MLIRContext, int) factory method.
  • The parser and printer are upgraded to the new format.
  • There is a new class called PolynomialTypeStorage that holds the int parameter and is hidden in an inner detail namespace.

The storage class is autogenerated for us now because integers have simple construction/destruction semantics. If we had a more complicated argument like an array that needed allocation, we’d have to implement special classes to define those semantics. And at the most extreme end, if we had a fully custom type parameter, we’d have to implement a storage class manually, implement things like hash_code, and register it with the dialect. For the curious, I had to do this in heir/pull/74 to implement a custom Polynomial type parameter.
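If the relationship between the type, its storage class, and the context is unclear, the following toy model may help. It is a deliberately simplified sketch in plain C++: ToyContext, ToyPolyStorage, and ToyPolyType are hypothetical names, and MLIR's real machinery (hashing, allocation, thread safety) is far more involved. The only idea it tries to capture is that the context owns one storage object per distinct parameter value, and a type is a cheap handle to that storage.

```cpp
#include <map>
#include <memory>

// Holds the parameter of the type. Hypothetical stand-in for a
// generated storage class.
struct ToyPolyStorage {
  int degreeBound;
};

// Owns the storage objects and hands out one per distinct parameter.
class ToyContext {
public:
  const ToyPolyStorage *getOrCreate(int degreeBound) {
    auto it = storages.find(degreeBound);
    if (it == storages.end()) {
      it = storages
               .emplace(degreeBound, std::make_unique<ToyPolyStorage>(
                                         ToyPolyStorage{degreeBound}))
               .first;
    }
    return it->second.get();
  }

private:
  std::map<int, std::unique_ptr<ToyPolyStorage>> storages;
};

// A lightweight handle: copying it copies a pointer, and equality is a
// pointer comparison because storage is unique per parameter value.
class ToyPolyType {
public:
  static ToyPolyType get(ToyContext &ctx, int degreeBound) {
    return ToyPolyType(ctx.getOrCreate(degreeBound));
  }
  int getDegreeBound() const { return storage->degreeBound; }
  bool operator==(const ToyPolyType &other) const {
    return storage == other.storage;
  }

private:
  explicit ToyPolyType(const ToyPolyStorage *s) : storage(s) {}
  const ToyPolyStorage *storage;
};
```

In this model two calls to ToyPolyType::get with the same context and the same bound compare equal, while different bounds do not. A parameter that is harder to use as a lookup key than an int is where the extra work of custom hashing and storage comes in.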

The same commit updates the syntax test to include the type parameter.

Adding some simple operations

Moving a bit faster now, we want to add some polynomial operations. This commit adds a polynomial addition op. The tablegen:

include ""
include ""

def Poly_AddOp : Op<Poly_Dialect, "add"> {
  let summary = "Addition operation between polynomials.";
  let arguments = (ins Polynomial:$lhs, Polynomial:$rhs);
  let results = (outs Polynomial:$output);
  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)";
}

It looks very similar to a type, but the base class is Op, the arguments correspond to the operation’s inputs, and the assembly format is more complicated. We’ll enhance it in a future article, but the reason is that without inserting some special hooks, tablegen isn’t able to generate a parser that can infer what the types of the inputs and outputs should be when constructing it from the textual representation. [Aside: I feel like it should be able to with the simple example above, but for whatever reason the MLIR devs appear to have made auto-generated type inference opt-in via the trait infrastructure, see next article.]

Still, we can add a test and it passes

  // CHECK-LABEL: test_add_syntax
  func.func @test_add_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> {
    // CHECK: poly.add
    %0 = poly.add %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
    return %0 : !poly.poly<10>
  }

The generated C++ has quite a bit more going on. Here’s the full generated code. Mainly the generated header defines AddOp which has getters for the arguments and results, “mutable” getters for modifying the op in place, parse/print, and generated builder methods. It also defines an AddOpAdaptor class which is used during lowering. The cpp file contains mostly rote implementations of these, converting from a generic internal representation to specific types and named objects that client code would use.

Next, in this commit I added a sub and mul operation, with a slight refactoring to make a base tablegen class for a binary operation.

Next, in this commit I added a from_tensor operation that converts a list of coefficients to a polynomial, and an eval op that evaluates a polynomial at a given input value.
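For intuition about what an evaluation op computes, here is a plain C++ sketch of evaluating a polynomial from its list of coefficients using Horner's method. This is an assumption-laden illustration, not the op's implementation: the function name evalPoly, the lowest-degree-first coefficient order, and the 32-bit wrapping arithmetic are my choices here, and the op's exact operand types are defined in the linked commit rather than in this text.

```cpp
#include <cstdint>
#include <vector>

// Evaluate c0 + c1*x + c2*x^2 + ... at the point x using Horner's
// method, with arithmetic wrapping modulo 2^32. coeffs[n] is the
// coefficient of x^n. Hypothetical helper for illustration only.
uint32_t evalPoly(const std::vector<uint32_t> &coeffs, uint32_t x) {
  uint32_t result = 0;
  // Work from the highest-degree coefficient down to the constant term.
  for (auto it = coeffs.rbegin(); it != coeffs.rend(); ++it) {
    result = result * x + *it;
  }
  return result;
}
```

For the coefficients 1, 2, 3 (that is, $1 + 2x + 3x^2$) and $x = 2$, the loop computes $3$, then $3 \cdot 2 + 2 = 8$, then $8 \cdot 2 + 1 = 17$.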

In the next few articles we’ll expand on this dialect’s capabilities. First, we’ll study what other “batteries” we can include in the dialect itself, and how we can run optimizations on poly programs. Then we’ll study the nuances of lowering poly to existing MLIR dialects, and eventually through to LLVM and then machine code.