Fuse convolution and batch normalization (#253)

* Rewriting rule

* Fix formulas

* Reuse op results

* Const propagation for Div and Sqrt

* Explicitly use ONNXConstantOp

* Minor revise

* Const propagation for unsqueeze

* Do const propagationnce all tensors have inferred shapes

* LIT tests for fusion

* Add LIT tests for constant propagation on Div, Sqrt, and Unsqueeze

* Missing dash

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Tung D. Le 2020-08-18 17:41:40 +09:00 committed by GitHub
parent 38bd77e51a
commit 7c1e67898d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 409 additions and 37 deletions

View File

@ -141,6 +141,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX BatchNormalization operation in test mode"; let summary = "ONNX BatchNormalization operation in test mode";
let hasCanonicalizer = 1;
let description = [{ let description = [{
"Carries out batch normalization as described in the paper" "Carries out batch normalization as described in the paper"
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run," "https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"

View File

@ -2894,6 +2894,18 @@ def ONNXNegOp:ONNX_Op<"Neg",
}]; }];
let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$X); let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$X);
let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$Y); let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
auto elementType = X.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), X);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 1; return 1;
@ -5098,6 +5110,18 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
}]; }];
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X); let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X);
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Y); let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
auto elementType = X.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), X);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 1; return 1;
@ -5574,6 +5598,18 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$data, let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$data,
I64ArrayAttr:$axes); I64ArrayAttr:$axes);
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$expanded); let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$expanded);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value data, ArrayAttr axes", [{
auto elementType = data.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), data, axes);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 1; return 1;

View File

