Add shape inference for several ops (#163)

* 1. Add shape inference for the following ops:
 - Atan
 - Tan
 - Sin
 - Cast
 - ConvTranspose
 - Flatten
 - DynamicQuantizeLinear
 - QuantizeLinear
 - DequantizeLinear
 - ConvInteger

2. Import attributes for generic nodes
3. Fixes for cases where .cast<> should be .isa<> (ONNXConcat::inferShapes)

* Fix foormatting issues

* Address comments:
 - SmallVector<> * -> SmallVectorImpl<> &
 - switch-case -> helper function
   - Inside helper function, preserve signed-ness
 - add TODOs

* Can't use signed integers yet in convertONNXTypeToMLIRType, add TODO

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Aman LaChapelle 2020-06-08 23:55:49 -07:00 committed by GitHub
parent 43dc1a1e01
commit ca185002f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 697 additions and 33 deletions

View File

@ -157,6 +157,7 @@ private:
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
} }
result.addOperands(inputs); result.addOperands(inputs);
result.addAttributes(ImportNodeAttributes(node));
auto op = builder_.createOperation(result); auto op = builder_.createOperation(result);
for (int i = 0; i < node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {
auto r = op->getResult(i); auto r = op->getResult(i);

View File

@ -19,6 +19,7 @@
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "ONNXOps.hpp" #include "ONNXOps.hpp"
@ -436,6 +437,37 @@ static LogicalResult RNNShapeInference(T *op) {
return success(); return success();
} }
static void insertConvTransposeSpatialDim(SmallVectorImpl<int64_t> &outputDims,
ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
Optional<ArrayAttr> outputPadsOpt, Optional<ArrayAttr> outputShapeOpt,
Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
auto xRank = xShape.size();
auto spatialRank = ArrayAttrSize(kernelShape);
auto spatialOffset = xRank - spatialRank;
int64_t dilationVal = 1;
int64_t outputPadsVal = 0;
// output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] +
// ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i]
for (int i = 0; i < spatialRank; ++i) {
auto inputSize = xShape[spatialOffset + i];
auto sumOfPads =
ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i);
auto kernelSize = ArrayAttrIntVal(kernelShape, i);
if (dilationsOpt.hasValue())
dilationVal = ArrayAttrIntVal(dilationsOpt, i);
auto strideVal = ArrayAttrIntVal(stridesOpt, i);
if (outputPadsOpt.hasValue())
outputPadsVal = ArrayAttrIntVal(outputPadsOpt, i);
// Number of useful values: input plus pad - effective size of kernel (see
// processConvTypeParams comments to see how this value is derived).
int64_t res = strideVal * (inputSize - 1) + outputPadsVal +
((kernelSize - 1) * dilationVal + 1) - sumOfPads;
outputDims.emplace_back(res);
}
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ONNXOpsDialect // ONNXOpsDialect
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -482,6 +514,24 @@ LogicalResult ONNXExpOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Atan
/// Infer the output shape of the ONNXAtanOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAtanOp::inferShapes() {
getResult().setType(getOperand().getType());
return success();
}
//===----------------------------------------------------------------------===//
// Tan
/// Infer the output shape of the ONNXTanOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXTanOp::inferShapes() {
getResult().setType(getOperand().getType());
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Tanh // Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the /// Infer the output shape of the ONNXTanhOp. This method is required by the
@ -491,6 +541,15 @@ LogicalResult ONNXTanhOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Sin
/// Infer the output shape of the ONNXSinOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSinOp::inferShapes() {
getResult().setType(getOperand().getType());
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sinh // Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the /// Infer the output shape of the ONNXSinhOp. This method is required by the
@ -1316,6 +1375,138 @@ LogicalResult ONNXConvOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConvTranspose
// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
// Conv attributes output:
// - auto_pad set to NOTSET;
// - dilations, strides: set to 1 if not defined by user;
// - kernelShape: inferred from weight matrix if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
LogicalResult ONNXConvTransposeOp::inferShapes() {
// Generic shape for data input X, weight tensor W, and optional bias B
// X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn)
// B: (M) Optional
bool hasBias = !B().getType().isa<NoneType>();
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>() ||
(hasBias && !B().getType().isa<RankedTensorType>())) {
return emitError("Input tensor not ranked");
}
auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape();
auto weightTy = W().getType().cast<RankedTensorType>();
auto weightShape = weightTy.getShape();
auto builder = mlir::Builder(this->getContext());
// Lowest supported convolution is a one dimensional convolution.
if (xShape.size() < 3) {
return emitError("Data input shape must be at least (NxCxD1)");
}
// Check that shape of weight and data have same length.
if (xShape.size() != weightShape.size()) {
return emitError("Weight size not compatible with data size");
}
// Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvTransposeOp::group().getSExtValue();
// Check if the attribute actually exists. If it does not then add it.
if (!groupAttr())
groupAttr(builder.getI64IntegerAttr(group));
// Check that the X.shape[1] == (W.shape[0] * group) == C condition holds.
if (xShape[1] != -1 && weightShape[0] != -1 &&
xShape[1] != (weightShape[0] * group)) {
return emitError("Channel dimension mismatch");
}
// Check the size of bias.
if (hasBias) {
auto bTx = B().getType().cast<RankedTensorType>();
auto bShape = bTx.getShape();
if (bShape.size() != 1) {
return emitError("bias should be one dimensional");
}
if (bShape[0] != weightShape[1]) {
return emitError(
"bias should have same dimensions as weight's second dimension");
}
}
// Note: the value of the group attribut only impacts the way the
// computation is carried out and not the actual output size.
// Number of spatial dimensions.
auto spatialOffset = 2;
int32_t spatialRank = xShape.size() - spatialOffset;
// Use kernel_shape attribute if present otherwise use size from weight
// argument.
auto kernelShape = kernel_shape();
if (kernelShape.hasValue()) {
if (ArrayAttrSize(kernelShape) != spatialRank) {
return emitError(
"kernel_shape length incompatible with spatial dimensions");
}
// Have the right number of values, check them.
for (int i = 0; i < spatialRank; ++i)
if (ArrayAttrIntVal(kernelShape, i) < 1) {
return emitError("bad kernel_shape value");
}
} else {
// Deduce shape from weight input.
SmallVector<int64_t, 2> defaultVals;
for (int i = 0; i < spatialRank; ++i)
defaultVals.emplace_back(weightShape[spatialOffset + i]);
// Convert to ArrayRef, then build attribute, then store attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
auto builder = mlir::Builder(getContext());
kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
kernelShape = kernel_shape();
}
// Process strides, dilations, and pads.
processConvTypeParams<>(this, X());
auto dilationsOpt = dilations();
auto stridesOpt = strides();
auto padsOpt = pads();
auto outputPads = output_padding();
auto outputShape = output_shape();
// TODO: handle the spatial dimension computation if output shape is specified
assert(!outputShape.hasValue() && "unhandled option in ConvTranspose");
// First two output dimensions consist of the number of batches and the
// number of kernels being applied.
SmallVector<int64_t, 4> outputDims;
// Insert batch size.
outputDims.emplace_back(xShape[0]);
// Insert number of filters being applied (number of output channels).
outputDims.emplace_back(weightShape[1]);
// Compute and insert spatial dims.
insertConvTransposeSpatialDim(outputDims, xShape, kernelShape, padsOpt,
stridesOpt, outputPads, outputShape, dilationsOpt);
// Set the output shape if it's not already set
if (!outputShape.hasValue()) {
output_shapeAttr(builder.getI64ArrayAttr(outputDims));
}
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return success();
}
//===----------------------------------------------------------------------===//
// AveragePool // AveragePool
// Infer shape attributes output: // Infer shape attributes output:
// - auto_pad set to NOTSET; // - auto_pad set to NOTSET;
@ -1561,6 +1752,34 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Cast
LogicalResult ONNXCastOp::inferShapes() {
ShapedType inputType = input().getType().dyn_cast<ShapedType>();
if (!inputType) {
return emitError("Non-shaped input type");
}
auto getOutputType = [&inputType](Type elementType) -> Type {
if (inputType.hasRank()) {
return RankedTensorType::get(inputType.getShape(), elementType);
}
return UnrankedTensorType::get(elementType);
};
int64_t targetType = toAttr().getInt();
OpBuilder builder(getContext());
if (auto elementType = convertONNXTypeToMLIRType(
builder, static_cast<onnx::TensorProto_DataType>(targetType))) {
getResult().setType(getOutputType(elementType));
} else {
return emitOpError("Unable to get the element type for to = " +
std::to_string(targetType));
}
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Constant // Constant
@ -1583,7 +1802,7 @@ LogicalResult ONNXConstantOp::inferShapes() {
LogicalResult ONNXConcatOp::inferShapes() { LogicalResult ONNXConcatOp::inferShapes() {
int inputNum = getNumOperands(); int inputNum = getNumOperands();
for (int i = 0; i < inputNum; ++i) { for (int i = 0; i < inputNum; ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>()) if (!getOperand(i).getType().isa<RankedTensorType>())
return emitError("Input tensor(s) not ranked"); return emitError("Input tensor(s) not ranked");
} }
// Checking value of axis parameter. // Checking value of axis parameter.
@ -1713,6 +1932,219 @@ LogicalResult ONNXSplitOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Flatten
LogicalResult ONNXFlattenOp::inferShapes() {
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
auto inTy = input().getType().dyn_cast<ShapedType>();
if (!inTy) {
return emitOpError("Input is a non-shaped type");
}
auto outTy = output().getType().dyn_cast<ShapedType>();
if (!outTy) {
return emitOpError("Output is a non-shaped type");
}
// TODO(tjingrant): Seems like we can also fairly easily support the case
// where the batch dimension is dynamic
if (!outTy.hasStaticShape()) {
auto inShape = inTy.getShape();
assert(inShape.size() >= 1 && "ONNXFlattenOp inShape.size() should be > 0");
uint64_t outDim = 1;
for (auto it = inShape.begin() + 1; it < inShape.end(); it++) {
outDim *= *it;
}
SmallVector<int64_t, 2> dims;
// https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
dims.emplace_back(inShape[0]);
dims.emplace_back(outDim);
getResult().setType(RankedTensorType::get(dims, outTy.getElementType()));
}
return success();
}
//===----------------------------------------------------------------------===//
// DynamicQuantizeLinear
LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>();
if (!inTy || !inTy.hasStaticShape()) {
return emitOpError("Input is not a statically-shaped type");
}
auto yTy = y().getType().cast<ShapedType>();
auto yScaleTy = y_scale().getType().cast<ShapedType>();
auto yZPTy = y_zero_point().getType().cast<ShapedType>();
IntegerType i8Type = IntegerType::get(8, getContext());
RankedTensorType scalarType = RankedTensorType::get({}, i8Type);
// Set the types for the scalars
if (!yScaleTy.hasStaticShape()) {
y_scale().setType(scalarType);
}
if (!yZPTy.hasStaticShape()) {
y_zero_point().setType(scalarType);
}
if (!yTy.hasStaticShape()) {
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type);
y().setType(outType);
}
return success();
}
//===----------------------------------------------------------------------===//
// QuantizeLinear
LogicalResult ONNXQuantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>();
if (!inTy || !inTy.hasStaticShape()) {
return emitOpError("Input is not a statically-shaped type");
}
auto yTy = y().getType().cast<ShapedType>();
if (!yTy.hasStaticShape()) {
// TODO: Unfortunately, we can't tell if this should be signed or unsigned
// here...
IntegerType i8Type = IntegerType::get(8, getContext());
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type);
y().setType(outType);
}
return success();
}
//===----------------------------------------------------------------------===//
// DequantizeLinear
LogicalResult ONNXDequantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>();
if (!inTy || !inTy.hasStaticShape()) {
return emitOpError("Input is not a statically-shaped type");
}
auto yTy = y().getType().cast<ShapedType>();
if (!yTy.hasStaticShape()) {
FloatType f32 = FloatType::getF32(getContext());
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), f32);
y().setType(outType);
}
return success();
}
//===----------------------------------------------------------------------===//
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
LogicalResult ONNXConvIntegerOp::inferShapes() {
// Generic shape for data input X, weight tensor W
// X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn)
// Cannot infer shape if no shape exists.
if (!x().getType().isa<RankedTensorType>() ||
!w().getType().isa<RankedTensorType>()) {
return emitOpError("Input tensor not ranked");
}
auto xTy = x().getType().cast<RankedTensorType>();
if (!xTy.getElementType().isInteger(8)) {
return emitOpError("Invalid input type");
}
auto xShape = xTy.getShape();
auto weightTy = w().getType().cast<RankedTensorType>();
if (!weightTy.getElementType().isInteger(8)) {
return emitOpError("Invalid input type");
}
auto weightShape = weightTy.getShape();
auto builder = mlir::Builder(this->getContext());
// Lowest supported convolution is a one dimensional convolution.
if (xShape.size() < 3) {
return emitOpError("Data input shape must be at least (NxCxD1)");
}
// Check that shape of weight and data have same length.
if (xShape.size() != weightShape.size()) {
return emitError("Weight size not compatible with data size");
}
// Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvIntegerOp::group().getSExtValue();
// Check if the attribute actually exists. If it does not then add it.
if (!groupAttr())
groupAttr(builder.getI64IntegerAttr(group));
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (xShape[1] != -1 && weightShape[1] != -1 &&
xShape[1] != (weightShape[1] * group)) {
return emitOpError("Channel dimension mismatch");
}
// Note: the value of the group attribut only impacts the way the
// computation is carried out and not the actual output size.
// Number of spatial dimensions.
auto spatialOffset = 2;
int32_t spatialRank = xShape.size() - spatialOffset;
// Use kernel_shape attribute if present otherwise use size from weight
// argument.
auto kernelShape = kernel_shape();
if (kernelShape.hasValue()) {
if (ArrayAttrSize(kernelShape) != spatialRank) {
return emitOpError(
"kernel_shape length incompatible with spatial dimensions");
}
// Have the right number of values, check them.
for (int i = 0; i < spatialRank; ++i)
if (ArrayAttrIntVal(kernelShape, i) < 1) {
return emitError("bad kernel_shape value");
}
} else {
// Deduce shape from weight input.
SmallVector<int64_t, 2> defaultVals;
for (int i = 0; i < spatialRank; ++i)
defaultVals.emplace_back(weightShape[spatialOffset + i]);
// Convert to ArrayRef, then build attribute, then store attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
auto builder = mlir::Builder(getContext());
kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
kernelShape = kernel_shape();
}
// Process strides, dilations, and pads.
processConvTypeParams<>(this, x());
auto dilationsOpt = dilations();
auto stridesOpt = strides();
auto padsOpt = pads();
// First two output dimensions consist of the number of batches and the
// number of kernels being applied.
SmallVector<int64_t, 4> outputDims;
// Insert batch size.
outputDims.emplace_back(xShape[0]);
// Insert number of filters being applied (number of output channels).
outputDims.emplace_back(weightShape[0]);
// Compute and insert spatial dims.
insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
stridesOpt, dilationsOpt);
// ONNX spec specifies the output type as an int32
Type outputType = IntegerType::get(32, getContext());
getResult().setType(RankedTensorType::get(outputDims, outputType));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -278,7 +278,7 @@ def ONNXAsinhOp:ONNX_Op<"Asinh",
} }
def ONNXAtanOp:ONNX_Op<"Atan", def ONNXAtanOp:ONNX_Op<"Atan",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Atan operation"; let summary = "ONNX Atan operation";
let description = [{ let description = [{
"Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise." "Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise."
@ -449,7 +449,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
} }
def ONNXCastOp:ONNX_Op<"Cast", def ONNXCastOp:ONNX_Op<"Cast",
[NoSideEffect, OpInterface<"ResultTypeInferenceOpInterface">]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
let summary = "ONNX Cast operation"; let summary = "ONNX Cast operation";
let description = [{ let description = [{
"The operator casts the elements of a given input tensor to a data type" "The operator casts the elements of a given input tensor to a data type"
@ -715,7 +715,7 @@ def ONNXConvOp:ONNX_Op<"Conv",
} }
def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX ConvInteger operation"; let summary = "ONNX ConvInteger operation";
let description = [{ let description = [{
"The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point," "The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point,"
@ -746,7 +746,7 @@ def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
} }
def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX ConvTranspose operation"; let summary = "ONNX ConvTranspose operation";
let description = [{ let description = [{
"The convolution transpose operator consumes an input tensor and a filter," "The convolution transpose operator consumes an input tensor and a filter,"
@ -924,7 +924,7 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace",
} }
def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX DequantizeLinear operation"; let summary = "ONNX DequantizeLinear operation";
let description = [{ let description = [{
"The linear dequantization operator. It consumes a quantized tensor, a scale, a zero point to compute the full precision tensor." "The linear dequantization operator. It consumes a quantized tensor, a scale, a zero point to compute the full precision tensor."
@ -1053,7 +1053,7 @@ def ONNXDropoutOp:ONNX_Op<"Dropout",
} }
def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX DynamicQuantizeLinear operation"; let summary = "ONNX DynamicQuantizeLinear operation";
let description = [{ let description = [{
"A Function to fuse calculation for Scale, Zero Point and FP32->8Bit convertion of FP32 Input data." "A Function to fuse calculation for Scale, Zero Point and FP32->8Bit convertion of FP32 Input data."
@ -1285,7 +1285,7 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike",
} }
def ONNXFlattenOp:ONNX_Op<"Flatten", def ONNXFlattenOp:ONNX_Op<"Flatten",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Flatten operation"; let summary = "ONNX Flatten operation";
let description = [{ let description = [{
"Flattens the input tensor into a 2D matrix. If input tensor has shape" "Flattens the input tensor into a 2D matrix. If input tensor has shape"
@ -3327,7 +3327,7 @@ def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul",
} }
def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX QuantizeLinear operation"; let summary = "ONNX QuantizeLinear operation";
let description = [{ let description = [{
"The linear per-tensor/layer quantization operator. It consumes a high precision tensor, a scale, a zero point to compute the low precision / quantized tensor." "The linear per-tensor/layer quantization operator. It consumes a high precision tensor, a scale, a zero point to compute the low precision / quantized tensor."
@ -4787,7 +4787,7 @@ def ONNXSignOp:ONNX_Op<"Sign",
} }
def ONNXSinOp:ONNX_Op<"Sin", def ONNXSinOp:ONNX_Op<"Sin",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Sin operation"; let summary = "ONNX Sin operation";
let description = [{ let description = [{
"Calculates the sine of the given input tensor, element-wise." "Calculates the sine of the given input tensor, element-wise."
@ -5223,7 +5223,7 @@ def ONNXSumOp:ONNX_Op<"Sum",
} }
def ONNXTanOp:ONNX_Op<"Tan", def ONNXTanOp:ONNX_Op<"Tan",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Tan operation"; let summary = "ONNX Tan operation";
let description = [{ let description = [{
"Calculates the tangent of the given input tensor, element-wise." "Calculates the tangent of the given input tensor, element-wise."

View File

@ -44,6 +44,7 @@ AffineMap getConvDimMap(Builder &builder, bool ceilMode) {
// Convert type to MLIR type. // Convert type to MLIR type.
// A complete list of types can be found in: // A complete list of types can be found in:
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h // <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
// TODO: Update Int*/Uint* to emit signed/unsigned MLIR types
mlir::Type convertONNXTypeToMLIRType( mlir::Type convertONNXTypeToMLIRType(
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) { mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) {
switch (onnxType) { switch (onnxType) {

View File

@ -37,8 +37,7 @@ public:
if (returnsDynamicShape(op)) { if (returnsDynamicShape(op)) {
if (auto shape_op = dyn_cast<ShapeInference>(op)) { if (auto shape_op = dyn_cast<ShapeInference>(op)) {
if (failed(shape_op.inferShapes())) { if (failed(shape_op.inferShapes())) {
op->emitError("unable to infer shape of operation without shape " op->emitError("shape inference failed");
"inference method");
return signalPassFailure(); return signalPassFailure();
} }
} else { } else {
@ -79,7 +78,10 @@ public:
// shaped outputs. All those operation need to implement the inferShape() // shaped outputs. All those operation need to implement the inferShape()
// method. // method.
if (op->getName().getStringRef() != "onnx.Exp" && if (op->getName().getStringRef() != "onnx.Exp" &&
op->getName().getStringRef() != "onnx.Atan" &&
op->getName().getStringRef() != "onnx.Tan" &&
op->getName().getStringRef() != "onnx.Tanh" && op->getName().getStringRef() != "onnx.Tanh" &&
op->getName().getStringRef() != "onnx.Sin" &&
op->getName().getStringRef() != "onnx.Sinh" && op->getName().getStringRef() != "onnx.Sinh" &&
op->getName().getStringRef() != "onnx.Cosh" && op->getName().getStringRef() != "onnx.Cosh" &&
op->getName().getStringRef() != "onnx.Cos" && op->getName().getStringRef() != "onnx.Cos" &&
@ -130,7 +132,14 @@ public:
op->getName().getStringRef() != "onnx.RNN" && op->getName().getStringRef() != "onnx.RNN" &&
op->getName().getStringRef() != "onnx.LSTM" && op->getName().getStringRef() != "onnx.LSTM" &&
op->getName().getStringRef() != "onnx.GRU" && op->getName().getStringRef() != "onnx.GRU" &&
op->getName().getStringRef() != "onnx.Unsqueeze") op->getName().getStringRef() != "onnx.Unsqueeze" &&
op->getName().getStringRef() != "onnx.Cast" &&
op->getName().getStringRef() != "onnx.ConvTranspose" &&
op->getName().getStringRef() != "onnx.Flatten" &&
op->getName().getStringRef() != "onnx.DynamicQuantizeLinear" &&
op->getName().getStringRef() != "onnx.QuantizeLinear" &&
op->getName().getStringRef() != "onnx.DequantizeLinear" &&
op->getName().getStringRef() != "onnx.ConvInteger")
return false; return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) { return llvm::any_of(op->getResultTypes(), [](Type result_type) {
return !result_type.isa<NoneType>() && return !result_type.isa<NoneType>() &&

View File

@ -589,6 +589,19 @@ func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
// ----- // -----
//===----------------------------------------------------------------------===//
/// Test the flatten op inference.
//===----------------------------------------------------------------------===//
func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> {
%1 = "onnx.Flatten"(%arg0) {axis = 1 : i64} : (tensor<5x2x3x4xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_flatten_1
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 1 : i64} : (tensor<5x2x3x4xf32>) -> tensor<5x24xf32>
// CHECK: return [[RES]] : tensor<5x24xf32>
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Test the reshape op inference when concat are present. /// Test the reshape op inference when concat are present.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -872,3 +885,210 @@ func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>) // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>)
// CHECK: return [[RES]]#0 : tensor<16x2x64xf32> // CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
} }
//===----------------------------------------------------------------------===//
/// Test the cast op inference.
//===----------------------------------------------------------------------===//
func @test_cast_1(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
%1 = "onnx.Cast"(%arg0) {to = 1} : (tensor<2x3x4xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_cast_1
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
// CHECK: return [[RES]] : tensor<2x3x4xf32>
}
func @test_cast_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xui8> {
%1 = "onnx.Cast"(%arg0) {to = 2} : (tensor<2x3x4xf32>) -> tensor<*xui8>
"std.return"(%1) : (tensor<*xui8>) -> ()
// CHECK-LABEL: test_cast_2
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
// CHECK: return [[RES]] : tensor<2x3x4xi8>
}
func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xsi8> {
%1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xsi8>
"std.return"(%1) : (tensor<*xsi8>) -> ()
// CHECK-LABEL: test_cast_3
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 3 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
// CHECK: return [[RES]] : tensor<2x3x4xi8>
}
func @test_cast_10(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf16> {
%1 = "onnx.Cast"(%arg0) {to = 10} : (tensor<2x3x4xf32>) -> tensor<*xf16>
"std.return"(%1) : (tensor<*xf16>) -> ()
// CHECK-LABEL: test_cast_10
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 10 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf16>
// CHECK: return [[RES]] : tensor<2x3x4xf16>
}
//===----------------------------------------------------------------------===//
/// Test the quantization op inferences.
//===----------------------------------------------------------------------===//
func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xi8> {
%1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xi8>, tensor<*xi8>, tensor<*xi8>)
"std.return"(%1#0) {} : (tensor<*xi8>) -> ()
// CHECK-LABEL: test_dyn_quantize_linear_1
// CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>)
// CHECK: return [[RES]] : tensor<5x2x3x4xi8>
}
func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xi8> {
%1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<*xi8>
"std.return"(%1) {} : (tensor<*xi8>) -> ()
// CHECK-LABEL: test_quantize_linear_1
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xi8>
// CHECK: return [[RES]] : tensor<5x2x3x4xi8>
}
func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xf32> {
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<*xf32>
"std.return"(%1) {} : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_dequantize_linear_1
// CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xf32>
// CHECK: return [[RES]] : tensor<5x2x3x4xf32>
}
//===----------------------------------------------------------------------===//
/// Test shape inference for ConvInteger operation and all its attributes.
//===----------------------------------------------------------------------===//
/// Default and required attributes for 1-D convolution.
func @test_convinteger_0(%arg0 : tensor<1x2x32xi8>, %arg1 : tensor<5x2x6xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xi8>, tensor<5x2x6xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_0
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xi8>, tensor<5x2x6xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x27xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x27xi32>
}
/// Default and required attributes.
func @test_convinteger_1(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_1
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x27x58xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xi32>
}
/// kernel_shape attribute.
func @test_convinteger_2(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_2
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x25x56xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xi32>
}
/// pads attribute.
/// Use pads to make output size equal to input size by adding K - 1 to the result.
func @test_convinteger_3(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_3
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
}
/// auto_pad set to SAME_UPPER and SAME_LOWER.
func @test_convinteger_4(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_4
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
}
func @test_convinteger_5(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_5
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
}
/// auto_pad set to VALID.
func @test_convinteger_6(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_6
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x27x55xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xi32>
}
/// With strides attribute.
func @test_convinteger_7(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_7
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x14x20xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xi32>
}
/// auto_pad set to SAME_UPPER with strides attribute.
/// The auto_pad will pas as if stride is equal to 1.
func @test_convinteger_8(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_8
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x16x22xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xi32>
}
/// dilations attribute.
func @test_convinteger_9(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_9
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x22x46xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xi32>
}
/// dilations attribute with stride.
func @test_convinteger_10(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_10
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x11x23xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xi32>
}
/// dilations attribute with auto_pad set to SAME_UPPER.
func @test_convinteger_11(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_convinteger_11
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
}

View File

@ -249,13 +249,14 @@ special_op_handler = dict([
# Operations supporting shape inference. # Operations supporting shape inference.
OpsWithShapeInference = [ OpsWithShapeInference = [
'Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', 'Add', 'Mul', 'Div', 'Exp', 'Atan', 'Tan', 'Tanh', 'Sin', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul',
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Gemm', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN', 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
'LSTM', 'GRU', 'Split', 'Pad' 'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.