Shape inference for some ops in RNN models (#207)

* Shape inference for ShapeOp

* Shape inference for TileOp

* Fix check-doc

* Fix onnx-mlir-doc

* Shape inference for GatherOp

* Check validity of GatherOp's indices tensor

* Shape inference for Slice

* Tests for SliceOp

* Fix importing none inputs

* Type inference for constantofshape

* Empty tensor in case of ConstantOfShape

* Remove unrelated changes

Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
Tung D. Le 2020-07-22 23:15:56 +09:00 committed by GitHub
parent 1263d01968
commit 034f98c00c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 629 additions and 10 deletions

View File

@ -5556,7 +5556,7 @@ ONNX Tile operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values `input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values
`repeats` | tensor of 64-bit signless integer values or memref of any type values `repeats` | tensor of 64-bit signless integer values or memref of any type values or none type
#### Results: #### Results:

View File

@ -2258,6 +2258,362 @@ LogicalResult ONNXConvIntegerOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Shape
//===----------------------------------------------------------------------===//
LogicalResult ONNXShapeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>())
return emitError("Input tensor not ranked");
// Output is an 1D int64 tensor containing the shape of the input tensor.
int64_t rank = data().getType().cast<RankedTensorType>().getRank();
SmallVector<int64_t, 1> outDims(1, rank);
getResult().setType(
RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
return success();
}
//===----------------------------------------------------------------------===//
// Tile
//===----------------------------------------------------------------------===//
LogicalResult ONNXTileOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!input().getType().isa<RankedTensorType>())
return emitError("Input tensor not ranked");
// Read 'repeats' value.
if (!repeats().getType().isa<RankedTensorType>())
return emitError("Repeats tensor not ranked");
auto inputTensorTy = input().getType().cast<RankedTensorType>();
auto repeatsTensorTy = repeats().getType().cast<RankedTensorType>();
// 'repeats' tensor is an 1D tensor.
if (repeatsTensorTy.getShape().size() != 1)
return emitError("Repeats tensor must have rank one");
// 'repeats' tensor must have constant shape.
int64_t repeatsLength = repeatsTensorTy.getShape()[0];
if (repeatsLength < 0)
return emitError("Repeats tensor must have constant shape");
// Check the 1D repeats tensor length.
int64_t inputRank = inputTensorTy.getShape().size();
if (inputRank != repeatsLength)
return emitError("Repeats tensor must have the same length as the input's "
"dimension number.");
// Check if second argument of TileOp is a constant.
auto constantOp = getONNXConstantOp(repeats());
// Compute output's dimensions: output_dim[i] = input_dim[i] * repeats[i]
SmallVector<int64_t, 2> dims(inputRank, -1);
if (constantOp) {
// 1. Initialize output_dim with values from 'input'.
// output_dim[i] = input[i]
for (decltype(inputRank) i = 0; i < inputRank; ++i)
dims[i] = inputTensorTy.getShape()[i];
// 2. Update output_dim using values from 'repeats'.
// Do this only for static 'input_dim[i]'.
// if (output_dim[i] != -1) output_dim[i] *= repeats[i]
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
if (!valueAttribute)
return emitError("DenseElementsAttr expected");
// Get repeat values from valueAttribute.
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
for (int i = 0; i < inputRank; ++i)
if (dims[i] != -1)
dims[i] *= (*valueIt++).cast<IntegerAttr>().getInt();
if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
return emitError("Constant value must have same length as output's rank");
}
getResult().setType(
RankedTensorType::get(dims, inputTensorTy.getElementType()));
return success();
}
//===----------------------------------------------------------------------===//
// Gather
//===----------------------------------------------------------------------===//
LogicalResult ONNXGatherOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>())
return emitError("Input tensor not ranked");
if (!indices().getType().isa<RankedTensorType>())
return emitError("Indices tensor not ranked");
auto inputShape = data().getType().cast<RankedTensorType>().getShape();
auto indicesShape = indices().getType().cast<RankedTensorType>().getShape();
int64_t inputRank = inputShape.size();
int64_t indicesRank = indicesShape.size();
if (inputRank < 1)
return emitError("Input tensor must have rank >= 1");
// Read 'axis' attribute.
auto axisIndex = axis().getSExtValue();
// 'axis' must be in [-rank, rank-1]
if (axisIndex < -inputRank || axisIndex >= inputRank)
return emitError("Gather axis value out of bound");
// Convert a negative axis to a positive axis.
if (axisIndex < 0) {
axisIndex += inputRank;
auto builder = mlir::Builder(getContext());
axisAttr(builder.getI64IntegerAttr(axisIndex));
}
// If 'indices' is a constant, check whether its values are valid or not.
auto constantOp = getONNXConstantOp(indices());
if (constantOp && inputShape[axisIndex] != -1) {
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
if (!valueAttribute)
return emitError("DenseElementsAttr expected");
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
auto index = value.cast<IntegerAttr>().getInt();
if (index < -inputShape[axisIndex] || index >= inputShape[axisIndex])
return emitError("Indices tensor contains an out-of-bound index");
}
}
// Output has rank of 'indicesRank + (inputRank - 1).
// Output shape is constructed from 'input' by:
// replacing the dimension at 'axis' in 'input' by the shape of 'indices'.
SmallVector<int64_t, 1> outDims;
for (decltype(inputRank) i = 0; i < inputRank; ++i) {
if (i == axisIndex)
for (decltype(indicesRank) j = 0; j < indicesRank; ++j)
outDims.emplace_back(indicesShape[j]);
else
outDims.emplace_back(inputShape[i]);
}
getResult().setType(RankedTensorType::get(
outDims, data().getType().cast<RankedTensorType>().getElementType()));
return success();
}
//===----------------------------------------------------------------------===//
// ConstantOfShape
//===----------------------------------------------------------------------===//
LogicalResult ONNXConstantOfShapeOp::inferShapes() {
Type elementType;
// 'value' attribute is a one-element tensor whose value and datatype are used
// to set the output tensor's value and datatype..
if (value().hasValue()) {
elementType =
valueAttr().cast<DenseElementsAttr>().getType().getElementType();
} else {
// If 'value' attribute is not specified, it defaults to a tensor of value 0
// and datatype float32.
elementType = FloatType::getF32(getContext());
llvm::SmallVector<int64_t, 2> dims(1, 1);
auto tensorType = mlir::RankedTensorType::get(dims, elementType);
llvm::SmallVector<float, 1> values(1, 0.);
valueAttr(
mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)));
}
// 'input' must be a 1D tensor.
auto inputShape = input().getType().cast<RankedTensorType>().getShape();
if (inputShape.size() != 1)
return emitError("Input tensor must be a 1D tensor");
if (inputShape[0] == -1)
return emitError("Input tensor must have static shape");
if (inputShape[0] == 0) {
// If 'input' is an empty tensor, the output would be a scalar.
getResult().setType(RankedTensorType::get({}, elementType));
return success();
}
// Calculate output dimensions.
SmallVector<int64_t, 4> outputDims(inputShape[0], -1);
// If 'input' is a constant, check whether its values are valid or not.
// If the values are valid, it is possible to infer shape.
if (auto constantOp = getONNXConstantOp(input())) {
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
// Get repeat values from valueAttribute.
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
for (int i = 0; i < inputShape[0]; ++i) {
auto dim = (*valueIt++).cast<IntegerAttr>().getInt();
if (dim < 0)
return emitError("All values of the input tensor must be >=0");
outputDims[i] = dim;
}
if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
return emitError("Constant value must have same length as output's rank");
}
getResult().setType(RankedTensorType::get(outputDims, elementType));
return success();
}
//===----------------------------------------------------------------------===//
// Slice
//===----------------------------------------------------------------------===//
LogicalResult ONNXSliceOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>())
return emitError("Input tensor not ranked");
auto elementType = data().getType().cast<ShapedType>().getElementType();
auto dataShape = data().getType().cast<ShapedType>().getShape();
int64_t numDims = dataShape.size();
SmallVector<int64_t, 2> outputDims(numDims, -1);
// If 'starts', 'ends', 'axes', and 'steps' are constants, check whether their
// values are valid or not. If the values are valid, it is possible to infer
// shape.
//
// 'starts', 'ends', and 'steps' are for each axis in the list of axes, so
// processing 'axes' first.
// Check and get 'axes' tensor.
SmallVector<int64_t, 2> axesValue;
if (axes().getType().isa<NoneType>()) {
// If `axes` are omitted, they are set to `[0, ..., ndim-1]`."
for (int i = 0; i < numDims; ++i)
axesValue.emplace_back(i);
} else if (auto constantOp = getONNXConstantOp(axes())) {
// If `axes` are constants, read them."
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
int64_t axis = value.cast<IntegerAttr>().getInt();
if (axis < -numDims || axis >= numDims)
return emitError("Axes contains an out-of-bound index");
if (axis < 0)
axis += numDims;
if (dataShape[axis] == -1) {
// It is unsafe to infer shape for an axis with an unknown dimension,
// since we can not validate 'start' and 'end' values from this
// dimension.
getResult().setType(RankedTensorType::get(outputDims, elementType));
return success();
}
axesValue.emplace_back(axis);
}
} else {
// Cannot infer a static shape.
getResult().setType(RankedTensorType::get(outputDims, elementType));
return success();
}
// Check 'starts' tensor.
SmallVector<int64_t, 2> startsValue;
if (auto constantOp = getONNXConstantOp(starts())) {
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
int i = 0;
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
int64_t axis = axesValue[i];
int64_t index = value.cast<IntegerAttr>().getInt();
if (index < -dataShape[axis])
index = 0;
else if (index > dataShape[axis])
index = dataShape[axis];
else if (index < 0)
index += dataShape[axis];
startsValue.emplace_back(index);
i++;
}
if (i != axesValue.size())
emitError("starts and axes tensors must have the same length");
} else {
// Cannot infer a static shape.
getResult().setType(RankedTensorType::get(outputDims, elementType));
return success();
}
// Check 'ends' tensor.
SmallVector<int64_t, 2> endsValue;
if (auto constantOp = getONNXConstantOp(ends())) {
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
int i = 0;
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
int64_t axis = axesValue[i];
int64_t index = value.cast<IntegerAttr>().getInt();
if (index < -dataShape[axis])
index = 0;
else if (index > dataShape[axis])
index = dataShape[axis];
else if (index < 0)
index += dataShape[axis];
endsValue.emplace_back(index);
i++;
}
if (i != axesValue.size())
emitError("ends and axes tensors must have the same length");
} else {
// Cannot infer a static shape.
getResult().setType(RankedTensorType::get(outputDims, elementType));
return success();
}
// Check and get 'steps' tensor.
SmallVector<int64_t, 2> stepsValue;
if (steps().getType().isa<NoneType>()) {
// If `steps` are omitted, they are set to `[1, ..., 1]` of len(starts)."
for (int i = 0; i < startsValue.size(); ++i)
stepsValue.emplace_back(1);
} else if (auto constantOp = getONNXConstantOp(steps())) {
// If `steps` are constants, read them."
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
int i = 0;
for (auto value : valueAttribute.getValues<IntegerAttr>()) {
int64_t index = value.cast<IntegerAttr>().getInt();
if (index == 0)
emitError("step cannot be zero");
stepsValue.emplace_back(index);
i++;
}
if (i != axesValue.size())
emitError("steps and axes tensors must have the same length");
} else {
// Cannot infer a static shape.
getResult().setType(RankedTensorType::get(outputDims, elementType));
return success();
}
// All 'starts', 'ends', 'steps' values are valid. Now calculate output
// dimensions for axes in 'axes'.
for (int i = 0; i < axesValue.size(); i++) {
int64_t axis = axesValue[i];
int64_t start = startsValue[i];
int64_t end = endsValue[i];
int64_t step = stepsValue[i];
if (step < 0)
step = -step;
int64_t q = (end - start) / step;
int64_t r = (end - start) % step;
if (r != 0)
q += 1;
outputDims[axis] = q;
}
getResult().setType(RankedTensorType::get(outputDims, elementType));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -662,7 +662,7 @@ def ONNXConstantOp:ONNX_Op<"Constant",
} }
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
let summary = "ONNX ConstantOfShape operation"; let summary = "ONNX ConstantOfShape operation";
let description = [{ let description = [{
"Generate a tensor with given value and shape." "Generate a tensor with given value and shape."
@ -680,6 +680,17 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
static std::vector<int> getTypeMap() { static std::vector<int> getTypeMap() {
return {-1}; return {-1};
} }
std::vector<mlir::Type> resultTypeInference() {
std::vector<mlir::Type> resultTypes;
if (auto attr = valueAttr()) {
resultTypes.push_back(mlir::UnrankedTensorType::get(
attr.getType().cast<ShapedType>().getElementType()));
} else {
resultTypes.push_back(mlir::UnrankedTensorType::get(
FloatType::getF32(getContext())));
}
return resultTypes;
}
}]; }];
} }
@ -1438,7 +1449,7 @@ def ONNXGRUOp:ONNX_Op<"GRU",
} }
def ONNXGatherOp:ONNX_Op<"Gather", def ONNXGatherOp:ONNX_Op<"Gather",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Gather operation"; let summary = "ONNX Gather operation";
let description = [{ let description = [{
"Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather" "Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather"
@ -4695,7 +4706,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
} }
def ONNXShapeOp:ONNX_Op<"Shape", def ONNXShapeOp:ONNX_Op<"Shape",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Shape operation"; let summary = "ONNX Shape operation";
let description = [{ let description = [{
"Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor." "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
@ -4850,7 +4861,7 @@ def ONNXSizeOp:ONNX_Op<"Size",
} }
def ONNXSliceOp:ONNX_Op<"Slice", def ONNXSliceOp:ONNX_Op<"Slice",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Slice operation"; let summary = "ONNX Slice operation";
let description = [{ let description = [{
"Produces a slice of the input tensor along multiple axes. Similar to numpy:" "Produces a slice of the input tensor along multiple axes. Similar to numpy:"
@ -5345,7 +5356,7 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu",
} }
def ONNXTileOp:ONNX_Op<"Tile", def ONNXTileOp:ONNX_Op<"Tile",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
let summary = "ONNX Tile operation"; let summary = "ONNX Tile operation";
let description = [{ let description = [{
"Constructs a tensor by tiling a given tensor." "Constructs a tensor by tiling a given tensor."
@ -5353,7 +5364,7 @@ def ONNXTileOp:ONNX_Op<"Tile",
"For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]" "For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]"
}]; }];
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$input, let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$input,
AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$repeats); AnyTypeOf<[TensorOf<[I64]>, AnyMemRef, NoneType]>:$repeats);
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$output); let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$output);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
@ -5365,6 +5376,9 @@ def ONNXTileOp:ONNX_Op<"Tile",
static std::vector<int> getTypeMap() { static std::vector<int> getTypeMap() {
return {20}; return {20};
} }
std::map<std::string, size_t> promotableConstOperands() {
return {{"repeats", 1}};
}
}]; }];
} }

View File

@ -1128,3 +1128,230 @@ func @test_convinteger_11(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32> // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
} }
// -----
func @test_shape(%arg0: tensor<?x3x2xf32>) -> tensor<*xi64> {
%0 = "onnx.Shape"(%arg0) : (tensor<?x3x2xf32>) -> tensor<*xi64>
return %0 : tensor<*xi64>
// CHECK-LABEL: test_shape
// CHECK: [[RES:%.+]] = "onnx.Shape"(%arg0) : (tensor<?x3x2xf32>) -> tensor<3xi64>
// CHECK: return [[RES]] : tensor<3xi64>
}
// -----
func @test_tile_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
%0 = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_tile_dynamic
// CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: return [[RES]] : tensor<?x?x?x?xf32>
}
// -----
func @test_tile_constant(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi64> } : () -> tensor<4xi64>
%1 = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_tile_constant
// CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<25x25x16x64xf32>
// CHECK: return [[RES]] : tensor<25x25x16x64xf32>
}
// -----
func @test_gather_axis0(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
%0 = "onnx.Gather"(%arg0, %arg1) {axis = 0} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_gather_axis0
// CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 0 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x2x3xf32>
// CHECK: return [[RES]] : tensor<1x2x3xf32>
}
// -----
func @test_gather_axis1(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
%0 = "onnx.Gather"(%arg0, %arg1) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_gather_axis1
// CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
// CHECK: return [[RES]] : tensor<3x1x2xf32>
}
// -----
func @test_gather_negative_axis(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
%0 = "onnx.Gather"(%arg0, %arg1) {axis = -1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_gather_negative_axis
// CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
// CHECK: return [[RES]] : tensor<3x1x2xf32>
}
// -----
func @test_constant_of_shape_empty_tensor(%arg0 : tensor<0xi64>) -> tensor<*xf32> {
%0 = "onnx.ConstantOfShape"(%arg0) : (tensor<0xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_constant_of_shape_empty_tensor
// CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<0.000000e+00> : tensor<1xf32>} : (tensor<0xi64>) -> tensor<f32>
// CHECK: return [[RES]] : tensor<f32>
}
// -----
func @test_constant_of_shape(%arg0 : tensor<3xi64>) -> tensor<*xf32> {
%0 = "onnx.ConstantOfShape"(%arg0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_constant_of_shape
// CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: return [[RES]] : tensor<?x?x?xf32>
}
// -----
func @test_constant_of_shape_constant() -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64> } : () -> tensor<3xi64>
%1 = "onnx.ConstantOfShape"(%0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_constant_of_shape_constant
// CHECK: [[CONSTANT:%.+]] = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64>} : () -> tensor<3xi64>
// CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"([[CONSTANT]]) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: return [[RES]] : tensor<3x4x5xf32>
}
// -----
func @test_slice(%arg0 : tensor<2x4xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>, %arg3: tensor<2xi64>, %arg4: tensor<2xi64>) -> tensor<*xf32> {
%1 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_slice
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: return [[RES:%.+]] : tensor<?x?xf32>
}
// -----
func @test_slice_constant_default_axes(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
%axes = constant unit
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
%ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
%steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_slice_constant_default_axes
// CHECK: [[AXES:%.+]] = constant unit
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<1x2xf32>
// CHECK: return [[RES]] : tensor<1x2xf32>
}
// -----
func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
%axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
%ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
%steps = constant unit
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_slice_constant_default_steps
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STEPS:%.+]] = constant unit
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<1x3xf32>
// CHECK: return [[RES]] : tensor<1x3xf32>
}
// -----
func @test_slice_all_constant(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
%axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
%ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
%steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_slice_all_constant
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
// CHECK: return [[RES]] : tensor<1x2xf32>
}
// -----
func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
%axes = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64> } : () -> tensor<2xi64>
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
%ends = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64> } : () -> tensor<2xi64>
%steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_slice_all_constant_negative
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
// CHECK: return [[RES]] : tensor<1x2xf32>
}
// -----
func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
%axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
%ends = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
%steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_slice_all_constant_end_outofbound
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
// CHECK: return [[RES]] : tensor<1x2xf32>
}
// -----
func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
%axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
%starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
%ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
%steps = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64> } : () -> tensor<2xi64>
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_slice_all_constant_negative_steps
// CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
// CHECK: return [[RES]] : tensor<1x2xf32>
}

View File

@ -51,3 +51,16 @@ func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32
// CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32> // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
} }
// -----
func @test_tile_should_promote_to_attribute(%arg0 : tensor<?x10x10xf32>) -> tensor<*xf32> {
%shape = constant dense<[6, 7, 42]> : tensor<3xi64>
%0 = "onnx.Tile"(%arg0, %shape) : (tensor<?x10x10xf32>, tensor<3xi64>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: test_tile_should_promote_to_attribute
// CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK-NEXT: [[TILE:%.+]] = "onnx.Tile"(%{{.*}}, [[NONE]]) {repeats = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10x10xf32>, none) -> tensor<*xf32>
// CHECK-NEXT: return [[TILE]] : tensor<*xf32>
}

View File

@ -252,7 +252,7 @@ OpsWithShapeInference = [
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN', 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten', 'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger', 'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
'Squeeze' 'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice'
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.
@ -266,7 +266,8 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Scaler']
# tuples, whose first item is the attribute/operand name, and the second item is # tuples, whose first item is the attribute/operand name, and the second item is
# the index at which such operand occurs in the list of the operation's inputs. # the index at which such operand occurs in the list of the operation's inputs.
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)], OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
"Pad": [("pads", 1), ("constant_value", 2)]} "Pad": [("pads", 1), ("constant_value", 2)],
"Tile": [("repeats", 1)]}
# Interface for special handling of type inference # Interface for special handling of type inference
# The common code are put into get_type_inference_func # The common code are put into get_type_inference_func
@ -281,7 +282,15 @@ OpsWithResultTypeInference = {
'''auto toAttr = to().getSExtValue(); '''auto toAttr = to().getSExtValue();
auto builder = mlir::OpBuilder(getContext()); auto builder = mlir::OpBuilder(getContext());
resultTypes.push_back(mlir::UnrankedTensorType::get( resultTypes.push_back(mlir::UnrankedTensorType::get(
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));''' convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));''',
"ConstantOfShape":
'''if (auto attr = valueAttr()) {
resultTypes.push_back(mlir::UnrankedTensorType::get(
attr.getType().cast<ShapedType>().getElementType()));
} else {
resultTypes.push_back(mlir::UnrankedTensorType::get(
FloatType::getF32(getContext())));
}'''
} }
# Add an Op in this list if the Op needs result type deduction which is required # Add an Op in this list if the Op needs result type deduction which is required