# Quickstart tutorial to adding MLIR graph rewrite

This document presents a quickstart to adding graph rewrites. We start by
defining an operation, then show multiple ways to define the rewrite using
patterns, as well as how to define the rewrite using a graph walker (note:
using patterns and the rewrite engine is preferred; the walker is shown for
demonstration purposes only).

See [MLIR specification](../LangRef.md) for more information about MLIR, the
structure of the IR, operations, etc. See
[Table-driven Operation Definition](../DefiningDialects/Operations.md) and
[Declarative Rewrite Rule](../DeclarativeRewrites.md) for the detailed
explanation of all available mechanisms for defining operations and rewrites in
a table-driven manner.

## Adding operation

An operation in MLIR is specified using a definition in a
[TableGen](https://llvm.org/docs/TableGen/index.html) file. TableGen is a
modeling tool used to specify the ops, and the C++ code to interact with these
operations is generated from these definitions. To define an operation one
needs to specify:

* The operation name. This name is a unique identifier of the operation within
  MLIR. Most operations are within a dialect, so for example one could have
  `tfl.add` to represent the add operation in the TensorFlow Lite dialect.
  Instead of repeating the dialect in the op definition, a base class for the
  op dialect is commonly created that prepends the dialect namespace given an
  op name.
* The traits of the operation, such as whether it has side effects or whether
  it should be verified that the operand and result types are the same. These
  are backed by C++ traits that perform the verification.
* The arguments of the operation. These are the input operands (values at
  runtime produced by other ops) and attributes (compile time known constant
  values that affect the behavior of the op) that are the inputs of/define the
  behavior of the operation. The input operands may be named, the attributes
  must be named.
* The result(s) of the operation. These may again be named or not.
* Documentation of the operation. This includes a one-line summary as well as
  a longer human-readable description of the operation.
* Dialect specific information. Additional information can be added to the
  operation definition that is only used by dialect specific drivers. This
  information is ignored by the main op and doc generators, but could be used
  in, say, the translation from a dialect to another representation.

```tablegen
def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
      [NoMemoryEffect, SameValueType]>,
      Results<(outs Tensor)> {
  let arguments = (ins
    F32Tensor:$x,
    // Slope of the activation function at x < 0.
    F32Attr:$alpha
  );

  let summary = "Leaky ReLU operator";
  let description = [{
    Element-wise Leaky ReLU operator
      x -> x >= 0 ? x : (alpha * x)
  }];

  // TFLite specific attribute that is used when generating the output
  // flatbuffer.
  let hasOptions = 1;
}
```

Note in the above the result types and inputs are specified in different ways,
one by way of trait and the other by way of let. It is possible to specify both
in either way.
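
For reference, the `TFL_Op` base class used in the example above is typically
just a thin wrapper around the core `Op` class; a hedged sketch (the actual
TensorFlow Lite definition differs in its details) might look like:

```tablegen
// Sketch of a dialect-specific op base class: the underlying Op class
// prepends the dialect namespace (e.g. "tfl.") to the given mnemonic, so the
// op definition itself never spells out the full "tfl.leaky_relu" name.
class TFL_Op<Dialect dialect, string mnemonic, list<Trait> traits = []> :
    Op<dialect, mnemonic, traits>;
```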

<!-- TODO: Define a style convention. -->

Operations can also have a custom parser, printer, builder, verifier, constant
folder, or canonicalizer. These require specifying additional C++ methods to
invoke for additional functionality. For example, if an operation is marked to
have a folder, the constant folder also needs to be added, e.g.:

```c++
OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
  if (unable_to_fold)
    return {};
  ....
  return val;
}
```
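
The fold hook itself is requested from ODS. A hedged sketch for a hypothetical
op (the op name, dialect base class, and types here are made up; only the
`hasFolder` line is the relevant part):

```tablegen
def SpecificOp : MyDialect_Op<"specific", [Pure]> {
  let arguments = (ins AnyType:$input);
  let results = (outs AnyType:$result);

  // Declares the `fold` hook on the generated C++ class; the body shown
  // above must then be implemented in C++.
  let hasFolder = 1;
}
```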

## Adding patterns

There are multiple forms of graph rewrite that can be performed in MLIR. One of
the most common is DAG tile to DAG tile rewrite. Patterns provide a concise way
to express this transformation as a pair of a source pattern to match and a
resultant pattern. There are both C++ classes to represent this transformation,
as well as patterns in TableGen from which these can be generated.

### TableGen patterns

Let us continue with LeakyRelu. To map from TensorFlow's `LeakyRelu` to
TensorFlow Lite's `LeakyRelu`:

```tablegen
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>;
```

The pattern is specified by instantiating a `Pat` with a source and result DAG.
The arguments in the source pattern are captured and can be used in the result
pattern. This is a simple pattern as we have a 1:1 mapping and the attribute
does not need to be transformed (e.g., both have a floating point attribute for
alpha). The names of the attributes specified in the pattern are for
matching/referencing and need not match the original attribute names in the op
definition, but the order of arguments of the dags does need to match.

To specify a pattern, both the source and resultant ops need to be defined
using TableGen.

If this were a more advanced pattern that the current framework could not
express as the destination pattern, then one could use a general native code
fallback method. This consists of defining a pattern as well as adding a C++
function to perform the replacement:

```tablegen
def createTFLLeakyRelu : NativeCodeCall<
    "createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;

def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a),
          (createTFLLeakyRelu $old_value, $arg, $a)>;
```

```c++
static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
                                Value operand, Attribute attr) {
  return rewriter.create<mlir::TFL::LeakyReluOp>(
      op->getLoc(), operand.getType(), /*arg=*/operand,
      /*alpha=*/attr.cast<FloatAttr>());
}
```

This allows for arbitrarily complex builders. On the input pattern side, one
can express multi-op patterns with constraints on input operands and
attributes, but input patterns cannot yet express constraints across multiple
operands/attributes.
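
As a hedged illustration (the fused result op here is hypothetical and not part
of the actual TensorFlow Lite dialect), a source pattern may match a small DAG
of ops and constrain operand types at the same time:

```tablegen
// Match relu(add(x, y)) where both operands are f32 tensors and rewrite the
// pair of ops into a single (made-up) fused op.
def : Pat<(TF_ReluOp (TF_AddOp F32Tensor:$lhs, F32Tensor:$rhs)),
          (TFL_FusedAddReluOp $lhs, $rhs)>;
```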

### Register the pattern

The file containing the patterns needs to be processed using `mlir-tblgen`
`-gen-rewriters` at compile time. It can be invoked with the following
configuration in CMake:

```cmake
set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
add_public_tablegen_target(<name-of-the-cmake-target>)
```

Then you can `#include` the generated file in any C++ implementation file you
like. (You will also need to make sure the library depends on the CMake target
defined above.) The generated file will have a
`populateWithGenerated(RewritePatternSet &patterns)` function that you can use
to collect all the generated patterns inside `patterns` and then use `patterns`
in any pass you would like.
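
For example (a hedged sketch: the pass, the generated file name, and the choice
of driver are assumptions, not something the generator produces for you), the
generated patterns can be applied with the greedy pattern rewrite driver:

```c++
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
// Pulls in the patterns emitted by `mlir-tblgen -gen-rewriters`; the file name
// stands in for <name-of-the-generated-inc-file>.
#include "MyRewriters.inc"

// Hypothetical pass that applies the generated patterns to each function.
struct MyLegalizePass
    : public PassWrapper<MyLegalizePass, OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateWithGenerated(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```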

### Simple C++ `matchAndRewrite` style specifications

Many simple rewrites can be expressed with a `matchAndRewrite` style of
pattern, e.g. when converting a multiply by a power of two into a shift. For
these cases, you can define the pattern as a simple function:

```c++
static LogicalResult
convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
      op, op->getResult(0).getType(), op->getOperand(0),
      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  return success();
}

void populateRewrites(RewritePatternSet &patternSet) {
  // Add it to a pattern set.
  patternSet.add(convertTFLeakyRelu);
}
```

ODS provides a simple way to define a function-style canonicalization for your
operation. In the TableGen definition of the op, specify
`let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method in
your .cpp file:

```c++
// Example from the CIRCT project which has a variadic integer multiply.
LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
  auto inputs = op.inputs();
  APInt value;

  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
  if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
      value.isPowerOf2()) {
    auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),
                                                  value.exactLogBase2());
    auto shlOp =
        rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
    rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
                                       ArrayRef<Value>(shlOp));
    return success();
  }

  return failure();
}
```

However, you may want the full generality of canonicalization patterns; for
that, you can specify an arbitrary list of `RewritePattern`s.

### Fully general C++ `RewritePattern` specifications

In case ODS patterns and `matchAndRewrite`-style functions are not sufficient,
you can also specify rewrites as a general set of `RewritePattern`s:

```c++
struct ConvertTFLeakyRelu : public RewritePattern {
  ConvertTFLeakyRelu(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
    return success();
  }
};
```
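
Such a class-based pattern is registered in the same way as the function-style
patterns above; a short sketch (the populate function name is hypothetical):

```c++
void populateTFLLegalizationPatterns(MLIRContext *context,
                                     RewritePatternSet &patterns) {
  // The context is forwarded to the pattern's constructor shown above.
  patterns.add<ConvertTFLeakyRelu>(context);
}
```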

In the C++ rewrite the static benefit of the rewrite pattern is specified at
construction, while in the pattern generator a simple heuristic is currently
employed based around the number of ops matched and replaced.

The above rule did not capture the matching operands/attributes, but in general
the `match` function in a multi-step rewrite may populate and return a
`PatternState` (or class derived from one) to pass information extracted during
matching to the rewrite. A single-step rewrite with the `matchAndRewrite`
function has the benefit of being able to directly use any values created when
matching, removing the need for `PatternState`.

## Testing

MLIR uses the [lit](https://llvm.org/docs/CommandGuide/lit.html) (LLVM
Integrated Testing) tool for performing testing. Testing is performed by way of
creating the input IR file, running a transformation and then verifying the
output IR. C++ unit tests are the exception, with the IR transformation serving
as the core testing mechanism. This results in fewer binaries that need to be
built (and linked) and forces one to focus on the representation as an
important piece.

For the legalization transform above we would have a test (probably as part of
the legalization pass test in TensorFlow Lite) such as:

```mlir
// RUN: mlir-opt -tfl-legalize-tf %s | FileCheck %s

func.func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  return %2 : tensor<1xf32>

// CHECK-LABEL: LeakyRelu
// CHECK: %0 = "tfl.leaky_relu"(%arg0) {alpha = 1.000000e-01 : f32} : (tensor<1xf32>) -> tensor<1xf32>
}
```

The RUN command at the top results in running the `mlir-opt` binary (which is a
compiler-writer tool to exercise different registered passes) to invoke the
optimization pass this transform was added as part of on the current file and
to verify its output using `FileCheck`. `FileCheck` is a textual output
verifier. In particular it uses the CHECK expressions to verify the given
output is produced.

There can be multiple RUN commands with different corresponding CHECK prefixes,
and, in addition, multiple independent tests separated by `// -----` with
`mlir-opt` invoked with the `-split-input-file` flag. This is especially useful
for error testing.
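
A hedged sketch of such a split test (the exact diagnostic text is only
illustrative; the real message comes from the op's verifier):

```mlir
// RUN: mlir-opt -split-input-file -verify-diagnostics %s

func.func @bad_alpha(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  // expected-error@+1 {{failed to satisfy constraint}}
  %0 = "tfl.leaky_relu"(%arg0) {alpha = true} : (tensor<1xf32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}

// -----

// Each section separated by `// -----` is parsed and verified independently.
func.func @ok_alpha(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  %0 = "tfl.leaky_relu"(%arg0) {alpha = 0.2 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}
```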

This results in very simple, directed testing without the need to work around
constant propagation or other, unrelated, optimization passes.

## Adding optimization pass

Optimization passes that do not fit/are difficult to specify in the above
structure can be specified as general iterations across modules/functions. See
[Writing a Pass](../PassManagement.md) for a general overview and introduction
to optimization passes in MLIR.
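
As a hedged sketch (the pass name is hypothetical), such a pass typically walks
the IR directly rather than going through the pattern rewrite driver:

```c++
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
// Hypothetical pass that iterates over every operation in a function and
// inspects or mutates it directly.
struct MyWalkingPass
    : public PassWrapper<MyWalkingPass, OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    getOperation().walk([&](Operation *op) {
      // Arbitrary, non-pattern-based analysis or rewriting of `op` goes here.
    });
  }
};
} // namespace
```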