Add shape inference for several ops (#163)
* 1. Add shape inference for the following ops: - Atan - Tan - Sin - Cast - ConvTranspose - Flatten - DynamicQuantizeLinear - QuantizeLinear - DequantizeLinear - ConvInteger 2. Import attributes for generic nodes 3. Fixes for cases where .cast<> should be .isa<> (ONNXConcat::inferShapes) * Fix foormatting issues * Address comments: - SmallVector<> * -> SmallVectorImpl<> & - switch-case -> helper function - Inside helper function, preserve signed-ness - add TODOs * Can't use signed integers yet in convertONNXTypeToMLIRType, add TODO Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
43dc1a1e01
commit
ca185002f2
|
@ -157,6 +157,7 @@ private:
|
||||||
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
}
|
}
|
||||||
result.addOperands(inputs);
|
result.addOperands(inputs);
|
||||||
|
result.addAttributes(ImportNodeAttributes(node));
|
||||||
auto op = builder_.createOperation(result);
|
auto op = builder_.createOperation(result);
|
||||||
for (int i = 0; i < node.output().size(); i++) {
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
auto r = op->getResult(i);
|
auto r = op->getResult(i);
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
#include "ONNXOps.hpp"
|
#include "ONNXOps.hpp"
|
||||||
|
|
||||||
|
@ -436,6 +437,37 @@ static LogicalResult RNNShapeInference(T *op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void insertConvTransposeSpatialDim(SmallVectorImpl<int64_t> &outputDims,
|
||||||
|
ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
|
||||||
|
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
|
||||||
|
Optional<ArrayAttr> outputPadsOpt, Optional<ArrayAttr> outputShapeOpt,
|
||||||
|
Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
|
||||||
|
auto xRank = xShape.size();
|
||||||
|
auto spatialRank = ArrayAttrSize(kernelShape);
|
||||||
|
auto spatialOffset = xRank - spatialRank;
|
||||||
|
|
||||||
|
int64_t dilationVal = 1;
|
||||||
|
int64_t outputPadsVal = 0;
|
||||||
|
// output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] +
|
||||||
|
// ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i]
|
||||||
|
for (int i = 0; i < spatialRank; ++i) {
|
||||||
|
auto inputSize = xShape[spatialOffset + i];
|
||||||
|
auto sumOfPads =
|
||||||
|
ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i);
|
||||||
|
auto kernelSize = ArrayAttrIntVal(kernelShape, i);
|
||||||
|
if (dilationsOpt.hasValue())
|
||||||
|
dilationVal = ArrayAttrIntVal(dilationsOpt, i);
|
||||||
|
auto strideVal = ArrayAttrIntVal(stridesOpt, i);
|
||||||
|
if (outputPadsOpt.hasValue())
|
||||||
|
outputPadsVal = ArrayAttrIntVal(outputPadsOpt, i);
|
||||||
|
// Number of useful values: input plus pad - effective size of kernel (see
|
||||||
|
// processConvTypeParams comments to see how this value is derived).
|
||||||
|
int64_t res = strideVal * (inputSize - 1) + outputPadsVal +
|
||||||
|
((kernelSize - 1) * dilationVal + 1) - sumOfPads;
|
||||||
|
outputDims.emplace_back(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ONNXOpsDialect
|
// ONNXOpsDialect
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -482,6 +514,24 @@ LogicalResult ONNXExpOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Atan
|
||||||
|
/// Infer the output shape of the ONNXAtanOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
LogicalResult ONNXAtanOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Tan
|
||||||
|
/// Infer the output shape of the ONNXTanOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
LogicalResult ONNXTanOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Tanh
|
// Tanh
|
||||||
/// Infer the output shape of the ONNXTanhOp. This method is required by the
|
/// Infer the output shape of the ONNXTanhOp. This method is required by the
|
||||||
|
@ -491,6 +541,15 @@ LogicalResult ONNXTanhOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Sin
|
||||||
|
/// Infer the output shape of the ONNXSinOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
LogicalResult ONNXSinOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sinh
|
// Sinh
|
||||||
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
||||||
|
@ -1316,6 +1375,138 @@ LogicalResult ONNXConvOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// ConvTranspose
|
||||||
|
|
||||||
|
// For this operation, we define the attributes once in the original Conv
|
||||||
|
// operation class. There is no need to redefine the attribute names for the
|
||||||
|
// other classes based on Conv.
|
||||||
|
// Conv attributes output:
|
||||||
|
// - auto_pad set to NOTSET;
|
||||||
|
// - dilations, strides: set to 1 if not defined by user;
|
||||||
|
// - kernelShape: inferred from weight matrix if not defined by user;
|
||||||
|
// - pads: set to proper value, 0 if not defined by user.
|
||||||
|
|
||||||
|
LogicalResult ONNXConvTransposeOp::inferShapes() {
|
||||||
|
// Generic shape for data input X, weight tensor W, and optional bias B
|
||||||
|
// X: (N x C x D1 x D2 ... x Dn)
|
||||||
|
// W: (M x C/group x k1 x k2 x ... x kn)
|
||||||
|
// B: (M) Optional
|
||||||
|
|
||||||
|
bool hasBias = !B().getType().isa<NoneType>();
|
||||||
|
|
||||||
|
// Cannot infer shape if no shape exists.
|
||||||
|
if (!X().getType().isa<RankedTensorType>() ||
|
||||||
|
!W().getType().isa<RankedTensorType>() ||
|
||||||
|
(hasBias && !B().getType().isa<RankedTensorType>())) {
|
||||||
|
return emitError("Input tensor not ranked");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto xTy = X().getType().cast<RankedTensorType>();
|
||||||
|
auto xShape = xTy.getShape();
|
||||||
|
auto weightTy = W().getType().cast<RankedTensorType>();
|
||||||
|
auto weightShape = weightTy.getShape();
|
||||||
|
auto builder = mlir::Builder(this->getContext());
|
||||||
|
|
||||||
|
// Lowest supported convolution is a one dimensional convolution.
|
||||||
|
if (xShape.size() < 3) {
|
||||||
|
return emitError("Data input shape must be at least (NxCxD1)");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that shape of weight and data have same length.
|
||||||
|
if (xShape.size() != weightShape.size()) {
|
||||||
|
return emitError("Weight size not compatible with data size");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group is a required attribute and should have default value of 1.
|
||||||
|
int64_t group = ONNXConvTransposeOp::group().getSExtValue();
|
||||||
|
|
||||||
|
// Check if the attribute actually exists. If it does not then add it.
|
||||||
|
if (!groupAttr())
|
||||||
|
groupAttr(builder.getI64IntegerAttr(group));
|
||||||
|
|
||||||
|
// Check that the X.shape[1] == (W.shape[0] * group) == C condition holds.
|
||||||
|
if (xShape[1] != -1 && weightShape[0] != -1 &&
|
||||||
|
xShape[1] != (weightShape[0] * group)) {
|
||||||
|
return emitError("Channel dimension mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the size of bias.
|
||||||
|
if (hasBias) {
|
||||||
|
auto bTx = B().getType().cast<RankedTensorType>();
|
||||||
|
auto bShape = bTx.getShape();
|
||||||
|
if (bShape.size() != 1) {
|
||||||
|
return emitError("bias should be one dimensional");
|
||||||
|
}
|
||||||
|
if (bShape[0] != weightShape[1]) {
|
||||||
|
return emitError(
|
||||||
|
"bias should have same dimensions as weight's second dimension");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: the value of the group attribut only impacts the way the
|
||||||
|
// computation is carried out and not the actual output size.
|
||||||
|
|
||||||
|
// Number of spatial dimensions.
|
||||||
|
auto spatialOffset = 2;
|
||||||
|
int32_t spatialRank = xShape.size() - spatialOffset;
|
||||||
|
|
||||||
|
// Use kernel_shape attribute if present otherwise use size from weight
|
||||||
|
// argument.
|
||||||
|
auto kernelShape = kernel_shape();
|
||||||
|
if (kernelShape.hasValue()) {
|
||||||
|
if (ArrayAttrSize(kernelShape) != spatialRank) {
|
||||||
|
return emitError(
|
||||||
|
"kernel_shape length incompatible with spatial dimensions");
|
||||||
|
}
|
||||||
|
// Have the right number of values, check them.
|
||||||
|
for (int i = 0; i < spatialRank; ++i)
|
||||||
|
if (ArrayAttrIntVal(kernelShape, i) < 1) {
|
||||||
|
return emitError("bad kernel_shape value");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Deduce shape from weight input.
|
||||||
|
SmallVector<int64_t, 2> defaultVals;
|
||||||
|
for (int i = 0; i < spatialRank; ++i)
|
||||||
|
defaultVals.emplace_back(weightShape[spatialOffset + i]);
|
||||||
|
// Convert to ArrayRef, then build attribute, then store attribute.
|
||||||
|
ArrayRef<int64_t> defaultRefs(defaultVals);
|
||||||
|
auto builder = mlir::Builder(getContext());
|
||||||
|
kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
|
||||||
|
kernelShape = kernel_shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process strides, dilations, and pads.
|
||||||
|
processConvTypeParams<>(this, X());
|
||||||
|
auto dilationsOpt = dilations();
|
||||||
|
auto stridesOpt = strides();
|
||||||
|
auto padsOpt = pads();
|
||||||
|
auto outputPads = output_padding();
|
||||||
|
auto outputShape = output_shape();
|
||||||
|
// TODO: handle the spatial dimension computation if output shape is specified
|
||||||
|
assert(!outputShape.hasValue() && "unhandled option in ConvTranspose");
|
||||||
|
|
||||||
|
// First two output dimensions consist of the number of batches and the
|
||||||
|
// number of kernels being applied.
|
||||||
|
SmallVector<int64_t, 4> outputDims;
|
||||||
|
// Insert batch size.
|
||||||
|
outputDims.emplace_back(xShape[0]);
|
||||||
|
// Insert number of filters being applied (number of output channels).
|
||||||
|
outputDims.emplace_back(weightShape[1]);
|
||||||
|
// Compute and insert spatial dims.
|
||||||
|
insertConvTransposeSpatialDim(outputDims, xShape, kernelShape, padsOpt,
|
||||||
|
stridesOpt, outputPads, outputShape, dilationsOpt);
|
||||||
|
|
||||||
|
// Set the output shape if it's not already set
|
||||||
|
if (!outputShape.hasValue()) {
|
||||||
|
output_shapeAttr(builder.getI64ArrayAttr(outputDims));
|
||||||
|
}
|
||||||
|
|
||||||
|
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// AveragePool
|
// AveragePool
|
||||||
// Infer shape attributes output:
|
// Infer shape attributes output:
|
||||||
// - auto_pad set to NOTSET;
|
// - auto_pad set to NOTSET;
|
||||||
|
@ -1561,6 +1752,34 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Cast
|
||||||
|
|
||||||
|
LogicalResult ONNXCastOp::inferShapes() {
|
||||||
|
ShapedType inputType = input().getType().dyn_cast<ShapedType>();
|
||||||
|
if (!inputType) {
|
||||||
|
return emitError("Non-shaped input type");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto getOutputType = [&inputType](Type elementType) -> Type {
|
||||||
|
if (inputType.hasRank()) {
|
||||||
|
return RankedTensorType::get(inputType.getShape(), elementType);
|
||||||
|
}
|
||||||
|
return UnrankedTensorType::get(elementType);
|
||||||
|
};
|
||||||
|
|
||||||
|
int64_t targetType = toAttr().getInt();
|
||||||
|
OpBuilder builder(getContext());
|
||||||
|
if (auto elementType = convertONNXTypeToMLIRType(
|
||||||
|
builder, static_cast<onnx::TensorProto_DataType>(targetType))) {
|
||||||
|
getResult().setType(getOutputType(elementType));
|
||||||
|
} else {
|
||||||
|
return emitOpError("Unable to get the element type for to = " +
|
||||||
|
std::to_string(targetType));
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Constant
|
// Constant
|
||||||
|
|
||||||
|
@ -1583,7 +1802,7 @@ LogicalResult ONNXConstantOp::inferShapes() {
|
||||||
LogicalResult ONNXConcatOp::inferShapes() {
|
LogicalResult ONNXConcatOp::inferShapes() {
|
||||||
int inputNum = getNumOperands();
|
int inputNum = getNumOperands();
|
||||||
for (int i = 0; i < inputNum; ++i) {
|
for (int i = 0; i < inputNum; ++i) {
|
||||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
if (!getOperand(i).getType().isa<RankedTensorType>())
|
||||||
return emitError("Input tensor(s) not ranked");
|
return emitError("Input tensor(s) not ranked");
|
||||||
}
|
}
|
||||||
// Checking value of axis parameter.
|
// Checking value of axis parameter.
|
||||||
|
@ -1713,6 +1932,219 @@ LogicalResult ONNXSplitOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Flatten
|
||||||
|
|
||||||
|
LogicalResult ONNXFlattenOp::inferShapes() {
|
||||||
|
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
|
||||||
|
auto inTy = input().getType().dyn_cast<ShapedType>();
|
||||||
|
if (!inTy) {
|
||||||
|
return emitOpError("Input is a non-shaped type");
|
||||||
|
}
|
||||||
|
auto outTy = output().getType().dyn_cast<ShapedType>();
|
||||||
|
if (!outTy) {
|
||||||
|
return emitOpError("Output is a non-shaped type");
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(tjingrant): Seems like we can also fairly easily support the case
|
||||||
|
// where the batch dimension is dynamic
|
||||||
|
if (!outTy.hasStaticShape()) {
|
||||||
|
auto inShape = inTy.getShape();
|
||||||
|
assert(inShape.size() >= 1 && "ONNXFlattenOp inShape.size() should be > 0");
|
||||||
|
uint64_t outDim = 1;
|
||||||
|
for (auto it = inShape.begin() + 1; it < inShape.end(); it++) {
|
||||||
|
outDim *= *it;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> dims;
|
||||||
|
// https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
|
||||||
|
dims.emplace_back(inShape[0]);
|
||||||
|
dims.emplace_back(outDim);
|
||||||
|
getResult().setType(RankedTensorType::get(dims, outTy.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DynamicQuantizeLinear
|
||||||
|
|
||||||
|
LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
|
||||||
|
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!inTy || !inTy.hasStaticShape()) {
|
||||||
|
return emitOpError("Input is not a statically-shaped type");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto yTy = y().getType().cast<ShapedType>();
|
||||||
|
auto yScaleTy = y_scale().getType().cast<ShapedType>();
|
||||||
|
auto yZPTy = y_zero_point().getType().cast<ShapedType>();
|
||||||
|
|
||||||
|
IntegerType i8Type = IntegerType::get(8, getContext());
|
||||||
|
RankedTensorType scalarType = RankedTensorType::get({}, i8Type);
|
||||||
|
|
||||||
|
// Set the types for the scalars
|
||||||
|
if (!yScaleTy.hasStaticShape()) {
|
||||||
|
y_scale().setType(scalarType);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!yZPTy.hasStaticShape()) {
|
||||||
|
y_zero_point().setType(scalarType);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!yTy.hasStaticShape()) {
|
||||||
|
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type);
|
||||||
|
y().setType(outType);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// QuantizeLinear
|
||||||
|
|
||||||
|
LogicalResult ONNXQuantizeLinearOp::inferShapes() {
|
||||||
|
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!inTy || !inTy.hasStaticShape()) {
|
||||||
|
return emitOpError("Input is not a statically-shaped type");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto yTy = y().getType().cast<ShapedType>();
|
||||||
|
|
||||||
|
if (!yTy.hasStaticShape()) {
|
||||||
|
// TODO: Unfortunately, we can't tell if this should be signed or unsigned
|
||||||
|
// here...
|
||||||
|
IntegerType i8Type = IntegerType::get(8, getContext());
|
||||||
|
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type);
|
||||||
|
y().setType(outType);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DequantizeLinear
|
||||||
|
|
||||||
|
LogicalResult ONNXDequantizeLinearOp::inferShapes() {
|
||||||
|
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!inTy || !inTy.hasStaticShape()) {
|
||||||
|
return emitOpError("Input is not a statically-shaped type");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto yTy = y().getType().cast<ShapedType>();
|
||||||
|
|
||||||
|
if (!yTy.hasStaticShape()) {
|
||||||
|
FloatType f32 = FloatType::getF32(getContext());
|
||||||
|
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), f32);
|
||||||
|
y().setType(outType);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
|
||||||
|
|
||||||
|
LogicalResult ONNXConvIntegerOp::inferShapes() {
|
||||||
|
// Generic shape for data input X, weight tensor W
|
||||||
|
// X: (N x C x D1 x D2 ... x Dn)
|
||||||
|
// W: (M x C/group x k1 x k2 x ... x kn)
|
||||||
|
|
||||||
|
// Cannot infer shape if no shape exists.
|
||||||
|
if (!x().getType().isa<RankedTensorType>() ||
|
||||||
|
!w().getType().isa<RankedTensorType>()) {
|
||||||
|
return emitOpError("Input tensor not ranked");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto xTy = x().getType().cast<RankedTensorType>();
|
||||||
|
if (!xTy.getElementType().isInteger(8)) {
|
||||||
|
return emitOpError("Invalid input type");
|
||||||
|
}
|
||||||
|
auto xShape = xTy.getShape();
|
||||||
|
auto weightTy = w().getType().cast<RankedTensorType>();
|
||||||
|
if (!weightTy.getElementType().isInteger(8)) {
|
||||||
|
return emitOpError("Invalid input type");
|
||||||
|
}
|
||||||
|
auto weightShape = weightTy.getShape();
|
||||||
|
auto builder = mlir::Builder(this->getContext());
|
||||||
|
|
||||||
|
// Lowest supported convolution is a one dimensional convolution.
|
||||||
|
if (xShape.size() < 3) {
|
||||||
|
return emitOpError("Data input shape must be at least (NxCxD1)");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that shape of weight and data have same length.
|
||||||
|
if (xShape.size() != weightShape.size()) {
|
||||||
|
return emitError("Weight size not compatible with data size");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group is a required attribute and should have default value of 1.
|
||||||
|
int64_t group = ONNXConvIntegerOp::group().getSExtValue();
|
||||||
|
|
||||||
|
// Check if the attribute actually exists. If it does not then add it.
|
||||||
|
if (!groupAttr())
|
||||||
|
groupAttr(builder.getI64IntegerAttr(group));
|
||||||
|
|
||||||
|
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||||
|
if (xShape[1] != -1 && weightShape[1] != -1 &&
|
||||||
|
xShape[1] != (weightShape[1] * group)) {
|
||||||
|
return emitOpError("Channel dimension mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: the value of the group attribut only impacts the way the
|
||||||
|
// computation is carried out and not the actual output size.
|
||||||
|
|
||||||
|
// Number of spatial dimensions.
|
||||||
|
auto spatialOffset = 2;
|
||||||
|
int32_t spatialRank = xShape.size() - spatialOffset;
|
||||||
|
|
||||||
|
// Use kernel_shape attribute if present otherwise use size from weight
|
||||||
|
// argument.
|
||||||
|
auto kernelShape = kernel_shape();
|
||||||
|
if (kernelShape.hasValue()) {
|
||||||
|
if (ArrayAttrSize(kernelShape) != spatialRank) {
|
||||||
|
return emitOpError(
|
||||||
|
"kernel_shape length incompatible with spatial dimensions");
|
||||||
|
}
|
||||||
|
// Have the right number of values, check them.
|
||||||
|
for (int i = 0; i < spatialRank; ++i)
|
||||||
|
if (ArrayAttrIntVal(kernelShape, i) < 1) {
|
||||||
|
return emitError("bad kernel_shape value");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Deduce shape from weight input.
|
||||||
|
SmallVector<int64_t, 2> defaultVals;
|
||||||
|
for (int i = 0; i < spatialRank; ++i)
|
||||||
|
defaultVals.emplace_back(weightShape[spatialOffset + i]);
|
||||||
|
// Convert to ArrayRef, then build attribute, then store attribute.
|
||||||
|
ArrayRef<int64_t> defaultRefs(defaultVals);
|
||||||
|
auto builder = mlir::Builder(getContext());
|
||||||
|
kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
|
||||||
|
kernelShape = kernel_shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process strides, dilations, and pads.
|
||||||
|
processConvTypeParams<>(this, x());
|
||||||
|
auto dilationsOpt = dilations();
|
||||||
|
auto stridesOpt = strides();
|
||||||
|
auto padsOpt = pads();
|
||||||
|
|
||||||
|
// First two output dimensions consist of the number of batches and the
|
||||||
|
// number of kernels being applied.
|
||||||
|
SmallVector<int64_t, 4> outputDims;
|
||||||
|
// Insert batch size.
|
||||||
|
outputDims.emplace_back(xShape[0]);
|
||||||
|
// Insert number of filters being applied (number of output channels).
|
||||||
|
outputDims.emplace_back(weightShape[0]);
|
||||||
|
// Compute and insert spatial dims.
|
||||||
|
insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
|
||||||
|
stridesOpt, dilationsOpt);
|
||||||
|
|
||||||
|
// ONNX spec specifies the output type as an int32
|
||||||
|
Type outputType = IntegerType::get(32, getContext());
|
||||||
|
getResult().setType(RankedTensorType::get(outputDims, outputType));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -278,7 +278,7 @@ def ONNXAsinhOp:ONNX_Op<"Asinh",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXAtanOp:ONNX_Op<"Atan",
|
def ONNXAtanOp:ONNX_Op<"Atan",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Atan operation";
|
let summary = "ONNX Atan operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise."
|
"Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise."
|
||||||
|
@ -449,7 +449,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXCastOp:ONNX_Op<"Cast",
|
def ONNXCastOp:ONNX_Op<"Cast",
|
||||||
[NoSideEffect, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
||||||
let summary = "ONNX Cast operation";
|
let summary = "ONNX Cast operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The operator casts the elements of a given input tensor to a data type"
|
"The operator casts the elements of a given input tensor to a data type"
|
||||||
|
@ -715,7 +715,7 @@ def ONNXConvOp:ONNX_Op<"Conv",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
|
def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX ConvInteger operation";
|
let summary = "ONNX ConvInteger operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point,"
|
"The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point,"
|
||||||
|
@ -746,7 +746,7 @@ def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
|
def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX ConvTranspose operation";
|
let summary = "ONNX ConvTranspose operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The convolution transpose operator consumes an input tensor and a filter,"
|
"The convolution transpose operator consumes an input tensor and a filter,"
|
||||||
|
@ -924,7 +924,7 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear",
|
def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX DequantizeLinear operation";
|
let summary = "ONNX DequantizeLinear operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The linear dequantization operator. It consumes a quantized tensor, a scale, a zero point to compute the full precision tensor."
|
"The linear dequantization operator. It consumes a quantized tensor, a scale, a zero point to compute the full precision tensor."
|
||||||
|
@ -1053,7 +1053,7 @@ def ONNXDropoutOp:ONNX_Op<"Dropout",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear",
|
def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX DynamicQuantizeLinear operation";
|
let summary = "ONNX DynamicQuantizeLinear operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"A Function to fuse calculation for Scale, Zero Point and FP32->8Bit convertion of FP32 Input data."
|
"A Function to fuse calculation for Scale, Zero Point and FP32->8Bit convertion of FP32 Input data."
|
||||||
|
@ -1285,7 +1285,7 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXFlattenOp:ONNX_Op<"Flatten",
|
def ONNXFlattenOp:ONNX_Op<"Flatten",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Flatten operation";
|
let summary = "ONNX Flatten operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Flattens the input tensor into a 2D matrix. If input tensor has shape"
|
"Flattens the input tensor into a 2D matrix. If input tensor has shape"
|
||||||
|
@ -3327,7 +3327,7 @@ def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
|
def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX QuantizeLinear operation";
|
let summary = "ONNX QuantizeLinear operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The linear per-tensor/layer quantization operator. It consumes a high precision tensor, a scale, a zero point to compute the low precision / quantized tensor."
|
"The linear per-tensor/layer quantization operator. It consumes a high precision tensor, a scale, a zero point to compute the low precision / quantized tensor."
|
||||||
|
@ -4787,7 +4787,7 @@ def ONNXSignOp:ONNX_Op<"Sign",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXSinOp:ONNX_Op<"Sin",
|
def ONNXSinOp:ONNX_Op<"Sin",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Sin operation";
|
let summary = "ONNX Sin operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Calculates the sine of the given input tensor, element-wise."
|
"Calculates the sine of the given input tensor, element-wise."
|
||||||
|
@ -5223,7 +5223,7 @@ def ONNXSumOp:ONNX_Op<"Sum",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXTanOp:ONNX_Op<"Tan",
|
def ONNXTanOp:ONNX_Op<"Tan",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Tan operation";
|
let summary = "ONNX Tan operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Calculates the tangent of the given input tensor, element-wise."
|
"Calculates the tangent of the given input tensor, element-wise."
|
||||||
|
|
|
@ -44,6 +44,7 @@ AffineMap getConvDimMap(Builder &builder, bool ceilMode) {
|
||||||
// Convert type to MLIR type.
|
// Convert type to MLIR type.
|
||||||
// A complete list of types can be found in:
|
// A complete list of types can be found in:
|
||||||
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
||||||
|
// TODO: Update Int*/Uint* to emit signed/unsigned MLIR types
|
||||||
mlir::Type convertONNXTypeToMLIRType(
|
mlir::Type convertONNXTypeToMLIRType(
|
||||||
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) {
|
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) {
|
||||||
switch (onnxType) {
|
switch (onnxType) {
|
||||||
|
|
|
@ -37,8 +37,7 @@ public:
|
||||||
if (returnsDynamicShape(op)) {
|
if (returnsDynamicShape(op)) {
|
||||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||||
if (failed(shape_op.inferShapes())) {
|
if (failed(shape_op.inferShapes())) {
|
||||||
op->emitError("unable to infer shape of operation without shape "
|
op->emitError("shape inference failed");
|
||||||
"inference method");
|
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -79,7 +78,10 @@ public:
|
||||||
// shaped outputs. All those operation need to implement the inferShape()
|
// shaped outputs. All those operation need to implement the inferShape()
|
||||||
// method.
|
// method.
|
||||||
if (op->getName().getStringRef() != "onnx.Exp" &&
|
if (op->getName().getStringRef() != "onnx.Exp" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Atan" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Tan" &&
|
||||||
op->getName().getStringRef() != "onnx.Tanh" &&
|
op->getName().getStringRef() != "onnx.Tanh" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Sin" &&
|
||||||
op->getName().getStringRef() != "onnx.Sinh" &&
|
op->getName().getStringRef() != "onnx.Sinh" &&
|
||||||
op->getName().getStringRef() != "onnx.Cosh" &&
|
op->getName().getStringRef() != "onnx.Cosh" &&
|
||||||
op->getName().getStringRef() != "onnx.Cos" &&
|
op->getName().getStringRef() != "onnx.Cos" &&
|
||||||
|
@ -130,7 +132,14 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.RNN" &&
|
op->getName().getStringRef() != "onnx.RNN" &&
|
||||||
op->getName().getStringRef() != "onnx.LSTM" &&
|
op->getName().getStringRef() != "onnx.LSTM" &&
|
||||||
op->getName().getStringRef() != "onnx.GRU" &&
|
op->getName().getStringRef() != "onnx.GRU" &&
|
||||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
op->getName().getStringRef() != "onnx.Unsqueeze" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Cast" &&
|
||||||
|
op->getName().getStringRef() != "onnx.ConvTranspose" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Flatten" &&
|
||||||
|
op->getName().getStringRef() != "onnx.DynamicQuantizeLinear" &&
|
||||||
|
op->getName().getStringRef() != "onnx.QuantizeLinear" &&
|
||||||
|
op->getName().getStringRef() != "onnx.DequantizeLinear" &&
|
||||||
|
op->getName().getStringRef() != "onnx.ConvInteger")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
return !result_type.isa<NoneType>() &&
|
return !result_type.isa<NoneType>() &&
|
||||||
|
|
|
@ -589,6 +589,19 @@ func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test the flatten op inference.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> {
|
||||||
|
%1 = "onnx.Flatten"(%arg0) {axis = 1 : i64} : (tensor<5x2x3x4xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_flatten_1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 1 : i64} : (tensor<5x2x3x4xf32>) -> tensor<5x24xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<5x24xf32>
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
/// Test the reshape op inference when concat are present.
|
/// Test the reshape op inference when concat are present.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -872,3 +885,210 @@ func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>)
|
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>)
|
||||||
// CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
|
// CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test the cast op inference.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_cast_1(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
|
||||||
|
%1 = "onnx.Cast"(%arg0) {to = 1} : (tensor<2x3x4xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_cast_1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<2x3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_cast_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xui8> {
|
||||||
|
%1 = "onnx.Cast"(%arg0) {to = 2} : (tensor<2x3x4xf32>) -> tensor<*xui8>
|
||||||
|
"std.return"(%1) : (tensor<*xui8>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_cast_2
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
|
||||||
|
// CHECK: return [[RES]] : tensor<2x3x4xi8>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xsi8> {
|
||||||
|
%1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xsi8>
|
||||||
|
"std.return"(%1) : (tensor<*xsi8>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_cast_3
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 3 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
|
||||||
|
// CHECK: return [[RES]] : tensor<2x3x4xi8>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_cast_10(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf16> {
|
||||||
|
%1 = "onnx.Cast"(%arg0) {to = 10} : (tensor<2x3x4xf32>) -> tensor<*xf16>
|
||||||
|
"std.return"(%1) : (tensor<*xf16>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_cast_10
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 10 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf16>
|
||||||
|
// CHECK: return [[RES]] : tensor<2x3x4xf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test the quantization op inferences.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xi8> {
|
||||||
|
%1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xi8>, tensor<*xi8>, tensor<*xi8>)
|
||||||
|
"std.return"(%1#0) {} : (tensor<*xi8>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_dyn_quantize_linear_1
|
||||||
|
// CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>)
|
||||||
|
// CHECK: return [[RES]] : tensor<5x2x3x4xi8>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xi8> {
|
||||||
|
%1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<*xi8>
|
||||||
|
"std.return"(%1) {} : (tensor<*xi8>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_quantize_linear_1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xi8>
|
||||||
|
// CHECK: return [[RES]] : tensor<5x2x3x4xi8>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xf32> {
|
||||||
|
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) {} : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_dequantize_linear_1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<5x2x3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test shape inference for ConvInteger operation and all its attributes.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Default and required attributes for 1-D convolution.
|
||||||
|
|
||||||
|
func @test_convinteger_0(%arg0 : tensor<1x2x32xi8>, %arg1 : tensor<5x2x6xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xi8>, tensor<5x2x6xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_0
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xi8>, tensor<5x2x6xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x27xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x27xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default and required attributes.
|
||||||
|
|
||||||
|
func @test_convinteger_1(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_1
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x27x58xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// kernel_shape attribute.
|
||||||
|
|
||||||
|
func @test_convinteger_2(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_2
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x25x56xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// pads attribute.
|
||||||
|
/// Use pads to make output size equal to input size by adding K - 1 to the result.
|
||||||
|
|
||||||
|
func @test_convinteger_3(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_3
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// auto_pad set to SAME_UPPER and SAME_LOWER.
|
||||||
|
|
||||||
|
func @test_convinteger_4(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_4
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_convinteger_5(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_5
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// auto_pad set to VALID.
|
||||||
|
|
||||||
|
func @test_convinteger_6(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_6
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x27x55xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// With strides attribute.
|
||||||
|
|
||||||
|
func @test_convinteger_7(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_7
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x14x20xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// auto_pad set to SAME_UPPER with strides attribute.
|
||||||
|
/// The auto_pad will pas as if stride is equal to 1.
|
||||||
|
|
||||||
|
func @test_convinteger_8(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_8
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x16x22xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// dilations attribute.
|
||||||
|
|
||||||
|
func @test_convinteger_9(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_9
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x22x46xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// dilations attribute with stride.
|
||||||
|
|
||||||
|
func @test_convinteger_10(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_10
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x11x23xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// dilations attribute with auto_pad set to SAME_UPPER.
|
||||||
|
|
||||||
|
func @test_convinteger_11(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor<i8>, %arg3 : tensor<i8>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_convinteger_11
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
|
||||||
|
}
|
||||||
|
|
|
@ -249,13 +249,14 @@ special_op_handler = dict([
|
||||||
|
|
||||||
# Operations supporting shape inference.
|
# Operations supporting shape inference.
|
||||||
OpsWithShapeInference = [
|
OpsWithShapeInference = [
|
||||||
'Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', 'Add', 'Mul', 'Div',
|
'Exp', 'Atan', 'Tan', 'Tanh', 'Sin', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
||||||
'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm',
|
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul',
|
||||||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'Gemm', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
||||||
'LSTM', 'GRU', 'Split', 'Pad'
|
'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
|
||||||
|
'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
|
|
Loading…
Reference in New Issue