diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 0f6a3f3..f71032b 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -765,11 +765,150 @@ void ONNXReduceSumOp::inferShapes() { getResult().setType(getReductionOutputType(operandTy, axes(), keepdims())); } +//===----------------------------------------------------------------------===// + +// Conv + +// Support function that computes default values for dilations, strides, and +// pads. + +template +static void processConvTypeParams(T *op, Value inputOperand) { + auto builder = mlir::Builder(op->getContext()); + + // 1) Get shape of input. + auto inputShape = inputOperand.getType().cast().getShape(); + auto inputRank = inputShape.size(); + + // 2) Get kernel sizes from kernel_shape attribute. + auto kernelShape = op->kernel_shape(); + auto kernelRank = ArrayAttrSize(kernelShape); + auto kernelOffset = inputRank - kernelRank; + + // Dilatation. + auto dilationsOpt = op->dilations(); + if (dilationsOpt.hasValue()) { + if (ArrayAttrSize(dilationsOpt) != kernelRank) + op->emitError("dialation rank is not the same as the spatial rank"); + // Test values to be greater than 0. + for (int i = 0; i < kernelRank; ++i) { + if (ArrayAttrIntVal(dilationsOpt, i) < 1) + op->emitError("dialation value must be nonzero positive"); + } + } else { + // Default dilatation is needed, all dimensions init with 1. + SmallVector defaultVals(kernelRank, 1); + // Convert to ArrayRef, then build attribute, then store attribute. + ArrayRef defaultRefs(defaultVals); + op->dilationsAttr(builder.getI64ArrayAttr(defaultRefs)); + } + + // Strides. + auto stridesOpt = op->strides(); + if (stridesOpt.hasValue()) { + if (ArrayAttrSize(stridesOpt) != kernelRank) + op->emitError("strides rank is not the same as the spatial rank"); + // Check values to be greater than 0. + for (int i = 0; i < kernelRank; ++i) { + if (ArrayAttrIntVal(stridesOpt, i) < 1) + op->emitError("strides value must be nonzero positive"); + } + } else { + // Default stride is needed, all dimensions init with 1. + SmallVector defaultVals(kernelRank, 1); + // Convert to ArrayRef, then build attribute, then store attribute. + ArrayRef defaultRefs(defaultVals); + op->stridesAttr(builder.getI64ArrayAttr(defaultRefs)); + } + + // Now try to find padding, getting auto_pad attribute first. + auto autoPad = op->auto_pad(); + // And then investigate the various different cases. Prefill pad values with + // zeros, the most common case. + SmallVector actualPads(2 * kernelRank, 0); + bool updatedPad = false; + if (autoPad == "NOTSET") { + auto padsOpt = op->pads(); + if (padsOpt.hasValue()) { + // Only option where pads are not updated. Pads consists of two entries + // for each spatial axis. + if (ArrayAttrSize(padsOpt) != 2 * kernelRank) + op->emitError("pads rank is not twice the spatial rank"); + // Check values, pads cannot be negative. + for (int i = 0; i < 2 * kernelRank; ++i) { + if (ArrayAttrIntVal(padsOpt, i) < 0) + op->emitError("pads value must be nonnegative"); + } + } else { + // We have notset with no pads, they are assumed to be all zero. + updatedPad = true; + } + } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { + // Reload dialtion and strides as they may have gotten default values. + updatedPad = true; + dilationsOpt = op->dilations(); + stridesOpt = op->strides(); + for (int i = 0; i < kernelRank; ++i) { + auto inputSize = inputShape[kernelOffset + i]; + auto kernelSize = ArrayAttrIntVal(kernelShape, i); + auto dilationVal = ArrayAttrIntVal(dilationsOpt, i); + auto strideVal = ArrayAttrIntVal(stridesOpt, i); + // Output size is input size divided by stride. When stride is 1, then + // input and output are the same size, which is the usual case. When + // stride is greater than 1, take the ceil to be sure to have each input + // value used, as padding will be used to fill the gaps. + int64_t outputSize = ceil((1.0 * inputSize) / (1.0 * strideVal)); + // Forumla is from ONNX MaxPool, and can be explained as follows. Pads is + // the difference between the needed values for the computations, minus + // the input values. The needed values for the computation is the + // effective side of the kernel plus the number of times we jump to the + // next kernel. Number of time we jump is (outputSize - 1). That number is + // multiplied with the size of the jump, namely strideVal. Now for the + // effective kernel size. It is the kernelSize + the number of times we + // have dilation holes time the dialtion. The number of dialtion holes is + // (kernelSize -1). Thus the effective size is "kernelSize + + // (kernelSize-1)*dialation". This simplifies to "(kernelSize + // -1)*dialation + 1". + auto sumOfPad = (outputSize - 1) * strideVal + + ((kernelSize - 1) * dilationVal + 1) - inputSize; + // Pad values are assumed equal on both size, at half the total value. + actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2; + // But if the total pad value is odd, we add 1 to begining or end + // depending on autoPad value. + if (sumOfPad % 2 != 0) { + if (autoPad == "SAME_UPPER") { + actualPads[kernelRank + i] += 1; + } else { + actualPads[i] += 1; + } + } + } + } else if (autoPad == "VALID") { + // No pad, default value was set to zero, we are all set. + updatedPad = true; + } else { + op->emitError("auto_pad of unknown / unsupported value"); + } + // Set pads values in attributes, if it is needed. + if (updatedPad) { + ArrayRef defaultRefs(actualPads); + op->padsAttr(builder.getI64ArrayAttr(defaultRefs)); + } + // In all cases now, the acutal pad values are found in the pads attribute. + op->auto_padAttr(builder.getStringAttr("NOTSET")); +} + // Conv // For this operation, we define the attributes once in the original Conv // operation class. There is no need to redefine the attribute names for the // other classes based on Conv. +// Conv attributes output: +// - auto_pad set to NOTSET; +// - dilations, strides: set to 1 if not defined by user; +// - kernelShape: inferred from weight matrix if not defined by user; +// - pads: set to proper value, 0 if not defined by user. + void ONNXConvNoBiasOp::inferShapes() { // Generic shape for data input X and weight tensor W: // X: (N x C x D1 x D2 ... x Dn) @@ -780,179 +919,86 @@ void ONNXConvNoBiasOp::inferShapes() { !W().getType().isa()) return; - auto dataTy = X().getType().cast(); + auto xTy = X().getType().cast(); + auto xShape = xTy.getShape(); auto weightTy = W().getType().cast(); - auto inDataShape = dataTy.getShape(); auto weightShape = weightTy.getShape(); // Lowest supported convolution is a one dimensional convolution. - if (inDataShape.size() < 3) + if (xShape.size() < 3) emitError("Data input shape must be at least (NxCxD1)"); // Check that shape of weight and data have same length. - if (inDataShape.size() != weightShape.size()) + if (xShape.size() != weightShape.size()) emitError("Weight size not compatible with data size"); - // Required attribute auto_pad defaults to NOTSET. - auto autoPad = auto_pad(); // Group is a required attribute and should have default value of 1. - int64_t group = - ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue(); + int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. - if (inDataShape[1] != -1 && weightShape[1] != -1 && - inDataShape[1] != (weightShape[1] * group)) + if (xShape[1] != -1 && weightShape[1] != -1 && + xShape[1] != (weightShape[1] * group)) emitError("Channel dimension mismatch"); // Note: the value of the group attribut only impacts the way the // computation is carried out and not the actual output size. - // First two output dimensions consist of the number of batches and the - // number of kernels being applied. - // - SmallVector dims; - // Insert batch size. - dims.emplace_back(inDataShape[0]); - // Insert number of filters being applied (number of output channels). - dims.emplace_back(weightShape[0]); - - // Spatial dimensions of the output are computed using the formula: - // - // dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1 - // - SmallVector outSpatialDims; // Number of spatial dimensions. - int32_t nSpatialDims = inDataShape.size() - 2; - - // Initialize dimenions based on the input spatial dimensions. - for (int i = 2; i < inDataShape.size(); ++i) - outSpatialDims.emplace_back(inDataShape[i]); + auto spatialOffset = 2; + int32_t spatialRank = xShape.size() - spatialOffset; // Use kernel_shape attribute if present otherwise use size from weight // argument. - SmallVector kernelDims; - if (auto kernelShape = kernel_shapeAttr()) { - if (ArrayAttrSize(kernelShape) != nSpatialDims) + auto kernelShape = kernel_shape(); + if (kernelShape.hasValue()) { + if (ArrayAttrSize(kernelShape) != spatialRank) emitError("kernel_shape length incompatible with spatial dimensions"); - for (int i = 0; i < nSpatialDims; ++i) - kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i)); + // Have the right number of values, check them. + for (int i = 0; i < spatialRank; ++i) + if (ArrayAttrIntVal(kernelShape, i) < 1) + emitError("bad kernel_shape value"); } else { - for (int i = 0; i < nSpatialDims; ++i) - kernelDims.emplace_back(weightShape[i + 2]); + // Deduce shape from weight input. + SmallVector defaultVals; + for (int i = 0; i < spatialRank; ++i) + defaultVals.emplace_back(weightShape[spatialOffset + i]); + // Convert to ArrayRef, then build attribute, then store attribute. + ArrayRef defaultRefs(defaultVals); + auto builder = mlir::Builder(getContext()); + kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs)); + kernelShape = kernel_shape(); } - // Check if dilations attribute is present. - // If it is then compute new kernel size that includes the receptive field. - // In this calculation we assume that the receptive field pixels must all be - // within the bounds of the image. In this case the new kernel size is given - // by: - // - // ( K + 1 ) * d - 1 - // where K is a kernel dimension and d is the dilation along that axis. - // - // From a dimensionality perspective the kernel size becomes the dilated - // kernel size. - if (auto dilations = dilationsAttr()) { - if (ArrayAttrSize(dilations) != nSpatialDims) - emitError("dilations length incompatible with spatial dimensions"); - for (int i = 0; i < nSpatialDims; ++i) - kernelDims[i] = - (kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1; + // Process strides, dilations, and pads. + processConvTypeParams<>(this, X()); + auto dilationsOpt = dilations(); + auto stridesOpt = strides(); + auto padsOpt = pads(); + + // First two output dimensions consist of the number of batches and the + // number of kernels being applied. + SmallVector outputDims; + // Insert batch size. + outputDims.emplace_back(xShape[0]); + // Insert number of filters being applied (number of output channels). + outputDims.emplace_back(weightShape[0]); + + // Then the spatial dimensions of the output are computed. + for (int i = 0; i < spatialRank; ++i) { + auto inputSize = xShape[spatialOffset + i]; + auto sumOfPads = + ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i); + auto kernelSize = ArrayAttrIntVal(kernelShape, i); + auto dilationVal = ArrayAttrIntVal(dilationsOpt, i); + auto strideVal = ArrayAttrIntVal(stridesOpt, i); + // Number of useful values: input plus pad - effective size of kernel (see + // processConvTypeParams comments to see how this value is derived). + double numerator = + inputSize + sumOfPads - ((kernelSize - 1) * dilationVal + 1); + // Useful number is divided by the strides. + double denominator = strideVal; + outputDims.emplace_back(floor(numerator / denominator) + 1); } - - // Subtract kernel dimensions from input data dimensions. - for (int i = 0; i < nSpatialDims; ++i) - outSpatialDims[i] -= kernelDims[i]; - - // Array which holds the padding information. - SmallVector actualPads(2 * nSpatialDims, 0); - auto stridesAttr = ONNXConvNoBiasOp::stridesAttr(); - - // Add padding information. - if (autoPad == "NOTSET") { - // Use pads to to determine the padding. If attribute is not - // present then pads is considered to be all zeros (no padding). - if (auto pads = padsAttr()) { - // pads consists of two entries for each spatial axis. - if (ArrayAttrSize(pads) != 2 * nSpatialDims) - emitError("pads size is not twice the spatial size"); - - for (int i = 0; i < nSpatialDims; ++i) { - // Padding for beginning of axis. - outSpatialDims[i] += ArrayAttrIntVal(pads, i); - // Padding for end of axis. - outSpatialDims[i] += ArrayAttrIntVal(pads, i + nSpatialDims); - } - } - } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { - // Pad input so that output size matches input size. - // Each spatial dimension needs to be padded by a total of: - // - // stride * (InDim - 1) + KerDim - InDim - // - // where K is a kernel spatial dimension. - for (int i = 0; i < nSpatialDims; ++i) { - // If strides are given use them otherwise stride is 1. - int64_t stride = 1; - if (stridesAttr) - stride = ArrayAttrIntVal(stridesAttr, i); - - // Compute necessary padding. The input dimensions are stored in - // inDataShape. - int64_t totalPadding = stride * (inDataShape[i + 2] - 1) + - kernelDims[i] - inDataShape[i + 2]; - - // Adjust current output value with the value of the padding. - // When dividing by stride later on, the output dimension should - // be equal to the input dimension. - outSpatialDims[i] += totalPadding; - - // Record the upper and lower axis padding. - actualPads[i] = actualPads[i + nSpatialDims] = totalPadding / 2; - if (totalPadding % 2 != 0) { - if (autoPad == "SAME_LOWER") { - actualPads[i]++; - } else { - actualPads[i + nSpatialDims]++; - } - } - } - } else if (autoPad == "VALID") { - // No padding - } else { - emitError("Unexpected attribute value for auto_pad"); - } - - // Strides - if (stridesAttr) { - if (ArrayAttrSize(stridesAttr) != nSpatialDims) - emitError("strides length incompatible with spatial dimensions"); - for (int i = 0; i < nSpatialDims; ++i) { - int64_t stride = ArrayAttrIntVal(stridesAttr, i); - outSpatialDims[i] = floor(outSpatialDims[i] / stride); - } - } - - for (int i = 0; i < nSpatialDims; ++i) - outSpatialDims[i] += 1; - - // Check input and output sizes match. - if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { - for (int i = 0; i < nSpatialDims; ++i) - if (outSpatialDims[i] != inDataShape[i + 2]) - emitError("input and output spatial dimension mismatch"); - - // Set pads values in attributes. - auto builder = mlir::Builder(this->getContext()); - ArrayRef defaultRefs(actualPads); - padsAttr(builder.getI64ArrayAttr(defaultRefs)); - - // Change auto padding attribute to NOTSET since padding values - // are now explicitly included in the operation. - auto_padAttr(builder.getStringAttr("NOTSET")); - } - - dims.append(outSpatialDims.begin(), outSpatialDims.end()); - getResult().setType(RankedTensorType::get(dims, dataTy.getElementType())); + getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); } //===----------------------------------------------------------------------===// @@ -987,112 +1033,29 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { // Ceil mode. auto ceilMode = ceil_mode().getSExtValue(); - // Dilatation. - auto dilationsOpt = dilations(); - if (dilationsOpt.hasValue()) { - if (ArrayAttrSize(dilationsOpt) != kernelRank) - emitError("dialation rank is not the same as the spatial rank"); - // Test values. - for (int i = 0; i < kernelRank; ++i) { - if (ArrayAttrIntVal(dilationsOpt, i) < 1) - emitError("dialation value must be nonzero positive"); - } - } else { - // Default dilatation is needed. - SmallVector defaultVals(kernelRank, 1); - // Convert to ArrayRef, then build attribute, then store attribute. - ArrayRef defaultRefs(defaultVals); - auto defaultAttr = builder.getI64ArrayAttr(defaultRefs); - dilationsAttr(defaultAttr); - dilationsOpt = dilations(); - } - // Storage order. auto storageOrder = storage_order().getSExtValue(); if (storageOrder != 0) emitError("column major storage order not supported at this time"); - // Strides. - auto stridesOpt = strides(); - if (stridesOpt.hasValue()) { - if (ArrayAttrSize(stridesOpt) != kernelRank) - emitError("strides rank is not the same as the spatial rank"); - // Check values. - for (int i = 0; i < kernelRank; ++i) { - if (ArrayAttrIntVal(stridesOpt, i) < 1) - emitError("strides value must be nonzero positive"); - } - } else { - SmallVector defaultVals(kernelRank, 1); - // Convert to ArrayRef, then build attribute, then store attribute. - ArrayRef defaultRefs(defaultVals); - auto defaultAttr = builder.getI64ArrayAttr(defaultRefs); - stridesAttr(defaultAttr); - stridesOpt = strides(); - } - - // Now try to find padding, getting auto_pad attribute first. - auto autoPad = auto_pad(); - // And then investigate the various different cases. - SmallVector actualPads(2 * kernelRank, 0); - if (autoPad == "NOTSET") { - auto padsOpt = pads(); - if (padsOpt.hasValue()) { - // Pads consists of two entries for each spatial axis. - if (ArrayAttrSize(padsOpt) != 2 * kernelRank) - emitError("pads rank is not twice the spatial rank"); - // Check values - for (int i = 0; i < 2 * kernelRank; ++i) { - int64_t p = ArrayAttrIntVal(padsOpt, i); - if (p < 0) - emitError("pads value must be nonnegative"); - actualPads[i] = p; - } - } - } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { - for (int i = 0; i < kernelRank; ++i) { - auto inputSpatialShape = xShape[kernelOffset + i]; - auto kernelSpatialShape = ArrayAttrIntVal(kernelShape, i); - auto dilations = ArrayAttrIntVal(dilationsOpt, i); - auto strideSpatialShape = ArrayAttrIntVal(stridesOpt, i); - int64_t outputSpatialShape = - ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape)); - auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape + - ((kernelSpatialShape - 1) * dilations + 1) - - inputSpatialShape; - actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2; - if (sumOfPad % 2 != 0) { - if (autoPad == "SAME_UPPER") { - actualPads[kernelRank + i] += 1; - } else { - actualPads[i] += 1; - } - } - } - } else if (autoPad != "VALID") { - emitError("auto_pad of unknown / unsupported value"); - } - // Set pads values in attributes. - { - ArrayRef defaultRefs(actualPads); - auto defaultAttr = builder.getI64ArrayAttr(defaultRefs); - padsAttr(defaultAttr); - auto defaultAutoPadAttr = builder.getStringAttr("NOTSET"); - auto_padAttr(defaultAutoPadAttr); - } + processConvTypeParams(this, X()); // Initialize output shape. SmallVector yShape(xShape.begin(), xShape.end()); + auto dilationsOpt = dilations(); + auto stridesOpt = strides(); + auto padsOpt = pads(); // Process for all kernel dimensions. for (int i = 0; i < kernelRank; ++i) { - auto inputSpatialShape = xShape[kernelOffset + i]; - auto padShape = actualPads[i] + actualPads[kernelRank + i]; - auto kernelSpatialShape = ArrayAttrIntVal(kernelShape, i); - auto dilations = ArrayAttrIntVal(dilationsOpt, i); - auto strideSpatialShape = ArrayAttrIntVal(stridesOpt, i); - double numerator = inputSpatialShape + padShape - - ((kernelSpatialShape - 1) * dilations + 1); - double denominator = strideSpatialShape; + auto inputSize = xShape[kernelOffset + i]; + auto sumOfPads = + ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, kernelRank + i); + auto kernelSize = ArrayAttrIntVal(kernelShape, i); + auto dilationVal = ArrayAttrIntVal(dilationsOpt, i); + auto strideVal = ArrayAttrIntVal(stridesOpt, i); + double numerator = + inputSize + sumOfPads - ((kernelSize - 1) * dilationVal + 1); + double denominator = strideVal; int64_t res; if (ceilMode) { res = ceil(numerator / denominator) + 1; @@ -1118,14 +1081,15 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) { if (padsOpt) { auto padsArray = padsOpt.getValue(); // Pads consists of two values for each axis of data. - // The two values specify the number of elements padded before and after respectively. + // The two values specify the number of elements padded before and after + // respectively. for (int i = 0; i < dataRank; ++i) { - int64_t p1 = (padsArray[2*i]).cast().getInt(); - int64_t p2 = (padsArray[2*i+1]).cast().getInt(); - //Have to non-negative constant - if (p1 < 0 || p2 <0) + int64_t p1 = (padsArray[2 * i]).cast().getInt(); + int64_t p2 = (padsArray[2 * i + 1]).cast().getInt(); + // Have to non-negative constant + if (p1 < 0 || p2 < 0) return (Type)NULL; - outputShape[i] += p1+p2; + outputShape[i] += p1 + p2; } return (RankedTensorType::get(outputShape, dataTy.getElementType())); @@ -1136,11 +1100,11 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) { // PadConstantPad -void ONNXPadConstantPadOp::inferShapes(){ +void ONNXPadConstantPadOp::inferShapes() { auto outputType = padShapeInferenceHelper(data(), pads()); if (outputType) { getResult().setType(outputType); - } + } return; } @@ -1148,7 +1112,7 @@ void ONNXPadConstantPadOp::inferShapes(){ // PadConstantValuePad -void ONNXPadConstantValuePadOp::inferShapes(){ +void ONNXPadConstantValuePadOp::inferShapes() { auto outputType = padShapeInferenceHelper(data(), pads()); if (outputType) { getResult().setType(outputType); diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 1204dc3..5a350e4 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -150,7 +150,7 @@ func @test_conv_no_bias_0(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>) "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_0 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<1x5x27xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<1x5x27xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32> } @@ -161,7 +161,7 @@ func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_1 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32> } @@ -172,7 +172,7 @@ func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_2 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32> } @@ -184,7 +184,7 @@ func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10 "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_3 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> } @@ -195,7 +195,7 @@ func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10 "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_4 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> } @@ -204,7 +204,7 @@ func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10 "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_5 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [3, 5, 2, 4]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> } @@ -215,7 +215,7 @@ func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10 "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_6 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32> } @@ -226,7 +226,7 @@ func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_7 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32> } @@ -238,8 +238,8 @@ func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_8 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [18, 66, 18, 66], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32> - // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32> } /// dilations attribute. @@ -249,8 +249,8 @@ func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_9 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x20x42xf32> - // CHECK: return [[RES_ATTR]] : tensor<1x5x20x42xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x22x46xf32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xf32> } /// dilations attribute with stride. @@ -260,8 +260,8 @@ func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7 "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_conv_no_bias_10 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x10x21xf32> - // CHECK: return [[RES_ATTR]] : tensor<1x5x10x21xf32> + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x11x23xf32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xf32> } /// dilations attribute with auto_pad set to SAME_UPPER. @@ -269,30 +269,30 @@ func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7 func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-LABEL: test_conv_no_bias_11 - // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, pads = [6, 11, 6, 11]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32> - // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> } + // CHECK-LABEL: test_conv_no_bias_11 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> -/// Test PadConstantValuePad +//===----------------------------------------------------------------------===// +/// Test shape inference for PadConstantValuePad. +//===----------------------------------------------------------------------===// +/// Test PadConstantValuePad_1 func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> { %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<16x13xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-LABEL: test_PadConstantValuePad_1 - // CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32> - // CHECK: return [[RES]] : tensor<18x13xf32> } +// CHECK-LABEL: test_PadConstantValuePad_1 +// CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32> +// CHECK: return [[RES]] : tensor<18x13xf32> -/// Test PadConstantPad - +/// Test PadConstantPad_1 func @test_PadConstantPad_1(%arg0 : tensor<16x13xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 2, 3, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-LABEL: test_PadConstantPad_1 - // CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 2, 3, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32> - // CHECK: return [[RES]] : tensor<18x17xf32> } +// CHECK-LABEL: test_PadConstantPad_1 +// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 2, 3, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32> +// CHECK: return [[RES]] : tensor<18x17xf32> +