Shape inference for some ops in RNN models (#207)
* Shape inference for ShapeOp * Shape inference for TileOp * Fix check-doc * Fix onnx-mlir-doc * Shape inference for GatherOp * Check validity of GatherOp's indices tensor * Shape inference for Slice * Tests for SliceOp * Fix importing none inputs * Type inference for constantofshape * Empty tensor in case of ConstantOfShape * Remove unrelated changes Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
parent
1263d01968
commit
034f98c00c
|
@ -5556,7 +5556,7 @@ ONNX Tile operation
|
|||
| Operand | Description |
|
||||
| :-----: | ----------- |
|
||||
`input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values
|
||||
`repeats` | tensor of 64-bit signless integer values or memref of any type values
|
||||
`repeats` | tensor of 64-bit signless integer values or memref of any type values or none type
|
||||
|
||||
#### Results:
|
||||
|
||||
|
|
|
@ -2258,6 +2258,362 @@ LogicalResult ONNXConvIntegerOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shape
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXShapeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
return emitError("Input tensor not ranked");
|
||||
|
||||
// Output is an 1D int64 tensor containing the shape of the input tensor.
|
||||
int64_t rank = data().getType().cast<RankedTensorType>().getRank();
|
||||
SmallVector<int64_t, 1> outDims(1, rank);
|
||||
getResult().setType(
|
||||
RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tile
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXTileOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!input().getType().isa<RankedTensorType>())
|
||||
return emitError("Input tensor not ranked");
|
||||
|
||||
// Read 'repeats' value.
|
||||
if (!repeats().getType().isa<RankedTensorType>())
|
||||
return emitError("Repeats tensor not ranked");
|
||||
|
||||
auto inputTensorTy = input().getType().cast<RankedTensorType>();
|
||||
auto repeatsTensorTy = repeats().getType().cast<RankedTensorType>();
|
||||
|
||||
// 'repeats' tensor is an 1D tensor.
|
||||
if (repeatsTensorTy.getShape().size() != 1)
|
||||
return emitError("Repeats tensor must have rank one");
|
||||
|
||||
// 'repeats' tensor must have constant shape.
|
||||
int64_t repeatsLength = repeatsTensorTy.getShape()[0];
|
||||
if (repeatsLength < 0)
|
||||
return emitError("Repeats tensor must have constant shape");
|
||||
|
||||
// Check the 1D repeats tensor length.
|
||||
int64_t inputRank = inputTensorTy.getShape().size();
|
||||
if (inputRank != repeatsLength)
|
||||
return emitError("Repeats tensor must have the same length as the input's "
|
||||
"dimension number.");
|
||||
|
||||
// Check if second argument of TileOp is a constant.
|
||||
auto constantOp = getONNXConstantOp(repeats());
|
||||
|
||||
// Compute output's dimensions: output_dim[i] = input_dim[i] * repeats[i]
|
||||
SmallVector<int64_t, 2> dims(inputRank, -1);
|
||||
if (constantOp) {
|
||||
// 1. Initialize output_dim with values from 'input'.
|
||||
// output_dim[i] = input[i]
|
||||
for (decltype(inputRank) i = 0; i < inputRank; ++i)
|
||||
dims[i] = inputTensorTy.getShape()[i];
|
||||
|
||||
// 2. Update output_dim using values from 'repeats'.
|
||||
// Do this only for static 'input_dim[i]'.
|
||||
// if (output_dim[i] != -1) output_dim[i] *= repeats[i]
|
||||
DenseElementsAttr valueAttribute =
|
||||
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||
if (!valueAttribute)
|
||||
return emitError("DenseElementsAttr expected");
|
||||
// Get repeat values from valueAttribute.
|
||||
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
|
||||
for (int i = 0; i < inputRank; ++i)
|
||||
if (dims[i] != -1)
|
||||
dims[i] *= (*valueIt++).cast<IntegerAttr>().getInt();
|
||||
|
||||
if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
|
||||
return emitError("Constant value must have same length as output's rank");
|
||||
}
|
||||
|
||||
getResult().setType(
|
||||
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Gather
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXGatherOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
return emitError("Input tensor not ranked");
|
||||
if (!indices().getType().isa<RankedTensorType>())
|
||||
return emitError("Indices tensor not ranked");
|
||||
|
||||
auto inputShape = data().getType().cast<RankedTensorType>().getShape();
|
||||
auto indicesShape = indices().getType().cast<RankedTensorType>().getShape();
|
||||
int64_t inputRank = inputShape.size();
|
||||
int64_t indicesRank = indicesShape.size();
|
||||
|
||||
if (inputRank < 1)
|
||||
return emitError("Input tensor must have rank >= 1");
|
||||
|
||||
// Read 'axis' attribute.
|
||||
auto axisIndex = axis().getSExtValue();
|
||||
// 'axis' must be in [-rank, rank-1]
|
||||
if (axisIndex < -inputRank || axisIndex >= inputRank)
|
||||
return emitError("Gather axis value out of bound");
|
||||
// Convert a negative axis to a positive axis.
|
||||
if (axisIndex < 0) {
|
||||
axisIndex += inputRank;
|
||||
auto builder = mlir::Builder(getContext());
|
||||
axisAttr(builder.getI64IntegerAttr(axisIndex));
|
||||
}
|
||||
|
||||
// If 'indices' is a constant, check whether its values are valid or not.
|
||||
auto constantOp = getONNXConstantOp(indices());
|
||||
if (constantOp && inputShape[axisIndex] != -1) {
|
||||
DenseElementsAttr valueAttribute =
|
||||
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||
if (!valueAttribute)
|
||||
return emitError("DenseElementsAttr expected");
|
||||
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
|
||||
auto index = value.cast<IntegerAttr>().getInt();
|
||||
if (index < -inputShape[axisIndex] || index >= inputShape[axisIndex])
|
||||
return emitError("Indices tensor contains an out-of-bound index");
|
||||
}
|
||||
}
|
||||
|
||||
// Output has rank of 'indicesRank + (inputRank - 1).
|
||||
// Output shape is constructed from 'input' by:
|
||||
// replacing the dimension at 'axis' in 'input' by the shape of 'indices'.
|
||||
SmallVector<int64_t, 1> outDims;
|
||||
for (decltype(inputRank) i = 0; i < inputRank; ++i) {
|
||||
if (i == axisIndex)
|
||||
for (decltype(indicesRank) j = 0; j < indicesRank; ++j)
|
||||
outDims.emplace_back(indicesShape[j]);
|
||||
else
|
||||
outDims.emplace_back(inputShape[i]);
|
||||
}
|
||||
|
||||
getResult().setType(RankedTensorType::get(
|
||||
outDims, data().getType().cast<RankedTensorType>().getElementType()));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOfShape
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXConstantOfShapeOp::inferShapes() {
|
||||
Type elementType;
|
||||
|
||||
// 'value' attribute is a one-element tensor whose value and datatype are used
|
||||
// to set the output tensor's value and datatype..
|
||||
if (value().hasValue()) {
|
||||
elementType =
|
||||
valueAttr().cast<DenseElementsAttr>().getType().getElementType();
|
||||
} else {
|
||||
// If 'value' attribute is not specified, it defaults to a tensor of value 0
|
||||
// and datatype float32.
|
||||
elementType = FloatType::getF32(getContext());
|
||||
|
||||
llvm::SmallVector<int64_t, 2> dims(1, 1);
|
||||
auto tensorType = mlir::RankedTensorType::get(dims, elementType);
|
||||
|
||||
llvm::SmallVector<float, 1> values(1, 0.);
|
||||
valueAttr(
|
||||
mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)));
|
||||
}
|
||||
|
||||
// 'input' must be a 1D tensor.
|
||||
auto inputShape = input().getType().cast<RankedTensorType>().getShape();
|
||||
if (inputShape.size() != 1)
|
||||
return emitError("Input tensor must be a 1D tensor");
|
||||
if (inputShape[0] == -1)
|
||||
return emitError("Input tensor must have static shape");
|
||||
if (inputShape[0] == 0) {
|
||||
// If 'input' is an empty tensor, the output would be a scalar.
|
||||
getResult().setType(RankedTensorType::get({}, elementType));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Calculate output dimensions.
|
||||
SmallVector<int64_t, 4> outputDims(inputShape[0], -1);
|
||||
// If 'input' is a constant, check whether its values are valid or not.
|
||||
// If the values are valid, it is possible to infer shape.
|
||||
if (auto constantOp = getONNXConstantOp(input())) {
|
||||
DenseElementsAttr valueAttribute =
|
||||
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||
// Get repeat values from valueAttribute.
|
||||
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
|
||||
for (int i = 0; i < inputShape[0]; ++i) {
|
||||
auto dim = (*valueIt++).cast<IntegerAttr>().getInt();
|
||||
if (dim < 0)
|
||||
return emitError("All values of the input tensor must be >=0");
|
||||
outputDims[i] = dim;
|
||||
}
|
||||
|
||||
if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
|
||||
return emitError("Constant value must have same length as output's rank");
|
||||
}
|
||||
|
||||
getResult().setType(RankedTensorType::get(outputDims, elementType));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Slice
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXSliceOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
return emitError("Input tensor not ranked");
|
||||
|
||||
auto elementType = data().getType().cast<ShapedType>().getElementType();
|
||||
auto dataShape = data().getType().cast<ShapedType>().getShape();
|
||||
int64_t numDims = dataShape.size();
|
||||
|
||||
SmallVector<int64_t, 2> outputDims(numDims, -1);
|
||||
// If 'starts', 'ends', 'axes', and 'steps' are constants, check whether their
|
||||
// values are valid or not. If the values are valid, it is possible to infer
|
||||
// shape.
|
||||
//
|
||||
// 'starts', 'ends', and 'steps' are for each axis in the list of axes, so
|
||||
// processing 'axes' first.
|
||||
|
||||
// Check and get 'axes' tensor.
|
||||
SmallVector<int64_t, 2> axesValue;
|
||||
if (axes().getType().isa<NoneType>()) {
|
||||
// If `axes` are omitted, they are set to `[0, ..., ndim-1]`."
|
||||
for (int i = 0; i < numDims; ++i)
|
||||
axesValue.emplace_back(i);
|
||||
} else if (auto constantOp = getONNXConstantOp(axes())) {
|
||||
// If `axes` are constants, read them."
|
||||
DenseElementsAttr valueAttribute =
|
||||
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
|
||||
int64_t axis = value.cast<IntegerAttr>().getInt();
|
||||
if (axis < -numDims || axis >= numDims)
|
||||
return emitError("Axes contains an out-of-bound index");
|
||||
if (axis < 0)
|
||||
axis += numDims;
|
||||
if (dataShape[axis] == -1) {
|
||||
// It is unsafe to infer shape for an axis with an unknown dimension,
|
||||
// since we can not validate 'start' and 'end' values from this
|
||||
// dimension.
|
||||
getResult().setType(RankedTensorType::get(outputDims, elementType));
|
||||
return success();
|
||||
}
|
||||
axesValue.emplace_back(axis);
|
||||
}
|
||||
} else {
|
||||
// Cannot infer a static shape.
|
||||
getResult().setType(RankedTensorType::get(outputDims, elementType));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Check 'starts' tensor.
|
||||
SmallVector<int64_t, 2> startsValue;
|
||||
if (auto constantOp = getONNXConstantOp(starts())) {
|
||||
DenseElementsAttr valueAttribute =
|
||||
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||
int i = 0;
|
||||
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
|
||||
int64_t axis = axesValue[i];
|
||||
int64_t index = value.cast<IntegerAttr>().getInt();
|
||||
if (index < -dataShape[axis])
|
||||
index = 0;
|
||||
else if (index > dataShape[axis])
|
||||
index = dataShape[axis];
|
||||
else if (index < 0)
|
||||
index += dataShape[axis];
|
||||
startsValue.emplace_back(index);
|
||||
i++;
|
||||
}
|
||||
if (i != axesValue.size())
|
||||
emitError("starts and axes tensors must have the same length");
|
||||
} else {
|
||||
// Cannot infer a static shape.
|
||||
getResult().setType(RankedTensorType::get(outputDims, elementType));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Check 'ends' tensor.
|
||||
SmallVector<int64_t, 2> endsValue;
|
||||
if (auto constantOp = getONNXConstantOp(ends())) {
|
||||
DenseElementsAttr valueAttribute =
|
||||
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||
int i = 0;
|
||||
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
|
||||
int64_t axis = axesValue[i];
|
||||
int64_t index = value.cast<IntegerAttr>().getInt();
|
||||
if (index < -dataShape[axis])
|
||||
index = 0;
|
||||
else if (index > dataShape[axis])
|
||||
index = dataShape[axis];
|
||||
else if (index < 0)
|
||||
index += dataShape[axis];
|
||||
endsValue.emplace_back(index);
|
||||
i++;
|
||||
}
|
||||
if (i != axesValue.size())
|
||||
emitError("ends and axes tensors must have the same length");
|
||||
} else {
|
||||
// Cannot infer a static shape.
|
||||
getResult().setType(RankedTensorType::get(outputDims, elementType));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Check and get 'steps' tensor.
|
||||
SmallVector<int64_t, 2> stepsValue;
|
||||
if (steps().getType().isa<NoneType>()) {
|
||||
// If `steps` are omitted, they are set to `[1, ..., 1]` of len(starts)."
|
||||
for (int i = 0; i < startsValue.size(); ++i)
|
||||
stepsValue.emplace_back(1);
|
||||
} else if (auto constantOp = getONNXConstantOp(steps())) {
|
||||
// If `steps` are constants, read them."
|
||||
DenseElementsAttr valueAttribute =
|
||||
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||
int i = 0;
|
||||
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
|
||||
int64_t index = value.cast<IntegerAttr>().getInt();
|
||||
if (index == 0)
|
||||
emitError("step cannot be zero");
|
||||
stepsValue.emplace_back(index);
|
||||
i++;
|
||||
}
|
||||
if (i != axesValue.size())
|
||||
emitError("steps and axes tensors must have the same length");
|
||||
} else {
|
||||
// Cannot infer a static shape.
|
||||
getResult().setType(RankedTensorType::get(outputDims, elementType));
|
||||
return success();
|
||||
}
|
||||
|
||||
// All 'starts', 'ends', 'steps' values are valid. Now calculate output
|
||||
// dimensions for axes in 'axes'.
|
||||
for (int i = 0; i < axesValue.size(); i++) {
|
||||
int64_t axis = axesValue[i];
|
||||
int64_t start = startsValue[i];
|
||||
int64_t end = endsValue[i];
|
||||
int64_t step = stepsValue[i];
|
||||
if (step < 0)
|
||||
step = -step;
|
||||
|
||||
int64_t q = (end - start) / step;
|
||||
int64_t r = (end - start) % step;
|
||||
if (r != 0)
|
||||
q += 1;
|
||||
outputDims[axis] = q;
|
||||
}
|
||||
|
||||
getResult().setType(RankedTensorType::get(outputDims, elementType));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -662,7 +662,7 @@ def ONNXConstantOp:ONNX_Op<"Constant",
|
|||
}
|
||||
|
||||
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
||||
let summary = "ONNX ConstantOfShape operation";
|
||||
let description = [{
|
||||
"Generate a tensor with given value and shape."
|
||||
|
@ -680,6 +680,17 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
|||
static std::vector<int> getTypeMap() {
|
||||
return {-1};
|
||||
}
|
||||
std::vector<mlir::Type> resultTypeInference() {
|
||||
std::vector<mlir::Type> resultTypes;
|
||||
if (auto attr = valueAttr()) {
|
||||
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||
attr.getType().cast<ShapedType>().getElementType()));
|
||||
} else {
|
||||
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||
FloatType::getF32(getContext())));
|
||||
}
|
||||
return resultTypes;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -1438,7 +1449,7 @@ def ONNXGRUOp:ONNX_Op<"GRU",
|
|||
}
|
||||
|
||||
def ONNXGatherOp:ONNX_Op<"Gather",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Gather operation";
|
||||
let description = [{
|
||||
"Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather"
|
||||
|
@ -4695,7 +4706,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
|
|||
}
|
||||
|
||||
def ONNXShapeOp:ONNX_Op<"Shape",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Shape operation";
|
||||
let description = [{
|
||||
"Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
|
||||
|
@ -4850,7 +4861,7 @@ def ONNXSizeOp:ONNX_Op<"Size",
|
|||
}
|
||||
|
||||
def ONNXSliceOp:ONNX_Op<"Slice",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Slice operation";
|
||||
let description = [{
|
||||
"Produces a slice of the input tensor along multiple axes. Similar to numpy:"
|
||||
|
@ -5345,7 +5356,7 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu",
|
|||
}
|
||||
|
||||
def ONNXTileOp:ONNX_Op<"Tile",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
|
||||
let summary = "ONNX Tile operation";
|
||||
let description = [{
|
||||
"Constructs a tensor by tiling a given tensor."
|
||||
|
@ -5353,7 +5364,7 @@ def ONNXTileOp:ONNX_Op<"Tile",
|
|||
"For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]"
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$input,
|
||||
AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$repeats);
|
||||
AnyTypeOf<[TensorOf<[I64]>, AnyMemRef, NoneType]>:$repeats);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$output);
|
||||
let extraClassDeclaration = [{
|
||||
static int getNumberOfOperands() {
|
||||
|
@ -5365,6 +5376,9 @@ def ONNXTileOp:ONNX_Op<"Tile",
|
|||
static std::vector<int> getTypeMap() {
|
||||
return {20};
|
||||
}
|
||||
std::map<std::string, size_t> promotableConstOperands() {
|
||||
return {{"repeats", 1}};
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -1128,3 +1128,230 @@ func @test_convinteger_11(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi
|
|||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
|
||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_shape(%arg0: tensor<?x3x2xf32>) -> tensor<*xi64> {
|
||||
%0 = "onnx.Shape"(%arg0) : (tensor<?x3x2xf32>) -> tensor<*xi64>
|
||||
return %0 : tensor<*xi64>
|
||||
|
||||
// CHECK-LABEL: test_shape
|
||||
// CHECK: [[RES:%.+]] = "onnx.Shape"(%arg0) : (tensor<?x3x2xf32>) -> tensor<3xi64>
|
||||
// CHECK: return [[RES]] : tensor<3xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_tile_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_tile_dynamic
|
||||
// CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: return [[RES]] : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_tile_constant(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi64> } : () -> tensor<4xi64>
|
||||
%1 = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_tile_constant
|
||||
// CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<25x25x16x64xf32>
|
||||
// CHECK: return [[RES]] : tensor<25x25x16x64xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gather_axis0(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Gather"(%arg0, %arg1) {axis = 0} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_gather_axis0
|
||||
// CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 0 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x2x3xf32>
|
||||
// CHECK: return [[RES]] : tensor<1x2x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gather_axis1(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Gather"(%arg0, %arg1) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_gather_axis1
|
||||
// CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<3x1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gather_negative_axis(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Gather"(%arg0, %arg1) {axis = -1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_gather_negative_axis
|
||||
// CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<3x1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_constant_of_shape_empty_tensor(%arg0 : tensor<0xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConstantOfShape"(%arg0) : (tensor<0xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_constant_of_shape_empty_tensor
|
||||
// CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<0.000000e+00> : tensor<1xf32>} : (tensor<0xi64>) -> tensor<f32>
|
||||
// CHECK: return [[RES]] : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_constant_of_shape(%arg0 : tensor<3xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConstantOfShape"(%arg0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_constant_of_shape
|
||||
// CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: return [[RES]] : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_constant_of_shape_constant() -> tensor<*xf32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64> } : () -> tensor<3xi64>
|
||||
%1 = "onnx.ConstantOfShape"(%0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_constant_of_shape_constant
|
||||
// CHECK: [[CONSTANT:%.+]] = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64>} : () -> tensor<3xi64>
|
||||
// CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"([[CONSTANT]]) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: return [[RES]] : tensor<3x4x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_slice(%arg0 : tensor<2x4xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>, %arg3: tensor<2xi64>, %arg4: tensor<2xi64>) -> tensor<*xf32> {
|
||||
%1 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_slice
|
||||
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||
// CHECK: return [[RES:%.+]] : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_slice_constant_default_axes(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
|
||||
%axes = constant unit
|
||||
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_slice_constant_default_axes
|
||||
// CHECK: [[AXES:%.+]] = constant unit
|
||||
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<1x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
|
||||
%axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%steps = constant unit
|
||||
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_slice_constant_default_steps
|
||||
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STEPS:%.+]] = constant unit
|
||||
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<1x3xf32>
|
||||
// CHECK: return [[RES]] : tensor<1x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_slice_all_constant(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
|
||||
%axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_slice_all_constant
|
||||
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
|
||||
%axes = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%ends = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_slice_all_constant_negative
|
||||
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
|
||||
%axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%ends = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_slice_all_constant_end_outofbound
|
||||
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
|
||||
%axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%steps = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64> } : () -> tensor<2xi64>
|
||||
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_slice_all_constant_negative_steps
|
||||
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<1x2xf32>
|
||||
}
|
||||
|
|
|
@ -51,3 +51,16 @@ func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32
|
|||
// CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_tile_should_promote_to_attribute(%arg0 : tensor<?x10x10xf32>) -> tensor<*xf32> {
|
||||
%shape = constant dense<[6, 7, 42]> : tensor<3xi64>
|
||||
%0 = "onnx.Tile"(%arg0, %shape) : (tensor<?x10x10xf32>, tensor<3xi64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: test_tile_should_promote_to_attribute
|
||||
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||
// CHECK-NEXT: [[TILE:%.+]] = "onnx.Tile"(%{{.*}}, [[NONE]]) {repeats = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[TILE]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -252,7 +252,7 @@ OpsWithShapeInference = [
|
|||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
||||
'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
|
||||
'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
|
||||
'Squeeze'
|
||||
'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice'
|
||||
]
|
||||
|
||||
# Operations supporting canonicalization.
|
||||
|
@ -266,7 +266,8 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Scaler']
|
|||
# tuples, whose first item is the attribute/operand name, and the second item is
|
||||
# the index at which such operand occurs in the list of the operation's inputs.
|
||||
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
|
||||
"Pad": [("pads", 1), ("constant_value", 2)]}
|
||||
"Pad": [("pads", 1), ("constant_value", 2)],
|
||||
"Tile": [("repeats", 1)]}
|
||||
|
||||
# Interface for special handling of type inference
|
||||
# The common code are put into get_type_inference_func
|
||||
|
@ -281,7 +282,15 @@ OpsWithResultTypeInference = {
|
|||
'''auto toAttr = to().getSExtValue();
|
||||
auto builder = mlir::OpBuilder(getContext());
|
||||
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));'''
|
||||
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));''',
|
||||
"ConstantOfShape":
|
||||
'''if (auto attr = valueAttr()) {
|
||||
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||
attr.getType().cast<ShapedType>().getElementType()));
|
||||
} else {
|
||||
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||
FloatType::getF32(getContext())));
|
||||
}'''
|
||||
}
|
||||
|
||||
# Add an Op in this list if the Op needs result type deduction which is required
|
||||
|
|
Loading…
Reference in New Issue