Fuse convolution and batch normalization (#253)
* Rewriting rule
* Fix formulas
* Reuse op results
* Const propagation for Div and Sqrt
* Explicitly use ONNXConstantOp
* Minor revisions
* Const propagation for Unsqueeze
* Do const propagation once all tensors have inferred shapes
* LIT tests for fusion
* Add LIT tests for constant propagation on Div, Sqrt, and Unsqueeze
* Missing dash

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent 38bd77e51a
commit 7c1e67898d
@@ -141,6 +141,7 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
 def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX BatchNormalization operation in test mode";
+  let hasCanonicalizer = 1;
   let description = [{
     "Carries out batch normalization as described in the paper"
     "https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
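Setting 'hasCanonicalizer = 1' is what lets the fusion hook into MLIR's canonicalizer: ODS declares a pattern-registration method on the generated op class, and this commit supplies its body further down. A minimal sketch of the declared hook (signature as used elsewhere in this diff):

    // Declared on ONNXBatchNormalizationTestModeOp by ODS; the definition
    // registering FuseBatchNormTestModeConvPattern appears later in this commit.
    static void getCanonicalizationPatterns(
        OwningRewritePatternList &results, MLIRContext *context);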
@@ -2894,6 +2894,18 @@ def ONNXNegOp:ONNX_Op<"Neg",
   }];
   let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$X);
   let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
+      auto elementType = X.getType().cast<TensorType>().getElementType();
+      build(builder, state, UnrankedTensorType::get(elementType), X);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto elementType = operands[0].getType().cast<TensorType>().getElementType();
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(UnrankedTensorType::get(elementType));
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+    ];
     let extraClassDeclaration = [{
       static int getNumberOfOperands() {
         return 1;
@@ -5098,6 +5110,18 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
   }];
   let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$X);
   let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, AnyMemRef]>:$Y);
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{
+      auto elementType = X.getType().cast<TensorType>().getElementType();
+      build(builder, state, UnrankedTensorType::get(elementType), X);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto elementType = operands[0].getType().cast<TensorType>().getElementType();
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(UnrankedTensorType::get(elementType));
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+    ];
     let extraClassDeclaration = [{
       static int getNumberOfOperands() {
         return 1;
@@ -5574,6 +5598,18 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
   let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$data,
     I64ArrayAttr:$axes);
   let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$expanded);
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value data, ArrayAttr axes", [{
+      auto elementType = data.getType().cast<TensorType>().getElementType();
+      build(builder, state, UnrankedTensorType::get(elementType), data, axes);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto elementType = operands[0].getType().cast<TensorType>().getElementType();
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(UnrankedTensorType::get(elementType));
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+    ];
     let extraClassDeclaration = [{
       static int getNumberOfOperands() {
         return 1;
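With these builders in place, rewrite patterns can create Sqrt, Neg, and Unsqueeze ops without spelling out a result type: the result defaults to an unranked tensor of the operand's element type and is refined later by shape inference. A minimal usage sketch (assuming a 'rewriter', 'loc', and a tensor-typed 'input' from some surrounding pattern):

    // Both results are unranked (e.g. tensor<*xf32>) until shape inference runs.
    Value sq = rewriter.create<ONNXSqrtOp>(loc, input);
    Value un = rewriter.create<ONNXUnsqueezeOp>(
        loc, sq, rewriter.getI64ArrayAttr({1, 2, 3}));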
@@ -393,6 +393,11 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
   pm.addPass(mlir::createAttributePromotionPass());
   pm.addPass(mlir::createShapeInferencePass());
   pm.addPass(mlir::createAttributePromotionPass());
+  // There are more opportunities for const propagation once all tensors have
+  // inferred shapes.
+  pm.addPass(mlir::createConstPropONNXToONNXPass());
+  // Clean dead code.
+  pm.addPass(mlir::createSymbolDCEPass());
 }
 
 void addONNXToKrnlPasses(mlir::PassManager &pm) {
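The placement is deliberate: constant propagation is scheduled after shape inference because patterns such as ConstPropUnsqueeze cast their result type to RankedTensorType, and SymbolDCE then sweeps whatever the folding left dead. A condensed sketch of just that ordering, using the same 'pm' as the function above:

    pm.addPass(mlir::createShapeInferencePass());       // give tensors ranked types
    pm.addPass(mlir::createConstPropONNXToONNXPass());  // fold constant subgraphs
    pm.addPass(mlir::createSymbolDCEPass());            // drop now-dead symbols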
@@ -21,6 +21,8 @@
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Pass/Passes.hpp"
 
+#include <math.h>
+
 using namespace mlir;
 
 namespace {
@@ -120,6 +122,26 @@ Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
   llvm_unreachable("constant propagation for MulOp: unknown data type");
 }
 
+template <>
+Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
+    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
+    Attribute &secondAttr) {
+  if (elementType.isa<FloatType>()) {
+    double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
+    double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
+    assert(rhsVal != 0 && "division by zero");
+    double res = lhsVal / rhsVal;
+    return rewriter.getFloatAttr(elementType, res);
+  }
+  if (elementType.isa<IntegerType>()) {
+    // getInt() returns a signed value, so keep the division signed.
+    int64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
+    int64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
+    assert(rhsVal != 0 && "division by zero");
+    int64_t res = lhsVal / rhsVal;
+    return rewriter.getIntegerAttr(elementType, res);
+  }
+  llvm_unreachable("constant propagation for DivOp: unknown data type");
+}
 // Recursively process one dimension in the rank of the two references. There
 // can be one of 3 cases.
 // 1) We have fully defined accesses for both operands, launch the computations.
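A hedged note on the integer branch above: C++ integer division truncates toward zero, so the folded quotient follows C semantics rather than floor division. For example:

    int64_t a = -7, b = 2;
    int64_t q = a / b; // q == -3 (truncation toward zero), not -4 (floor)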
@@ -246,6 +268,17 @@ Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
   llvm_unreachable("constant propagation for NegOp: unknown data type");
 }
 
+template <>
+Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
+    PatternRewriter &rewriter, Type elementType, Attribute &attr) {
+  if (elementType.isa<FloatType>()) {
+    double val = attr.cast<FloatAttr>().getValueAsDouble();
+    double res = sqrt(val);
+    return rewriter.getFloatAttr(elementType, res);
+  }
+  llvm_unreachable("constant propagation for SqrtOp: unknown data type");
+}
+
 template <typename ElementwiseUnaryOp>
 void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
     std::vector<Attribute> &resVector, DenseElementsAttr &attr,
@@ -340,6 +373,28 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
   return DenseElementsAttr::get(resType, resRef);
 }
 
+//===----------------------------------------------------------------------===//
+// Code to perform constant propagation for Unsqueeze.
+//===----------------------------------------------------------------------===//
+
+DenseElementsAttr ConstPropUnsqueeze(
+    PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
+  // Read the dense attribute, the constant tensor we are transforming.
+  DenseElementsAttr denseAttr =
+      attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
+  assert(denseAttr && "expected dense attribute");
+  ShapedType resType = resOperand.getType().cast<RankedTensorType>();
+
+  // Unsqueeze does not change the order of access, so just copy the whole data.
+  std::vector<Attribute> resVector;
+  for (auto value : denseAttr.getValues<Attribute>()) {
+    resVector.emplace_back(value);
+  }
+
+  ArrayRef<Attribute> resRef(resVector);
+  return DenseElementsAttr::get(resType, resRef);
+}
 
 //===----------------------------------------------------------------------===//
 // Pattern definition.
 //===----------------------------------------------------------------------===//
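The verbatim copy in ConstPropUnsqueeze is sound because Unsqueeze only inserts size-1 dimensions, which leaves the row-major linearization of the elements untouched. A tiny illustration, using the values from the LIT test further below:

    #include <vector>
    // Shape {2} and shape {2,1,1} linearize the same two elements in the
    // same order, so the underlying buffer can be copied as-is.
    std::vector<float> data = {4.0f, 16.0f};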
@@ -60,12 +60,21 @@ def CreateSubOfTwoConst :
 def CreateNegOfConst :
    NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
 
-def CreateMulOfTwoConst :
+def CreateSqrtOfConst :
+   NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXSqrtOp>($_builder, $0, $1)">;
+
+def CreateMulOfTwoConst :
    NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
 
+def CreateDivOfTwoConst :
+   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXDivOp>($_builder, $0, $1, $2)">;
+
 def CreateTransposeOfConst :
    NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
 
+def CreateUnsqueezeOfConst:
+   NativeCodeCall<"ConstPropUnsqueeze($_builder, $0, $1)">;
+
 //===----------------------------------------------------------------------===//
 // Patterns to enable opportunities with elementwise ADD operations.
 //===----------------------------------------------------------------------===//
@@ -163,6 +172,13 @@ def SubConstToNeg : Pat<
     (ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
     [(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
 
+// Constant propagation for Sqrt.
+def SqrtofConst :  Pat<
+    // From onnx.Sqrt(c)
+    (ONNXSqrtOp (ONNXConstantOp:$constOp $s, $v)),
+    // To sqrt(c)
+    (ONNXConstantOp (GetNullAttr), (CreateSqrtOfConst $constOp, $v)),
+    [(AttributeIsNull:$s)]>;
 
 //===----------------------------------------------------------------------===//
 // Patterns to enable opportunities with elementwise MUL operations.
 //===----------------------------------------------------------------------===//
@@ -232,6 +248,16 @@ def MulConstProp : Pat<
     // Multiplication constraints (no sparse)
     [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
 
+// Constant propagation for Div.
+def DivConstProp : Pat<
+    // From div(c1, c2).
+    (ONNXDivOp:$divOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
+    // To c1/c2
+    (ONNXConstantOp (GetNullAttr), (CreateDivOfTwoConst $divOp, $v1, $v2)),
+    // Division constraints (no sparse)
+    [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
+
 
 //===----------------------------------------------------------------------===//
 // Patterns to enable opportunities with Transpose operations.
 //===----------------------------------------------------------------------===//
@@ -244,5 +270,16 @@ def TransposeofConst :  Pat<
     (ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
     [(AttributeIsNull:$s)]>;
 
+//===----------------------------------------------------------------------===//
+// Patterns to enable opportunities with Unsqueeze operations.
+//===----------------------------------------------------------------------===//
+
+def UnsqueezeofConst :  Pat<
+    // From Unsqueeze(c, axes)
+    (ONNXUnsqueezeOp:$resOp (ONNXConstantOp $s, $v), $_),
+    // To c', where c' is the unsqueezed value.
+    (ONNXConstantOp (GetNullAttr), (CreateUnsqueezeOfConst $resOp, $v)),
+    [(AttributeIsNull:$s)]>;
+
 
 #endif // ONNX_CONSTPROP
@@ -18,6 +18,36 @@ using namespace mlir;
 
 namespace {
 
+// Create a DenseElementsAttr from a float attribute.
+DenseElementsAttr createDenseElementsAttrFromFloatAttr(
+    PatternRewriter &rewriter, Type elementType, FloatAttr attr) {
+  SmallVector<int64_t, 1> dims(1, 1);
+  SmallVector<float, 1> values(1, attr.getValue().convertToFloat());
+  auto tensorType = mlir::RankedTensorType::get(dims, elementType);
+  return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
+}
+
+// If 'lhs' is not NoneType, return 'lhs - rhs'.
+// Otherwise, return '-rhs'.
+Value subtractOrNeg(
+    PatternRewriter &rewriter, Location loc, Value lhs, Value rhs) {
+  if (lhs.getType().isa<NoneType>()) {
+    Value result = rewriter.create<ONNXNegOp>(loc, rhs);
+    return result;
+  } else {
+    Value result = rewriter.create<ONNXSubOp>(loc, lhs, rhs);
+    return result;
+  }
+}
+
+// Create an ArrayAttr of IntegerAttr(s) of values in [1, N].
+ArrayAttr createArrayAttrOfOneToN(PatternRewriter &rewriter, int N) {
+  SmallVector<int64_t, 4> vals;
+  for (int i = 1; i <= N; ++i)
+    vals.emplace_back(i);
+  return rewriter.getI64ArrayAttr(vals);
+}
+
 // Check whether an ArrayAttr contains non-zero values or not.
 bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
   bool allZeros = true;
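To see how these helpers compose in the fusion pattern below: for a rank-4 Conv weight, the TableGen binding passes rank - 1 = 3, so the helper yields the axes [1, 2, 3], and unsqueezing the per-channel coefficient with those axes makes it broadcastable against the weight. A hedged sketch (assuming a 'rewriter' from some pattern context):

    // For a 64x3x7x7 weight: N = rank - 1 = 3 -> axes [1, 2, 3].
    // Unsqueeze turns the {64} coefficient into {64,1,1,1}, which then
    // broadcasts against the weight in ONNXMulOp.
    ArrayAttr axes = createArrayAttrOfOneToN(rewriter, /*N=*/3);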
@@ -92,3 +122,9 @@ void ONNXConvOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<ConvOpPaddingPattern>(context);
 }
+
+/// Register the fusion pattern as canonicalization on the ONNXBatchNormalizationTestModeOp.
+void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<FuseBatchNormTestModeConvPattern>(context);
+}
@@ -24,6 +24,19 @@ include "src/Dialect/ONNX/ONNXOps.td"
 ///    dag benefitsAdded = (addBenefit 0)
 /// >;
 
+// Create a DenseElementsAttr from a float attribute and an element type.
+def createDenseElementsAttrFromFloatAttr : NativeCodeCall<
+  "createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast<ShapedType>().getElementType(), $1)">;
+
+// If '$1' is not NoneType, do the subtraction '$1 - $2'.
+// Otherwise, take the negative of '$2'.
+def subtractOrNeg: NativeCodeCall<
+  "subtractOrNeg($_builder, $0.getDefiningOp()->getLoc(), $1, $2)">;
+
+// Create an ArrayAttr of IntegerAttr(s) of values in [1, N].
+def createArrayAttrOfOneToRankOf : NativeCodeCall<
+  "createArrayAttrOfOneToN($_builder, $0.getType().cast<ShapedType>().getRank() - 1)">;
+
 def GetNullAttr :
    NativeCodeCall<"Attribute()">;
@@ -100,4 +113,63 @@ def ConvOpPaddingPattern: Pat<
   [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)]
 >;
 
+//===----------------------------------------------------------------------===//
+// Fuse the composition 'BatchNorm o Conv' into a single 'Conv'
+// by deriving a new 'w' and 'b' for the 'Conv'.
+//
+// We have:
+//   (Conv)      z = w * x + b
+//   (BatchNorm) y = scale * (z - mean) / sqrt(var + eps) + B
+//
+// which corresponds to the following computation:
+//   y = w_ * x + b_
+// where
+//   w_ = scale * w / sqrt(var + eps)
+//   b_ = B + scale * (b - mean) / sqrt(var + eps)
+//
+// Hence, we rewrite:
+//   onnx.BatchNormalizationTestMode(
+//       onnx.Conv(x, w, b),
+//       scale, B, mean, var
+//   ) {eps = ...}
+// as:
+//   onnx.Conv(x, w_, b_)
+// with w_ and b_ as above.
+//===----------------------------------------------------------------------===//
+
+def FuseBatchNormTestModeConvPattern: Pat<
+  (ONNXBatchNormalizationTestModeOp:$res
+    (ONNXConvOp $x, $w, $b,
+                $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides),
+    $scale, $B, $mean, $var, $epsilon, $momentum),
+  (ONNXConvOp
+     $x,
+     // w_
+     (ONNXMulOp
+        $w,
+        (ONNXUnsqueezeOp
+           (ONNXDivOp:$coefficientW
+              $scale,
+              (ONNXSqrtOp
+                 (ONNXAddOp
+                    $var,
+                    (ONNXConstantOp
+                       (GetNullAttr),
+                       (createDenseElementsAttrFromFloatAttr $res, $epsilon))))),
+           (createArrayAttrOfOneToRankOf $w))),
+     // b_
+     (ONNXAddOp
+        $B,
+        (ONNXMulOp
+           $coefficientW,
+           (subtractOrNeg $res, $b, $mean))),
+
+     $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides)
+>;
+
 #endif // ONNX_REWRITE
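As a quick sanity check of the derivation in the comment above, the scalar case can be verified directly; a self-contained sketch with arbitrary sample values:

    #include <cassert>
    #include <cmath>
    int main() {
      // Hypothetical sample values; any choice with var + eps > 0 works.
      double x = 3.0, w = 0.5, b = 1.0;
      double scale = 2.0, B = 0.25, mean = 0.75, var = 4.0, eps = 1e-5;
      double z = w * x + b;                                      // Conv
      double y = scale * (z - mean) / std::sqrt(var + eps) + B;  // BatchNorm
      double w_ = scale * w / std::sqrt(var + eps);
      double b_ = B + scale * (b - mean) / std::sqrt(var + eps);
      assert(std::abs(y - (w_ * x + b_)) < 1e-12);               // fused Conv
      return 0;
    }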
@@ -106,3 +106,88 @@ func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 
   // CHECK-NEXT: return %arg0 : tensor<2xf32>
 }
+
+// -----
+
+func @test_conv_batchnormtestmode_fusion_nobias(%arg0 : tensor<1x3x224x224xf32>) -> tensor<1x64x112x112xf32> {
+    %cst = constant unit
+    %0 = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
+    %1 = "onnx.Conv"(%arg0, %0, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, none) -> tensor<1x64x112x112xf32>
+    %2 = "onnx.Constant"() : () -> tensor<64xf32>
+    %3 = "onnx.Constant"() : () -> tensor<64xf32>
+    %4 = "onnx.Constant"() : () -> tensor<64xf32>
+    %5 = "onnx.Constant"() : () -> tensor<64xf32>
+    %6 = "onnx.BatchNormalizationTestMode"(%1, %2, %3, %4, %5) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
+    return %6 : tensor<1x64x112x112xf32>
+
+    // CHECK-LABEL: test_conv_batchnormtestmode_fusion_nobias
+    // CHECK: [[WEIGHT:%.+]] = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
+    // CHECK: [[SCALE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+    // CHECK: [[B:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+    // CHECK: [[MEAN:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+    // CHECK: [[VARIANCE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+    // CHECK: [[EPSILON:%.+]] = "onnx.Constant"() {value = dense<1.00000007E-5> : tensor<1xf32>} : () -> tensor<1xf32>
+
+    // CHECK: [[VAR_EPSILON:%.+]] = "onnx.Add"([[VARIANCE]], [[EPSILON]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+    // CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
+    // CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
+    // CHECK: [[UNSQUEEZE:%.+]] = "onnx.Unsqueeze"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+    // CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
+
+    // CHECK: [[NEG_MEAN:%.+]] = "onnx.Neg"([[MEAN]]) : (tensor<64xf32>) -> tensor<*xf32>
+    // CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[NEG_MEAN]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    // CHECK: [[NEW_BIAS:%.+]] = "onnx.Add"([[B]], [[MUL]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
+
+    // CHECK: [[PAD_ARG1:%.+]] = "onnx.Constant"() {value = dense<[0, 0, 3, 3, 0, 0, 3, 3]> : tensor<8xi64>} : () -> tensor<8xi64>
+    // CHECK: [[PAD_ARG2:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
+    // CHECK: [[PADDED_INPUT:%.+]] = "onnx.Pad"(%arg0, [[PAD_ARG1]], [[PAD_ARG2]]) {mode = "constant"} : (tensor<1x3x224x224xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
+
+    // CHECK: [[RES:%.+]] = "onnx.Conv"([[PADDED_INPUT]], [[NEW_WEIGHT]], [[NEW_BIAS]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32>
+
+    // CHECK-NOT: {{.*}} = "onnx.BatchNormalizationTestMode"{{.*}}
+
+    // CHECK: return [[RES]] : tensor<1x64x112x112xf32>
+}
+
+// -----
+
+func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1 : tensor<64xf32>) -> tensor<1x64x112x112xf32> {
+    %cst = constant unit
+    %0 = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
+    %1 = "onnx.Conv"(%arg0, %0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
+    %2 = "onnx.Constant"() : () -> tensor<64xf32>
+    %3 = "onnx.Constant"() : () -> tensor<64xf32>
+    %4 = "onnx.Constant"() : () -> tensor<64xf32>
+    %5 = "onnx.Constant"() : () -> tensor<64xf32>
+    %6 = "onnx.BatchNormalizationTestMode"(%1, %2, %3, %4, %5) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
+    return %6 : tensor<1x64x112x112xf32>
+
+    // CHECK-LABEL: test_conv_batchnormtestmode_fusion
+    // CHECK: [[WEIGHT:%.+]] = "onnx.Constant"() : () -> tensor<64x3x7x7xf32>
+    // CHECK: [[SCALE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+    // CHECK: [[B:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+    // CHECK: [[MEAN:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+    // CHECK: [[VARIANCE:%.+]] = "onnx.Constant"() : () -> tensor<64xf32>
+    // CHECK: [[EPSILON:%.+]] = "onnx.Constant"() {value = dense<1.00000007E-5> : tensor<1xf32>} : () -> tensor<1xf32>
+
+    // CHECK: [[VAR_EPSILON:%.+]] = "onnx.Add"([[VARIANCE]], [[EPSILON]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+    // CHECK: [[SQRT:%.+]] = "onnx.Sqrt"([[VAR_EPSILON]]) : (tensor<64xf32>) -> tensor<*xf32>
+    // CHECK: [[COEFFICIENT_W:%.+]] = "onnx.Div"([[SCALE]], [[SQRT]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
+    // CHECK: [[UNSQUEEZE:%.+]] = "onnx.Unsqueeze"([[COEFFICIENT_W]]) {axes = [1, 2, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+    // CHECK: [[NEW_WEIGHT:%.+]] = "onnx.Mul"([[WEIGHT]], [[UNSQUEEZE]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32>
+
+    // CHECK: [[SUB:%.+]] = "onnx.Sub"(%arg1, [[MEAN]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+    // CHECK: [[MUL:%.+]] = "onnx.Mul"([[COEFFICIENT_W]], [[SUB]]) : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
+    // CHECK: [[NEW_BIAS:%.+]] = "onnx.Add"([[B]], [[MUL]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32>
+
+    // CHECK: [[PAD_ARG1:%.+]] = "onnx.Constant"() {value = dense<[0, 0, 3, 3, 0, 0, 3, 3]> : tensor<8xi64>} : () -> tensor<8xi64>
+    // CHECK: [[PAD_ARG2:%.+]] = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
+    // CHECK: [[PADDED_INPUT:%.+]] = "onnx.Pad"(%arg0, [[PAD_ARG1]], [[PAD_ARG2]]) {mode = "constant"} : (tensor<1x3x224x224xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
+
+    // CHECK: [[RES:%.+]] = "onnx.Conv"([[PADDED_INPUT]], [[NEW_WEIGHT]], [[NEW_BIAS]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [7, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32>
+
+    // CHECK-NOT: {{.*}} = "onnx.BatchNormalizationTestMode"{{.*}}
+
+    // CHECK: return [[RES]] : tensor<1x64x112x112xf32>
+}
@@ -227,3 +227,47 @@ func @test_default_transpose_const_3() -> tensor<*xi32> {
  // CHECK: [[RES:%.+]] =  "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
   // CHECK: return [[RES]] : tensor<3x2x4xi32>
 }
+
+//===----------------------------------------------------------------------===//
+/// Div tests
+
+// -----
+
+// CHECK-LABEL: @test_div(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32>
+func @test_div(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %0 = "onnx.Constant"() {value = dense<[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
+  %1 = "onnx.Constant"() {value = dense<[[2.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+  %2 = "onnx.Div"(%0, %1) : (tensor<3x2xf32>, tensor<1x1xf32>) -> tensor<3x2xf32>
+  "std.return"(%2) : (tensor<3x2xf32>) -> ()
+  // CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]{{\]}}> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
+  // CHECK-NOT: {{.*}} = "onnx.Div"{{.*}}
+}
+
+//===----------------------------------------------------------------------===//
+/// Sqrt tests
+
+// -----
+
+// CHECK-LABEL: @test_sqrt() -> tensor<1x2xf32>
+func @test_sqrt() -> tensor<1x2xf32> {
+  %0 = "onnx.Constant"() {value = dense<[[4.0, 16.0]]> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  %1 = "onnx.Sqrt"(%0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  "std.return"(%1) : (tensor<1x2xf32>) -> ()
+  // CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}[2.000000e+00, 4.000000e+00]{{\]}}> : tensor<1x2xf32>} : () -> tensor<1x2xf32>
+  // CHECK-NOT: {{.*}} = "onnx.Sqrt"{{.*}}
+}
+
+//===----------------------------------------------------------------------===//
+/// Unsqueeze tests
+
+// -----
+
+// CHECK-LABEL: @test_unsqueeze() -> tensor<2x1x1xf32>
+func @test_unsqueeze() -> tensor<*xf32> {
+  %0 = "onnx.Constant"() {value = dense<[4.0, 16.0]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %1 = "onnx.Unsqueeze"(%0) {axes = [1, 2]} : (tensor<2xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+  // CHECK: {{.*}} = "onnx.Constant"() {value = dense<{{\[}}{{\[}}[4.000000e+00]{{\]}}, {{\[}}[1.600000e+01]{{\]}}{{\]}}> : tensor<2x1x1xf32>} : () -> tensor<2x1x1xf32>
+  // CHECK-NOT: {{.*}} = "onnx.Unsqueeze"{{.*}}
+}
@@ -364,7 +364,8 @@ OpsWithResultTypeInference = {
 # Currently, there are only two build methods generated:
 #  - one with operands and attributes each having a separate parameter, and
 #  - one with operands and attributes having aggregated parameters.
-custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
+custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare',
+                                    'Pad', 'Sqrt', 'Neg', 'Unsqueeze']
 # Custom builder op list for operations with broadcast; we can deduce the right
 # output type, no need to leave it undef as in the above list.
 # Ops must have two operands, not one, not three... And there shall be two.