Unify Conv implementation (#54)

* fixed readme for new git repo

* conv with bias as an optional input
This commit is contained in:
Alexandre Eichenberger 2020-03-26 11:03:19 -04:00 committed by GitHub
parent 1777c07b1e
commit 653fa69102
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 166 additions and 130 deletions

View File

@ -636,35 +636,6 @@ ONNX ConvInteger operation
1. `y`: memref of any type values or tensor of any type values 1. `y`: memref of any type values or tensor of any type values
### onnx.ConvNoBias (ONNXConvNoBiasOp)
ONNX Conv operation with no Bias operand.
#### Description:
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
#### Operands:
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
#### Attributes:
| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `auto_pad` | `StringAttr` | string attribute attribute |
| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute |
| `group` | `IntegerAttr` | 64-bit integer attribute attribute |
| `kernel_shape` | `ArrayAttr` | 64-bit integer array attribute attribute |
| `pads` | `ArrayAttr` | 64-bit integer array attribute attribute |
| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute |
#### Results:
1. `o_Y`: memref of any type values or tensor of any type values
### onnx.Conv (ONNXConvOp) ### onnx.Conv (ONNXConvOp)
ONNX Conv operation ONNX Conv operation

View File

@ -32,7 +32,6 @@ special_attr_defaults = dict([
# Special operation importing handlers. # Special operation importing handlers.
special_op_handler = dict([ special_op_handler = dict([
("Conv", "ImportNodeConv"),
("MaxPool", "ImportNodeMaxPool"), ("MaxPool", "ImportNodeMaxPool"),
("BatchNormalization", "ImportNodeBatchNormalization"), ("BatchNormalization", "ImportNodeBatchNormalization"),
("Pad", "ImportNodePad"), ("Pad", "ImportNodePad"),
@ -47,11 +46,11 @@ OpsWithShapeInference = [
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
'Sign', 'Constant', 'AveragePool', 'Abs' 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv'
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm'] OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
# Operations who have operands that, if produced by constant operations, should # Operations who have operands that, if produced by constant operations, should
# be promoted to become an attribute (via attribute promotion). # be promoted to become an attribute (via attribute promotion).

View File

@ -303,30 +303,6 @@ private:
buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut); buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
} }
/*!
* Special handle for Conv operations.
* c++ does not allow template specialization inside a class scope
* a specialized function is used
*/
void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto
// dilations_attr = get_attr_ints(node, "dilations",
// std::vector<int>(inputs[0]->getType().cast<RankedTensorType>.getDims()-2,
// 1));
// attributes.push_back(dilations_attr)
// similar situation for pads, strides in AveragePool
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
// TODO: fix this after type inference
int nOps = node.input().size();
if (nOps == 2)
buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
else
buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
}
/*! /*!
* Special handle for MaxPool operations. * Special handle for MaxPool operations.
*/ */

View File

@ -50,7 +50,7 @@ if (opName == "Constant")
if (opName == "ConstantOfShape") if (opName == "ConstantOfShape")
return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Conv") if (opName == "Conv")
return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); return buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ConvInteger") if (opName == "ConvInteger")
return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ConvTranspose") if (opName == "ConvTranspose")

View File

@ -12,18 +12,19 @@
using namespace mlir; using namespace mlir;
struct ONNXConvNoBiasOpLowering : public ConversionPattern { struct ONNXConvOpLowering : public ConversionPattern {
ONNXConvNoBiasOpLowering(MLIRContext *ctx) ONNXConvOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXConvOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
ONNXConvOpOperandAdaptor operandAdaptor(operands);
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op); ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
@ -32,12 +33,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
memRefType, loc, rewriter, insertDealloc, {operands[0]}); memRefType, loc, rewriter, insertDealloc, {operands[0]});
auto resultShape = memRefType.getShape(); auto resultShape = memRefType.getShape();
auto &inputOperand = operands[0]; auto inputOperand = operandAdaptor.X();
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape(); auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto &kernelOperand = operands[1]; auto kernelOperand = operandAdaptor.W();
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape(); auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
auto biasOperand = operandAdaptor.B();
bool hasBias = !biasOperand.getType().isa<NoneType>();
// R = ConvNoBias(D, K) // R = Conv(D, K)
// //
// The input/output shapes will look like this: // The input/output shapes will look like this:
// //
@ -169,8 +172,23 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// 3.4 Emit inner loop nest. // 3.4 Emit inner loop nest.
innerLoops.createIterateOp(); innerLoops.createIterateOp();
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
// Emit the bias, if needed.
if (hasBias) {
auto loadResult =
rewriter.create<LoadOp>(loc, alloc, resultIndices);
SmallVector<Value, 4> biasIndices;
biasIndices.emplace_back(kernel);
auto loadBias =
rewriter.create<LoadOp>(loc, biasOperand, kernel);
auto resultWithBias = rewriter.create<MulFOp>(
loc, loadResult, loadBias);
// Store initializer value into output location.
rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
}
//
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
{ {
// 4. Emit inner loop body // 4. Emit inner loop body
// R[n][kernel][r1][r2] = // R[n][kernel][r1][r2] =
@ -238,5 +256,5 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
void populateLoweringONNXConvOpPattern( void populateLoweringONNXConvOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) { OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConvNoBiasOpLowering>(ctx); patterns.insert<ONNXConvOpLowering>(ctx);
} }

View File

@ -1022,14 +1022,18 @@ void ONNXReduceSumOp::inferShapes() {
// - kernelShape: inferred from weight matrix if not defined by user; // - kernelShape: inferred from weight matrix if not defined by user;
// - pads: set to proper value, 0 if not defined by user. // - pads: set to proper value, 0 if not defined by user.
void ONNXConvNoBiasOp::inferShapes() { void ONNXConvOp::inferShapes() {
// Generic shape for data input X and weight tensor W: // Generic shape for data input X, weight tensor W, and optional bias B
// X: (N x C x D1 x D2 ... x Dn) // X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn) // W: (M x C/group x k1 x k2 x ... x kn)
// B: (M) Optional
bool hasBias = !B().getType().isa<NoneType>();
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>() || if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>()) !W().getType().isa<RankedTensorType>() ||
(hasBias && !B().getType().isa<RankedTensorType>()))
return; return;
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
@ -1047,7 +1051,7 @@ void ONNXConvNoBiasOp::inferShapes() {
emitError("Weight size not compatible with data size"); emitError("Weight size not compatible with data size");
// Group is a required attribute and should have default value of 1. // Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); int64_t group = ONNXConvOp::group().getSExtValue();
// Check if the attribute actually exists. If it does not then add it. // Check if the attribute actually exists. If it does not then add it.
if (!groupAttr()) if (!groupAttr())
@ -1058,6 +1062,16 @@ void ONNXConvNoBiasOp::inferShapes() {
xShape[1] != (weightShape[1] * group)) xShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch"); emitError("Channel dimension mismatch");
// Check the size of bias.
if (hasBias) {
auto bTx = B().getType().cast<RankedTensorType>();
auto bShape = bTx.getShape();
if (bShape.size() != 1)
emitError("bias should be one dimensional");
if (bShape[0] != weightShape[0])
emitError("bias should have same dimensions as weight's first dimension");
}
// Note: the value of the group attribut only impacts the way the // Note: the value of the group attribut only impacts the way the
// computation is carried out and not the actual output size. // computation is carried out and not the actual output size.

View File

@ -95,25 +95,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
// or outputs. This decision affects only ONNX operations with optional // or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands. // arguments not ONNX operations with variadic operands.
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Conv operation with no Bias operand.";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
OptionalAttr<I64ArrayAttr>:$kernel_shape,
OptionalAttr<I64ArrayAttr>:$pads,
OptionalAttr<I64ArrayAttr>:$strides);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1; let hasCanonicalizer = 1;

View File

@ -363,7 +363,8 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
} }
def ONNXConvOp:ONNX_Op<"Conv", def ONNXConvOp:ONNX_Op<"Conv",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Conv operation"; let summary = "ONNX Conv operation";
let description = [{ let description = [{
"The convolution operator consumes an input tensor and a filter, and" "The convolution operator consumes an input tensor and a filter, and"

View File

@ -72,7 +72,7 @@ ArrayAttr insertZerosForNonPaddedDims(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Rewrite: // Rewrite:
// %0 = onnx.ConvNoBiasOp(%D : tensor<DShape>, %K) // %0 = onnx.Conv(%D : tensor<DShape>, %K)
// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} -> // {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
// tensor<OutShape> // tensor<OutShape>
// //
@ -80,14 +80,14 @@ ArrayAttr insertZerosForNonPaddedDims(
// %0 = onnx.PadConstantValuePasOp(%D) // %0 = onnx.PadConstantValuePasOp(%D)
// {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} -> // {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
// tensor<DPaddedShape> // tensor<DPaddedShape>
// %1 = onnx.ConvNoBias(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} -> // %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
// tensor<OutShape> // tensor<OutShape>
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
struct SplitConvOpPattern : public RewritePattern { struct SplitConvOpPattern : public RewritePattern {
SplitConvOpPattern(MLIRContext *context) SplitConvOpPattern(MLIRContext *context)
: RewritePattern(ONNXConvNoBiasOp::getOperationName(), : RewritePattern(ONNXConvOp::getOperationName(),
{ONNXPadConstantValuePadOp::getOperationName(), {ONNXPadConstantValuePadOp::getOperationName(),
ONNXConvNoBiasOp::getOperationName()}, ONNXConvOp::getOperationName()},
1, context) {} 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op, PatternMatchResult matchAndRewrite(Operation *op,
@ -95,7 +95,7 @@ struct SplitConvOpPattern : public RewritePattern {
auto loc = op->getLoc(); auto loc = op->getLoc();
// If convolution does not use padding then no rewrite is required. // If convolution does not use padding then no rewrite is required.
ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op); ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
auto padsAttribute = convOp.padsAttr(); auto padsAttribute = convOp.padsAttr();
if (!padsAttribute) if (!padsAttribute)
return matchFailure(); return matchFailure();
@ -155,8 +155,9 @@ struct SplitConvOpPattern : public RewritePattern {
SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0); SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0);
auto tensorType = (*op->result_type_begin()).cast<TensorType>(); auto tensorType = (*op->result_type_begin()).cast<TensorType>();
ONNXConvNoBiasOp newConvOp = rewriter.create<ONNXConvNoBiasOp>( ONNXConvOp newConvOp = rewriter.create<ONNXConvOp>(
loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1], loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1],
convOp.getOperands()[2],
convOp.auto_padAttr(), convOp.dilationsAttr(), convOp.auto_padAttr(), convOp.dilationsAttr(),
convOp.groupAttr(), convOp.kernel_shapeAttr(), convOp.groupAttr(), convOp.kernel_shapeAttr(),
rewriter.getI64ArrayAttr(newConvPads), rewriter.getI64ArrayAttr(newConvPads),
@ -173,8 +174,8 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<MaxPoolSingleOutOpPaddingPattern>(context); results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
} }
/// on the ONNXConvNoBiasOp. /// on the ONNXConvOp.
void ONNXConvNoBiasOp::getCanonicalizationPatterns( void ONNXConvOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) { OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SplitConvOpPattern>(context); results.insert<SplitConvOpPattern>(context);
} }

View File

@ -113,7 +113,7 @@ public:
op->getName().getStringRef() != "onnx.ReduceSum" && op->getName().getStringRef() != "onnx.ReduceSum" &&
op->getName().getStringRef() != "onnx.Softmax" && op->getName().getStringRef() != "onnx.Softmax" &&
op->getName().getStringRef() != "onnx.Sqrt" && op->getName().getStringRef() != "onnx.Sqrt" &&
op->getName().getStringRef() != "onnx.ConvNoBias" && op->getName().getStringRef() != "onnx.Conv" &&
op->getName().getStringRef() != "onnx.PadConstantPad" && op->getName().getStringRef() != "onnx.PadConstantPad" &&
op->getName().getStringRef() != "onnx.PadConstantValuePad" && op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&

View File

@ -48,10 +48,12 @@ func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
// CHECK-LABEL: @test_conv_split(%{{.*}}: tensor<1x9x32x64xf32>, %{{.*}}: tensor<5x9x6x7xf32>) -> tensor<*xf32> { // CHECK-LABEL: @test_conv_split(%{{.*}}: tensor<1x9x32x64xf32>, %{{.*}}: tensor<5x9x6x7xf32>) -> tensor<*xf32> {
func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> { func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: %cst = constant unit
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32> // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32>
// CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> // CHECK-NEXT: %1 = "onnx.Conv"(%0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
// CHECK-NEXT: return %1 : tensor<*xf32> // CHECK-NEXT: return %1 : tensor<*xf32>
} }

View File

@ -1149,7 +1149,8 @@ func @test_matmul7(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<*xf32
} }
func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_no_pad // CHECK-LABEL: test_conv_no_bias_no_pad
@ -1191,8 +1192,55 @@ func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2
// CHECK: return [[RES]] : memref<1x5x27x58xf32> // CHECK: return [[RES]] : memref<1x5x27x58xf32>
} }
func @test_conv_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>, %arg2 : tensor<5xf32>) -> tensor<*xf32> {
%0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, tensor<5xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_bias_no_pad
// CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
// CHECK: [[CONST0:%.+]] = constant 5 : index
// CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[CONST2:%.+]] = constant 2 : index
// CHECK: [[OUTER_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_OUTER_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1) with ([[OUTER_LOOPS]]#0 -> %arg3 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg4 = 0 to 5) {
// CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) {
// CHECK: store [[CONST1]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
// CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg7 = 0 to 2, [[INNER_LOOPS]]#1 -> %arg8 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg9 = 0 to 7) {
// CHECK: [[R1PLUSK1:%.+]] = addi %arg5, %arg8 : index
// CHECK: [[R2PLUSK2:%.+]] = addi %arg6, %arg9 : index
// CHECK: [[DATA:%.+]] = load %arg0[%arg3, %arg7, [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x2x32x64xf32>
// CHECK: [[KERNEL:%.+]] = load %arg1[%arg4, %arg7, %arg8, %arg9] : memref<5x2x6x7xf32>
// CHECK: [[ACC_RES:%.+]] = load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
// CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32
// CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
// CHECK: }
// CHECK: [[BIAS1:%.+]] = load [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
// CHECK: [[BIAS2:%.+]] = load %arg2[%arg4] : memref<5xf32>
// CHECK: [[BIAS3:%.+]] = mulf [[BIAS1]], [[BIAS2]] : f32
// CHECK: store [[BIAS3]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<1x5x27x58xf32>
}
func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x3x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x3x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_no_pad_w_group // CHECK-LABEL: test_conv_no_bias_no_pad_w_group
@ -1239,7 +1287,8 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te
} }
func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_no_pad_w_strides // CHECK-LABEL: test_conv_no_bias_no_pad_w_strides

View File

@ -140,39 +140,42 @@ func @test_matmul_10(%arg0 : tensor<?x42x32xf32>, %arg1 : tensor<32xf32>) -> ten
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Test shape inference for ConvNoBias operation and all its attributes. /// Test shape inference for Conv (first with no bias) operation and all its attributes.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Default and required attributes for 1-D convolution. /// Default and required attributes for 1-D convolution.
func @test_conv_no_bias_0(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>) -> tensor<*xf32> { func @test_conv_no_bias_0(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_0 // CHECK-LABEL: test_conv_no_bias_0
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<1x5x27xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, none) -> tensor<1x5x27xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32>
} }
/// Default and required attributes. /// Default and required attributes.
func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_1 // CHECK-LABEL: test_conv_no_bias_1
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x27x58xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32>
} }
/// kernel_shape attribute. /// kernel_shape attribute.
func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_2 // CHECK-LABEL: test_conv_no_bias_2
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x25x56xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32>
} }
@ -180,53 +183,58 @@ func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x
/// Use pads to make output size equal to input size by adding K - 1 to the result. /// Use pads to make output size equal to input size by adding K - 1 to the result.
func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_3 // CHECK-LABEL: test_conv_no_bias_3
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
} }
/// auto_pad set to SAME_UPPER and SAME_LOWER. /// auto_pad set to SAME_UPPER and SAME_LOWER.
func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_4 // CHECK-LABEL: test_conv_no_bias_4
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
} }
func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_5 // CHECK-LABEL: test_conv_no_bias_5
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
} }
/// auto_pad set to VALID. /// auto_pad set to VALID.
func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_6 // CHECK-LABEL: test_conv_no_bias_6
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x27x55xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32>
} }
/// With strides attribute. /// With strides attribute.
func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_7 // CHECK-LABEL: test_conv_no_bias_7
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x14x20xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32>
} }
@ -234,45 +242,61 @@ func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x
/// The auto_pad will pas as if stride is equal to 1. /// The auto_pad will pas as if stride is equal to 1.
func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_8 // CHECK-LABEL: test_conv_no_bias_8
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x16x22xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32>
} }
/// dilations attribute. /// dilations attribute.
func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_9 // CHECK-LABEL: test_conv_no_bias_9
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x22x46xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x22x46xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xf32>
} }
/// dilations attribute with stride. /// dilations attribute with stride.
func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_10 // CHECK-LABEL: test_conv_no_bias_10
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x11x23xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x11x23xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xf32>
} }
/// dilations attribute with auto_pad set to SAME_UPPER. /// dilations attribute with auto_pad set to SAME_UPPER.
func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_conv_no_bias_11 // CHECK-LABEL: test_conv_no_bias_11
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x32x64xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
}
// Test convolution with bias input.
func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2 : tensor<5xf32>) -> tensor<*xf32> {
%0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, tensor<5xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_12
// CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, tensor<5xf32>) -> tensor<1x5x27xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32>
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Test shape inference for PadConstantValuePad. /// Test shape inference for PadConstantValuePad.