Rewrite ReduceL1, ReduceL2, ReduceLogSum, ReduceLogSumExp, ReduceSumSquare in the ONNX dialect (#38)

* Rewrite ReduceSumSquare

* Edit gen_doc.py

* Revise the code

* Do shape inference after canonicalization so that there is no need to implement shape inference of rewritten ops

* Rewrite ReduceL2

* Add onnx_rewrite.cpp for all rewriting for ONNX ops

* Rewrite ReduceL1, ReduceLogSum, ReduceLogSumExp

* Edit comments

* Change the use of -> to .

* Checkout gen_doc.py from the master branch

* Use emplace_back instead of push_back

* Revise the code

* Edit comments

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Tung D. Le 2020-01-31 20:00:39 +09:00 committed by GitHub
parent 0d77840969
commit 2b56c09454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 351 additions and 2 deletions

View File

@ -47,7 +47,8 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze'] 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
CanonicalList=['Add', 'Identity'] CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
'ReduceLogSumExp', 'ReduceSumSquare']
manual_code_in_op_def = dict([ manual_code_in_op_def = dict([
('DummyExample', ' let extraClassDeclaration = [{ \n'+ ('DummyExample', ' let extraClassDeclaration = [{ \n'+

View File

@ -10,6 +10,7 @@ add_library(compiler
pass/shape_inference_interface.hpp pass/shape_inference_interface.hpp
dialect/onnx/onnxop.inc dialect/onnx/onnxop.inc
pass/onnx_combine.cpp pass/onnx_combine.cpp
pass/onnx_rewrite.cpp
pass/passes.hpp) pass/passes.hpp)
# Include root src directory. # Include root src directory.

View File

@ -2266,6 +2266,7 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal",
def ONNXReduceL1Op:ONNX_Op<"ReduceL1", def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceL1 operation"; let summary = "ONNX ReduceL1 operation";
let description = [{ let description = [{
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted" "Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
@ -2283,6 +2284,7 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
def ONNXReduceL2Op:ONNX_Op<"ReduceL2", def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceL2 operation"; let summary = "ONNX ReduceL2 operation";
let description = [{ let description = [{
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted" "Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
@ -2300,6 +2302,7 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceLogSum operation"; let summary = "ONNX ReduceLogSum operation";
let description = [{ let description = [{
"Computes the log sum of the input tensor's element along the provided axes. The resulted" "Computes the log sum of the input tensor's element along the provided axes. The resulted"
@ -2317,6 +2320,7 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceLogSumExp operation"; let summary = "ONNX ReduceLogSumExp operation";
let description = [{ let description = [{
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted" "Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
@ -2419,6 +2423,7 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",
def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceSumSquare operation"; let summary = "ONNX ReduceSumSquare operation";
let description = [{ let description = [{
"Computes the sum square of the input tensor's element along the provided axes. The resulted" "Computes the sum square of the input tensor's element along the provided axes. The resulted"

View File

@ -117,8 +117,8 @@ int main(int argc, char *argv[]) {
} }
mlir::PassManager pm(&context); mlir::PassManager pm(&context);
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createShapeInferencePass());
if (emissionTarget >= EmitMLIR) { if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass()); pm.addPass(mlir::createLowerToKrnlPass());

295
src/pass/onnx_rewrite.cpp Normal file
View File

@ -0,0 +1,295 @@
//===- onnx_rewrite.cpp - ONNX High Level Optimizer -----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a set of rewriters for operations in the ONNX dialect
// that can be rewritten by using other ONNX operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "src/dialect/onnx/onnx_ops.hpp"
using namespace mlir;
namespace {
// There are two ways to write rewrite rules:
// - Declarative manner: specify rewrite rules in a TableGen record, and
// - Manual Manner: subclass the mlir::RewritePattern.
//
// We prefer to use the former way as much as possible. However, there is a
// limitation about operation definition specification (ODS) in TableGen that
// requires us to write custom builders, that is
// "all ODS-generated `build()` methods require specifying the result type(s),
// unless the op has known traits like `SameOperandsAndResultType` that we can
// use to auto-generate a `build()` method with result type deduction".
//
// More information about the limitation can be found here:
// https://github.com/llvm/llvm-project/blob/master/mlir/docs/DeclarativeRewrites.md#building-operations
//
// Currently, we use the latter way of writing rewrite rules. There are two
// reasons for this decision:
// - To insert custom builders for operations, it is better to change the script
// gen_doc.py to generate all possibles custom builders for a large class of
// operations. At the time of this patch created, the gen_doc.py was changing,
// so we decided to write manually to reduce conflicts.
// - In declarative rewriting, we would also have to deal with optional
//   attributes (e.g. via NativeCodeCall helpers), but we have not tried that yet.
//
// Once we have done the above issues, we will switch to use the declarative
// manner.
//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
//===----------------------------------------------------------------------===//
struct ReduceL1OpPattern : public RewritePattern {
  ReduceL1OpPattern(MLIRContext *context)
      : RewritePattern(ONNXReduceL1Op::getOperationName(),
                       {ONNXAbsOp::getOperationName(),
                        ONNXReduceSumOp::getOperationName()},
                       1, context) {}

  /// Replace ReduceL1 with ReduceSum(Abs(%X)). All reduction attributes
  /// (axes, keepdims) are forwarded to the new ReduceSum op.
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto input = op->getOperands()[0]; // %X

    // |%X| with an unranked result type; shape inference runs after
    // canonicalization and fills in the shape.
    auto elemTy = input.getType().cast<TensorType>().getElementType();
    auto absOp = rewriter.create<ONNXAbsOp>(
        loc, UnrankedTensorType::get(elemTy), input);

    // Sum the absolute values, keeping the original result types and copying
    // every attribute of the matched op onto the new ReduceSum.
    SmallVector<Type, 4> resultTypes;
    for (auto result : op->getResults())
      resultTypes.emplace_back(result.getType());
    SmallVector<Value, 1> sumOperands;
    sumOperands.emplace_back(absOp.getResult());
    SmallVector<NamedAttribute, 4> sumAttrs;
    for (auto attr : op->getAttrs())
      sumAttrs.emplace_back(attr);
    auto sumOp =
        rewriter.create<ONNXReduceSumOp>(loc, resultTypes, sumOperands, sumAttrs);

    rewriter.replaceOp(op, sumOp.getResult());
    return matchSuccess();
  }
};
//===----------------------------------------------------------------------===//
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
//===----------------------------------------------------------------------===//
struct ReduceL2OpPattern : public RewritePattern {
  ReduceL2OpPattern(MLIRContext *context)
      : RewritePattern(ONNXReduceL2Op::getOperationName(),
                       {ONNXSqrtOp::getOperationName(),
                        ONNXReduceSumSquareOp::getOperationName()},
                       1, context) {}

  /// Replace ReduceL2 with Sqrt(ReduceSumSquare(%X)). All reduction
  /// attributes (axes, keepdims) are forwarded to ReduceSumSquare.
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto input = op->getOperands()[0]; // %X

    // Sum of squares with an unranked result type; shape inference runs
    // after canonicalization and fills in the shape.
    auto elemTy = input.getType().cast<TensorType>().getElementType();
    auto sumSquareOp = rewriter.create<ONNXReduceSumSquareOp>(
        loc, UnrankedTensorType::get(elemTy), input, op->getAttrs());

    // Square root of the sum, keeping the original result types.
    SmallVector<Type, 4> resultTypes;
    for (auto result : op->getResults())
      resultTypes.emplace_back(result.getType());
    auto sqrtOp =
        rewriter.create<ONNXSqrtOp>(loc, resultTypes, sumSquareOp.getResult());

    rewriter.replaceOp(op, sqrtOp.getResult());
    return matchSuccess();
  }
};
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
//===----------------------------------------------------------------------===//
struct ReduceLogSumOpPattern : public RewritePattern {
  ReduceLogSumOpPattern(MLIRContext *context)
      : RewritePattern(ONNXReduceLogSumOp::getOperationName(),
                       {ONNXReduceSumOp::getOperationName(),
                        ONNXLogOp::getOperationName()},
                       1, context) {}

  /// Replace ReduceLogSum with Log(ReduceSum(%X)). All reduction attributes
  /// (axes, keepdims) are forwarded to the new ReduceSum op.
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto input = op->getOperands()[0]; // %X

    // Reduce-sum with an unranked result type; shape inference runs after
    // canonicalization and fills in the shape.
    auto elemTy = input.getType().cast<TensorType>().getElementType();
    auto sumOp = rewriter.create<ONNXReduceSumOp>(
        loc, UnrankedTensorType::get(elemTy), input, op->getAttrs());

    // Log of the sum, keeping the original result types.
    SmallVector<Type, 4> resultTypes;
    for (auto result : op->getResults())
      resultTypes.emplace_back(result.getType());
    auto logOp = rewriter.create<ONNXLogOp>(loc, resultTypes, sumOp.getResult());

    rewriter.replaceOp(op, logOp.getResult());
    return matchSuccess();
  }
};
//===----------------------------------------------------------------------===//
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
//===----------------------------------------------------------------------===//
struct ReduceLogSumExpOpPattern : public RewritePattern {
  ReduceLogSumExpOpPattern(MLIRContext *context)
      : RewritePattern(ONNXReduceLogSumExpOp::getOperationName(),
                       {ONNXExpOp::getOperationName(),
                        ONNXReduceLogSumOp::getOperationName()},
                       1, context) {}

  /// Replace ReduceLogSumExp with ReduceLogSum(Exp(%X)); the resulting
  /// ReduceLogSum is itself canonicalized further by ReduceLogSumOpPattern.
  /// All reduction attributes (axes, keepdims) are forwarded.
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto input = op->getOperands()[0]; // %X

    // exp(%X) with an unranked result type; shape inference runs after
    // canonicalization and fills in the shape.
    auto elemTy = input.getType().cast<TensorType>().getElementType();
    auto expOp = rewriter.create<ONNXExpOp>(
        loc, UnrankedTensorType::get(elemTy), input);

    // Log-sum of the exponentials, keeping the original result types and
    // copying every attribute of the matched op.
    SmallVector<Type, 4> resultTypes;
    for (auto result : op->getResults())
      resultTypes.emplace_back(result.getType());
    SmallVector<Value, 1> logSumOperands;
    logSumOperands.emplace_back(expOp.getResult());
    SmallVector<NamedAttribute, 4> logSumAttrs;
    for (auto attr : op->getAttrs())
      logSumAttrs.emplace_back(attr);
    auto logSumOp = rewriter.create<ONNXReduceLogSumOp>(loc, resultTypes,
                                                        logSumOperands,
                                                        logSumAttrs);

    rewriter.replaceOp(op, logSumOp.getResult());
    return matchSuccess();
  }
};
//===----------------------------------------------------------------------===//
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
//===----------------------------------------------------------------------===//
struct ReduceSumSquareOpPattern : public RewritePattern {
  ReduceSumSquareOpPattern(MLIRContext *context)
      : RewritePattern(ONNXReduceSumSquareOp::getOperationName(),
                       {ONNXMulOp::getOperationName(),
                        ONNXReduceSumOp::getOperationName()},
                       1, context) {}

  /// Replace ReduceSumSquare with ReduceSum(Mul(%X, %X)). All reduction
  /// attributes (axes, keepdims) are forwarded to the new ReduceSum op.
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto input = op->getOperands()[0]; // %X

    // %X * %X with an unranked result type; shape inference runs after
    // canonicalization and fills in the shape.
    auto elemTy = input.getType().cast<TensorType>().getElementType();
    auto mulOp = rewriter.create<ONNXMulOp>(
        loc, UnrankedTensorType::get(elemTy), input, input);

    // Sum the squares, keeping the original result types and copying every
    // attribute of the matched op onto the new ReduceSum.
    SmallVector<Type, 4> resultTypes;
    for (auto result : op->getResults())
      resultTypes.emplace_back(result.getType());
    SmallVector<Value, 1> sumOperands;
    sumOperands.emplace_back(mulOp.getResult());
    SmallVector<NamedAttribute, 4> sumAttrs;
    for (auto attr : op->getAttrs())
      sumAttrs.emplace_back(attr);
    auto sumOp =
        rewriter.create<ONNXReduceSumOp>(loc, resultTypes, sumOperands, sumAttrs);

    rewriter.replaceOp(op, sumOp.getResult());
    return matchSuccess();
  }
};
} // end anonymous namespace
/// Register the rewrite pattern (ReduceL1 -> ReduceSum(Abs))
/// on the ONNXReduceL1Op.
void ONNXReduceL1Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceL1OpPattern>(context);
}
/// Register the rewrite pattern (ReduceL2 -> Sqrt(ReduceSumSquare))
/// on the ONNXReduceL2Op.
void ONNXReduceL2Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceL2OpPattern>(context);
}
/// Register the rewrite pattern (ReduceLogSum -> Log(ReduceSum))
/// on the ONNXReduceLogSumOp.
void ONNXReduceLogSumOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceLogSumOpPattern>(context);
}
/// Register the rewrite pattern (ReduceLogSumExp -> ReduceLogSum(Exp))
/// on the ONNXReduceLogSumExpOp.
void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceLogSumExpOpPattern>(context);
}
/// Register the rewrite pattern (ReduceSumSquare -> ReduceSum(Mul))
/// on the ONNXReduceSumSquareOp.
void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceSumSquareOpPattern>(context);
}

View File

@ -27,3 +27,50 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
%2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
"std.return"(%2) : (tensor<10x10xf32>) -> () "std.return"(%2) : (tensor<10x10xf32>) -> ()
} }
// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
// ReduceL1 canonicalizes to ReduceSum(Abs(%X)); the reduction attributes
// (axes, keepdims) move to the new ReduceSum op.
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
// ReduceL2 first becomes Sqrt(ReduceSumSquare(%X)); the intermediate
// ReduceSumSquare is canonicalized again in the same pass, so the final
// output is Sqrt(ReduceSum(Mul(%X, %X))).
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
// ReduceLogSum canonicalizes to Log(ReduceSum(%X)); the reduction attributes
// (axes, keepdims) move to the new ReduceSum op.
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
// ReduceLogSumExp first becomes ReduceLogSum(Exp(%X)); the intermediate
// ReduceLogSum is canonicalized again in the same pass, so the final output
// is Log(ReduceSum(Exp(%X))).
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
// ReduceSumSquare canonicalizes to ReduceSum(Mul(%X, %X)); the reduction
// attributes (axes, keepdims) move to the new ReduceSum op.
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}