@ -393,6 +393,11 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createAttributePromotionPass()); pm.addPass(mlir::createAttributePromotionPass());
pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createAttributePromotionPass()); pm.addPass(mlir::createAttributePromotionPass());
// There are more opportunities for const propagation once all tensors have
// inferred shapes.
pm.addPass(mlir::createConstPropONNXToONNXPass());
// Clean dead code.
pm.addPass(mlir::createSymbolDCEPass());
} }
void addONNXToKrnlPasses(mlir::PassManager &pm) { void addONNXToKrnlPasses(mlir::PassManager &pm) {

View File

@ -21,6 +21,8 @@
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
#include <math.h>
using namespace mlir; using namespace mlir;
namespace { namespace {
@ -120,6 +122,26 @@ Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
llvm_unreachable("constant propagation for MulOp: unkonwn data type"); llvm_unreachable("constant propagation for MulOp: unkonwn data type");
} }
template <>
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
assert(rhsVal != 0 && "division by a zero");
double res = lhsVal / rhsVal;
return rewriter.getFloatAttr(elementType, res);
}
if (elementType.isa<IntegerType>()) {
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
assert(rhsVal != 0 && "division by a zero");
uint64_t res = lhsVal / rhsVal;
return rewriter.getIntegerAttr(elementType, res);
}
llvm_unreachable("constant propagation for DivOp: unkonwn data type");
}
// Recursively process one dimension in the rank of the two references. There // Recursively process one dimension in the rank of the two references. There
// can be one of 3 cases. // can be one of 3 cases.
// 1) We have fully defined accesses for both operands, launch the computations. // 1) We have fully defined accesses for both operands, launch the computations.
@ -246,6 +268,17 @@ Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
llvm_unreachable("constant propagation for NegOp: unkonwn data type"); llvm_unreachable("constant propagation for NegOp: unkonwn data type");
} }
template <>
Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
if (elementType.isa<FloatType>()) {
double val = attr.cast<FloatAttr>().getValueAsDouble();
double res = sqrt(val);
return rewriter.getFloatAttr(elementType, res);
}
llvm_unreachable("constant propagation for SqrtOp: unkonwn data type");
}
template <typename ElementwiseUnaryOp> template <typename ElementwiseUnaryOp>
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter, void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &attr, std::vector<Attribute> &resVector, DenseElementsAttr &attr,
@ -340,6 +373,28 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
return DenseElementsAttr::get(resType, resRef); return DenseElementsAttr::get(resType, resRef);
} }
//===----------------------------------------------------------------------===//
// Code to perform constant propagation for unsqueeze.
//===----------------------------------------------------------------------===//
DenseElementsAttr ConstPropUnsqueeze(
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
// Read dense attribute, the constant tensor we are transforming.
DenseElementsAttr denseAttr =
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
assert(denseAttr && "expected dense attribute");
ShapedType resType = resOperand.getType().cast<RankedTensorType>();
// Unqueeze does not change the order of access, so just copy the whole data.
std::vector<Attribute> resVector;
for (auto value : denseAttr.getValues<Attribute>()) {
resVector.emplace_back(value);
}
ArrayRef<Attribute> resRef(resVector);
return DenseElementsAttr::get(resType, resRef);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pattern definition. // Pattern definition.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -60,12 +60,21 @@ def CreateSubOfTwoConst :
def CreateNegOfConst : def CreateNegOfConst :
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">; NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
def CreateMulOfTwoConst : def CreateSqrtOfConst :
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXSqrtOp>($_builder, $0, $1)">;
def CreateMulOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">; NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
def CreateDivOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXDivOp>($_builder, $0, $1, $2)">;
def CreateTransposeOfConst : def CreateTransposeOfConst :
NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">; NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
def CreateUnsqueezeOfConst:
NativeCodeCall<"ConstPropUnsqueeze($_builder, $0, $1)">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise ADD operations. // Patterns to enable opportunities with elementwise ADD operations.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -163,6 +172,13 @@ def SubConstToNeg : Pat<
(ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))), (ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>; [(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
// Constant Propagation for Sqrt
def SqrtofConst : Pat<
// From onnx.Sqrt(c)
(ONNXSqrtOp (ONNXConstantOp:$constOp $s, $v)),
// To sqrt(c)
(ONNXConstantOp (GetNullAttr), (CreateSqrtOfConst $constOp, $v)),
[(AttributeIsNull:$s)]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise MUL operations. // Patterns to enable opportunities with elementwise MUL operations.
@ -232,6 +248,16 @@ def MulConstProp : Pat<
// Mulitional constraints (no sparse) // Mulitional constraints (no sparse)
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>; [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
// Constant Propagation for Div
def DivConstProp : Pat<
// From div(c1, c2).
(ONNXDivOp:$mulOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
// To c1/c2
(ONNXConstantOp (GetNullAttr), (CreateDivOfTwoConst $mulOp, $v1, $v2)),
// Division constraints (no sparse)
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Patterns to enable opportunities with Transpose operations. // Patterns to enable opportunities with Transpose operations.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -244,5 +270,16 @@ def TransposeofConst : Pat<
(ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)), (ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
[(AttributeIsNull:$s)]>; [(AttributeIsNull:$s)]>;
//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with Unsqueeze operations.
//===----------------------------------------------------------------------===//
def UnsqueezeofConst : Pat<
// From Unsqueeze (c, axis)
(ONNXUnsqueezeOp:$resOp (ONNXConstantOp $s, $v), $_),
// To c' where c' is the unsqueezed value.
(ONNXConstantOp (GetNullAttr), (CreateUnsqueezeOfConst $resOp, $v)),
[(AttributeIsNull:$s)]>;
#endif // ONNX_CONSTPROP #endif // ONNX_CONSTPROP

View File

@ -18,6 +18,36 @@ using namespace mlir;
namespace { namespace {
// Create a DenseElementsAttr from a float attribute.
DenseElementsAttr createDenseElementsAttrFromFloatAttr(
PatternRewriter &rewriter, Type elementType, FloatAttr attr) {
SmallVector<int64_t, 1> dims(1, 1);
SmallVector<float, 1> values(1, attr.getValue().convertToFloat());
auto tensorType = mlir::RankedTensorType::get(dims, elementType);
return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
}
// If 'lhs' is not NoneType, return 'lhs - rhs'.
// Otherwise, return '-rhs'.
Value subtractOrNeg(
PatternRewriter &rewriter, Location loc, Value lhs, Value rhs) {
if (lhs.getType().isa<NoneType>()) {
Value result = rewriter.create<ONNXNegOp>(loc, rhs);
return result;
} else {
Value result = rewriter.create<ONNXSubOp>(loc, lhs, rhs);
return result;
}
}
// Create an ArrayAttr of IntergerAttr(s) of values in [1, N].
ArrayAttr createArrayAttrOfOneToN(PatternRewriter &rewriter, int N) {
SmallVector<int64_t, 4> vals;
for (int i = 1; i <= N; ++i)
vals.emplace_back(i);
return rewriter.getI64ArrayAttr(vals);
}
// Check whether an ArrayAttr contains non-zero values or not. // Check whether an ArrayAttr contains non-zero values or not.
bool hasNonZeroInArrayAttr(ArrayAttr attrs) { bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
bool allZeros = true; bool allZeros = true;
@ -92,3 +122,9 @@ void ONNXConvOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ConvOpPaddingPattern>(context); results.insert<ConvOpPaddingPattern>(context);
} }
/// on the ONNXBatchNormalizationTestModeOp.
void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FuseBatchNormTestModeConvPattern>(context);
}

View File

@ -24,6 +24,19 @@ include "src/Dialect/ONNX/ONNXOps.td"
/// dag benefitsAdded = (addBenefit 0) /// dag benefitsAdded = (addBenefit 0)
/// >; /// >;
// Create a DenseElementsAttr from a float attribute and an element type.
def createDenseElementsAttrFromFloatAttr : NativeCodeCall<
"createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast<ShapedType>().getElementType(), $1)">;
// If '$1' is not NoneType, do subtraction '$1 - $2'.
// Otherwise, take the negative of '$2'.
def subtractOrNeg: NativeCodeCall<
"subtractOrNeg($_builder, $0.getDefiningOp()->getLoc(), $1, $2)">;
// Create an ArrayAttr of IntergerAttr(s) of values in [1, N].
def createArrayAttrOfOneToRankOf : NativeCodeCall<
"createArrayAttrOfOneToN($_builder, $0.getType().cast<ShapedType>().getRank() - 1)">;
def GetNullAttr : def GetNullAttr :
NativeCodeCall<"Attribute()">; NativeCodeCall<"Attribute()">;
@ -100,4 +113,63 @@ def ConvOpPaddingPattern: Pat<
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)] [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
>; >;
//===----------------------------------------------------------------------===//
// This is to fuse the composition: 'BatchNorm o Conv' into 'Conv'
// by deriving new 'w' and 'b' for 'Conv':
//
// We have:
// (Conv) z = w * x + b
// (BatchNorm) y = scale * (z - mean) / sqrt(var + eps) + bias
//
// which corresponds to the following computation:
// y = w_ * x + b_
// where
// w_ = scale * w / sqrt(var + eps)
// b_ = B + scale * (b - mean) / sqrt(var + eps)
//
// Hence, we rewrite:
// onnx.BatchNormalizationTestMode(
// onnx.Conv(x, w, b),
// scale, B, mean, var
// ) {eps = ...}
//
// as:
// onnx.Conv(x, w_, b_)
//
// where
// w_ = scale * w / sqrt(var + eps)
// b_ = B + scale * (b - mean) / sqrt(var + eps)
//
//===----------------------------------------------------------------------===//
def FuseBatchNormTestModeConvPattern: Pat<
(ONNXBatchNormalizationTestModeOp:$res
(ONNXConvOp $x, $w, $b,
$auto_pad, $dilation, $group, $kernel_shape, $pads, $strides),
$scale, $B, $mean, $var, $epsilon, $momentum),
(ONNXConvOp
$x,
// w_
(ONNXMulOp
$w,
(ONNXUnsqueezeOp
(ONNXDivOp:$coefficientW
$scale,
(ONNXSqrtOp
(ONNXAddOp
$var,
(ONNXConstantOp
(GetNullAttr),
(createDenseElementsAttrFromFloatAttr $res, $epsilon))))),
(createArrayAttrOfOneToRankOf $w))),
// b_
(ONNXAddOp
$B,
(ONNXMulOp
$coefficientW,
(subtractOrNeg $res, $b, $mean))),
$auto_pad, $dilation, $group, $kernel_shape, $pads, $strides)
>;
#endif // ONNX_REWRITE #endif // ONNX_REWRITE

View File

@ -106,3 +106,88 @@ func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-NEXT: return %arg0 : tensor<2xf32> // CHECK-NEXT: return %arg0 : tensor<2xf32>
} }
// -----
func @test_conv_batchnormtestmode_fusion_nobias(%arg0 : tensor<1x3x224x224xf32>) -> tensor<1x64x112x112xf32> {
%cst = constant unit
%0 = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
%1 = "onnx.Conv"(%arg0, %0, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, none) -> tensor<1x64x112x112xf32>
%2 = "onnx.Constant"() : () -> tensor<64xf32>
%3 = "onnx.Constant"() : () -> tensor<64xf32>
%4 = "onnx.Constant"() : () -> tensor<64xf32>
%5 = "onnx.Constant"() : () -> tensor<64xf32>
%6 = "onnx.BatchNormalizationTestMode"(%1, %2, %3, %4, %5) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
return %6 : tensor<1x64x112x112xf32>
// CHECK-LABEL: test_conv_batchnormtestmode_fusion_nobias
// CHECK: [[WEIGHT:%.+]] = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
// CHECK: [[SCALE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
// CHECK: [[B:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
// CHECK: [[MEAN:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
// CHECK: [[VARIANCE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
// CHECK: [[EPSILON:%.+]] = "onnx.Constant"() {value = dense<1.00000007E-5> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: [[VAR_EPSILON:%.+]] = "onnx.Add"([[VARIANCE]], [[EPSILON]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
// CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
// CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[UNSQUEEZE:%.+]] = "onnx.Unsqueeze"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[NEG_MEAN:%.+]] = "onnx.Neg"([[MEAN]]) : (tensor<64xf32>) -> tensor<*xf32>
// CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[NEG_MEAN]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[NEW_BIAS:%.+]] = "onnx.Add"([[B]], [[MUL]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[PAD_ARG1:%.+]] = "onnx.Constant"() {value = dense<[0, 0, 3, 3, 0, 0, 3, 3]> : tensor<8xi64>} : () -> tensor<8xi64>
// CHECK: [[PAD_ARG2:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: [[PADDED_INPUT:%.+]] = "onnx.Pad"(%arg0, [[PAD_ARG1]], [[PAD_ARG2]]) {mode = "constant"} : (tensor<1x3x224x224xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: [[RES:%.+]] = "onnx.Conv"([[PADDED_INPUT]], [[NEW_WEIGHT]], [[NEW_BIAS]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32>
// CHECK-NOT: {{.*}} = "onnx.BatchNormalizationTestMode"{{.*}}
// CHECK: return [[RES]] : tensor<1x64x112x112xf32>
}
// -----
func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1 : tensor<64xf32>) -> tensor<1x64x112x112xf32> {
%cst = constant unit
%0 = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
%1 = "onnx.Conv"(%arg0, %0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
%2 = "onnx.Constant"() : () -> tensor<64xf32>
%3 = "onnx.Constant"() : () -> tensor<64xf32>
%4 = "onnx.Constant"() : () -> tensor<64xf32>
%5 = "onnx.Constant"() : () -> tensor<64xf32>
%6 = "onnx.BatchNormalizationTestMode"(%1, %2, %3, %4, %5) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
return %6 : tensor<1x64x112x112xf32>
// CHECK-LABEL: test_conv_batchnormtestmode_fusion
// CHECK: [[WEIGHT:%.+]] = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
// CHECK: [[SCALE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
// CHECK: [[B:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
// CHECK: [[MEAN:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
// CHECK: [[VARIANCE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
// CHECK: [[EPSILON:%.+]] = "onnx.Constant"() {value = dense<1.00000007E-5> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: [[VAR_EPSILON:%.+]] = "onnx.Add"([[VARIANCE]], [[EPSILON]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
// CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
// CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[UNSQUEEZE:%.+]] = "onnx.Unsqueeze"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[SUB:%.+]] = "onnx.Sub"(%arg1, [[MEAN]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
// CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[SUB]]) : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: [[NEW_BIAS:%.+]] = "onnx.Add"([[B]], [[MUL]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[PAD_ARG1:%.+]] = "onnx.Constant"() {value = dense<[0, 0, 3, 3, 0, 0, 3, 3]> : tensor<8xi64>} : () -> tensor<8xi64>
// CHECK: [[PAD_ARG2:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: [[PADDED_INPUT:%.+]] = "onnx.Pad"(%arg0, [[PAD_ARG1]], [[PAD_ARG2]]) {mode = "constant"} : (tensor<1x3x224x224xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: [[RES:%.+]] = "onnx.Conv"([[PADDED_INPUT]], [[NEW_WEIGHT]], [[NEW_BIAS]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32>
// CHECK-NOT: {{.*}} = "onnx.BatchNormalizationTestMode"{{.*}}
// CHECK: return [[RES]] : tensor<1x64x112x112xf32>
}

View File

@ -227,3 +227,47 @@ func @test_default_transpose_const_3() -> tensor<*xi32> {
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32> // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
// CHECK: return [[RES]] : tensor<3x2x4xi32> // CHECK: return [[RES]] : tensor<3x2x4xi32>
} }
//===----------------------------------------------------------------------===//
/// Div tests
// -----
// CHECK-LABEL: @test_div(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32>
func @test_div(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
%0 = "onnx.Constant"() {value = dense<[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
%1 = "onnx.Constant"() {value = dense<[[2.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%2 = "onnx.Div"(%0, %1) : (tensor<3x2xf32>, tensor<1x1xf32>) -> tensor<3x2xf32>
"std.return"(%2) : (tensor<3x2xf32>) -> ()
// CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]{{\]}}> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
// CHECK-NOT: {{.*}} = "onnx.Div"{{.*}}
}
//===----------------------------------------------------------------------===//
/// Sqrt tests
// -----
// CHECK-LABEL: @test_sqrt() -> tensor<1x2xf32>
func @test_sqrt() -> tensor<1x2xf32> {
%0 = "onnx.Constant"() {value = dense<[[4.0, 16.0]]> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
%1 = "onnx.Sqrt"(%0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
"std.return"(%1) : (tensor<1x2xf32>) -> ()
// CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}[2.000000e+00, 4.000000e+00]{{\]}}> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
// CHECK-NOT: {{.*}} = "onnx.Sqrt"{{.*}}
}
//===----------------------------------------------------------------------===//
/// Unsqueeze tests
// -----
// CHECK-LABEL: @test_unsqueeze() -> tensor<2x1x1xf32>
func @test_unsqueeze() -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[4.0, 16.0]> : tensor<2xf32>} : () -> tensor<2xf32>
%1 = "onnx.Unsqueeze"(%0) {axes = [1, 2]} : (tensor<2xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}{{\[}}[4.000000e+00]{{\]}}, {{\[}}[1.600000e+01]{{\]}}{{\]}}> : tensor<2x1x1xf32>} : () -> tensor<2x1x1xf32>
// CHECK-NOT: {{.*}} = "onnx.Unsqueeze"{{.*}}
}

View File

@ -364,7 +364,8 @@ OpsWithResultTypeInference = {
# Currenlty, there are only two build methods generated: # Currenlty, there are only two build methods generated:
# - one with operands and attributes having a separate parameter, and # - one with operands and attributes having a separate parameter, and
# - one with operands and attributes having aggregated parameters. # - one with operands and attributes having aggregated parameters.
custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad'] custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare',
'Pad', 'Sqrt', 'Neg', 'Unsqueeze']
# Custom builder op list for operations with broadcast; we can deduce the right # Custom builder op list for operations with broadcast; we can deduce the right
# output type, no need to leave it undef as in the above list. # output type, no need to leave it undef as in the above list.
# Ops must have two operands, not one, not three... And there shall be two. # Ops must have two operands, not one, not three... And there shall be two.