Merge pull request #41 from clang-ykt/infer-conv
Infer shape for ConvNoBias operation.
This commit is contained in:
		
						commit
						e64c63b07e
					
				|  | @ -104,7 +104,7 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", | def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", | ||||||
|     [NoSideEffect]> { |     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||||
|   let summary = "ONNX Conv operation with no Bias operand."; |   let summary = "ONNX Conv operation with no Bias operand."; | ||||||
|   let description = [{ |   let description = [{ | ||||||
|     "The convolution operator consumes an input tensor and a filter, and" |     "The convolution operator consumes an input tensor and a filter, and" | ||||||
|  | @ -112,6 +112,8 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", | ||||||
|   }]; |   }]; | ||||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W); |   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W); | ||||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); |   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); | ||||||
|  | 
 | ||||||
|  |   let verifier = [{ return ::verify(*this); }]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | ||||||
|  |  | ||||||
|  | @ -412,7 +412,7 @@ void ONNXReshapeOp::inferShapes() { | ||||||
| void ONNXTransposeOp::inferShapes() { | void ONNXTransposeOp::inferShapes() { | ||||||
|   // Cannot infer shape if no shape exists.
 |   // Cannot infer shape if no shape exists.
 | ||||||
|   if (!getOperand().getType().isa<RankedTensorType>()) |   if (!getOperand().getType().isa<RankedTensorType>()) | ||||||
|     emitError("Shape tensor not ranked."); |     return; | ||||||
| 
 | 
 | ||||||
|   // Naive transposition which handles the default case of
 |   // Naive transposition which handles the default case of
 | ||||||
|   // reversing the shape of the tensor (similar to numpy.transpose).
 |   // reversing the shape of the tensor (similar to numpy.transpose).
 | ||||||
|  | @ -448,6 +448,181 @@ LogicalResult verify(ONNXTransposeOp op) { | ||||||
|   return success(); |   return success(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | // Conv
 | ||||||
|  | 
 | ||||||
|  | // For this operation, we define the attributes once in the original Conv
 | ||||||
|  | // operation class. There is no need to redefine the attribute names for the
 | ||||||
|  | // other classes based on Conv.
 | ||||||
|  | void ONNXConvNoBiasOp::inferShapes() { | ||||||
|  |   // Generic shape for data input X and weight tensor W:
 | ||||||
|  |   // X: (N x C x D1 x D2 ... x Dn)
 | ||||||
|  |   // W: (M x C/group x k1 x k2 x ... x kn)
 | ||||||
|  | 
 | ||||||
|  |   // Cannot infer shape if no shape exists.
 | ||||||
|  |   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||||
|  |       !getOperand(1).getType().isa<RankedTensorType>()) | ||||||
|  |     return; | ||||||
|  | 
 | ||||||
|  |   auto dataTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||||
|  |   auto weightTy = getOperand(1).getType().cast<RankedTensorType>(); | ||||||
|  |   auto dataShape = dataTy.getShape(); | ||||||
|  |   auto weightShape = weightTy.getShape(); | ||||||
|  | 
 | ||||||
|  |   // Check that shape of weight and data have same length.
 | ||||||
|  |   if (dataShape.size() != weightShape.size()) | ||||||
|  |     emitError("Weight size not compatible with data size."); | ||||||
|  | 
 | ||||||
|  |   // Required attribute auto_pad defaults to NOTSET.
 | ||||||
|  |   auto autoPad = getAttrOfType<StringAttr>( | ||||||
|  |       ONNXConvOp::getAutoPadAttrName()).getValue(); | ||||||
|  |   // Group is a required attribute and should have default value of 1.
 | ||||||
|  |   int64_t group = getAttrOfType<IntegerAttr>( | ||||||
|  |       ONNXConvOp::getGroupAttrName()).getInt(); | ||||||
|  |   // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
 | ||||||
|  |   if (dataShape[1] != (weightShape[1] * group)) | ||||||
|  |     emitError("Channel dimension mismatch."); | ||||||
|  | 
 | ||||||
|  |   // Note: the value of the group attribut only impacts the way the
 | ||||||
|  |   // computation is carried out and not the actual output size.
 | ||||||
|  | 
 | ||||||
|  |   // First two output dimensions consist of the number of batches and the
 | ||||||
|  |   // number of kernels being applied.
 | ||||||
|  |   //
 | ||||||
|  |   SmallVector<int64_t, 2> dims; | ||||||
|  |   // Insert batch size.
 | ||||||
|  |   dims.emplace_back(dataShape[0]); | ||||||
|  |   // Insert number of filters being applied (number of output channels).
 | ||||||
|  |   dims.emplace_back(weightShape[0]); | ||||||
|  | 
 | ||||||
|  |   // Spatial dimensions of the output are computed using the formula:
 | ||||||
|  |   //
 | ||||||
|  |   // dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
 | ||||||
|  |   //
 | ||||||
|  |   SmallVector<int64_t, 2> outSpatialDims; | ||||||
|  |   // Number of spatial dimensions.
 | ||||||
|  |   int32_t nDims = dataShape.size() - 2; | ||||||
|  | 
 | ||||||
|  |   // Initialize dimenions based on the input spatial dimensions.
 | ||||||
|  |   for (int i = 2; i < dataShape.size(); ++i) | ||||||
|  |     outSpatialDims.emplace_back(dataShape[i]); | ||||||
|  | 
 | ||||||
|  |   // Use kernel_shape attribute if present otherwise use size from weight
 | ||||||
|  |   // argument.
 | ||||||
|  |   SmallVector<int64_t, 2> kernelDims; | ||||||
|  |   if (auto kernelShape = getAttrOfType<ArrayAttr>( | ||||||
|  |           ONNXConvOp::getKernelShapeAttrName())) { | ||||||
|  |     if (kernelShape.getValue().size() != nDims) | ||||||
|  |       emitError("kernel_shape length incompatible with spatial dimensions."); | ||||||
|  |     for (int i = 0; i < nDims; ++i) | ||||||
|  |       kernelDims.emplace_back( | ||||||
|  |           (kernelShape.getValue()[i]).cast<IntegerAttr>().getInt()); | ||||||
|  |   } else { | ||||||
|  |     for (int i = 0; i < nDims; ++i) | ||||||
|  |       kernelDims.emplace_back(weightShape[i + 2]); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Check if dilations attribute is present.
 | ||||||
|  |   // If it is then compute new kernel size that includes the receptive field.
 | ||||||
|  |   // In this calculation we assume that the receptive field pixels must all be
 | ||||||
|  |   // within the bounds of the image. In this case the new kernel size is given
 | ||||||
|  |   // by:
 | ||||||
|  |   //
 | ||||||
|  |   // ( K + 1 ) * d - 1
 | ||||||
|  |   // where K is a kernel dimension and d is the dilation along that axis.
 | ||||||
|  |   //
 | ||||||
|  |   // From a dimensionality perspective the kernel size becomes the dilated
 | ||||||
|  |   // kernel size.
 | ||||||
|  |   if (auto dilations = getAttrOfType<ArrayAttr>( | ||||||
|  |           ONNXConvOp::getDilationsAttrName())) { | ||||||
|  |     if (dilations.getValue().size() != nDims) | ||||||
|  |       emitError("dilations length incompatible with spatial dimensions."); | ||||||
|  |     for (int i = 0; i < nDims; ++i) | ||||||
|  |       kernelDims[i] = (kernelDims[i] + 1) * | ||||||
|  |           (dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Subtract kernel dimensions from input data dimensions.
 | ||||||
|  |   for (int i = 0; i < nDims; ++i) | ||||||
|  |     outSpatialDims[i] -= kernelDims[i]; | ||||||
|  | 
 | ||||||
|  |   // Add padding information.
 | ||||||
|  |   if (autoPad == "NOTSET") { | ||||||
|  |     // Use pads to to determine the padding. If attribute is not
 | ||||||
|  |     // present then pads is considered to be all zeros (no padding).
 | ||||||
|  |     if (auto pads = getAttrOfType<ArrayAttr>( | ||||||
|  |             ONNXConvOp::getPadsAttrName())) { | ||||||
|  |       // pads consists of two entries for each spatial axis.
 | ||||||
|  |       if (pads.getValue().size() != 2 * nDims) | ||||||
|  |         emitError("pads size is not twice the spatial size."); | ||||||
|  | 
 | ||||||
|  |       for (int i = 0; i < nDims; ++i) { | ||||||
|  |         // Padding for beginning of axis.
 | ||||||
|  |         int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt(); | ||||||
|  |         outSpatialDims[i] += p; | ||||||
|  |         // Padding for end of axis.
 | ||||||
|  |         p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt(); | ||||||
|  |         outSpatialDims[i] += p; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { | ||||||
|  |     // Pad input so that output size matches input size.
 | ||||||
|  |     // Each spatial dimension needs to be padded by a total of:
 | ||||||
|  |     //
 | ||||||
|  |     // K - 1
 | ||||||
|  |     //
 | ||||||
|  |     // where K is a kernel spatial dimension.
 | ||||||
|  |     // Pad as if stride is 1.
 | ||||||
|  |     for (int i = 0; i < nDims; ++i) | ||||||
|  |       outSpatialDims[i] += kernelDims[i] - 1; | ||||||
|  |   } else if (autoPad == "VALID") { | ||||||
|  |     // No padding
 | ||||||
|  |   } else { | ||||||
|  |     emitError("Unexpected attribute value for auto_pad."); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Strides
 | ||||||
|  |   if (auto strides = getAttrOfType<ArrayAttr>( | ||||||
|  |       ONNXConvOp::getStridesAttrName())) { | ||||||
|  |     if (strides.getValue().size() != nDims) | ||||||
|  |       emitError("strides length incompatible with spatial dimensions."); | ||||||
|  |     for (int i = 0; i < nDims; ++i) { | ||||||
|  |       int64_t stride = | ||||||
|  |           (strides.getValue()[i]).cast<IntegerAttr>().getInt(); | ||||||
|  |       outSpatialDims[i] = floor(outSpatialDims[i] / stride); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   for (int i = 0; i < nDims; ++i) | ||||||
|  |     outSpatialDims[i] += 1; | ||||||
|  | 
 | ||||||
|  |   dims.append(outSpatialDims.begin(), outSpatialDims.end()); | ||||||
|  |   getResult().setType(RankedTensorType::get(dims, dataTy.getElementType())); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | LogicalResult verify(ONNXConvNoBiasOp op) { | ||||||
|  |   auto module = op.getParentOfType<ModuleOp>(); | ||||||
|  |   if (!module) | ||||||
|  |     op.emitError("expected to belong to a module"); | ||||||
|  | 
 | ||||||
|  |   auto autoPadAttr = op.getAttrOfType<StringAttr>( | ||||||
|  |       ONNXConvOp::getAutoPadAttrName()); | ||||||
|  |   if (!autoPadAttr) | ||||||
|  |     op.emitError("auto_pad attribute not specified."); | ||||||
|  |   if (autoPadAttr.getValue() != "NOTSET") | ||||||
|  |     if (auto pads = op.getAttrOfType<ArrayAttr>( | ||||||
|  |             ONNXConvOp::getPadsAttrName())) | ||||||
|  |       op.emitError("auto_pad and pads are both set."); | ||||||
|  | 
 | ||||||
|  |   auto groupAttr = | ||||||
|  |       op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName()); | ||||||
|  |   if (!groupAttr) | ||||||
|  |     op.emitError("group attribute not specified."); | ||||||
|  | 
 | ||||||
|  |   return success(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| // TableGen'd op method definitions
 | // TableGen'd op method definitions
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
|  |  | ||||||
|  | @ -324,6 +324,15 @@ def ONNXConvOp:ONNX_Op<"Conv", | ||||||
|   }]; |   }]; | ||||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); |   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); | ||||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); |   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); | ||||||
|  | 
 | ||||||
|  |   let extraClassDeclaration = [{ | ||||||
|  |      static StringRef getAutoPadAttrName() { return "auto_pad"; } | ||||||
|  |      static StringRef getDilationsAttrName() { return "dilations"; } | ||||||
|  |      static StringRef getGroupAttrName() { return "group"; } | ||||||
|  |      static StringRef getKernelShapeAttrName() { return "kernel_shape"; } | ||||||
|  |      static StringRef getPadsAttrName() { return "pads"; } | ||||||
|  |      static StringRef getStridesAttrName() { return "strides"; } | ||||||
|  |   }]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",  | def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",  | ||||||
|  |  | ||||||
|  | @ -117,7 +117,8 @@ public: | ||||||
|         op->getName().getStringRef() != "onnx.GemmNoBias" && |         op->getName().getStringRef() != "onnx.GemmNoBias" && | ||||||
|         op->getName().getStringRef() != "onnx.Reshape" && |         op->getName().getStringRef() != "onnx.Reshape" && | ||||||
|         op->getName().getStringRef() != "onnx.Transpose" && |         op->getName().getStringRef() != "onnx.Transpose" && | ||||||
|         op->getName().getStringRef() != "onnx.Softmax") |         op->getName().getStringRef() != "onnx.Softmax" && | ||||||
|  |         op->getName().getStringRef() != "onnx.ConvNoBias") | ||||||
|       return false; |       return false; | ||||||
|     return llvm::any_of(op->getResultTypes(), [](Type result_type) { |     return llvm::any_of(op->getResultTypes(), [](Type result_type) { | ||||||
|       return !result_type.isa<RankedTensorType>(); |       return !result_type.isa<RankedTensorType>(); | ||||||
|  |  | ||||||
|  | @ -1,7 +1,10 @@ | ||||||
| // RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s | // RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s | ||||||
| 
 | 
 | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
| /// Test the default behavior of transpose when no information for the | /// Test the default behavior of transpose when no information for the | ||||||
| /// permutation of the axes is provided. | /// permutation of the axes is provided and when a permutation is provided. | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | 
 | ||||||
| func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32> |   %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | @ -12,6 +15,7 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | ||||||
| // CHECK: return [[RES]] : tensor<32x1x5x5xf32> | // CHECK: return [[RES]] : tensor<32x1x5x5xf32> | ||||||
| 
 | 
 | ||||||
| /// Test shape inference for transposition when perm attribute is specified. | /// Test shape inference for transposition when perm attribute is specified. | ||||||
|  | 
 | ||||||
| func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32> |   %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | @ -20,3 +24,128 @@ func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | ||||||
| // CHECK-LABEL: test_transpose | // CHECK-LABEL: test_transpose | ||||||
| // CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32> | // CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32> | ||||||
| // CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32> | // CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32> | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | /// Test shape inference for ConvNoBias operation and all its attributes. | ||||||
|  | //===----------------------------------------------------------------------===// | ||||||
|  | 
 | ||||||
|  | /// Default and required attributes. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_1 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32> | ||||||
|  | 
 | ||||||
|  | /// kernel_shape attribute. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_2 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32> | ||||||
|  | 
 | ||||||
|  | /// pads attribute. | ||||||
|  | /// Use pads to make output size equal to input size by adding K - 1 to the result. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_3 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> | ||||||
|  | 
 | ||||||
|  | /// auto_pad set to SAME_UPPER and SAME_LOWER. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_4 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_5 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> | ||||||
|  | 
 | ||||||
|  | /// auto_pad set to VALID. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_6 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32> | ||||||
|  | 
 | ||||||
|  | /// With strides attribute. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_7 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32> | ||||||
|  | 
 | ||||||
|  | /// auto_pad set to SAME_UPPER with strides attribute. | ||||||
|  | /// The auto_pad will pas as if stride is equal to 1. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_8 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32> | ||||||
|  | 
 | ||||||
|  | /// dilations attribute. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_9 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x20x42xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x20x42xf32> | ||||||
|  | 
 | ||||||
|  | /// dilations attribute with stride. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_10 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i32, strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x10x21xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x10x21xf32> | ||||||
|  | 
 | ||||||
|  | /// dilations attribute with auto_pad set to SAME_UPPER. | ||||||
|  | 
 | ||||||
|  | func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: test_conv_no_bias_11 | ||||||
|  | // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", dilations = [2, 3], group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32> | ||||||
|  | // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue