From 2b56c094542f960ec283bd509170c1a136c678ff Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Fri, 31 Jan 2020 20:00:39 +0900 Subject: [PATCH] Rewrite ReduceL1, ReduceL2, ReduceLogSum, ReduceLogSumExp, ReduceSumSquare in the ONNX dialect (#38) * Rewrite ReduceSumSquare * Edit gen_doc.py * Revise the code * Do shape inference after canonicalization so that there is no need to implement shape inference of rewritten ops * Rewrite ReduceL2 * Add onnx_rewrite.cpp for all rewriting for ONNX ops * Rewrite ReduceL1, ReduceLogSum, ReduceLogSumExp * Edit comments * Change the use of -> to . * Checkout gen_doc.py from the master branch * Use emplace_back instead of push_back * Revise the code * Edit comments Co-authored-by: Tian Jin --- doc/gen_doc.py | 3 +- src/CMakeLists.txt | 1 + src/dialect/onnx/onnxop.inc | 5 + src/main.cpp | 2 +- src/pass/onnx_rewrite.cpp | 295 ++++++++++++++++++++++ test/mlir/onnx/onnx_canonicalization.mlir | 47 ++++ 6 files changed, 351 insertions(+), 2 deletions(-) create mode 100644 src/pass/onnx_rewrite.cpp diff --git a/doc/gen_doc.py b/doc/gen_doc.py index e5c2b8a..428c360 100644 --- a/doc/gen_doc.py +++ b/doc/gen_doc.py @@ -47,7 +47,8 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze'] -CanonicalList=['Add', 'Identity'] +CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', + 'ReduceLogSumExp', 'ReduceSumSquare'] manual_code_in_op_def = dict([ ('DummyExample', ' let extraClassDeclaration = [{ \n'+ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 73c0ffe..8531f57 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,6 +10,7 @@ add_library(compiler pass/shape_inference_interface.hpp dialect/onnx/onnxop.inc pass/onnx_combine.cpp + pass/onnx_rewrite.cpp pass/passes.hpp) # Include root src directory. diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index d9d3c8b..02ce5e7 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -2266,6 +2266,7 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal", def ONNXReduceL1Op:ONNX_Op<"ReduceL1", [NoSideEffect]> { + let hasCanonicalizer = 1; let summary = "ONNX ReduceL1 operation"; let description = [{ "Computes the L1 norm of the input tensor's element along the provided axes. The resulted" @@ -2283,6 +2284,7 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1", def ONNXReduceL2Op:ONNX_Op<"ReduceL2", [NoSideEffect]> { + let hasCanonicalizer = 1; let summary = "ONNX ReduceL2 operation"; let description = [{ "Computes the L2 norm of the input tensor's element along the provided axes. The resulted" @@ -2300,6 +2302,7 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2", def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", [NoSideEffect]> { + let hasCanonicalizer = 1; let summary = "ONNX ReduceLogSum operation"; let description = [{ "Computes the log sum of the input tensor's element along the provided axes. The resulted" @@ -2317,6 +2320,7 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", [NoSideEffect]> { + let hasCanonicalizer = 1; let summary = "ONNX ReduceLogSumExp operation"; let description = [{ "Computes the log sum exponent of the input tensor's element along the provided axes. 
The resulted" @@ -2419,6 +2423,7 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum", def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", [NoSideEffect]> { + let hasCanonicalizer = 1; let summary = "ONNX ReduceSumSquare operation"; let description = [{ "Computes the sum square of the input tensor's element along the provided axes. The resulted" diff --git a/src/main.cpp b/src/main.cpp index 6e2c8e2..f0de7e9 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -117,8 +117,8 @@ int main(int argc, char *argv[]) { } mlir::PassManager pm(&context); - pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createShapeInferencePass()); if (emissionTarget >= EmitMLIR) { pm.addPass(mlir::createLowerToKrnlPass()); diff --git a/src/pass/onnx_rewrite.cpp b/src/pass/onnx_rewrite.cpp new file mode 100644 index 0000000..2f172de --- /dev/null +++ b/src/pass/onnx_rewrite.cpp @@ -0,0 +1,295 @@ +//===- onnx_rewrite.cpp - ONNX High Level Optimizer -----------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements a set of rewriters for operations in the ONNX dialect +// that can be rewritten by using other ONNX operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" + +#include "src/dialect/onnx/onnx_ops.hpp" + +using namespace mlir; + +namespace { + +// There are two ways to write rewrite rules: +// - Declarative manner: specify rewrite rules in a TableGen record, and +// - Manual Manner: subclass the mlir::RewritePattern. +// +// We prefer to use the former way as much as possible. However, there is a +// limitation about operation definition specification (ODS) in TableGen that +// requires us to write custom builders, that is +// "all ODS-generated `build()` methods require specifying the result type(s), +// unless the op has known traits like `SameOperandsAndResultType` that we can +// use to auto-generate a `build()` method with result type deduction". +// +// More information about the limitation can be found here: +// https://github.com/llvm/llvm-project/blob/master/mlir/docs/DeclarativeRewrites.md#building-operations +// +// Currently, we use the latter way of writing rewrite rules. There are two +// reasons for this decision: +// - To insert custom builders for operations, it is better to change the script +// gen_doc.py to generate all possibles custom builders for a large class of +// operations. At the time of this patch created, the gen_doc.py was changing, +// so we decided to write manually to reduce conflicts. +// - In declarative rewriting, we should deal with optional attributes. E.g. for +// to handle optional attributes, but I haven't tried it yet. +// +// Once we have done the above issues, we will switch to use the declarative +// manner. 
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
+//===----------------------------------------------------------------------===//
+struct ReduceL1OpPattern : public RewritePattern {
+  ReduceL1OpPattern(MLIRContext *context)
+      : RewritePattern(ONNXReduceL1Op::getOperationName(),
+                       {ONNXAbsOp::getOperationName(),
+                        ONNXReduceSumOp::getOperationName()},
+                       1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto opInput = op->getOperands()[0]; // %X
+    auto opResults = op->getResults();
+    auto opAttrs = op->getAttrs();
+
+    // Rewrite
+    ONNXAbsOp absOp;
+    {
+      auto elementType = opInput.getType().cast<TensorType>().getElementType();
+      absOp = rewriter.create<ONNXAbsOp>(
+          loc, UnrankedTensorType::get(elementType), opInput);
+    }
+
+    ONNXReduceSumOp sumOp;
+    {
+      SmallVector<Type, 4> types;
+      for (auto v : opResults) {
+        types.emplace_back(v.getType());
+      }
+
+      SmallVector<Value, 1> values;
+      values.emplace_back(absOp.getResult());
+
+      SmallVector<NamedAttribute, 4> attrs;
+      for (auto attr : opAttrs) {
+        attrs.emplace_back(attr);
+      }
+
+      sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
+    }
+
+    rewriter.replaceOp(op, sumOp.getResult());
+    return matchSuccess();
+  };
+};
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
+//===----------------------------------------------------------------------===//
+struct ReduceL2OpPattern : public RewritePattern {
+  ReduceL2OpPattern(MLIRContext *context)
+      : RewritePattern(ONNXReduceL2Op::getOperationName(),
+                       {ONNXSqrtOp::getOperationName(),
+                        ONNXReduceSumSquareOp::getOperationName()},
+                       1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto opInput = op->getOperands()[0]; // %X
+    auto opResults = op->getResults();
+    auto opAttrs = op->getAttrs();
+
+    // Rewrite
+    ONNXReduceSumSquareOp sumSquareOp;
+    {
+      auto elementType = opInput.getType().cast<TensorType>().getElementType();
+      sumSquareOp = rewriter.create<ONNXReduceSumSquareOp>(
+          loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
+    }
+
+    ONNXSqrtOp sqrtOp;
+    {
+      SmallVector<Type, 4> types;
+      for (auto v : opResults) {
+        types.emplace_back(v.getType());
+      }
+
+      sqrtOp = rewriter.create<ONNXSqrtOp>(loc, types, sumSquareOp.getResult());
+    }
+
+    rewriter.replaceOp(op, sqrtOp.getResult());
+    return matchSuccess();
+  };
+};
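Note that ReduceL2 is expressed in terms of ONNXReduceSumSquareOp, whose own pattern (below) fires in the same canonicalization run, so the two rewrites compose. A sketch of the fully expanded form that @test_reducel2 checks (function name and shape illustrative):

```mlir
// ReduceL2(X) = Sqrt(ReduceSumSquare(X)), and ReduceSumSquare itself
// further expands to ReduceSum(Mul(X, X)):
func @reducel2_after(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
  %0 = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
  %1 = "onnx.ReduceSum"(%0) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
  %2 = "onnx.Sqrt"(%1) : (tensor<*xf32>) -> tensor<*xf32>
  "std.return"(%2) : (tensor<*xf32>) -> ()
}
```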
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
+//===----------------------------------------------------------------------===//
+struct ReduceLogSumOpPattern : public RewritePattern {
+  ReduceLogSumOpPattern(MLIRContext *context)
+      : RewritePattern(ONNXReduceLogSumOp::getOperationName(),
+                       {ONNXReduceSumOp::getOperationName(),
+                        ONNXLogOp::getOperationName()},
+                       1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto opInput = op->getOperands()[0]; // %X
+    auto opResults = op->getResults();
+    auto opAttrs = op->getAttrs();
+
+    // Rewrite
+    ONNXReduceSumOp sumOp;
+    {
+      auto elementType = opInput.getType().cast<TensorType>().getElementType();
+      sumOp = rewriter.create<ONNXReduceSumOp>(
+          loc, UnrankedTensorType::get(elementType), opInput, opAttrs);
+    }
+
+    ONNXLogOp logOp;
+    {
+      SmallVector<Type, 4> types;
+      for (auto v : opResults) {
+        types.emplace_back(v.getType());
+      }
+
+      logOp = rewriter.create<ONNXLogOp>(loc, types, sumOp.getResult());
+    }
+
+    rewriter.replaceOp(op, logOp.getResult());
+    return matchSuccess();
+  };
+};
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
+//===----------------------------------------------------------------------===//
+struct ReduceLogSumExpOpPattern : public RewritePattern {
+  ReduceLogSumExpOpPattern(MLIRContext *context)
+      : RewritePattern(ONNXReduceLogSumExpOp::getOperationName(),
+                       {ONNXExpOp::getOperationName(),
+                        ONNXReduceLogSumOp::getOperationName()},
+                       1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto opInput = op->getOperands()[0]; // %X
+    auto opResults = op->getResults();
+    auto opAttrs = op->getAttrs();
+
+    // Rewrite
+    ONNXExpOp expOp;
+    {
+      auto elementType = opInput.getType().cast<TensorType>().getElementType();
+      expOp = rewriter.create<ONNXExpOp>(
+          loc, UnrankedTensorType::get(elementType), opInput);
+    }
+
+    ONNXReduceLogSumOp logSumOp;
+    {
+      SmallVector<Type, 4> types;
+      for (auto v : opResults) {
+        types.emplace_back(v.getType());
+      }
+
+      SmallVector<Value, 1> values;
+      values.emplace_back(expOp.getResult());
+
+      SmallVector<NamedAttribute, 4> attrs;
+      for (auto attr : opAttrs) {
+        attrs.emplace_back(attr);
+      }
+
+      logSumOp = rewriter.create<ONNXReduceLogSumOp>(loc, types, values, attrs);
+    }
+
+    rewriter.replaceOp(op, logSumOp.getResult());
+    return matchSuccess();
+  };
+};
+
+//===----------------------------------------------------------------------===//
+// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
+//===----------------------------------------------------------------------===//
+struct ReduceSumSquareOpPattern : public RewritePattern {
+  ReduceSumSquareOpPattern(MLIRContext *context)
+      : RewritePattern(ONNXReduceSumSquareOp::getOperationName(),
+                       {ONNXMulOp::getOperationName(),
+                        ONNXReduceSumOp::getOperationName()},
+                       1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto opInput = op->getOperands()[0]; // %X
+    auto opResults = op->getResults();
+    auto opAttrs = op->getAttrs();
+
+    // Rewrite
+    ONNXMulOp mulOp;
+    {
+      auto elementType = opInput.getType().cast<TensorType>().getElementType();
+      mulOp = rewriter.create<ONNXMulOp>(
+          loc, UnrankedTensorType::get(elementType), opInput, opInput);
+    }
+
+    ONNXReduceSumOp sumOp;
+    {
+      SmallVector<Type, 4> types;
+      for (auto v : opResults) {
+        types.emplace_back(v.getType());
+      }
+
+      SmallVector<Value, 1> values;
+      values.emplace_back(mulOp.getResult());
+
+      SmallVector<NamedAttribute, 4> attrs;
+      for (auto attr : opAttrs) {
+        attrs.emplace_back(attr);
+      }
+
+      sumOp = rewriter.create<ONNXReduceSumOp>(loc, types, values, attrs);
+    }
+
+    rewriter.replaceOp(op, sumOp.getResult());
+    return matchSuccess();
+  };
+};
+} // end anonymous namespace
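The same composition happens for ReduceLogSumExp: it expands to ReduceLogSum(Exp(X)), and the ReduceLogSum pattern then fires on the result, so the canonicalizer's fixed-point iteration leaves only primitive ops, as @test_reducelogsumexp checks. A sketch (function name and shape illustrative):

```mlir
// ReduceLogSumExp(X) -> ReduceLogSum(Exp(X)) -> Log(ReduceSum(Exp(X)))
func @reducelogsumexp_after(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
  %0 = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
  %1 = "onnx.ReduceSum"(%0) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
  %2 = "onnx.Log"(%1) : (tensor<*xf32>) -> tensor<*xf32>
  "std.return"(%2) : (tensor<*xf32>) -> ()
}
```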
+
+/// Register optimization patterns as "canonicalization" patterns
+/// on the ONNXReduceL1Op.
+void ONNXReduceL1Op::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ReduceL1OpPattern>(context);
+}
+
+/// Register optimization patterns as "canonicalization" patterns
+/// on the ONNXReduceL2Op.
+void ONNXReduceL2Op::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ReduceL2OpPattern>(context);
+}
+
+/// Register optimization patterns as "canonicalization" patterns
+/// on the ONNXReduceLogSumOp.
+void ONNXReduceLogSumOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ReduceLogSumOpPattern>(context);
+}
+
+/// Register optimization patterns as "canonicalization" patterns
+/// on the ONNXReduceLogSumExpOp.
+void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ReduceLogSumExpOpPattern>(context);
+}
+
+/// Register optimization patterns as "canonicalization" patterns
+/// on the ONNXReduceSumSquareOp.
+void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ReduceSumSquareOpPattern>(context);
+}
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 9697c8e..f791a59 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -27,3 +27,50 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
   %2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   "std.return"(%2) : (tensor<10x10xf32>) -> ()
 }
+
+// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
+func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
+func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
+func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
+func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
+}
+
+// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
+func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
+  %0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+}
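A closing note on the pass reordering in src/main.cpp: because the canonicalizer now runs before shape inference, the five rewritten ops never reach the shape-inference pass, so no shape-inference rules are needed for them, only for the primitive ops they expand into. A sketch of what shape inference sees for the ReduceLogSum case (function name illustrative; the result types shown are the unrefined ones produced by the rewrite, not inferred shapes):

```mlir
// Input to the shape-inference pass after --canonicalize: onnx.ReduceLogSum
// is already gone; only onnx.ReduceSum and onnx.Log remain.
func @reducelogsum_after(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
  %0 = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
  %1 = "onnx.Log"(%0) : (tensor<*xf32>) -> tensor<*xf32>
  "std.return"(%1) : (tensor<*xf32>) -> ()
}
```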