From 7c1e67898d89266e28762a08a11cb2cba26216f4 Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Tue, 18 Aug 2020 17:41:40 +0900
Subject: [PATCH] Fuse convolution and batch normalization (#253)

* Rewriting rule

* Fix formulas

* Reuse op results

* Const propagation for Div and Sqrt

* Explicitly use ONNXConstantOp

* Minor revise

* Const propagation for unsqueeze

* Do const propagation once all tensors have inferred shapes

* LIT tests for fusion

* Add LIT tests for constant propagation on Div, Sqrt, and Unsqueeze

* Missing dash

Co-authored-by: Tian Jin
---
 src/Dialect/ONNX/ONNXOps.td               |   1 +
 src/Dialect/ONNX/ONNXOps.td.inc           | 102 +++++++++++++++-------
 src/MainUtils.cpp                         |   5 ++
 src/Transform/ONNX/ConstProp.cpp          |  55 ++++++++++++
 src/Transform/ONNX/ConstProp.td           |  43 ++++++++-
 src/Transform/ONNX/Rewrite.cpp            |  36 ++++++++
 src/Transform/ONNX/Rewrite.td             |  72 +++++++++++++++
 test/mlir/onnx/onnx_canonicalization.mlir |  85 ++++++++++++++++++
 test/mlir/onnx/onnx_constprop.mlir        |  44 ++++++++++
 utils/gen_onnx_mlir.py                    |   3 +-
 10 files changed, 409 insertions(+), 37 deletions(-)

diff --git a/src/Dialect/ONNX/ONNXOps.td b/src/Dialect/ONNX/ONNXOps.td
index f926760..93d0f17 100644
--- a/src/Dialect/ONNX/ONNXOps.td
+++ b/src/Dialect/ONNX/ONNXOps.td
@@ -141,6 +141,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
 def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX BatchNormalization operation in test mode";
+  let hasCanonicalizer = 1;
   let description = [{
   "Carries out batch normalization as described in the paper"
   "https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 250c69c..bbed3cc 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -2894,17 +2894,29 @@ def ONNXNegOp:ONNX_Op<"Neg",
   }];
   let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$X);
   let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 1;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {20};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
+      auto elementType = X.getType().cast<TensorType>().getElementType();
+      build(builder, state, UnrankedTensorType::get(elementType), X);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto elementType = operands[0].getType().cast<TensorType>().getElementType();
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(UnrankedTensorType::get(elementType));
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 1;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {20};
+    }
+  }];
 }
 
 def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression",
@@ -5098,17 +5110,29 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
   }];
   let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X);
   let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 1;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {20};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
+      auto elementType = X.getType().cast<TensorType>().getElementType();
+      build(builder, state, UnrankedTensorType::get(elementType), X);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto elementType = operands[0].getType().cast<TensorType>().getElementType();
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(UnrankedTensorType::get(elementType));
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 1;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {20};
+    }
+  }];
 }
 
 def ONNXSqueezeOp:ONNX_Op<"Squeeze",
@@ -5574,17 +5598,29 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
   let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$data, I64ArrayAttr:$axes);
   let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$expanded);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 1;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {20};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value data, ArrayAttr axes", [{
+      auto elementType = data.getType().cast<TensorType>().getElementType();
+      build(builder, state, UnrankedTensorType::get(elementType), data, axes);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto elementType = operands[0].getType().cast<TensorType>().getElementType();
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(UnrankedTensorType::get(elementType));
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 1;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {20};
+    }
+  }];
 }
 
 def ONNXUpsampleOp:ONNX_Op<"Upsample",
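These generated builders let rewrite patterns create ops whose output shapes are not yet known: the result is typed as an unranked tensor of the operand's element type, and a later shape-inference pass refines it. As a rough illustration, a fragment (not code from this patch) that would sit inside a RewritePattern, assuming rewriter, loc, and a Value named varPlusEps are in scope:

  // The one-operand builder added above picks the result type itself: an
  // unranked tensor of varPlusEps's element type, e.g. tensor<*xf32>.
  Value sqrtVal = rewriter.create<ONNXSqrtOp>(loc, varPlusEps);
  // After shape inference, tensor<*xf32> is refined to a ranked type
  // such as tensor<64xf32>.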
diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp
index bebad23..01793c0 100644
--- a/src/MainUtils.cpp
+++ b/src/MainUtils.cpp
@@ -393,6 +393,11 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
   pm.addPass(mlir::createAttributePromotionPass());
   pm.addPass(mlir::createShapeInferencePass());
   pm.addPass(mlir::createAttributePromotionPass());
+  // There are more opportunities for const propagation once all tensors have
+  // inferred shapes.
+  pm.addPass(mlir::createConstPropONNXToONNXPass());
+  // Clean dead code.
+  pm.addPass(mlir::createSymbolDCEPass());
 }
 
 void addONNXToKrnlPasses(mlir::PassManager &pm) {
diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp
index a48a41b..773fe64 100644
--- a/src/Transform/ONNX/ConstProp.cpp
+++ b/src/Transform/ONNX/ConstProp.cpp
@@ -21,6 +21,8 @@
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Pass/Passes.hpp"
 
+#include <math.h>
+
 using namespace mlir;
 
 namespace {
@@ -120,6 +122,26 @@ Attribute ComputeConstPropElementwiseBinary(
   llvm_unreachable("constant propagation for MulOp: unknown data type");
 }
 
+template <>
+Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
+    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
+    Attribute &secondAttr) {
+  if (elementType.isa<FloatType>()) {
+    double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
+    double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
+    assert(rhsVal != 0 && "division by zero");
+    double res = lhsVal / rhsVal;
+    return rewriter.getFloatAttr(elementType, res);
+  }
+  if (elementType.isa<IntegerType>()) {
+    uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
+    uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
+    assert(rhsVal != 0 && "division by zero");
+    uint64_t res = lhsVal / rhsVal;
+    return rewriter.getIntegerAttr(elementType, res);
+  }
+  llvm_unreachable("constant propagation for DivOp: unknown data type");
+}
+
 // Recursively process one dimension in the rank of the two references. There
 // can be one of 3 cases.
 // 1) We have fully defined accesses for both operands, launch the computations.
@@ -246,6 +268,17 @@ Attribute ComputeConstPropElementwiseUnary(
   llvm_unreachable("constant propagation for NegOp: unknown data type");
 }
 
+template <>
+Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
+    PatternRewriter &rewriter, Type elementType, Attribute &attr) {
+  if (elementType.isa<FloatType>()) {
+    double val = attr.cast<FloatAttr>().getValueAsDouble();
+    double res = sqrt(val);
+    return rewriter.getFloatAttr(elementType, res);
+  }
+  llvm_unreachable("constant propagation for SqrtOp: unknown data type");
+}
+
 template <typename ElementwiseUnaryOp>
 void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
     std::vector<Attribute> &resVector, DenseElementsAttr &attr,
@@ -340,6 +373,28 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
   return DenseElementsAttr::get(resType, resRef);
 }
 
+//===----------------------------------------------------------------------===//
+// Code to perform constant propagation for unsqueeze.
+//===----------------------------------------------------------------------===//
+
+DenseElementsAttr ConstPropUnsqueeze(
+    PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
+  // Read dense attribute, the constant tensor we are transforming.
+  DenseElementsAttr denseAttr =
+      attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
+  assert(denseAttr && "expected dense attribute");
+  ShapedType resType = resOperand.getType().cast<ShapedType>();
+
+  // Unsqueeze does not change the order of access, so just copy the whole data.
+  std::vector<Attribute> resVector;
+  for (auto value : denseAttr.getValues<Attribute>()) {
+    resVector.emplace_back(value);
+  }
+
+  ArrayRef<Attribute> resRef(resVector);
+  return DenseElementsAttr::get(resType, resRef);
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern definition.
 //===----------------------------------------------------------------------===//
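The two specializations above fold Div and Sqrt at compile time, element by element, exactly as the existing Add/Sub/Mul and Neg specializations do. A standalone sketch of the same per-element computation (plain C++ with illustrative names, not code from this patch) may help:

#include <cassert>
#include <cmath>
#include <vector>

// Per-element fold for Div: both operands are compile-time constants that
// have already been broadcast to a common shape, flattened here to vectors.
std::vector<double> foldDiv(const std::vector<double> &lhs,
                            const std::vector<double> &rhs) {
  assert(lhs.size() == rhs.size() && "operands share one broadcast shape");
  std::vector<double> res(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    assert(rhs[i] != 0 && "division by zero");
    res[i] = lhs[i] / rhs[i];
  }
  return res;
}

// Per-element fold for Sqrt, mirroring the sqrt(val) call above.
std::vector<double> foldSqrt(const std::vector<double> &in) {
  std::vector<double> res(in.size());
  for (size_t i = 0; i < in.size(); ++i)
    res[i] = std::sqrt(in[i]);
  return res;
}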
diff --git a/src/Transform/ONNX/ConstProp.td b/src/Transform/ONNX/ConstProp.td
index 231a94a..cecbb38 100644
--- a/src/Transform/ONNX/ConstProp.td
+++ b/src/Transform/ONNX/ConstProp.td
@@ -60,12 +60,21 @@ def CreateSubOfTwoConst :
 def CreateNegOfConst :
   NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
 
- def CreateMulOfTwoConst :
+def CreateSqrtOfConst :
+  NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXSqrtOp>($_builder, $0, $1)">;
+
+def CreateMulOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
 
+def CreateDivOfTwoConst :
+  NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXDivOp>($_builder, $0, $1, $2)">;
+
 def CreateTransposeOfConst :
   NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
 
+def CreateUnsqueezeOfConst :
+  NativeCodeCall<"ConstPropUnsqueeze($_builder, $0, $1)">;
+
 //===----------------------------------------------------------------------===//
 // Patterns to enable opportunities with elementwise ADD operations.
 //===----------------------------------------------------------------------===//
@@ -163,7 +172,14 @@ def SubConstToNeg : Pat<
   (ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
   [(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
 
-
+// Constant Propagation for Sqrt
+def SqrtofConst : Pat<
+  // From onnx.Sqrt(c)
+  (ONNXSqrtOp (ONNXConstantOp:$constOp $s, $v)),
+  // To sqrt(c)
+  (ONNXConstantOp (GetNullAttr), (CreateSqrtOfConst $constOp, $v)),
+  [(AttributeIsNull:$s)]>;
+
 //===----------------------------------------------------------------------===//
 // Patterns to enable opportunities with elementwise MUL operations.
 // Exactly the same pattern as for the elementwise ADD operations.
@@ -232,6 +248,16 @@ def MulConstProp : Pat<
   // Multiplication constraints (no sparse)
   [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
 
+// Constant Propagation for Div
+def DivConstProp : Pat<
+  // From div(c1, c2).
+  (ONNXDivOp:$divOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
+  // To c1/c2
+  (ONNXConstantOp (GetNullAttr), (CreateDivOfTwoConst $divOp, $v1, $v2)),
+  // Division constraints (no sparse)
+  [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
+
+
 //===----------------------------------------------------------------------===//
 // Patterns to enable opportunities with Transpose operations.
 //===----------------------------------------------------------------------===//
@@ -244,5 +270,16 @@ def TransposeofConst : Pat<
   (ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
   [(AttributeIsNull:$s)]>;
 
+//===----------------------------------------------------------------------===//
+// Patterns to enable opportunities with Unsqueeze operations.
+//===----------------------------------------------------------------------===//
+
+def UnsqueezeofConst : Pat<
+  // From Unsqueeze (c, axis)
+  (ONNXUnsqueezeOp:$resOp (ONNXConstantOp $s, $v), $_),
+  // To c' where c' is the unsqueezed value.
+  (ONNXConstantOp (GetNullAttr), (CreateUnsqueezeOfConst $resOp, $v)),
+  [(AttributeIsNull:$s)]>;
+
 #endif // ONNX_CONSTPROP
diff --git a/src/Transform/ONNX/Rewrite.cpp b/src/Transform/ONNX/Rewrite.cpp
index b390419..a5fbd7e 100644
--- a/src/Transform/ONNX/Rewrite.cpp
+++ b/src/Transform/ONNX/Rewrite.cpp
@@ -18,6 +18,36 @@ using namespace mlir;
 
 namespace {
 
+// Create a DenseElementsAttr from a float attribute.
+DenseElementsAttr createDenseElementsAttrFromFloatAttr(
+    PatternRewriter &rewriter, Type elementType, FloatAttr attr) {
+  SmallVector<int64_t, 1> dims(1, 1);
+  SmallVector<float, 1> values(1, attr.getValue().convertToFloat());
+  auto tensorType = mlir::RankedTensorType::get(dims, elementType);
+  return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
+}
+
+// If 'lhs' is not NoneType, return 'lhs - rhs'.
+// Otherwise, return '-rhs'.
+Value subtractOrNeg(
+    PatternRewriter &rewriter, Location loc, Value lhs, Value rhs) {
+  if (lhs.getType().isa<NoneType>()) {
+    Value result = rewriter.create<ONNXNegOp>(loc, rhs);
+    return result;
+  } else {
+    Value result = rewriter.create<ONNXSubOp>(loc, lhs, rhs);
+    return result;
+  }
+}
+
+// Create an ArrayAttr of IntegerAttr(s) of values in [1, N].
+ArrayAttr createArrayAttrOfOneToN(PatternRewriter &rewriter, int N) {
+  SmallVector<int64_t, 4> vals;
+  for (int i = 1; i <= N; ++i)
+    vals.emplace_back(i);
+  return rewriter.getI64ArrayAttr(vals);
+}
+
 // Check whether an ArrayAttr contains non-zero values or not.
 bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
   bool allZeros = true;
@@ -92,3 +122,9 @@ void ONNXConvOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<ConvOpPaddingPattern>(context);
 }
+
+/// Register canonicalization patterns on the ONNXBatchNormalizationTestModeOp.
+void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<FuseBatchNormTestModeConvPattern>(context);
+}
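The helper createArrayAttrOfOneToN is what produces the axes = [1, 2, 3] seen in the tests below: for a rank-4 convolution weight, the per-output-channel coefficient (a tensor<64xf32>) is unsqueezed at axes 1 through rank-1, becoming tensor<64x1x1x1xf32>, which then broadcasts against the tensor<64x3x7x7xf32> weight. A minimal standalone sketch of that loop (plain C++, illustrative only):

#include <cstdio>
#include <vector>

// Mirrors createArrayAttrOfOneToN without the MLIR types: values in [1, N].
std::vector<long> oneToN(int N) {
  std::vector<long> vals;
  for (int i = 1; i <= N; ++i)
    vals.push_back(i);
  return vals;
}

int main() {
  // For a rank-4 weight, N = 4 - 1 = 3, giving unsqueeze axes [1, 2, 3].
  for (long v : oneToN(3))
    std::printf("%ld ", v); // prints: 1 2 3
  return 0;
}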
diff --git a/src/Transform/ONNX/Rewrite.td b/src/Transform/ONNX/Rewrite.td
index 4857348..6d6fb10 100644
--- a/src/Transform/ONNX/Rewrite.td
+++ b/src/Transform/ONNX/Rewrite.td
@@ -24,6 +24,19 @@ include "src/Dialect/ONNX/ONNXOps.td"
 ///   dag benefitsAdded = (addBenefit 0)
 /// >;
 
+// Create a DenseElementsAttr from a float attribute and an element type.
+def createDenseElementsAttrFromFloatAttr : NativeCodeCall<
+  "createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast<TensorType>().getElementType(), $1)">;
+
+// If '$1' is not NoneType, do subtraction '$1 - $2'.
+// Otherwise, take the negative of '$2'.
+def subtractOrNeg : NativeCodeCall<
+  "subtractOrNeg($_builder, $0.getDefiningOp()->getLoc(), $1, $2)">;
+
+// Create an ArrayAttr of IntegerAttr(s) of values in [1, N].
+def createArrayAttrOfOneToRankOf : NativeCodeCall<
+  "createArrayAttrOfOneToN($_builder, $0.getType().cast<ShapedType>().getRank() - 1)">;
+
 def GetNullAttr :
   NativeCodeCall<"Attribute()">;
 
@@ -100,4 +113,63 @@ def ConvOpPaddingPattern: Pat<
   [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
 >;
 
+//===----------------------------------------------------------------------===//
+// This is to fuse the composition: 'BatchNorm o Conv' into 'Conv'
+// by deriving new 'w' and 'b' for 'Conv':
+//
+// We have:
+//   (Conv)      z = w * x + b
+//   (BatchNorm) y = scale * (z - mean) / sqrt(var + eps) + B
+//
+// which corresponds to the following computation:
+//   y = w_ * x + b_
+// where
+//   w_ = scale * w / sqrt(var + eps)
+//   b_ = B + scale * (b - mean) / sqrt(var + eps)
+//
+// Hence, we rewrite:
+//   onnx.BatchNormalizationTestMode(
+//     onnx.Conv(x, w, b),
+//     scale, B, mean, var
+//   ) {eps = ...}
+//
+// as:
+//   onnx.Conv(x, w_, b_)
+//
+// where
+//   w_ = scale * w / sqrt(var + eps)
+//   b_ = B + scale * (b - mean) / sqrt(var + eps)
+//
+//===----------------------------------------------------------------------===//

+def FuseBatchNormTestModeConvPattern: Pat<
+  (ONNXBatchNormalizationTestModeOp:$res
+    (ONNXConvOp $x, $w, $b,
+      $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides),
+    $scale, $B, $mean, $var, $epsilon, $momentum),
+  (ONNXConvOp
+     $x,
+     // w_
+     (ONNXMulOp
+        $w,
+        (ONNXUnsqueezeOp
+           (ONNXDivOp:$coefficientW
+              $scale,
+              (ONNXSqrtOp
+                 (ONNXAddOp
+                    $var,
+                    (ONNXConstantOp
+                       (GetNullAttr),
+                       (createDenseElementsAttrFromFloatAttr $res, $epsilon))))),
+           (createArrayAttrOfOneToRankOf $w))),
+     // b_
+     (ONNXAddOp
+        $B,
+        (ONNXMulOp
+           $coefficientW,
+           (subtractOrNeg $res, $b, $mean))),
+
+     $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides)
+>;
+
 #endif // ONNX_REWRITE
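Before reading the LIT tests, it is easy to convince oneself of the algebra with scalars. The following standalone check (made-up numbers, not part of the patch) verifies that the fused and unfused forms agree:

#include <cassert>
#include <cmath>

int main() {
  // Arbitrary scalar stand-ins for one output channel.
  double x = 0.5, w = 2.0, b = 0.25;         // Conv input, weight, bias
  double scale = 1.5, B = -0.75;             // BatchNorm scale and bias
  double mean = 0.1, var = 0.04, eps = 1e-5; // running statistics

  // Unfused: z = w * x + b; y = scale * (z - mean) / sqrt(var + eps) + B.
  double z = w * x + b;
  double y = scale * (z - mean) / std::sqrt(var + eps) + B;

  // Fused: y = w_ * x + b_ with the derived weight and bias.
  double c = scale / std::sqrt(var + eps); // the coefficientW of the pattern
  double w_ = c * w;
  double b_ = B + c * (b - mean);
  double yFused = w_ * x + b_;

  assert(std::fabs(y - yFused) < 1e-12 && "fused form matches unfused form");
  return 0;
}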
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 954fe75..e74a8be 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -106,3 +106,88 @@ func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK-NEXT: return %arg0 : tensor<2xf32>
 }
+
+// -----
+
+func @test_conv_batchnormtestmode_fusion_nobias(%arg0 : tensor<1x3x224x224xf32>) -> tensor<1x64x112x112xf32> {
+  %cst = constant unit
+  %0 = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
+  %1 = "onnx.Conv"(%arg0, %0, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, none) -> tensor<1x64x112x112xf32>
+  %2 = "onnx.Constant"() : () -> tensor<64xf32>
+  %3 = "onnx.Constant"() : () -> tensor<64xf32>
+  %4 = "onnx.Constant"() : () -> tensor<64xf32>
+  %5 = "onnx.Constant"() : () -> tensor<64xf32>
+  %6 = "onnx.BatchNormalizationTestMode"(%1, %2, %3, %4, %5) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
+  return %6 : tensor<1x64x112x112xf32>
+
+  // CHECK-LABEL: test_conv_batchnormtestmode_fusion_nobias
+  // CHECK: [[WEIGHT:%.+]] = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
+  // CHECK: [[SCALE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+  // CHECK: [[B:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+  // CHECK: [[MEAN:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+  // CHECK: [[VARIANCE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+  // CHECK: [[EPSILON:%.+]] = "onnx.Constant"() {value = dense<1.00000007E-5> : tensor<1xf32>} : () -> tensor<1xf32>
+
+  // CHECK: [[VAR_EPSILON:%.+]] = "onnx.Add"([[VARIANCE]], [[EPSILON]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+  // CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
+  // CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[UNSQUEEZE:%.+]] = "onnx.Unsqueeze"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
+
+  // CHECK: [[NEG_MEAN:%.+]] = "onnx.Neg"([[MEAN]]) : (tensor<64xf32>) -> tensor<*xf32>
+  // CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[NEG_MEAN]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[NEW_BIAS:%.+]] = "onnx.Add"([[B]], [[MUL]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
+
+  // CHECK: [[PAD_ARG1:%.+]] = "onnx.Constant"() {value = dense<[0, 0, 3, 3, 0, 0, 3, 3]> : tensor<8xi64>} : () -> tensor<8xi64>
+  // CHECK: [[PAD_ARG2:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
+  // CHECK: [[PADDED_INPUT:%.+]] = "onnx.Pad"(%arg0, [[PAD_ARG1]], [[PAD_ARG2]]) {mode = "constant"} : (tensor<1x3x224x224xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: [[RES:%.+]] = "onnx.Conv"([[PADDED_INPUT]], [[NEW_WEIGHT]], [[NEW_BIAS]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32>
+
+  // CHECK-NOT: {{.*}} = "onnx.BatchNormalizationTestMode"{{.*}}
+
+  // CHECK: return [[RES]] : tensor<1x64x112x112xf32>
+}
+
+// -----
+
+func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1 : tensor<64xf32>) -> tensor<1x64x112x112xf32> {
+  %cst = constant unit
+  %0 = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
+  %1 = "onnx.Conv"(%arg0, %0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
+  %2 = "onnx.Constant"() : () -> tensor<64xf32>
+  %3 = "onnx.Constant"() : () -> tensor<64xf32>
+  %4 = "onnx.Constant"() : () -> tensor<64xf32>
+  %5 = "onnx.Constant"() : () -> tensor<64xf32>
+  %6 = "onnx.BatchNormalizationTestMode"(%1, %2, %3, %4, %5) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
+  return %6 : tensor<1x64x112x112xf32>
+
+  // CHECK-LABEL: test_conv_batchnormtestmode_fusion
+  // CHECK: [[WEIGHT:%.+]] = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
+  // CHECK: [[SCALE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+  // CHECK: [[B:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+  // CHECK: [[MEAN:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+  // CHECK: [[VARIANCE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+  // CHECK: [[EPSILON:%.+]] = "onnx.Constant"() {value = dense<1.00000007E-5> : tensor<1xf32>} : () -> tensor<1xf32>
+
+  // CHECK: [[VAR_EPSILON:%.+]] = "onnx.Add"([[VARIANCE]], [[EPSILON]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+  // CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
+  // CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[UNSQUEEZE:%.+]] = "onnx.Unsqueeze"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
+
+  // CHECK: [[SUB:%.+]] = "onnx.Sub"(%arg1, [[MEAN]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+  // CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[SUB]]) : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
+  // CHECK: [[NEW_BIAS:%.+]] = "onnx.Add"([[B]], [[MUL]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
+
+  // CHECK: [[PAD_ARG1:%.+]] = "onnx.Constant"() {value = dense<[0, 0, 3, 3, 0, 0, 3, 3]> : tensor<8xi64>} : () -> tensor<8xi64>
+  // CHECK: [[PAD_ARG2:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
+  // CHECK: [[PADDED_INPUT:%.+]] = "onnx.Pad"(%arg0, [[PAD_ARG1]], [[PAD_ARG2]]) {mode = "constant"} : (tensor<1x3x224x224xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: [[RES:%.+]] = "onnx.Conv"([[PADDED_INPUT]], [[NEW_WEIGHT]], [[NEW_BIAS]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32>
+
+  // CHECK-NOT: {{.*}} = "onnx.BatchNormalizationTestMode"{{.*}}
+
+  // CHECK: return [[RES]] : tensor<1x64x112x112xf32>
+}
+
diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir
index 7934aaf..cc447ca 100644
--- a/test/mlir/onnx/onnx_constprop.mlir
+++ b/test/mlir/onnx/onnx_constprop.mlir
@@ -227,3 +227,47 @@ func @test_default_transpose_const_3() -> tensor<*xi32> {
   // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
   // CHECK: return [[RES]] : tensor<3x2x4xi32>
 }
+
+//===----------------------------------------------------------------------===//
+/// Div tests
+
+// -----
+
+// CHECK-LABEL: @test_div(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32>
+func @test_div(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %0 = "onnx.Constant"() {value = dense<[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
+  %1 = "onnx.Constant"() {value = dense<[[2.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+  %2 = "onnx.Div"(%0, %1) : (tensor<3x2xf32>, tensor<1x1xf32>) -> tensor<3x2xf32>
+  "std.return"(%2) : (tensor<3x2xf32>) -> ()
+  // CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]{{\]}}> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
+  // CHECK-NOT: {{.*}} = "onnx.Div"{{.*}}
+}
+
+//===----------------------------------------------------------------------===//
+/// Sqrt tests
+
+// -----
+
+// CHECK-LABEL: @test_sqrt() -> tensor<1x2xf32>
+func @test_sqrt() -> tensor<1x2xf32> {
+  %0 = "onnx.Constant"() {value = dense<[[4.0, 16.0]]> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %1 = "onnx.Sqrt"(%0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  "std.return"(%1) : (tensor<1x2xf32>) -> ()
+  // CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}[2.000000e+00, 4.000000e+00]{{\]}}> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  // CHECK-NOT: {{.*}} = "onnx.Sqrt"{{.*}}
+}
+
+//===----------------------------------------------------------------------===//
+/// Unsqueeze tests
+
+// -----
+
+// CHECK-LABEL: @test_unsqueeze() -> tensor<2x1x1xf32>
+func @test_unsqueeze() -> tensor<*xf32> {
+  %0 = "onnx.Constant"() {value = dense<[4.0, 16.0]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %1 = "onnx.Unsqueeze"(%0) {axes = [1, 2]} : (tensor<2xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+  // CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}{{\[}}[4.000000e+00]{{\]}}, {{\[}}[1.600000e+01]{{\]}}{{\]}}> : tensor<2x1x1xf32>} : () -> tensor<2x1x1xf32>
+  // CHECK-NOT: {{.*}} = "onnx.Unsqueeze"{{.*}}
+}
+
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 1459ac1..87d59a3 100644
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -364,7 +364,8 @@ OpsWithResultTypeInference = {
 # Currently, there are only two build methods generated:
 #  - one with operands and attributes having a separate parameter, and
 #  - one with operands and attributes having aggregated parameters.
-custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
+custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare',
+                                    'Pad', 'Sqrt', 'Neg', 'Unsqueeze']
 # Custom builder op list for operations with broadcast; we can deduce the right
 # output type, no need to leave it undef as in the above list.
 # Ops must have two operands, not one, not three... And there shall be two.