Merge branch 'master' into enable-reciprocal
This commit is contained in:
commit
0ae8a0f23c
|
@ -104,7 +104,7 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
|||
}
|
||||
|
||||
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Conv operation with no Bias operand.";
|
||||
let description = [{
|
||||
"The convolution operator consumes an input tensor and a filter, and"
|
||||
|
@ -112,6 +112,8 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
|||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||
|
|
|
@ -412,7 +412,7 @@ void ONNXReshapeOp::inferShapes() {
|
|||
void ONNXTransposeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand().getType().isa<RankedTensorType>())
|
||||
emitError("Shape tensor not ranked.");
|
||||
return;
|
||||
|
||||
// Naive transposition which handles the default case of
|
||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||
|
@ -448,6 +448,181 @@ LogicalResult verify(ONNXTransposeOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Conv
|
||||
|
||||
// For this operation, we define the attributes once in the original Conv
|
||||
// operation class. There is no need to redefine the attribute names for the
|
||||
// other classes based on Conv.
|
||||
void ONNXConvNoBiasOp::inferShapes() {
|
||||
// Generic shape for data input X and weight tensor W:
|
||||
// X: (N x C x D1 x D2 ... x Dn)
|
||||
// W: (M x C/group x k1 x k2 x ... x kn)
|
||||
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
|
||||
auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
auto dataShape = dataTy.getShape();
|
||||
auto weightShape = weightTy.getShape();
|
||||
|
||||
// Check that shape of weight and data have same length.
|
||||
if (dataShape.size() != weightShape.size())
|
||||
emitError("Weight size not compatible with data size.");
|
||||
|
||||
// Required attribute auto_pad defaults to NOTSET.
|
||||
auto autoPad = getAttrOfType<StringAttr>(
|
||||
ONNXConvOp::getAutoPadAttrName()).getValue();
|
||||
// Group is a required attribute and should have default value of 1.
|
||||
int64_t group = getAttrOfType<IntegerAttr>(
|
||||
ONNXConvOp::getGroupAttrName()).getInt();
|
||||
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||
if (dataShape[1] != (weightShape[1] * group))
|
||||
emitError("Channel dimension mismatch.");
|
||||
|
||||
// Note: the value of the group attribut only impacts the way the
|
||||
// computation is carried out and not the actual output size.
|
||||
|
||||
// First two output dimensions consist of the number of batches and the
|
||||
// number of kernels being applied.
|
||||
//
|
||||
SmallVector<int64_t, 2> dims;
|
||||
// Insert batch size.
|
||||
dims.emplace_back(dataShape[0]);
|
||||
// Insert number of filters being applied (number of output channels).
|
||||
dims.emplace_back(weightShape[0]);
|
||||
|
||||
// Spatial dimensions of the output are computed using the formula:
|
||||
//
|
||||
// dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
|
||||
//
|
||||
SmallVector<int64_t, 2> outSpatialDims;
|
||||
// Number of spatial dimensions.
|
||||
int32_t nDims = dataShape.size() - 2;
|
||||
|
||||
// Initialize dimenions based on the input spatial dimensions.
|
||||
for (int i = 2; i < dataShape.size(); ++i)
|
||||
outSpatialDims.emplace_back(dataShape[i]);
|
||||
|
||||
// Use kernel_shape attribute if present otherwise use size from weight
|
||||
// argument.
|
||||
SmallVector<int64_t, 2> kernelDims;
|
||||
if (auto kernelShape = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getKernelShapeAttrName())) {
|
||||
if (kernelShape.getValue().size() != nDims)
|
||||
emitError("kernel_shape length incompatible with spatial dimensions.");
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
kernelDims.emplace_back(
|
||||
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
|
||||
} else {
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
kernelDims.emplace_back(weightShape[i + 2]);
|
||||
}
|
||||
|
||||
// Check if dilations attribute is present.
|
||||
// If it is then compute new kernel size that includes the receptive field.
|
||||
// In this calculation we assume that the receptive field pixels must all be
|
||||
// within the bounds of the image. In this case the new kernel size is given
|
||||
// by:
|
||||
//
|
||||
// ( K + 1 ) * d - 1
|
||||
// where K is a kernel dimension and d is the dilation along that axis.
|
||||
//
|
||||
// From a dimensionality perspective the kernel size becomes the dilated
|
||||
// kernel size.
|
||||
if (auto dilations = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getDilationsAttrName())) {
|
||||
if (dilations.getValue().size() != nDims)
|
||||
emitError("dilations length incompatible with spatial dimensions.");
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
kernelDims[i] = (kernelDims[i] + 1) *
|
||||
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1;
|
||||
}
|
||||
|
||||
// Subtract kernel dimensions from input data dimensions.
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
outSpatialDims[i] -= kernelDims[i];
|
||||
|
||||
// Add padding information.
|
||||
if (autoPad == "NOTSET") {
|
||||
// Use pads to to determine the padding. If attribute is not
|
||||
// present then pads is considered to be all zeros (no padding).
|
||||
if (auto pads = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getPadsAttrName())) {
|
||||
// pads consists of two entries for each spatial axis.
|
||||
if (pads.getValue().size() != 2 * nDims)
|
||||
emitError("pads size is not twice the spatial size.");
|
||||
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
// Padding for beginning of axis.
|
||||
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
outSpatialDims[i] += p;
|
||||
// Padding for end of axis.
|
||||
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
|
||||
outSpatialDims[i] += p;
|
||||
}
|
||||
}
|
||||
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
// Pad input so that output size matches input size.
|
||||
// Each spatial dimension needs to be padded by a total of:
|
||||
//
|
||||
// K - 1
|
||||
//
|
||||
// where K is a kernel spatial dimension.
|
||||
// Pad as if stride is 1.
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
outSpatialDims[i] += kernelDims[i] - 1;
|
||||
} else if (autoPad == "VALID") {
|
||||
// No padding
|
||||
} else {
|
||||
emitError("Unexpected attribute value for auto_pad.");
|
||||
}
|
||||
|
||||
// Strides
|
||||
if (auto strides = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getStridesAttrName())) {
|
||||
if (strides.getValue().size() != nDims)
|
||||
emitError("strides length incompatible with spatial dimensions.");
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
int64_t stride =
|
||||
(strides.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
outSpatialDims[i] += 1;
|
||||
|
||||
dims.append(outSpatialDims.begin(), outSpatialDims.end());
|
||||
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
||||
}
|
||||
|
||||
LogicalResult verify(ONNXConvNoBiasOp op) {
|
||||
auto module = op.getParentOfType<ModuleOp>();
|
||||
if (!module)
|
||||
op.emitError("expected to belong to a module");
|
||||
|
||||
auto autoPadAttr = op.getAttrOfType<StringAttr>(
|
||||
ONNXConvOp::getAutoPadAttrName());
|
||||
if (!autoPadAttr)
|
||||
op.emitError("auto_pad attribute not specified.");
|
||||
if (autoPadAttr.getValue() != "NOTSET")
|
||||
if (auto pads = op.getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getPadsAttrName()))
|
||||
op.emitError("auto_pad and pads are both set.");
|
||||
|
||||
auto groupAttr =
|
||||
op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
|
||||
if (!groupAttr)
|
||||
op.emitError("group attribute not specified.");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -324,6 +324,15 @@ def ONNXConvOp:ONNX_Op<"Conv",
|
|||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getAutoPadAttrName() { return "auto_pad"; }
|
||||
static StringRef getDilationsAttrName() { return "dilations"; }
|
||||
static StringRef getGroupAttrName() { return "group"; }
|
||||
static StringRef getKernelShapeAttrName() { return "kernel_shape"; }
|
||||
static StringRef getPadsAttrName() { return "pads"; }
|
||||
static StringRef getStridesAttrName() { return "strides"; }
|
||||
}];
|
||||
}
|
||||
|
||||
def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
|
||||
|
|
|
@ -117,7 +117,8 @@ public:
|
|||
op->getName().getStringRef() != "onnx.GemmNoBias" &&
|
||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||
op->getName().getStringRef() != "onnx.Transpose" &&
|
||||
op->getName().getStringRef() != "onnx.Softmax")
|
||||
op->getName().getStringRef() != "onnx.Softmax" &&
|
||||
op->getName().getStringRef() != "onnx.ConvNoBias")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
return !result_type.isa<RankedTensorType>();
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
// RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Test the default behavior of transpose when no information for the
|
||||
/// permutation of the axes is provided.
|
||||
/// permutation of the axes is provided and when a permutation is provided.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
@ -12,6 +15,7 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
|||
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
||||
|
||||
/// Test shape inference for transposition when perm attribute is specified.
|
||||
|
||||
func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
@ -20,3 +24,128 @@ func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
|||
// CHECK-LABEL: test_transpose
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Test shape inference for ConvNoBias operation and all its attributes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Default and required attributes.
|
||||
|
||||
func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_1
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32>
|
||||
|
||||
/// kernel_shape attribute.
|
||||
|
||||
func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_2
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32>
|
||||
|
||||
/// pads attribute.
|
||||
/// Use pads to make output size equal to input size by adding K - 1 to the result.
|
||||
|
||||
func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_3
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
||||
|
||||
/// auto_pad set to SAME_UPPER and SAME_LOWER.
|
||||
|
||||
func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_4
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
||||
|
||||
func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_5
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
||||
|
||||
/// auto_pad set to VALID.
|
||||
|
||||
func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_6
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32>
|
||||
|
||||
/// With strides attribute.
|
||||
|
||||
func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_7
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32>
|
||||
|
||||
/// auto_pad set to SAME_UPPER with strides attribute.
|
||||
/// The auto_pad will pas as if stride is equal to 1.
|
||||
|
||||
func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_8
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32>
|
||||
|
||||
/// dilations attribute.
|
||||
|
||||
func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_9
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x20x42xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x20x42xf32>
|
||||
|
||||
/// dilations attribute with stride.
|
||||
|
||||
func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_10
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i32, strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x10x21xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x10x21xf32>
|
||||
|
||||
/// dilations attribute with auto_pad set to SAME_UPPER.
|
||||
|
||||
func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_11
|
||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", dilations = [2, 3], group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
||||
|
|
Loading…
Reference in New Issue