MLIR — Canonicalizers and Declarative Rewrite Patterns


In a previous article we defined folding functions, and used them to enable some canonicalization and the sccp constant propagation pass for the poly dialect. This time we’ll see how to add more general canonicalization patterns.

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Why is Canonicalization Needed?

MLIR provides folding as a mechanism to simplify an IR, which can result in simpler, more efficient ops (e.g., replacing multiplication by a power of 2 with a shift) or removing ops entirely (e.g., deleting $y = x+0$ and using $x$ for $y$ downstream). It also has a special hook for constant materialization.

General canonicalization differs from folding in MLIR in that it can transform many operations at once. The MLIR docs call this “DAG-to-DAG” rewriting, where DAG stands for “directed acyclic graph” in the graph theory sense, but really just means a section of the IR.

Canonicalizers can be written in the standard way: declare the op has a canonicalizer in tablegen and then implement a generated C++ function declaration. The official docs for that are here. Or you can do it all declaratively in tablegen, the docs for that are here. We’ll do both in this article.

Aside: there is a third way, to use a new system called PDLL, but I haven’t figured out how to use that yet. It should be noted that PDLL is under active development, and in the meantime the tablegen-based approach in this article (called “DRR” for Declarative Rewrite Rules in the MLIR docs) is considered to be in maintenance mode, but not yet deprecated. I’ll try to cover PDLL in a future article.

Canonicalizers in C++

Reusing our poly dialect, we’ll start with the binary polynomial operations, adding let hasCanonicalizer = 1; to the op base class in this commit, which generates the following method headers on each of the binary op classes

// PolyOps.h.inc
static void getCanonicalizationPatterns(
  ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);

The body of this method asks to add custom rewrite patterns to the input results set, and we can define those patterns however we feel in the C++.

The first canonicalization pattern we’ll write in this commit is for the simple identity $x^2 - y^2 = (x+y)(x-y)$, which is useful because it replaces a multiplication with an addition. The only caveat is that this canonicalization is only more efficient if the squares have no other downstream uses.

// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
struct DifferenceOfSquares : public OpRewritePattern<SubOp> {
  DifferenceOfSquares(mlir::MLIRContext *context)
      : OpRewritePattern<SubOp>(context, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(SubOp op, 
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    if (!lhs.hasOneUse() || !rhs.hasOneUse()) { 
      return failure();
    }

    auto rhsMul = rhs.getDefiningOp<MulOp>();
    auto lhsMul = lhs.getDefiningOp<MulOp>();
    if (!rhsMul || !lhsMul) {
      return failure();
    }

    bool rhsMulOpsAgree = rhsMul.getLhs() == rhsMul.getRhs();
    bool lhsMulOpsAgree = lhsMul.getLhs() == lhsMul.getRhs();

    if (!rhsMulOpsAgree || !lhsMulOpsAgree) {
      return failure();
    }

    auto x = lhsMul.getLhs();
    auto y = rhsMul.getLhs();

    AddOp newAdd = rewriter.create<AddOp>(op.getLoc(), x, y);
    SubOp newSub = rewriter.create<SubOp>(op.getLoc(), x, y);
    MulOp newMul = rewriter.create<MulOp>(op.getLoc(), newAdd, newSub);

    rewriter.replaceOp(op, {newMul});
    // We don't need to remove the original ops because MLIR already has
    // canonicalization patterns that remove unused ops.

    return success();
  }
};

The test in the same commit shows the impact:

// Input:
func.func @test_difference_of_squares(
  %0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
  %2 = poly.mul %0, %0 : !poly.poly<3>
  %3 = poly.mul %1, %1 : !poly.poly<3>
  %4 = poly.sub %2, %3 : !poly.poly<3>
  %5 = poly.add %4, %4 : !poly.poly<3>
  return %5 : !poly.poly<3>
}

// Output:
// bazel run tools:tutorial-opt -- --canonicalize $FILE
func.func @test_difference_of_squares(%arg0: !poly.poly<3>, %arg1: !poly.poly<3>) -> !poly.poly<3> {
  %0 = poly.add %arg0, %arg1 : !poly.poly<3>
  %1 = poly.sub %arg0, %arg1 : !poly.poly<3>
  %2 = poly.mul %0, %1 : !poly.poly<3>
  %3 = poly.add %2, %2 : !poly.poly<3>
  return %3 : !poly.poly<3>
}

Other than this pattern being used in the getCanonicalizationPatterns function, there is nothing new here compared to the previous article on rewrite patterns.

Canonicalizers in Tablegen

The above canonicalization is really a simple kind of optimization attached to the canonicalization pass. It seems that this is how many minor optimizations end up being implemented, making the --canonicalize pass a heavyweight and powerful pass. However, the name “canonicalize” also suggests that it should be used to put the IR into a canonical form so that later passes don’t have to check as many cases.

So let’s implement a canonicalization like that as well. We’ll add one that has poly interact with the complex dialect, and implement a canonicalization that ensures complex conjugation always comes after polynomial evaluation. This works because of the identity $f(\overline{z}) = \overline{f(z)}$, which holds for every polynomial $f$ with real (in our case integer) coefficients and every complex number $z$.
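
If you want a quick numeric sanity check of that identity (a snippet of my own, not from the tutorial code), note that it depends on the coefficients being real, which is why it applies to poly’s integer coefficients:

import numpy

# Sanity check: for a polynomial f with real (here integer) coefficients,
# f(conj(z)) == conj(f(z)) for an arbitrary complex input z.
def f(z):
    return 2 - 3 * z**2 + z**3

z = 1.5 + 2.25j
assert numpy.isclose(f(numpy.conj(z)), numpy.conj(f(z)))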

This commit adds support for complex inputs in the poly.eval op. Note that in MLIR, complex types are forced to be floating point because all the op verifiers that construct complex numbers require it. The complex type itself, however, suggests a complex<i32> is perfectly legal, so it seems nobody in the MLIR community needs Gaussian integers yet.

The rule itself is implemented in tablegen in this commit. The main tablegen code is:

include "PolyOps.td"
include "mlir/Dialect/Complex/IR/ComplexOps.td"
include "mlir/IR/PatternBase.td"

def LiftConjThroughEval : Pat<
  (Poly_EvalOp $f, (ConjOp $z)),
  (ConjOp (Poly_EvalOp $f, $z))
>;

The two new pieces here are the Pat class (source) that defines a rewrite pattern, and the parenthetical notation that defines sub-trees of the IR being matched and rewritten. The source documentation on the Pattern parent class is quite well written, so read that for extra detail, and the normal docs provide a higher level view with additional semantics.

But the short story here is that the inputs to Pat are two “IR tree” objects (MLIR calls them “DAG nodes”), and each node in the tree is specified by parentheses ( ) with the first thing in the parentheses being the name of an operation (the tablegen name, e.g., Poly_EvalOp which comes from PolyOps.td), and the remaining arguments being the op’s arguments or attributes. Naturally, the nodes can nest, and that corresponds to a match applied to the argument. I.e., (Poly_EvalOp $f, (ConjOp $z)) means “an eval op whose first argument is anything (bind that to $f) and whose second argument is the output of a ConjOp whose input is anything (bind that to $z).”

When running tablegen with the -gen-rewriters option, this generates this code, which is not much more than a thorough version of the pattern we’d write manually. Then in this commit we show how to include it in the codebase. We still have to tell MLIR which pass to add the generated patterns to. You can add each pattern by name, or use the populateWithGenerated function to add them all.

As another example, this commit reimplements the difference of squares pattern in tablegen. This one uses three additional features: a pattern that generates multiple new ops (which uses Pattern instead of Pat), binding the ops to names, and constraints that control when the pattern may be run.

// PolyPatterns.td
def HasOneUse: Constraint<CPred<"$_self.hasOneUse()">, "has one use">;

// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
def DifferenceOfSquares : Pattern<
  (Poly_SubOp (Poly_MulOp:$lhs $x, $x), (Poly_MulOp:$rhs $y, $y)),
  [
    (Poly_AddOp:$sum $x, $y),
    (Poly_SubOp:$diff $x, $y),
    (Poly_MulOp:$res $sum, $diff),
  ],
  [(HasOneUse:$lhs), (HasOneUse:$rhs)]
>;

The HasOneUse constraint merely injects the quoted C++ code into a generated if guard, with $_self as a magic string to substitute in the argument when it’s used.

But then notice the syntax of (Poly_MulOp:$lhs $x, $x): the colon binds $lhs to refer to the op as a whole (or, via method overloads, its result value), so that it can be passed to the constraint. Similarly, the generated ops are all given names so they can be fed as the arguments of other generated ops. Finally, the second argument of Pattern is a list of generated ops to replace the matched input IR, rather than a single node as for Pat.

The benefit of doing this is significantly less boilerplate related to type casting, checking for nulls, and emitting error messages. But because you still occasionally need to inject random C++ code, and inspect the generated C++ to debug, it helps to be fluent in both styles. I don’t know how to express a constraint like “has a single use” in pure DRR tablegen without falling back to a C++ CPred like the one above.

MLIR — Defining a New Dialect


In the last article in the series, we migrated the passes we had written to use the tablegen code generation framework. That was a preface to using tablegen to define dialects.

In this article we’ll define a dialect that represents arithmetic on single-variable polynomials, with coefficients in $\mathbb{Z} / 2^{32} \mathbb{Z}$ (32-bit unsigned integers).

The code for this article is in this pull request, and as usual the commits are organized to be read in order.

Sketching out a design

The basic dialect will define a new polynomial type, and provide operations to define polynomials by specifying their coefficients from standard MLIR types, extract data about a polynomial to store the results in standard MLIR types, and to do arithmetic operations on polynomials.

There is quite a large surface area of design options when writing a dialect. This talk by Jeff Niu and Mehdi Amini gives some indication of how one might start to think about dialect design. But in brief, a polynomial dialect would fit into the “computation” bucket (a dialect to represent number crunching).

A slide from “MLIR Dialect Design and Composition for Front-End Compilers” (timestamped link), describing a taxonomy of dialects.

Another idea relevant to starting a dialect design is to ask how a dialect will enable easy optimizations. That is the Optimizer dialect class above. But since this tutorial series is focusing on the low-level, day to day, nitty gritty details of working with MLIR, we’re going to focus on getting some custom dialect defined, and then come back to what makes a good dialect design later.

An empty dialect

We’ll start by defining an empty dialect and just look at the tablegen-generated code, which is done in this commit. The tablegen file looks like

include "mlir/IR/DialectBase.td"

def Poly_Dialect : Dialect {
  let name = "poly";
  let summary = "A dialect for polynomial math";
  let description = [{
    The poly dialect defines types and operations for single-variable
    polynomials over integers.
  }];

  let cppNamespace = "::mlir::tutorial::poly";
}

It is almost indistinguishable from the tablegen for a Pass, except for the Dialect base class. The build rule has different flags too.

gentbl_cc_library(
    name = "dialect_inc_gen",
    tbl_outs = [
        (["-gen-dialect-decls"], "PolyDialect.h.inc"),
        (["-gen-dialect-defs"], "PolyDialect.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "PolyDialect.td",
    deps = [
        "@llvm-project//mlir:OpBaseTdFiles",
    ],
)

Unlike the Pass codegen, here the dialect has us specify a codegen’ed header and implementation file, which are short enough to include directly in the article:

// bazel-bin/lib/Dialect/Poly/PolyDialect.h.inc
namespace mlir {
namespace tutorial {

class PolyDialect : public ::mlir::Dialect {
  explicit PolyDialect(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~PolyDialect() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {
    return ::llvm::StringLiteral("poly");
  }
};
} // namespace tutorial
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::tutorial::PolyDialect)

And the cpp:

// bazel-bin/lib/Dialect/Poly/PolyDialect.cpp.inc
MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::tutorial::PolyDialect)
namespace mlir {
namespace tutorial {

PolyDialect::PolyDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<PolyDialect>()) {
  initialize();
}

PolyDialect::~PolyDialect() = default;

} // namespace tutorial
} // namespace mlir

Basically, they are empty containers that will hold the types and ops that we define next.

In this commit we register the dialect with the tutorial-opt main program, which is handled by a single API call. Now the dialect shows up in the help text of the tutorial-opt binary, though nothing else is there because we have no passes associated with it.

$ bazel run tools:tutorial-opt -- --help
Available Dialects: ..., pdl_interp, poly, quant, ...

Adding a trivial type

Next we’ll define a poly.poly type with no semantics, and in the next section we’ll focus on the semantics.

This commit defines a simple test that ensures we can parse and print our new type.

// RUN: tutorial-opt %s

module {
  func.func @main(%arg0: !poly.poly) -> !poly.poly {
    return %arg0 : !poly.poly
  }
}

Note that the exclamation mark ! sigil prefix is required for non-builtin MLIR types like ours (builtin types like i32 don’t use it).

Next, in this commit we add the tablegen for the poly type. The tablegen looks like

include "PolyDialect.td"
include "mlir/IR/AttrTypeBase.td"

// A base class for all types in this dialect
class Poly_Type<string name, string typeMnemonic> : TypeDef<Poly_Dialect, name> {
  let mnemonic = typeMnemonic;
}

def Poly : Poly_Type<"Polynomial", "poly"> {
  let summary = "A polynomial with i32 coefficients";

  let description = [{
    A type for polynomials with integer coefficients in a single-variable polynomial ring.
  }];
}

This is effectively a boilerplate shell since nothing here is specific to polynomial arithmetic. But it shows a few new things about tablegen worth mentioning.

First and most trivially, tablegen has include statements so that you can split definitions across files. I like to put each conceptual type of thing in its own tablegen file (types, ops, attributes, etc.), but conventions differ across projects.

Second, a common pattern when defining types (and ops) is to have a base class for each dialect that all concrete type definitions inherit from. This shows first of all that there is a difference between class and def in tablegen. def defines actual types, where by “actual” I mean it generates C++ code, whereas class is only an inheritance base and disappears after tablegen is done. Think of class as allowing us to refactor out common shared code among type definitions in one file. I make a stink of this because I have mixed up class and def many times and been confused by the error messages and generated code. For example, if you change the class to a def above, you’ll see the following unhelpful error message.

lib/Dialect/Poly/PolyTypes.td:8:5: error: Expected a class name, got 'Poly_Type'
def Poly_Type<string name, string typeMnemonic> : TypeDef<Poly_Dialect, name> {

Moreover, TypeDef itself is a class that takes as template arguments the dialect the type should belong to and a name field (related to the type’s eventual C++ class name), and results in Tablegen associating the generated C++ classes with the same namespace as we told the dialect to use, among other things.

Third, the new field is the mnemonic declaration. This determines the name of the type in the textual representation of the IR.

The generated code again has separate .h.inc and .cpp.inc files:

// PolyTypes.h.inc
#ifdef GET_TYPEDEF_CLASSES
#undef GET_TYPEDEF_CLASSES

namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
namespace mlir {
namespace tutorial {
namespace poly {
class PolynomialType;
class PolynomialType : public ::mlir::Type::TypeBase<PolynomialType, ::mlir::Type, ::mlir::TypeStorage> {
public:
  using Base::Base;
  static constexpr ::llvm::StringLiteral getMnemonic() {
    return {"poly"};
  }

};
} // namespace poly
} // namespace tutorial
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::tutorial::poly::PolynomialType)

#endif  // GET_TYPEDEF_CLASSES
// PolyTypes.cpp.inc
#ifdef GET_TYPEDEF_LIST
#undef GET_TYPEDEF_LIST

::mlir::tutorial::poly::PolynomialType

#endif  // GET_TYPEDEF_LIST

#ifdef GET_TYPEDEF_CLASSES
#undef GET_TYPEDEF_CLASSES

static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
  return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
    .Case(::mlir::tutorial::poly::PolynomialType::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {
      value = ::mlir::tutorial::poly::PolynomialType::get(parser.getContext());
      return ::mlir::success(!!value);
    })
    .Default([&](llvm::StringRef keyword, llvm::SMLoc) {
      *mnemonic = keyword;
      return std::nullopt;
    });
}

static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
  return ::llvm::TypeSwitch<::mlir::Type, ::mlir::LogicalResult>(def)
   .Case<::mlir::tutorial::poly::PolynomialType>([&](auto t) {
      printer << ::mlir::tutorial::poly::PolynomialType::getMnemonic();
      return ::mlir::success();
    })
    .Default([](auto) { return ::mlir::failure(); });
}

namespace mlir {
namespace tutorial {
namespace poly {
} // namespace poly
} // namespace tutorial
} // namespace mlir
MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::tutorial::poly::PolynomialType)

#endif  // GET_TYPEDEF_CLASSES

The name PolynomialType is generated by adding Type to the "Polynomial" template argument we passed in the tablegen file. The name of the def itself is used to refer to the class elsewhere in tablegen files, and the two can be different.

One thing to pay attention to here is that tablegen is attempting to generate a type parser and printer for us. But it’s not usable yet—we’ll come back to this. If you build at this commit you’ll see a compiler warning like generatedTypePrinter defined but not used and a hard failure if you try running the test.

A second thing to notice is that it uses the header-guards-as-function-arguments style again, here to separate the cpp file into two include-guarded sections, one that just has a list of the types defined, and the other that has implementations of functions. The first one, GET_TYPEDEF_LIST, is curious because it just includes a comma-separated list of class names. This is because the PolyDialect.cpp from this commit is responsible for registering the types with the dialect, and that happens by using this include to add the C++ class names for the types as template arguments in the Dialect’s initialization function.

// PolyDialect.cpp
#include "lib/Dialect/Poly/PolyDialect.h"

#include "lib/Dialect/Poly/PolyTypes.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h"

#include "lib/Dialect/Poly/PolyDialect.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/Poly/PolyTypes.cpp.inc"

namespace mlir {
namespace tutorial {
namespace poly {

void PolyDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "lib/Dialect/Poly/PolyTypes.cpp.inc"
      >();
}

} // namespace poly
} // namespace tutorial
} // namespace mlir

We’ll do the same registration dance for ops, attributes, etc., later.

The expected way to set up the C++ interface to the tablegen files is:

  • Create a header file PolyTypes.h which is the only file allowed to include PolyTypes.h.inc,
  • Include PolyTypes.cpp.inc inside PolyDialect.cpp with any additional #includes needed for the auto-generated implementations in PolyTypes.cpp.inc. In our case, the default parser/printer uses the type switch function from llvm/include/llvm/ADT/TypeSwitch.h.
  • If needed, add a PolyTypes.cpp with any additional implementations needed for functions declared by tablegen that can’t be automatically generated.

I don’t think I’ve found the best arrangement of these files and build targets yet. What I really want is to have each header with a corresponding .h.inc, a corresponding .cpp file that includes the relevant .cpp.inc, and to connect them all with PolyDialect.cpp. However, PolyDialect::initialize does some introspection on the classes declared in PolyTypes.h.inc to ensure they’re valid (specifically looking for details of Storage types like is_trivially_destructible, which are generated for us in the next section), and that requires the implementations to exist in the same compilation unit (I believe). From what I’ve read of other projects, people just tend to cram things into the same cpp files, leading to multi-thousand line implementation files, which I dislike.

As a final note, the build file uses a new td_library rule to group all the tablegen files into one build target.

The last ingredient required to get this to compile and run is to make the type parser and printer usable. Thankfully it’s one line: adding let useDefaultTypePrinterParser = 1; to the dialect tablegen (see this commit). This adds the following declarations to PolyDialect.h.inc, and modifies the generated implementations to be member functions on PolyDialect.

  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;

Adding a poly type parameter

When designing the poly type, we have to think about what we want to express, and how the type will be lowered to types in lower level dialects. For the sake of this tutorial we’ll restrict to single-variable polynomials, so the indeterminate can be arbitrary and implicit (we don’t need to keep track of the variable name x or y or t in the IR).

The next question is the polynomial analogue of integer bit width: degree. The difficulty here is that tracking the exact degree of a polynomial is impossible in general. The degree of a sum of two polynomials depends on the coefficient values (the highest degree terms may cancel), which may not be known statically because they are read from an external source at runtime. So at best we could track an upper bound on the degree.

The simplest approach is to require a polynomial type to declare its “degree upper bound” statically, and then define “overflow” semantics just like integers have. A natural option is to implicitly treat polynomials as members of a ring $R[x]/ (x^D-1)$, where $D$ is the degree upper bound and $R$ is the ring of 32-bit unsigned integers. In less mathematical terms, this means that whenever a polynomial contains a term $x^n$ with $n \geq D$, we replace it with $x^{n \mod D}$, which is the same thing as “declaring” $x^D=1$ and reducing polynomials via that substitution until all terms have degree less than $D$.
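
As a concrete example of these wrapping semantics (my own, just to illustrate), take $D = 3$, so that $x^3 = 1$:

\[ (x^2 + 1)(x + 1) = x^3 + x^2 + x + 1 \equiv x^2 + x + 2 \mod (x^3 - 1) \]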

This simple approach has the benefit of making the dialect design easy: all binary polynomial ops have the same input and output type, and lowering a poly type requires replacing poly.poly<D> with a tensor of D coefficients. Since the shapes are static and most polynomials would have the same bound, it’s probably the most performant option. The downsides are that now the programmer has to worry about overflow semantics, and if you want polynomials of larger degree you have to insert bookkeeping operations to extend the degree bound.

On the other hand, we could allow a poly to have a growing degree, so that, e.g., a poly.mul operation with two poly.poly<7> input polynomials would have a poly.poly<14> as output. This would require more work in the dialect definition to ensure the IR is valid, and the lowering would be more complex in that you’d have to manage tensors of different sizes. It would probably also have an impact on performance, since there would be more memory allocations and copying involved (but maybe the compiler could be smart enough to avoid that).

I would like to do both so as to show how these differences materialize in lowerings and optimization passes. But for now I will start with the one that I think is easier: wrapping overflow semantics.

To make that work, we need to add an attribute to our polynomial type representing its degree upper bound. The official docs on how to do this are here.

This is where tablegen starts to be quite useful, because the changes only require two lines added to the type definition’s tablegen.

  let parameters = (ins "int":$degreeBound);
  let assemblyFormat = "`<` $degreeBound `>`";

The first line associates each instance of the type with an integer parameter (the "int" can be any string containing a literal C++ type), and a name $degreeBound that is used both for the generated C++ and to refer to it elsewhere in the tablegen file.

The second line is required because now that a poly type has this associated data, we need to be able to print it to and parse it from the textual IR representation. The simplest option is to let tablegen auto-generate the parser and printer for us, but we could have also used the line let hasCustomAssemblyFormat = 1;, which would generate declarations that we’d be expected to implement ourselves. In the syntax of that line, tokens in backticks are literally printed/parsed.

After adding these two lines in this commit, the generated code gets quite a bit more complicated. Here’s the entirety of both files. Some things to note:

  • PolynomialType gets a new int getDegreeBound() method, as well as a static get(MLIRContext, int) factory method.
  • The parser and printer are upgraded to the new format.
  • There is a new class called PolynomialTypeStorage that holds the int parameter and is hidden in an inner detail namespace.

The storage class is autogenerated for us now because integers have simple construction/destruction semantics. If we had a more complicated argument like an array that needed allocation, we’d have to implement special classes to define those semantics. And at the most extreme end, if we had a fully custom type parameter, we’d have to implement a storage class manually, implement things like hash_code, and register it with the dialect. For the curious, I had to do this in heir/pull/74 to implement a custom Polynomial type parameter.

The same commit updates the syntax test to include the type parameter.

Adding some simple operations

Moving a bit faster now, we want to add some polynomial operations. This commit adds a polynomial addition op. The tablegen:

include "PolyDialect.td"
include "PolyTypes.td"

def Poly_AddOp : Op<Poly_Dialect, "add"> {
  let summary = "Addition operation between polynomials.";
  let arguments = (ins Polynomial:$lhs, Polynomial:$rhs);
  let results = (outs Polynomial:$output);
  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)";
}

It looks very similar to a type, but the base class is Op, the arguments correspond to the operation’s inputs, and the assembly format is more complicated. We’ll enhance it in a future article, but the reason is that without inserting some special hooks, tablegen isn’t able to generate a parser that can infer what the types of the inputs and outputs should be when constructing it from the textual representation. [Aside: I feel like it should be able to with the simple example above, but for whatever reason the MLIR devs appear to have made auto-generated type inference opt-in via the trait infrastructure, see next article.]

Still, we can add a test and it passes

  // CHECK-LABEL: test_add_syntax
  func.func @test_add_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> {
    // CHECK: poly.add
    %0 = poly.add %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
    return %0 : !poly.poly<10>
  }

The generated C++ has quite a bit more going on. Here’s the full generated code. Mainly the generated header defines AddOp which has getters for the arguments and results, “mutable” getters for modifying the op in place, parse/print, and generated builder methods. It also defines an AddOpAdaptor class which is used during lowering. The cpp file contains mostly rote implementations of these, converting from a generic internal representation to specific types and named objects that client code would use.

Next, in this commit I added a sub and mul operation, with a slight refactoring to make a base tablegen class for a binary operation.

Next, in this commit I added a from_tensor operation that converts a list of coefficients to a polynomial, and an eval op that evaluates a polynomial at a given input value.

In the next few articles we’ll expand on this dialect’s capabilities. First, we’ll study what other “batteries” we can include in the dialect itself, and how we can run optimizations on poly programs. Then we’ll study the nuances of lowering poly to existing MLIR dialects, and eventually through to LLVM and then machine code.

Negacyclic Polynomial Multiplication

In this article I’ll cover three techniques to compute special types of polynomial products that show up in lattice cryptography and fully homomorphic encryption. Namely, the negacyclic polynomial product, which is the product of two polynomials in the quotient ring $\mathbb{Z}[x] / (x^N + 1)$. As a precursor to the negacyclic product, we’ll cover the simpler cyclic product.

All of the Python code written for this article is on GitHub.

The DFT and Cyclic Polynomial Multiplication

A recent program gallery piece showed how single-variable polynomial multiplication could be implemented using the Discrete Fourier Transform (DFT). This boils down to two observations:

  1. The product of two polynomials $f, g$ can be computed via the convolution of the coefficients of $f$ and $g$.
  2. The Convolution Theorem, which says that the Fourier transform of a convolution of two signals $f, g$ is the point-wise product of the Fourier transforms of the two signals. (The same holds for the DFT)

This provides a much faster polynomial product operation than one could implement using the naïve polynomial multiplication algorithm (though see the last section for an implementation anyway). The DFT can be used to speed up large integer multiplication as well.

A caveat with normal polynomial multiplication is that one needs to pad the input coefficient lists with enough zeros so that the convolution doesn’t “wrap around.” That padding results in the output having length at least as large as the sum of the degrees of $f$ and $g$ (see the program gallery piece for more details).

If you don’t pad the polynomials, instead you get what’s called a cyclic polynomial product. More concretely, if the two input polynomials $f, g$ are represented by coefficient lists $(f_0, f_1, \dots, f_{N-1}), (g_0, g_1, \dots, g_{N-1})$ of length $N$ (implying the inputs are degree at most $N-1$, i.e., the lists may end in a tail of zeros), then the Fourier Transform technique computes

\[ f(x) \cdot g(x) \mod (x^N - 1) \]

This modulus is in the sense of a quotient ring $\mathbb{Z}[x] / (x^N - 1)$, where $(x^N - 1)$ denotes the ring ideal generated by $x^N-1$, i.e., all polynomials that are evenly divisible by $x^N - 1$. A particularly important interpretation of this quotient ring is achieved by interpreting the ideal generator $x^N - 1$ as an equation $x^N - 1 = 0$, also known as $x^N = 1$. To get the canonical ring element corresponding to any polynomial $h(x) \in \mathbb{Z}[x]$, you “set” $x^N = 1$ and reduce the polynomial until there are no more terms with degree bigger than $N-1$. For example, if $N=5$ then $x^{10} + x^6 - x^4 + x + 2 = -x^4 + 2x + 3$ (the $x^{10}$ becomes 1, and $x^6 = x$).

To prove the DFT product computes a product in this particular ring, note how the convolution theorem produces the following formula, where $\textup{fprod}(f, g)$ denotes the process of taking the Fourier transform of the two coefficient lists, multiplying them entrywise, and taking a (properly normalized) inverse FFT, and $\textup{fprod}(f, g)(j)$ is the $j$-th coefficient of the output polynomial:

\[ \textup{fprod}(f, g)(j) = \sum_{k=0}^{N-1} f_k g_{j-k \textup{ mod } N} \]

In words, the output polynomial coefficient $j$ equals the sum of all products of pairs of coefficients whose indices sum to $j$ when considered “wrapping around” $N$. Fixing $j=1$ as an example, $\textup{fprod}(f, g)(1) = f_0 g_1 + f_1g_0 + f_2 g_{N-1} + f_3 g_{N-2} + \dots$. This demonstrates the “set $x^N = 1$” interpretation above: the term $f_2 g_{N-1}$ corresponds to the product $f_2x^2 \cdot g_{N-1}x^{N-1}$, which contributes to the $x^1$ term of the polynomial product if and only if $x^{2 + N-1} = x$, if and only if $x^N = 1$.

To achieve this in code, we simply use the version of the code from the program gallery piece, but fix the size of the arrays given to numpy.fft.fft in advance. We will also, for simplicity, assume the $N$ one wishes to use is a power of 2. The resulting code is significantly simpler than the original program gallery code (we omit zero-padding to length $N$ for brevity).

import numpy
from numpy.fft import fft, ifft

def cyclic_polymul(p1, p2, N):
    """Multiply two integer polynomials modulo (x^N - 1).

    p1 and p2 are arrays of coefficients in degree-increasing order.
    """
    assert len(p1) == len(p2) == N
    product = fft(p1) * fft(p2)
    inverted = ifft(product)
    return numpy.round(numpy.real(inverted)).astype(numpy.int32)
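
As a quick usage example (my own, not from the linked code), multiplying $x^3$ by $x^2$ with $N = 4$ wraps $x^5$ around to $x$:

p1 = numpy.array([0, 0, 0, 1])  # x^3
p2 = numpy.array([0, 0, 1, 0])  # x^2
print(cyclic_polymul(p1, p2, 4))  # [0 1 0 0], i.e., x^5 = x mod (x^4 - 1)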

As a side note, there’s nothing that stops this from working with polynomials that have real or complex coefficients, but so long as we use small magnitude integer coefficients and round at the end, I don’t have to worry about precision issues (hat tip to Brad Lucier for suggesting an excellent paper by Colin Percival, “Rapid multiplication modulo the sum and difference of highly composite numbers“, which covers these precision issues in detail).

Negacyclic polynomials, DFT with duplication

Now the kind of polynomial quotient ring that shows up in cryptography is critically not $\mathbb{Z}[x]/(x^N-1)$, because that ring has enough easy-to-reason-about structure that it can’t hide secrets. Instead, cryptographers use the ring $\mathbb{Z}[x]/(x^N+1)$ (the minus becomes a plus), which is believed to be more secure for cryptography—although I don’t have a great intuitive grasp on why.

The interpretation is similar here as before, except we “set” $x^N = -1$ instead of $x^N = 1$ in our reductions. Repeating the above example, if $N=5$ then $x^{10} + x^6 - x^4 + x + 2 = -x^4 + 3$ (the $x^{10}$ becomes $(-1)^2 = 1$, and $x^6 = -x$). It’s called negacyclic because as a term $x^k$ passes $k \geq N$, it cycles back to $x^0 = 1$, but with a sign flip.

The negacyclic polynomial multiplication can’t use the DFT without some special hacks. The first and simplest hack is to double the input lists with a negation. That is, starting from $f(x) \in \mathbb{Z}[x]/(x^N+1)$, we can define $f^*(x) = f(x) - x^Nf(x)$ in a different ring $\mathbb{Z}[x]/(x^{2N} - 1)$ (and similarly for $g^*$ and $g$).

Before seeing how this causes the DFT to (almost) compute a negacyclic polynomial product, some math wizardry. The ring $\mathbb{Z}[x]/(x^{2N} - 1)$ is special because it contains our negacyclic ring as a subring. Indeed, because the polynomial $x^{2N} - 1$ factors as $(x^N-1)(x^N+1)$, and because these two factors are coprime in $\mathbb{Z}[x]/(x^{2N} - 1)$, the Chinese remainder theorem (aka Sun-tzu’s theorem) generalizes to polynomial rings and says that any polynomial in $\mathbb{Z}[x]/(x^{2N} - 1)$ is uniquely determined by its remainders when divided by $(x^N-1)$ and $(x^N+1)$. Another way to say it is that the ring $\mathbb{Z}[x]/(x^{2N} - 1)$ factors as a direct product of the two rings $\mathbb{Z}[x]/(x^{N} - 1)$ and $\mathbb{Z}[x]/(x^{N} + 1)$.

Now mapping a polynomial $f(x)$ from the bigger ring $(x^{2N} - 1)$ to the smaller ring $(x^{N}+1)$ involves taking a remainder of $f(x)$ when dividing by $x^{N}+1$ (“setting” $x^N = -1$ and reducing). There are many possible preimage mappings, depending on what your goal is. In this case, we actually intentionally choose a non-preimage mapping, because in general computing a preimage requires solving a system of congruences in the larger polynomial ring. So instead we choose $f(x) \mapsto f^*(x) = f(x) - x^Nf(x) = -f(x)(x^N - 1)$, which maps back down to $2f(x)$ in $\mathbb{Z}[x]/(x^{N} + 1)$. This mapping has a particularly nice structure, in that you build it by repeating the polynomial’s coefficients twice and flipping the sign of the second half. It’s easy to see that the product $f^*(x) g^*(x)$ maps down to $4f(x)g(x)$.

So if we properly account for these extra constant factors floating around, our strategy to perform negacyclic polynomial multiplication is to map $f$ and $g$ up to the larger ring as described, compute their cyclic product (modulo $x^{2N} - 1$) using the FFT, and then the result should be a degree $2N-1$ polynomial which can be reduced with one more modular reduction step to the right degree $N-1$ negacyclic product, i.e., setting $x^N = -1$, which materializes as taking the second half of the coefficients, flipping their signs, and adding them to the corresponding coefficients in the first half.

The code for this is:

def negacyclic_polymul_preimage_and_map_back(p1, p2):
    p1_preprocessed = numpy.concatenate([p1, -p1])
    p2_preprocessed = numpy.concatenate([p2, -p2])
    product = fft(p1_preprocessed) * fft(p2_preprocessed)
    inverted = ifft(product)
    rounded = numpy.round(numpy.real(inverted)).astype(p1.dtype)
    return (rounded[: p1.shape[0]] - rounded[p1.shape[0] :]) // 4

However, this chosen mapping hides another clever trick. The product of the two preimages has enough structure that we can “read” the result off without doing the full “set $x^N = -1$” reduction step. Mapping $f$ and $g$ up to $f^*, g^*$ and taking their product modulo $(x^{2N} - 1)$ gives

\[ \begin{aligned} f^*g^* &= -f(x^N-1) \cdot -g(x^N - 1) \\ &= fg (x^N-1)^2 \\ &= fg(x^{2N} - 2x^N + 1) \\ &= fg(2 - 2x^N) \\ &= 2(fg - x^Nfg) \end{aligned} \]

This has the same syntactical format as the original mapping $f \mapsto f – x^Nf$, with an extra factor of 2, and so its coefficients also have the form “repeat the coefficients and flip the sign of the second half” (times two). We can then do the “inverse mapping” by reading only the first half of the coefficients and dividing by 2.

def negacyclic_polymul_use_special_preimage(p1, p2):
    p1_preprocessed = numpy.concatenate([p1, -p1])
    p2_preprocessed = numpy.concatenate([p2, -p2])
    product = fft(p1_preprocessed) * fft(p2_preprocessed)
    inverted = ifft(product)
    rounded = numpy.round(0.5 * numpy.real(inverted)).astype(p1.dtype)
    return rounded[: p1.shape[0]]

Our chosen mapping $f \mapsto f-x^Nf$ is not particularly special, except that it uses a small number of pre and post-processing operations. For example, if you instead used the mapping $f \mapsto 2f + x^Nf$ (which would map back to $f$ exactly), then the FFT product would result in $5fg + 4x^Nfg$ in the larger ring. You can still read off the coefficients as before, but you’d have to divide by 5 instead of 2 (which, the superstitious would say, is harder). It seems that “double and negate” followed by “halve and take first half” is the least amount of pre/post processing possible.
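
To see where the 5 and the 4 come from, here is the expansion (my own derivation), using $x^{2N} = 1$ in the larger ring:

\[ (2f + x^Nf)(2g + x^Ng) = 4fg + 4x^Nfg + x^{2N}fg = 5fg + 4x^Nfg \]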

Negacyclic polynomials with a “twist”

The previous section identified a nice mapping (or embedding) of the input polynomials into a larger ring. But studying that shows some symmetric structure in the FFT output. I.e., the coefficients of $f$ and $g$ are repeated twice, with some scaling factors. It also involves taking an FFT of two $2N$-dimensional vectors when we start from two $N$-dimensional vectors.

This sort of situation should make you think that we can do this more efficiently, either by using a smaller size FFT or by packing some data into the complex part of the input, and indeed we can do both.

[Aside: it’s well known that if all the entries of an FFT input are real, then the result also has symmetry that can be exploited for efficiency by reframing the problem as a size-N/2 FFT in some cases, and just removing half the FFT algorithm’s steps in other cases; see Wikipedia for more.]

This technique was explained in Fast multiplication and its applications (pdf link) by Daniel Bernstein, a prominent cryptographer who specializes in cryptography performance, whose work appears in widely used standards like TLS and OpenSSH, and who designed a commonly used elliptic curve for cryptography.

[Aside: Bernstein cites this technique as using something called the “Tangent FFT” (pdf link). This is a drop-in FFT replacement he invented that is faster than the previous best (split-radix FFT), and Bernstein uses it mainly to give a precise expression for the number of operations required to do the multiplication end to end. We will continue to use the numpy FFT implementation, since in this article I’m just focusing on how to express negacyclic multiplication in terms of the FFT. Also worth noting: both the Tangent FFT and “Fast multiplication” papers frame their techniques—including FFT algorithm implementations!—in terms of polynomial ring factorizations and mappings. Be still, my beating cardioid.]

In terms of polynomial mappings, we start from the ring $\mathbb{R}[x] / (x^N + 1)$, where $N$ is a power of 2. We then pick a reversible mapping from $\mathbb{R}[x]/(x^N + 1) \to \mathbb{C}[x]/(x^{N/2} - 1)$ (note the field change from real to complex), apply the FFT to the image of the mapping, and reverse it appropriately at the end.

One such mapping takes two steps, first mapping $\mathbb{R}[x]/(x^N + 1) \to \mathbb{C}[x]/(x^{N/2} - i)$ and then from $\mathbb{C}[x]/(x^{N/2} - i) \to \mathbb{C}[x]/(x^{N/2} - 1)$. The first mapping is as easy as in the last section, because $(x^N + 1) = (x^{N/2} + i) (x^{N/2} - i)$, and so we can just set $x^{N/2} = i$ and reduce the polynomial. This has the effect of making the second half of the polynomial’s coefficients become the complex part of the first half of the coefficients.

The second mapping is more nuanced, because we’re not just reducing via factorization. And we can’t just map $i \mapsto 1$ generically, because that would reduce complex numbers down to real values. Instead, we observe that (momentarily using an arbitrary degree $k$ instead of $N/2$), for any polynomial $f \in \mathbb{C}[x]$, the remainder of $f \mod x^k-i$ uniquely determines the remainder of $f \mod x^k - 1$ via the change of variables $x \mapsto \omega_{4k} x$, where $\omega_{4k}$ is a $4k$-th primitive root of unity $\omega_{4k} = e^{\frac{2 \pi i}{4k}}$. Spelling this out in more detail: if $f(x) \in \mathbb{C}[x]$ can be written as $f(x) = g(x) + h(x)(x^k - i)$ for some polynomial $h(x)$ (so $g$ is the remainder), then

\[ \begin{aligned} f(\omega_{4k}x) &= g(\omega_{4k}x) + h(\omega_{4k}x)((\omega_{4k}x)^{k} - i) \\ &= g(\omega_{4k}x) + h(\omega_{4k}x)(e^{\frac{\pi i}{2}} x^k - i) \\ &= g(\omega_{4k}x) + i h(\omega_{4k}x)(x^k - 1) \\ &= g(\omega_{4k}x) \mod (x^k - 1) \end{aligned} \]

Translating this back to $k=N/2$, the mapping from $\mathbb{C}[x]/(x^{N/2} - i) \to \mathbb{C}[x]/(x^{N/2} - 1)$ is $f(x) \mapsto f(\omega_{2N}x)$. And if $f = f_0 + f_1x + \dots + f_{N/2 - 1}x^{N/2 - 1}$, then the mapping involves multiplying each coefficient $f_k$ by $\omega_{2N}^k$.

When you view polynomials as if they were a simple vector of their coefficients, then this operation $f(x) \mapsto f(\omega_{k}x)$ looks like $(a_0, a_1, \dots, a_n) \mapsto (a_0, \omega_{k} a_1, \dots, \omega_k^n a_n)$. Bernstein calls the operation a twist of $\mathbb{C}^n$, which I mused about in this Mathstodon thread.

What’s most important here is that each of these transformations is invertible. The first because the top half coefficients end up in the complex parts of the polynomial, and the second because the mapping $f(x) \mapsto f(\omega_{2N}^{-1}x)$ is an inverse. Together, this makes the preprocessing and postprocessing exact inverses of each other. The code is then

def negacyclic_polymul_complex_twist(p1, p2):
    n = p2.shape[0]
    # primitive_nth_root is a helper defined in the accompanying GitHub code;
    # I'm assuming it returns e^(2 pi i / m) for input m.
    primitive_root = primitive_nth_root(2 * n)
    root_powers = primitive_root ** numpy.arange(n // 2)

    p1_preprocessed = (p1[: n // 2] + 1j * p1[n // 2 :]) * root_powers
    p2_preprocessed = (p2[: n // 2] + 1j * p2[n // 2 :]) * root_powers

    p1_ft = fft(p1_preprocessed)
    p2_ft = fft(p2_preprocessed)
    prod = p1_ft * p2_ft
    ifft_prod = ifft(prod)
    ifft_rotated = ifft_prod * primitive_root ** numpy.arange(0, -n // 2, -1)

    return numpy.round(
        numpy.concatenate([numpy.real(ifft_rotated), numpy.imag(ifft_rotated)])
    ).astype(p1.dtype)

And so, at the cost of a bit more pre- and postprocessing, we can negacyclically multiply two degree $N-1$ polynomials using an FFT of length $N/2$. In theory, no information is wasted and this is optimal.

And finally, a simple matrix multiplication

The last technique I wanted to share is not based on the FFT, but it’s another method for doing negacyclic polynomial multiplication that has come in handy in situations where I am unable to use FFTs. I call it the Toeplitz method, because one of the polynomials is converted to a Toeplitz matrix. Sometimes I hear it referred to as a circulant matrix technique, but due to the negacyclic sign flip, I don’t think it’s a fully accurate term.

The idea is to put the coefficients of one polynomial $f(x) = f_0 + f_1x + \dots + f_{N-1}x^{N-1}$ into a matrix as follows:

\[ \begin{pmatrix} f_0 & -f_{N-1} & \dots & -f_1 \\ f_1 & f_0 & \dots & -f_2 \\ \vdots & \vdots & \ddots & \vdots \\ f_{N-1} & f_{N-2} & \dots & f_0 \end{pmatrix} \]

The polynomial coefficients are written down in the first column unchanged, then in each subsequent column, the coefficients are cyclically shifted down one, and the term that wraps around the top has its sign flipped. When the second polynomial is treated as a vector of its coefficients, say, $g(x) = g_0 + g_1x + \dots + g_{N-1}x^{N-1}$, then the matrix-vector product computes their negacyclic product (as a vector of coefficients):

\[ \begin{pmatrix} f_0 & -f_{N-1} & \dots & -f_1 \\ f_1 & f_0 & \dots & -f_2 \\ \vdots & \vdots & \ddots & \vdots \\ f_{N-1} & f_{N-2} & \dots & f_0 \end{pmatrix} \begin{pmatrix} g_0 \\ g_1 \\ \vdots \\ g_{N-1} \end{pmatrix} \]

This works because each row $j$ corresponds to one output term $x^j$, and the cyclic shift for that row accounts for the degree-wrapping, with the sign flip accounting for the negacyclic part. (If there were no sign attached, this method could be used to compute a cyclic polynomial product).
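
A tiny worked example with $N = 2$ (mine, for illustration) shows the sign flip performing the negacyclic reduction:

\[ \begin{pmatrix} f_0 & -f_1 \\ f_1 & f_0 \end{pmatrix} \begin{pmatrix} g_0 \\ g_1 \end{pmatrix} = \begin{pmatrix} f_0g_0 - f_1g_1 \\ f_0g_1 + f_1g_0 \end{pmatrix} \]

which agrees with $(f_0 + f_1x)(g_0 + g_1x) = f_0g_0 + (f_0g_1 + f_1g_0)x + f_1g_1x^2 \equiv (f_0g_0 - f_1g_1) + (f_0g_1 + f_1g_0)x \mod (x^2 + 1)$.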

The Python code for this is

def cylic_matrix(c: numpy.array) -> numpy.ndarray:
    """Generates a cyclic matrix with each row of the input shifted.

    For input: [1, 2, 3], generates the following matrix:

        [[1 2 3]
         [2 3 1]
         [3 1 2]]
    """
    c = numpy.asarray(c).ravel()
    a, b = numpy.ogrid[0 : len(c), 0 : -len(c) : -1]
    indx = a + b
    return c[indx]


def negacyclic_polymul_toeplitz(p1, p2):
    n = len(p1)

    # Generates a sign matrix with 1s on and below the diagonal and -1s above.
    lower_tri = numpy.tril(numpy.ones((n, n), dtype=int), 0)
    upper_tri = numpy.triu(numpy.ones((n, n), dtype=int), 1) * -1
    sign_matrix = lower_tri + upper_tri

    cyclic_matrix = cylic_matrix(p1)
    toeplitz_p1 = sign_matrix * cyclic_matrix
    return numpy.matmul(toeplitz_p1, p2)
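
As a sanity check (my own sketch, assuming the three functions above and the primitive_nth_root helper from the GitHub code are in scope), all three methods should agree on random integer inputs:

def test_negacyclic_methods_agree(n=16, trials=10):
    # Compare the Toeplitz method against both FFT-based methods.
    rng = numpy.random.default_rng(seed=0)
    for _ in range(trials):
        p1 = rng.integers(-10, 10, size=n)
        p2 = rng.integers(-10, 10, size=n)
        expected = negacyclic_polymul_toeplitz(p1, p2)
        assert numpy.array_equal(
            expected, negacyclic_polymul_preimage_and_map_back(p1, p2))
        assert numpy.array_equal(
            expected, negacyclic_polymul_complex_twist(p1, p2))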

Obviously on most hardware this would be less efficient than an FFT-based method (and there is some relationship between circulant matrices and Fourier Transforms, see Wikipedia). But in some cases—when the polynomials are small, or one of the two polynomials is static, or a particular hardware choice doesn’t handle FFTs with high-precision floats very well, or you want to take advantage of natural parallelism in the matrix-vector product—this method can be useful. It’s also simpler to reason about.

Until next time!

Polynomial Multiplication Using the FFT

Problem: Compute the product of two polynomials efficiently.

Solution:

import numpy
from numpy.fft import fft, ifft


def poly_mul(p1, p2):
    """Multiply two polynomials.

    p1 and p2 are arrays of coefficients in degree-increasing order.
    """
    deg1 = p1.shape[0] - 1
    deg2 = p2.shape[0] - 1
    # The product has degree deg1 + deg2, so deg1 + deg2 + 1 evaluation points
    # suffice; rounding 2 * (deg1 + deg2) up to a power of 2 gives plenty of room.
    total_num_pts = 2 * (deg1 + deg2)
    next_power_of_2 = 1 << (total_num_pts - 1).bit_length()

    ff_p1 = fft(numpy.pad(p1, (0, next_power_of_2 - p1.shape[0])))
    ff_p2 = fft(numpy.pad(p2, (0, next_power_of_2 - p2.shape[0])))
    product = ff_p1 * ff_p2
    inverted = ifft(product)
    rounded = numpy.round(numpy.real(inverted)).astype(numpy.int32)
    return numpy.trim_zeros(rounded, trim='b')
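
For example (my own usage check), multiplying $1 + 2x + 3x^2$ by $4 + 5x$:

p1 = numpy.array([1, 2, 3])  # 1 + 2x + 3x^2
p2 = numpy.array([4, 5])     # 4 + 5x
print(poly_mul(p1, p2))      # [ 4 13 22 15], i.e., 4 + 13x + 22x^2 + 15x^3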

Discussion: The Fourier Transform has a lot of applications to science, and I’ve covered it on this blog before, see the Signal Processing section of Main Content. But it also has applications to fast computational mathematics.

The naive algorithm for multiplying two polynomials is the “grade-school” algorithm most readers will already be familiar with (see e.g., this page), but for large polynomials that algorithm is slow. It requires $O(n^2)$ arithmetic operations to multiply two polynomials of degree $n$.

This short tip shows a different approach, which is based on the idea of polynomial interpolation. As a side note, I show the basic theory of polynomial interpolation in chapter 2 of my book, A Programmer’s Introduction to Mathematics, along with an application to cryptography called “Secret Sharing.”

The core idea is that given $n+1$ distinct evaluations of a polynomial $p(x)$ (i.e., points $(x, p(x))$ with different $x$ inputs), you can reconstruct the coefficients of $p(x)$ exactly. And if you have two such point sets for two different polynomials $p(x), q(x)$, a valid point set of the product $(pq)(x)$ is the product of the points that have the same $x$ inputs.

\[ \begin{aligned} p(x) &= \{ (x_0, p(x_0)), (x_1, p(x_1)), \dots, (x_n, p(x_n)) \} \\ q(x) &= \{ (x_0, q(x_0)), (x_1, q(x_1)), \dots, (x_n, q(x_n)) \} \\ (pq)(x) &= \{ (x_0, p(x_0)q(x_0)), (x_1, p(x_1)q(x_1)), \dots, (x_n, p(x_n)q(x_n)) \} \end{aligned} \]

The above uses $=$ loosely to represent that the polynomial $p$ can be represented by the point set on the right hand side.

So given two polynomials $p(x), q(x)$ in their coefficient forms, one can first convert them to their point forms, multiply the points, and then reconstruct the resulting product.
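
A tiny worked example (mine): let $p(x) = 1 + x$ and $q(x) = 2 - x$, and evaluate both at the points $x = 0, 1, 2$:

\[ \begin{aligned} p &= \{ (0, 1), (1, 2), (2, 3) \} \\ q &= \{ (0, 2), (1, 1), (2, 0) \} \\ pq &= \{ (0, 2), (1, 2), (2, 0) \} \end{aligned} \]

The unique polynomial of degree at most 2 passing through the product points is $2 + x - x^2$, which is indeed $(1+x)(2-x)$.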

The problem is that the two conversions, both to and from the coefficient form, are inefficient for arbitrary choices of points $x_0, \dots, x_n$. The trick comes from choosing special points, in such a way that the intermediate values computed in the conversion steps can be reused. This is where the Fourier Transform comes in: choose the points $x_k = \omega_N^k$ to be the powers of a primitive complex $N$-th root of unity $\omega_N$. $N$ is required to be large enough so that the powers of $\omega_N$ give the at least $2n+1$ distinct values required for interpolating a degree-at-most-$2n$ polynomial, and because we’re doing the Fourier Transform, it will naturally be “the next largest power of 2” bigger than the degree of the product polynomial.

Then one has to observe that, by its very formula, the Fourier Transform is exactly the evaluation of a polynomial at the powers of the $N$-th root of unity! In formulas: if $a = (a_0, \dots, a_{n-1})$ is a list of real numbers, define $p_a(x) = a_0 + a_1x + \dots + a_{n-1}x^{n-1}$. Then $\mathscr{F}(a)(k)$, the Fourier Transform of $a$ at index $k$, is equal to $p_a(\omega_n^k)$. These notes by Denis Pankratov have more details showing that the Fourier Transform formula is a polynomial evaluation (see Section 3), and this YouTube video by Reducible also has a nice exposition. This interpretation of the FT as polynomial evaluation seems to inspire quite a few additional techniques for computing the Fourier Transform that I plan to write about in the future.
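
Concretely, with the convention $\omega_n = e^{-2\pi i / n}$ (the sign convention varies by source, but the argument is identical either way), the Fourier Transform formula is literally a polynomial evaluation:

\[ \mathscr{F}(a)(k) = \sum_{j=0}^{n-1} a_j \omega_n^{jk} = p_a(\omega_n^k) \]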

The last step is to reconstruct the product polynomial from the product of the two point sets, but because the Fourier Transform is an invertible function (and linear, too), the inverse Fourier Transform does exactly that: given a list of the $n$ evaluations of a polynomial at $\omega_n^k, k=0, \dots, n-1$, return the coefficients of the polynomial.

This all fits together into the code above:

  1. Pad the input coefficient lists with zeros, so that their lengths are a power of 2 and at least 1 more than the degree of the output product polynomial.
  2. Compute the FFT of the two padded polynomials.
  3. Multiply the result pointwise.
  4. Compute the IFFT of the product.
  5. Round the resulting (complex) values back to integers.

Hey, wait a minute! What about precision issues?

They are a problem when you have large numbers or large polynomials, because the intermediate values in steps 2-4 can lose precision due to the floating point math involved. This short note by Richard Fateman discusses some of those issues, and there are two paths forward: deal with it somehow, or use an integer-exact analogue called the Number Theoretic Transform (which itself has issues I’ll discuss in a future, longer article).

Postscript: I’m not sure how widely this technique is used. It appears the NTL library uses the polynomial version of Karatsuba multiplication instead (though it implements FFT elsewhere). However, I know for sure that many of the software libraries involved in doing fully homomorphic encryption rely on the FFT for performance reasons, and the ones that don’t instead use the NTT.