Fuse convolution and batch normalization (#253)
* Rewriting rule * Fix formulas * Reuse op results * Const propagation for Div and Sqrt * Explicitly use ONNXConstantOp * Minor revise * Const propagation for unsqueeze * Do const propagationnce all tensors have inferred shapes * LIT tests for fusion * Add LIT tests for constant propagation on Div, Sqrt, and Unsqueeze * Missing dash Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
38bd77e51a
commit
7c1e67898d
|
@ -141,6 +141,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||||
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
|
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX BatchNormalization operation in test mode";
|
let summary = "ONNX BatchNormalization operation in test mode";
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let description = [{
|
let description = [{
|
||||||
"Carries out batch normalization as described in the paper"
|
"Carries out batch normalization as described in the paper"
|
||||||
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
|
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
|
||||||
|
|
|
@ -2894,6 +2894,18 @@ def ONNXNegOp:ONNX_Op<"Neg",
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$X);
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$X);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
|
let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
|
||||||
|
auto elementType = X.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), X);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -5098,6 +5110,18 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X);
|
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
|
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
|
||||||
|
auto elementType = X.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), X);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -5574,6 +5598,18 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$data,
|
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$data,
|
||||||
I64ArrayAttr:$axes);
|
I64ArrayAttr:$axes);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$expanded);
|
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$expanded);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value data, ArrayAttr axes", [{
|
||||||
|
auto elementType = data.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), data, axes);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 1;
|
return 1;
|
||||||
|
|
|
@ -393,6 +393,11 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
|
||||||
pm.addPass(mlir::createAttributePromotionPass());
|
pm.addPass(mlir::createAttributePromotionPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createAttributePromotionPass());
|
pm.addPass(mlir::createAttributePromotionPass());
|
||||||
|
// There are more opportunities for const propagation once all tensors have
|
||||||
|
// inferred shapes.
|
||||||
|
pm.addPass(mlir::createConstPropONNXToONNXPass());
|
||||||
|
// Clean dead code.
|
||||||
|
pm.addPass(mlir::createSymbolDCEPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
void addONNXToKrnlPasses(mlir::PassManager &pm) {
|
void addONNXToKrnlPasses(mlir::PassManager &pm) {
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
#include "src/Pass/Passes.hpp"
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -120,6 +122,26 @@ Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
|
||||||
llvm_unreachable("constant propagation for MulOp: unkonwn data type");
|
llvm_unreachable("constant propagation for MulOp: unkonwn data type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
|
||||||
|
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||||
|
Attribute &secondAttr) {
|
||||||
|
if (elementType.isa<FloatType>()) {
|
||||||
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
assert(rhsVal != 0 && "division by a zero");
|
||||||
|
double res = lhsVal / rhsVal;
|
||||||
|
return rewriter.getFloatAttr(elementType, res);
|
||||||
|
}
|
||||||
|
if (elementType.isa<IntegerType>()) {
|
||||||
|
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
|
||||||
|
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
|
||||||
|
assert(rhsVal != 0 && "division by a zero");
|
||||||
|
uint64_t res = lhsVal / rhsVal;
|
||||||
|
return rewriter.getIntegerAttr(elementType, res);
|
||||||
|
}
|
||||||
|
llvm_unreachable("constant propagation for DivOp: unkonwn data type");
|
||||||
|
}
|
||||||
// Recursively process one dimension in the rank of the two references. There
|
// Recursively process one dimension in the rank of the two references. There
|
||||||
// can be one of 3 cases.
|
// can be one of 3 cases.
|
||||||
// 1) We have fully defined accesses for both operands, launch the computations.
|
// 1) We have fully defined accesses for both operands, launch the computations.
|
||||||
|
@ -246,6 +268,17 @@ Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
|
||||||
llvm_unreachable("constant propagation for NegOp: unkonwn data type");
|
llvm_unreachable("constant propagation for NegOp: unkonwn data type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
|
||||||
|
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||||
|
if (elementType.isa<FloatType>()) {
|
||||||
|
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double res = sqrt(val);
|
||||||
|
return rewriter.getFloatAttr(elementType, res);
|
||||||
|
}
|
||||||
|
llvm_unreachable("constant propagation for SqrtOp: unkonwn data type");
|
||||||
|
}
|
||||||
|
|
||||||
template <typename ElementwiseUnaryOp>
|
template <typename ElementwiseUnaryOp>
|
||||||
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
||||||
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
||||||
|
@ -340,6 +373,28 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
|
||||||
return DenseElementsAttr::get(resType, resRef);
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Code to perform constant propagation for unsqueeze.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
DenseElementsAttr ConstPropUnsqueeze(
|
||||||
|
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
|
||||||
|
// Read dense attribute, the constant tensor we are transforming.
|
||||||
|
DenseElementsAttr denseAttr =
|
||||||
|
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
assert(denseAttr && "expected dense attribute");
|
||||||
|
ShapedType resType = resOperand.getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
// Unqueeze does not change the order of access, so just copy the whole data.
|
||||||
|
std::vector<Attribute> resVector;
|
||||||
|
for (auto value : denseAttr.getValues<Attribute>()) {
|
||||||
|
resVector.emplace_back(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<Attribute> resRef(resVector);
|
||||||
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pattern definition.
|
// Pattern definition.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -60,12 +60,21 @@ def CreateSubOfTwoConst :
|
||||||
def CreateNegOfConst :
|
def CreateNegOfConst :
|
||||||
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
|
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
|
||||||
|
|
||||||
def CreateMulOfTwoConst :
|
def CreateSqrtOfConst :
|
||||||
|
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXSqrtOp>($_builder, $0, $1)">;
|
||||||
|
|
||||||
|
def CreateMulOfTwoConst :
|
||||||
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
|
def CreateDivOfTwoConst :
|
||||||
|
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXDivOp>($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
def CreateTransposeOfConst :
|
def CreateTransposeOfConst :
|
||||||
NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
|
NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
|
def CreateUnsqueezeOfConst:
|
||||||
|
NativeCodeCall<"ConstPropUnsqueeze($_builder, $0, $1)">;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Patterns to enable opportunities with elementwise ADD operations.
|
// Patterns to enable opportunities with elementwise ADD operations.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -163,6 +172,13 @@ def SubConstToNeg : Pat<
|
||||||
(ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
|
(ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
|
||||||
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
|
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
|
||||||
|
|
||||||
|
// Constant Propagation for Sqrt
|
||||||
|
def SqrtofConst : Pat<
|
||||||
|
// From onnx.Sqrt(c)
|
||||||
|
(ONNXSqrtOp (ONNXConstantOp:$constOp $s, $v)),
|
||||||
|
// To sqrt(c)
|
||||||
|
(ONNXConstantOp (GetNullAttr), (CreateSqrtOfConst $constOp, $v)),
|
||||||
|
[(AttributeIsNull:$s)]>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Patterns to enable opportunities with elementwise MUL operations.
|
// Patterns to enable opportunities with elementwise MUL operations.
|
||||||
|
@ -232,6 +248,16 @@ def MulConstProp : Pat<
|
||||||
// Mulitional constraints (no sparse)
|
// Mulitional constraints (no sparse)
|
||||||
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||||
|
|
||||||
|
// Constant Propagation for Div
|
||||||
|
def DivConstProp : Pat<
|
||||||
|
// From div(c1, c2).
|
||||||
|
(ONNXDivOp:$mulOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
|
||||||
|
// To c1/c2
|
||||||
|
(ONNXConstantOp (GetNullAttr), (CreateDivOfTwoConst $mulOp, $v1, $v2)),
|
||||||
|
// Division constraints (no sparse)
|
||||||
|
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Patterns to enable opportunities with Transpose operations.
|
// Patterns to enable opportunities with Transpose operations.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -244,5 +270,16 @@ def TransposeofConst : Pat<
|
||||||
(ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
|
(ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
|
||||||
[(AttributeIsNull:$s)]>;
|
[(AttributeIsNull:$s)]>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Patterns to enable opportunities with Unsqueeze operations.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def UnsqueezeofConst : Pat<
|
||||||
|
// From Unsqueeze (c, axis)
|
||||||
|
(ONNXUnsqueezeOp:$resOp (ONNXConstantOp $s, $v), $_),
|
||||||
|
// To c' where c' is the unsqueezed value.
|
||||||
|
(ONNXConstantOp (GetNullAttr), (CreateUnsqueezeOfConst $resOp, $v)),
|
||||||
|
[(AttributeIsNull:$s)]>;
|
||||||
|
|
||||||
|
|
||||||
#endif // ONNX_CONSTPROP
|
#endif // ONNX_CONSTPROP
|
||||||
|
|
|
@ -18,6 +18,36 @@ using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Create a DenseElementsAttr from a float attribute.
|
||||||
|
DenseElementsAttr createDenseElementsAttrFromFloatAttr(
|
||||||
|
PatternRewriter &rewriter, Type elementType, FloatAttr attr) {
|
||||||
|
SmallVector<int64_t, 1> dims(1, 1);
|
||||||
|
SmallVector<float, 1> values(1, attr.getValue().convertToFloat());
|
||||||
|
auto tensorType = mlir::RankedTensorType::get(dims, elementType);
|
||||||
|
return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
|
||||||
|
}
|
||||||
|
|
||||||
|
// If 'lhs' is not NoneType, return 'lhs - rhs'.
|
||||||
|
// Otherwise, return '-rhs'.
|
||||||
|
Value subtractOrNeg(
|
||||||
|
PatternRewriter &rewriter, Location loc, Value lhs, Value rhs) {
|
||||||
|
if (lhs.getType().isa<NoneType>()) {
|
||||||
|
Value result = rewriter.create<ONNXNegOp>(loc, rhs);
|
||||||
|
return result;
|
||||||
|
} else {
|
||||||
|
Value result = rewriter.create<ONNXSubOp>(loc, lhs, rhs);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an ArrayAttr of IntergerAttr(s) of values in [1, N].
|
||||||
|
ArrayAttr createArrayAttrOfOneToN(PatternRewriter &rewriter, int N) {
|
||||||
|
SmallVector<int64_t, 4> vals;
|
||||||
|
for (int i = 1; i <= N; ++i)
|
||||||
|
vals.emplace_back(i);
|
||||||
|
return rewriter.getI64ArrayAttr(vals);
|
||||||
|
}
|
||||||
|
|
||||||
// Check whether an ArrayAttr contains non-zero values or not.
|
// Check whether an ArrayAttr contains non-zero values or not.
|
||||||
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
|
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
|
||||||
bool allZeros = true;
|
bool allZeros = true;
|
||||||
|
@ -92,3 +122,9 @@ void ONNXConvOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
results.insert<ConvOpPaddingPattern>(context);
|
results.insert<ConvOpPaddingPattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// on the ONNXBatchNormalizationTestModeOp.
|
||||||
|
void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<FuseBatchNormTestModeConvPattern>(context);
|
||||||
|
}
|
||||||
|
|
|
@ -24,6 +24,19 @@ include "src/Dialect/ONNX/ONNXOps.td"
|
||||||
/// dag benefitsAdded = (addBenefit 0)
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
/// >;
|
/// >;
|
||||||
|
|
||||||
|
// Create a DenseElementsAttr from a float attribute and an element type.
|
||||||
|
def createDenseElementsAttrFromFloatAttr : NativeCodeCall<
|
||||||
|
"createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast<ShapedType>().getElementType(), $1)">;
|
||||||
|
|
||||||
|
// If '$1' is not NoneType, do subtraction '$1 - $2'.
|
||||||
|
// Otherwise, take the negative of '$2'.
|
||||||
|
def subtractOrNeg: NativeCodeCall<
|
||||||
|
"subtractOrNeg($_builder, $0.getDefiningOp()->getLoc(), $1, $2)">;
|
||||||
|
|
||||||
|
// Create an ArrayAttr of IntergerAttr(s) of values in [1, N].
|
||||||
|
def createArrayAttrOfOneToRankOf : NativeCodeCall<
|
||||||
|
"createArrayAttrOfOneToN($_builder, $0.getType().cast<ShapedType>().getRank() - 1)">;
|
||||||
|
|
||||||
def GetNullAttr :
|
def GetNullAttr :
|
||||||
NativeCodeCall<"Attribute()">;
|
NativeCodeCall<"Attribute()">;
|
||||||
|
|
||||||
|
@ -100,4 +113,63 @@ def ConvOpPaddingPattern: Pat<
|
||||||
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
|
[(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// This is to fuse the composition: 'BatchNorm o Conv' into 'Conv'
|
||||||
|
// by deriving new 'w' and 'b' for 'Conv':
|
||||||
|
//
|
||||||
|
// We have:
|
||||||
|
// (Conv) z = w * x + b
|
||||||
|
// (BatchNorm) y = scale * (z - mean) / sqrt(var + eps) + bias
|
||||||
|
//
|
||||||
|
// which corresponds to the following computation:
|
||||||
|
// y = w_ * x + b_
|
||||||
|
// where
|
||||||
|
// w_ = scale * w / sqrt(var + eps)
|
||||||
|
// b_ = B + scale * (b - mean) / sqrt(var + eps)
|
||||||
|
//
|
||||||
|
// Hence, we rewrite:
|
||||||
|
// onnx.BatchNormalizationTestMode(
|
||||||
|
// onnx.Conv(x, w, b),
|
||||||
|
// scale, B, mean, var
|
||||||
|
// ) {eps = ...}
|
||||||
|
//
|
||||||
|
// as:
|
||||||
|
// onnx.Conv(x, w_, b_)
|
||||||
|
//
|
||||||
|
// where
|
||||||
|
// w_ = scale * w / sqrt(var + eps)
|
||||||
|
// b_ = B + scale * (b - mean) / sqrt(var + eps)
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def FuseBatchNormTestModeConvPattern: Pat<
|
||||||
|
(ONNXBatchNormalizationTestModeOp:$res
|
||||||
|
(ONNXConvOp $x, $w, $b,
|
||||||
|
$auto_pad, $dilation, $group, $kernel_shape, $pads, $strides),
|
||||||
|
$scale, $B, $mean, $var, $epsilon, $momentum),
|
||||||
|
(ONNXConvOp
|
||||||
|
$x,
|
||||||
|
// w_
|
||||||
|
(ONNXMulOp
|
||||||
|
$w,
|
||||||
|
(ONNXUnsqueezeOp
|
||||||
|
(ONNXDivOp:$coefficientW
|
||||||
|
$scale,
|
||||||
|
(ONNXSqrtOp
|
||||||
|
(ONNXAddOp
|
||||||
|
$var,
|
||||||
|
(ONNXConstantOp
|
||||||
|
(GetNullAttr),
|
||||||
|
(createDenseElementsAttrFromFloatAttr $res, $epsilon))))),
|
||||||
|
(createArrayAttrOfOneToRankOf $w))),
|
||||||
|
// b_
|
||||||
|
(ONNXAddOp
|
||||||
|
$B,
|
||||||
|
(ONNXMulOp
|
||||||
|
$coefficientW,
|
||||||
|
(subtractOrNeg $res, $b, $mean))),
|
||||||
|
|
||||||
|
$auto_pad, $dilation, $group, $kernel_shape, $pads, $strides)
|
||||||
|
>;
|
||||||
|
|
||||||
#endif // ONNX_REWRITE
|
#endif // ONNX_REWRITE
|
||||||
|
|
|
@ -106,3 +106,88 @@ func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
|
||||||
// CHECK-NEXT: return %arg0 : tensor<2xf32>
|
// CHECK-NEXT: return %arg0 : tensor<2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_conv_batchnormtestmode_fusion_nobias(%arg0 : tensor<1x3x224x224xf32>) -> tensor<1x64x112x112xf32> {
|
||||||
|
%cst = constant unit
|
||||||
|
%0 = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
|
||||||
|
%1 = "onnx.Conv"(%arg0, %0, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, none) -> tensor<1x64x112x112xf32>
|
||||||
|
%2 = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
%3 = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
%4 = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
%5 = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
%6 = "onnx.BatchNormalizationTestMode"(%1, %2, %3, %4, %5) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
|
||||||
|
return %6 : tensor<1x64x112x112xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_conv_batchnormtestmode_fusion_nobias
|
||||||
|
// CHECK: [[WEIGHT:%.+]] = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
|
||||||
|
// CHECK: [[SCALE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
// CHECK: [[B:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
// CHECK: [[MEAN:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
// CHECK: [[VARIANCE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
// CHECK: [[EPSILON:%.+]] = "onnx.Constant"() {value = dense<1.00000007E-5> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||||
|
|
||||||
|
// CHECK: [[VAR_EPSILON:%.+]] = "onnx.Add"([[VARIANCE]], [[EPSILON]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
|
||||||
|
// CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[UNSQUEEZE:%.+]] = "onnx.Unsqueeze"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK: [[NEG_MEAN:%.+]] = "onnx.Neg"([[MEAN]]) : (tensor<64xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[NEG_MEAN]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[NEW_BIAS:%.+]] = "onnx.Add"([[B]], [[MUL]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK: [[PAD_ARG1:%.+]] = "onnx.Constant"() {value = dense<[0, 0, 3, 3, 0, 0, 3, 3]> : tensor<8xi64>} : () -> tensor<8xi64>
|
||||||
|
// CHECK: [[PAD_ARG2:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||||
|
// CHECK: [[PADDED_INPUT:%.+]] = "onnx.Pad"(%arg0, [[PAD_ARG1]], [[PAD_ARG2]]) {mode = "constant"} : (tensor<1x3x224x224xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Conv"([[PADDED_INPUT]], [[NEW_WEIGHT]], [[NEW_BIAS]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32>
|
||||||
|
|
||||||
|
// CHECK-NOT: {{.*}} = "onnx.BatchNormalizationTestMode"{{.*}}
|
||||||
|
|
||||||
|
// CHECK: return [[RES]] : tensor<1x64x112x112xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1 : tensor<64xf32>) -> tensor<1x64x112x112xf32> {
|
||||||
|
%cst = constant unit
|
||||||
|
%0 = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
|
||||||
|
%1 = "onnx.Conv"(%arg0, %0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
|
||||||
|
%2 = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
%3 = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
%4 = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
%5 = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
%6 = "onnx.BatchNormalizationTestMode"(%1, %2, %3, %4, %5) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
|
||||||
|
return %6 : tensor<1x64x112x112xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_conv_batchnormtestmode_fusion
|
||||||
|
// CHECK: [[WEIGHT:%.+]] = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
|
||||||
|
// CHECK: [[SCALE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
// CHECK: [[B:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
// CHECK: [[MEAN:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
// CHECK: [[VARIANCE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
|
||||||
|
// CHECK: [[EPSILON:%.+]] = "onnx.Constant"() {value = dense<1.00000007E-5> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||||
|
|
||||||
|
// CHECK: [[VAR_EPSILON:%.+]] = "onnx.Add"([[VARIANCE]], [[EPSILON]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
|
||||||
|
// CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[UNSQUEEZE:%.+]] = "onnx.Unsqueeze"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK: [[SUB:%.+]] = "onnx.Sub"(%arg1, [[MEAN]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
|
||||||
|
// CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[SUB]]) : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: [[NEW_BIAS:%.+]] = "onnx.Add"([[B]], [[MUL]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK: [[PAD_ARG1:%.+]] = "onnx.Constant"() {value = dense<[0, 0, 3, 3, 0, 0, 3, 3]> : tensor<8xi64>} : () -> tensor<8xi64>
|
||||||
|
// CHECK: [[PAD_ARG2:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||||
|
// CHECK: [[PADDED_INPUT:%.+]] = "onnx.Pad"(%arg0, [[PAD_ARG1]], [[PAD_ARG2]]) {mode = "constant"} : (tensor<1x3x224x224xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Conv"([[PADDED_INPUT]], [[NEW_WEIGHT]], [[NEW_BIAS]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32>
|
||||||
|
|
||||||
|
// CHECK-NOT: {{.*}} = "onnx.BatchNormalizationTestMode"{{.*}}
|
||||||
|
|
||||||
|
// CHECK: return [[RES]] : tensor<1x64x112x112xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -227,3 +227,47 @@ func @test_default_transpose_const_3() -> tensor<*xi32> {
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
|
||||||
// CHECK: return [[RES]] : tensor<3x2x4xi32>
|
// CHECK: return [[RES]] : tensor<3x2x4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Div tests
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_div(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32>
|
||||||
|
func @test_div(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[[2.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
|
||||||
|
%2 = "onnx.Div"(%0, %1) : (tensor<3x2xf32>, tensor<1x1xf32>) -> tensor<3x2xf32>
|
||||||
|
"std.return"(%2) : (tensor<3x2xf32>) -> ()
|
||||||
|
// CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]{{\]}}> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
|
||||||
|
// CHECK-NOT: {{.*}} = "onnx.Div"{{.*}}
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Sqrt tests
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_sqrt() -> tensor<1x2xf32>
|
||||||
|
func @test_sqrt() -> tensor<1x2xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[4.0, 16.0]]> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
|
||||||
|
%1 = "onnx.Sqrt"(%0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||||
|
"std.return"(%1) : (tensor<1x2xf32>) -> ()
|
||||||
|
// CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}[2.000000e+00, 4.000000e+00]{{\]}}> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
|
||||||
|
// CHECK-NOT: {{.*}} = "onnx.Sqrt"{{.*}}
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Unsqueeze tests
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_unsqueeze() -> tensor<2x1x1xf32>
|
||||||
|
func @test_unsqueeze() -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[4.0, 16.0]> : tensor<2xf32>} : () -> tensor<2xf32>
|
||||||
|
%1 = "onnx.Unsqueeze"(%0) {axes = [1, 2]} : (tensor<2xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
// CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}{{\[}}[4.000000e+00]{{\]}}, {{\[}}[1.600000e+01]{{\]}}{{\]}}> : tensor<2x1x1xf32>} : () -> tensor<2x1x1xf32>
|
||||||
|
// CHECK-NOT: {{.*}} = "onnx.Unsqueeze"{{.*}}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -364,7 +364,8 @@ OpsWithResultTypeInference = {
|
||||||
# Currenlty, there are only two build methods generated:
|
# Currenlty, there are only two build methods generated:
|
||||||
# - one with operands and attributes having a separate parameter, and
|
# - one with operands and attributes having a separate parameter, and
|
||||||
# - one with operands and attributes having aggregated parameters.
|
# - one with operands and attributes having aggregated parameters.
|
||||||
custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
|
custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare',
|
||||||
|
'Pad', 'Sqrt', 'Neg', 'Unsqueeze']
|
||||||
# Custom builder op list for operations with broadcast; we can deduce the right
|
# Custom builder op list for operations with broadcast; we can deduce the right
|
||||||
# output type, no need to leave it undef as in the above list.
|
# output type, no need to leave it undef as in the above list.
|
||||||
# Ops must have two operands, not one, not three... And there shall be two.
|
# Ops must have two operands, not one, not three... And there shall be two.
|
||||||
|
|
Loading…
Reference in New Issue