//===- onnx_rewrite.cpp - ONNX High Level Optimizer -----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a set of rewriters for operations in the ONNX dialect
// that can be rewritten by using other ONNX operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "src/dialect/onnx/onnx_ops.hpp"

using namespace mlir;

namespace {

// There are two ways to write rewrite rules:
// - Declarative manner: specify rewrite rules in a TableGen record, and
// - Manual Manner: subclass the mlir::RewritePattern.
//
// We prefer to use the former way as much as possible. However, there is a
// limitation about operation definition specification (ODS) in TableGen that
// requires us to write custom builders, that is
// "all ODS-generated `build()` methods require specifying the result type(s),
// unless the op has known traits like `SameOperandsAndResultType` that we can
// use to auto-generate a `build()` method with result type deduction".
//
// More information about the limitation can be found here:
// https://github.com/llvm/llvm-project/blob/master/mlir/docs/DeclarativeRewrites.md#building-operations
//
// Currently, we use the latter way of writing rewrite rules. There are two
// reasons for this decision:
// - To insert custom builders for operations, it is better to change the script
// gen_doc.py to generate all possibles custom builders for a large class of
// operations. At the time of this patch created, the gen_doc.py was changing,
// so we decided to write manually to reduce conflicts.
// - In declarative rewriting, we should deal with optional attributes. E.g.
// there may be TableGen machinery (such as `NativeCodeCall`) to handle
// optional attributes, but I haven't tried it yet.
//
// Once we have done the above issues, we will switch to use the declarative
// manner.
//===----------------------------------------------------------------------===// // ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X) //===----------------------------------------------------------------------===// struct ReduceL1OpPattern : public RewritePattern { ReduceL1OpPattern(MLIRContext *context) : RewritePattern(ONNXReduceL1Op::getOperationName(), {ONNXAbsOp::getOperationName(), ONNXReduceSumOp::getOperationName()}, 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto opInput = op->getOperands()[0]; // %X auto opResults = op->getResults(); auto opAttrs = op->getAttrs(); // Rewrite ONNXAbsOp absOp; { auto elementType = opInput.getType().cast().getElementType(); absOp = rewriter.create( loc, UnrankedTensorType::get(elementType), opInput); } ONNXReduceSumOp sumOp; { SmallVector types; for (auto v : opResults) { types.emplace_back(v.getType()); } SmallVector values; values.emplace_back(absOp.getResult()); SmallVector attrs; for (auto attr : opAttrs) { attrs.emplace_back(attr); } sumOp = rewriter.create(loc, types, values, attrs); } rewriter.replaceOp(op, sumOp.getResult()); return matchSuccess(); }; }; //===----------------------------------------------------------------------===// // ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X)) //===----------------------------------------------------------------------===// struct ReduceL2OpPattern : public RewritePattern { ReduceL2OpPattern(MLIRContext *context) : RewritePattern(ONNXReduceL2Op::getOperationName(), {ONNXSqrtOp::getOperationName(), ONNXReduceSumSquareOp::getOperationName()}, 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto opInput = op->getOperands()[0]; // %X auto opResults = op->getResults(); auto opAttrs = op->getAttrs(); // Rewrite ONNXReduceSumSquareOp sumSquareOp; { auto elementType = 
opInput.getType().cast().getElementType(); sumSquareOp = rewriter.create( loc, UnrankedTensorType::get(elementType), opInput, opAttrs); } ONNXSqrtOp sqrtOp; { SmallVector types; for (auto v : opResults) { types.emplace_back(v.getType()); } sqrtOp = rewriter.create(loc, types, sumSquareOp.getResult()); } rewriter.replaceOp(op, sqrtOp.getResult()); return matchSuccess(); }; }; //===----------------------------------------------------------------------===// // ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X)) //===----------------------------------------------------------------------===// struct ReduceLogSumOpPattern : public RewritePattern { ReduceLogSumOpPattern(MLIRContext *context) : RewritePattern(ONNXReduceLogSumOp::getOperationName(), {ONNXReduceSumOp::getOperationName(), ONNXLogOp::getOperationName()}, 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto opInput = op->getOperands()[0]; // %X auto opResults = op->getResults(); auto opAttrs = op->getAttrs(); // Rewrite ONNXReduceSumOp sumOp; { auto elementType = opInput.getType().cast().getElementType(); sumOp = rewriter.create( loc, UnrankedTensorType::get(elementType), opInput, opAttrs); } ONNXLogOp logOp; { SmallVector types; for (auto v : opResults) { types.emplace_back(v.getType()); } logOp = rewriter.create(loc, types, sumOp.getResult()); } rewriter.replaceOp(op, logOp.getResult()); return matchSuccess(); }; }; //===----------------------------------------------------------------------===// // ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X) //===----------------------------------------------------------------------===// struct ReduceLogSumExpOpPattern : public RewritePattern { ReduceLogSumExpOpPattern(MLIRContext *context) : RewritePattern(ONNXReduceLogSumExpOp::getOperationName(), {ONNXExpOp::getOperationName(), ONNXReduceLogSumOp::getOperationName()}, 1, context) {} PatternMatchResult 
matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto opInput = op->getOperands()[0]; // %X auto opResults = op->getResults(); auto opAttrs = op->getAttrs(); // Rewrite ONNXExpOp expOp; { auto elementType = opInput.getType().cast().getElementType(); expOp = rewriter.create( loc, UnrankedTensorType::get(elementType), opInput); } ONNXReduceLogSumOp logSumOp; { SmallVector types; for (auto v : opResults) { types.emplace_back(v.getType()); } SmallVector values; values.emplace_back(expOp.getResult()); SmallVector attrs; for (auto attr : opAttrs) { attrs.emplace_back(attr); } logSumOp = rewriter.create(loc, types, values, attrs); } rewriter.replaceOp(op, logSumOp.getResult()); return matchSuccess(); }; }; //===----------------------------------------------------------------------===// // ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X) //===----------------------------------------------------------------------===// struct ReduceSumSquareOpPattern : public RewritePattern { ReduceSumSquareOpPattern(MLIRContext *context) : RewritePattern(ONNXReduceSumSquareOp::getOperationName(), {ONNXMulOp::getOperationName(), ONNXReduceSumOp::getOperationName()}, 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto opInput = op->getOperands()[0]; // %X auto opResults = op->getResults(); auto opAttrs = op->getAttrs(); // Rewrite ONNXMulOp mulOp; { auto elementType = opInput.getType().cast().getElementType(); mulOp = rewriter.create( loc, UnrankedTensorType::get(elementType), opInput, opInput); } ONNXReduceSumOp sumOp; { SmallVector types; for (auto v : opResults) { types.emplace_back(v.getType()); } SmallVector values; values.emplace_back(mulOp.getResult()); SmallVector attrs; for (auto attr : opAttrs) { attrs.emplace_back(attr); } sumOp = rewriter.create(loc, types, values, attrs); } rewriter.replaceOp(op, sumOp.getResult()); return 
matchSuccess(); }; }; //===----------------------------------------------------------------------===// // Rewrite: // %0 = onnx.ConvNoBiasOp(%D : tensor, %K) // {pads = [b0, b1, ... bK, e0, e1, ..., eK]} -> // tensor // // as: // %0 = onnx.PadConstantValuePasOp(%D) // {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} -> // tensor // %1 = onnx.ConvNoBias(%0 : tensor, %K) {pads = [0, ..., 0]} -> // tensor //===----------------------------------------------------------------------===// struct SplitConvOpPattern : public RewritePattern { SplitConvOpPattern(MLIRContext *context) : RewritePattern(ONNXConvNoBiasOp::getOperationName(), {ONNXPadConstantValuePadOp::getOperationName(), ONNXConvNoBiasOp::getOperationName()}, 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); // If convolution does not use padding then no rewrite is required. ONNXConvNoBiasOp convOp = llvm::dyn_cast(op); auto padsAttribute = convOp.padsAttr(); if (!padsAttribute) return matchFailure(); // If auto_pad is VALID then no padding happens and no rewrite isrequired. auto autoPad = convOp.auto_pad(); if (autoPad == "VALID") return matchFailure(); auto data = op->getOperands()[0]; auto inputShape = data.getType().cast().getShape(); // Dimensionality of the input: // inputRank // |----------------------| // D : (N x C x D1 x D2 x ... DK) // |______________| // inputDims // int64_t inputRank = inputShape.size(); int64_t inputDims = inputRank - 2; // If all pads values are equal to zero then no rewrite is required. bool allZeros = true; for (auto padsValue : padsAttribute.getValue()) { if (padsValue.cast().getInt() > 0) { allZeros = false; break; } } if (allZeros) return matchFailure(); // Create padding vector for the explicit padding op attribute. 
SmallVector pads(2 * inputRank, 0); SmallVector outPaddedShape(inputRank, 0); outPaddedShape[0] = inputShape[0]; outPaddedShape[1] = inputShape[1]; for (int i = 0; i < inputDims; ++i) { int64_t beginPad = padsAttribute.getValue()[i].cast().getInt(); int64_t endPad = padsAttribute.getValue()[inputDims + i].cast().getInt(); pads[i + 2] = beginPad; pads[inputRank + i + 2] = endPad; outPaddedShape[i + 2] += beginPad + inputShape[i + 2] + endPad; } // Create padding operation. auto inputElemType = data.getType().cast().getElementType(); ONNXPadConstantValuePadOp paddingOp = rewriter.create( loc, RankedTensorType::get(outPaddedShape, inputElemType), data, rewriter.getI64ArrayAttr(pads), FloatAttr::get(inputElemType, 0), StringAttr::get("constant", loc->getContext())); SmallVector newConvPads(2 * inputDims, 0); auto tensorType = (*op->result_type_begin()).cast(); ONNXConvNoBiasOp newConvOp = rewriter.create( loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1], convOp.auto_padAttr(), convOp.dilationsAttr(), convOp.groupAttr(), convOp.kernel_shapeAttr(), rewriter.getI64ArrayAttr(newConvPads), convOp.stridesAttr()); rewriter.replaceOp(op, newConvOp.getResult()); return matchSuccess(); }; }; } // end anonymous namespace /// on the ONNXReduceL1Op. void ONNXReduceL1Op::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } /// on the ONNXReduceL2Op. void ONNXReduceL2Op::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } /// on the ONNXReduceLogSumOp. void ONNXReduceLogSumOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } /// on the ONNXReduceLogSumExpOp. void ONNXReduceLogSumExpOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } /// on the ONNXReduceSumSquareOp. 
void ONNXReduceSumSquareOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } /// on the ONNXReduceSumSquareOp. void ONNXConvNoBiasOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); }