diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 63e5277..4da038b 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -5556,7 +5556,7 @@ ONNX Tile operation | Operand | Description | | :-----: | ----------- | `input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values -`repeats` | tensor of 64-bit signless integer values or memref of any type values +`repeats` | tensor of 64-bit signless integer values or memref of any type values or none type #### Results: diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index f9ee375..a83e000 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -2258,6 +2258,362 @@ LogicalResult ONNXConvIntegerOp::inferShapes() { return success(); } +//===----------------------------------------------------------------------===// +// Shape +//===----------------------------------------------------------------------===// + +LogicalResult ONNXShapeOp::inferShapes() { + // Cannot infer shape if no shape exists. + if (!data().getType().isa()) + return emitError("Input tensor not ranked"); + + // Output is an 1D int64 tensor containing the shape of the input tensor. + int64_t rank = data().getType().cast().getRank(); + SmallVector outDims(1, rank); + getResult().setType( + RankedTensorType::get(outDims, IntegerType::get(64, getContext()))); + return success(); +} + +//===----------------------------------------------------------------------===// +// Tile +//===----------------------------------------------------------------------===// + +LogicalResult ONNXTileOp::inferShapes() { + // Cannot infer shape if no shape exists. + if (!input().getType().isa()) + return emitError("Input tensor not ranked"); + + // Read 'repeats' value. + if (!repeats().getType().isa()) + return emitError("Repeats tensor not ranked"); + + auto inputTensorTy = input().getType().cast(); + auto repeatsTensorTy = repeats().getType().cast(); + + // 'repeats' tensor is an 1D tensor. + if (repeatsTensorTy.getShape().size() != 1) + return emitError("Repeats tensor must have rank one"); + + // 'repeats' tensor must have constant shape. + int64_t repeatsLength = repeatsTensorTy.getShape()[0]; + if (repeatsLength < 0) + return emitError("Repeats tensor must have constant shape"); + + // Check the 1D repeats tensor length. + int64_t inputRank = inputTensorTy.getShape().size(); + if (inputRank != repeatsLength) + return emitError("Repeats tensor must have the same length as the input's " + "dimension number."); + + // Check if second argument of TileOp is a constant. + auto constantOp = getONNXConstantOp(repeats()); + + // Compute output's dimensions: output_dim[i] = input_dim[i] * repeats[i] + SmallVector dims(inputRank, -1); + if (constantOp) { + // 1. Initialize output_dim with values from 'input'. + // output_dim[i] = input[i] + for (decltype(inputRank) i = 0; i < inputRank; ++i) + dims[i] = inputTensorTy.getShape()[i]; + + // 2. Update output_dim using values from 'repeats'. + // Do this only for static 'input_dim[i]'. + // if (output_dim[i] != -1) output_dim[i] *= repeats[i] + DenseElementsAttr valueAttribute = + constantOp.valueAttr().dyn_cast(); + if (!valueAttribute) + return emitError("DenseElementsAttr expected"); + // Get repeat values from valueAttribute. + auto valueIt = valueAttribute.getValues().begin(); + for (int i = 0; i < inputRank; ++i) + if (dims[i] != -1) + dims[i] *= (*valueIt++).cast().getInt(); + + if (valueIt != valueAttribute.getValues().end()) + return emitError("Constant value must have same length as output's rank"); + } + + getResult().setType( + RankedTensorType::get(dims, inputTensorTy.getElementType())); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Gather +//===----------------------------------------------------------------------===// + +LogicalResult ONNXGatherOp::inferShapes() { + // Cannot infer shape if no shape exists. + if (!data().getType().isa()) + return emitError("Input tensor not ranked"); + if (!indices().getType().isa()) + return emitError("Indices tensor not ranked"); + + auto inputShape = data().getType().cast().getShape(); + auto indicesShape = indices().getType().cast().getShape(); + int64_t inputRank = inputShape.size(); + int64_t indicesRank = indicesShape.size(); + + if (inputRank < 1) + return emitError("Input tensor must have rank >= 1"); + + // Read 'axis' attribute. + auto axisIndex = axis().getSExtValue(); + // 'axis' must be in [-rank, rank-1] + if (axisIndex < -inputRank || axisIndex >= inputRank) + return emitError("Gather axis value out of bound"); + // Convert a negative axis to a positive axis. + if (axisIndex < 0) { + axisIndex += inputRank; + auto builder = mlir::Builder(getContext()); + axisAttr(builder.getI64IntegerAttr(axisIndex)); + } + + // If 'indices' is a constant, check whether its values are valid or not. + auto constantOp = getONNXConstantOp(indices()); + if (constantOp && inputShape[axisIndex] != -1) { + DenseElementsAttr valueAttribute = + constantOp.valueAttr().dyn_cast(); + if (!valueAttribute) + return emitError("DenseElementsAttr expected"); + for (auto value : valueAttribute.getValues()) { + auto index = value.cast().getInt(); + if (index < -inputShape[axisIndex] || index >= inputShape[axisIndex]) + return emitError("Indices tensor contains an out-of-bound index"); + } + } + + // Output has rank of 'indicesRank + (inputRank - 1). + // Output shape is constructed from 'input' by: + // replacing the dimension at 'axis' in 'input' by the shape of 'indices'. + SmallVector outDims; + for (decltype(inputRank) i = 0; i < inputRank; ++i) { + if (i == axisIndex) + for (decltype(indicesRank) j = 0; j < indicesRank; ++j) + outDims.emplace_back(indicesShape[j]); + else + outDims.emplace_back(inputShape[i]); + } + + getResult().setType(RankedTensorType::get( + outDims, data().getType().cast().getElementType())); + return success(); +} + +//===----------------------------------------------------------------------===// +// ConstantOfShape +//===----------------------------------------------------------------------===// + +LogicalResult ONNXConstantOfShapeOp::inferShapes() { + Type elementType; + + // 'value' attribute is a one-element tensor whose value and datatype are used + // to set the output tensor's value and datatype.. + if (value().hasValue()) { + elementType = + valueAttr().cast().getType().getElementType(); + } else { + // If 'value' attribute is not specified, it defaults to a tensor of value 0 + // and datatype float32. + elementType = FloatType::getF32(getContext()); + + llvm::SmallVector dims(1, 1); + auto tensorType = mlir::RankedTensorType::get(dims, elementType); + + llvm::SmallVector values(1, 0.); + valueAttr( + mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values))); + } + + // 'input' must be a 1D tensor. + auto inputShape = input().getType().cast().getShape(); + if (inputShape.size() != 1) + return emitError("Input tensor must be a 1D tensor"); + if (inputShape[0] == -1) + return emitError("Input tensor must have static shape"); + if (inputShape[0] == 0) { + // If 'input' is an empty tensor, the output would be a scalar. + getResult().setType(RankedTensorType::get({}, elementType)); + return success(); + } + + // Calculate output dimensions. + SmallVector outputDims(inputShape[0], -1); + // If 'input' is a constant, check whether its values are valid or not. + // If the values are valid, it is possible to infer shape. + if (auto constantOp = getONNXConstantOp(input())) { + DenseElementsAttr valueAttribute = + constantOp.valueAttr().dyn_cast(); + // Get repeat values from valueAttribute. + auto valueIt = valueAttribute.getValues().begin(); + for (int i = 0; i < inputShape[0]; ++i) { + auto dim = (*valueIt++).cast().getInt(); + if (dim < 0) + return emitError("All values of the input tensor must be >=0"); + outputDims[i] = dim; + } + + if (valueIt != valueAttribute.getValues().end()) + return emitError("Constant value must have same length as output's rank"); + } + + getResult().setType(RankedTensorType::get(outputDims, elementType)); + return success(); +} + +//===----------------------------------------------------------------------===// +// Slice +//===----------------------------------------------------------------------===// + +LogicalResult ONNXSliceOp::inferShapes() { + // Cannot infer shape if no shape exists. + if (!data().getType().isa()) + return emitError("Input tensor not ranked"); + + auto elementType = data().getType().cast().getElementType(); + auto dataShape = data().getType().cast().getShape(); + int64_t numDims = dataShape.size(); + + SmallVector outputDims(numDims, -1); + // If 'starts', 'ends', 'axes', and 'steps' are constants, check whether their + // values are valid or not. If the values are valid, it is possible to infer + // shape. + // + // 'starts', 'ends', and 'steps' are for each axis in the list of axes, so + // processing 'axes' first. + + // Check and get 'axes' tensor. + SmallVector axesValue; + if (axes().getType().isa()) { + // If `axes` are omitted, they are set to `[0, ..., ndim-1]`." + for (int i = 0; i < numDims; ++i) + axesValue.emplace_back(i); + } else if (auto constantOp = getONNXConstantOp(axes())) { + // If `axes` are constants, read them." + DenseElementsAttr valueAttribute = + constantOp.valueAttr().dyn_cast(); + for (auto value : valueAttribute.getValues()) { + int64_t axis = value.cast().getInt(); + if (axis < -numDims || axis >= numDims) + return emitError("Axes contains an out-of-bound index"); + if (axis < 0) + axis += numDims; + if (dataShape[axis] == -1) { + // It is unsafe to infer shape for an axis with an unknown dimension, + // since we can not validate 'start' and 'end' values from this + // dimension. + getResult().setType(RankedTensorType::get(outputDims, elementType)); + return success(); + } + axesValue.emplace_back(axis); + } + } else { + // Cannot infer a static shape. + getResult().setType(RankedTensorType::get(outputDims, elementType)); + return success(); + } + + // Check 'starts' tensor. + SmallVector startsValue; + if (auto constantOp = getONNXConstantOp(starts())) { + DenseElementsAttr valueAttribute = + constantOp.valueAttr().dyn_cast(); + int i = 0; + for (auto value : valueAttribute.getValues()) { + int64_t axis = axesValue[i]; + int64_t index = value.cast().getInt(); + if (index < -dataShape[axis]) + index = 0; + else if (index > dataShape[axis]) + index = dataShape[axis]; + else if (index < 0) + index += dataShape[axis]; + startsValue.emplace_back(index); + i++; + } + if (i != axesValue.size()) + emitError("starts and axes tensors must have the same length"); + } else { + // Cannot infer a static shape. + getResult().setType(RankedTensorType::get(outputDims, elementType)); + return success(); + } + + // Check 'ends' tensor. + SmallVector endsValue; + if (auto constantOp = getONNXConstantOp(ends())) { + DenseElementsAttr valueAttribute = + constantOp.valueAttr().dyn_cast(); + int i = 0; + for (auto value : valueAttribute.getValues()) { + int64_t axis = axesValue[i]; + int64_t index = value.cast().getInt(); + if (index < -dataShape[axis]) + index = 0; + else if (index > dataShape[axis]) + index = dataShape[axis]; + else if (index < 0) + index += dataShape[axis]; + endsValue.emplace_back(index); + i++; + } + if (i != axesValue.size()) + emitError("ends and axes tensors must have the same length"); + } else { + // Cannot infer a static shape. + getResult().setType(RankedTensorType::get(outputDims, elementType)); + return success(); + } + + // Check and get 'steps' tensor. + SmallVector stepsValue; + if (steps().getType().isa()) { + // If `steps` are omitted, they are set to `[1, ..., 1]` of len(starts)." + for (int i = 0; i < startsValue.size(); ++i) + stepsValue.emplace_back(1); + } else if (auto constantOp = getONNXConstantOp(steps())) { + // If `steps` are constants, read them." + DenseElementsAttr valueAttribute = + constantOp.valueAttr().dyn_cast(); + int i = 0; + for (auto value : valueAttribute.getValues()) { + int64_t index = value.cast().getInt(); + if (index == 0) + emitError("step cannot be zero"); + stepsValue.emplace_back(index); + i++; + } + if (i != axesValue.size()) + emitError("steps and axes tensors must have the same length"); + } else { + // Cannot infer a static shape. + getResult().setType(RankedTensorType::get(outputDims, elementType)); + return success(); + } + + // All 'starts', 'ends', 'steps' values are valid. Now calculate output + // dimensions for axes in 'axes'. + for (int i = 0; i < axesValue.size(); i++) { + int64_t axis = axesValue[i]; + int64_t start = startsValue[i]; + int64_t end = endsValue[i]; + int64_t step = stepsValue[i]; + if (step < 0) + step = -step; + + int64_t q = (end - start) / step; + int64_t r = (end - start) % step; + if (r != 0) + q += 1; + outputDims[axis] = q; + } + + getResult().setType(RankedTensorType::get(outputDims, elementType)); + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index c7ae1b9..92b193a 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -662,7 +662,7 @@ def ONNXConstantOp:ONNX_Op<"Constant", } def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods, OpInterface<"ResultTypeInferenceOpInterface">]> { let summary = "ONNX ConstantOfShape operation"; let description = [{ "Generate a tensor with given value and shape." @@ -680,6 +680,17 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", static std::vector getTypeMap() { return {-1}; } + std::vector resultTypeInference() { + std::vector resultTypes; + if (auto attr = valueAttr()) { + resultTypes.push_back(mlir::UnrankedTensorType::get( + attr.getType().cast().getElementType())); + } else { + resultTypes.push_back(mlir::UnrankedTensorType::get( + FloatType::getF32(getContext()))); + } + return resultTypes; + } }]; } @@ -1438,7 +1449,7 @@ def ONNXGRUOp:ONNX_Op<"GRU", } def ONNXGatherOp:ONNX_Op<"Gather", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Gather operation"; let description = [{ "Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather" @@ -4695,7 +4706,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", } def ONNXShapeOp:ONNX_Op<"Shape", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Shape operation"; let description = [{ "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor." @@ -4850,7 +4861,7 @@ def ONNXSizeOp:ONNX_Op<"Size", } def ONNXSliceOp:ONNX_Op<"Slice", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Slice operation"; let description = [{ "Produces a slice of the input tensor along multiple axes. Similar to numpy:" @@ -5345,7 +5356,7 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu", } def ONNXTileOp:ONNX_Op<"Tile", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods, OpInterface<"PromotableConstOperandsOpInterface">]> { let summary = "ONNX Tile operation"; let description = [{ "Constructs a tensor by tiling a given tensor." @@ -5353,7 +5364,7 @@ def ONNXTileOp:ONNX_Op<"Tile", "For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]" }]; let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>, AnyMemRef]>:$input, - AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$repeats); + AnyTypeOf<[TensorOf<[I64]>, AnyMemRef, NoneType]>:$repeats); let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>, AnyMemRef]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { @@ -5365,6 +5376,9 @@ def ONNXTileOp:ONNX_Op<"Tile", static std::vector getTypeMap() { return {20}; } + std::map promotableConstOperands() { + return {{"repeats", 1}}; + } }]; } diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index a5008bd..2d816e2 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -1128,3 +1128,230 @@ func @test_convinteger_11(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<1x5x32x64xi32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32> } + +// ----- + +func @test_shape(%arg0: tensor) -> tensor<*xi64> { + %0 = "onnx.Shape"(%arg0) : (tensor) -> tensor<*xi64> + return %0 : tensor<*xi64> + + // CHECK-LABEL: test_shape + // CHECK: [[RES:%.+]] = "onnx.Shape"(%arg0) : (tensor) -> tensor<3xi64> + // CHECK: return [[RES]] : tensor<3xi64> +} + +// ----- + +func @test_tile_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> { + %0 = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_tile_dynamic + // CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor + // CHECK: return [[RES]] : tensor +} + +// ----- + +func @test_tile_constant(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi64> } : () -> tensor<4xi64> + %1 = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_tile_constant + // CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<25x25x16x64xf32> + // CHECK: return [[RES]] : tensor<25x25x16x64xf32> +} + +// ----- + +func @test_gather_axis0(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> { + %0 = "onnx.Gather"(%arg0, %arg1) {axis = 0} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_gather_axis0 + // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 0 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x2x3xf32> + // CHECK: return [[RES]] : tensor<1x2x3xf32> +} + +// ----- + +func @test_gather_axis1(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> { + %0 = "onnx.Gather"(%arg0, %arg1) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_gather_axis1 + // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> + // CHECK: return [[RES]] : tensor<3x1x2xf32> +} + +// ----- + +func @test_gather_negative_axis(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> { + %0 = "onnx.Gather"(%arg0, %arg1) {axis = -1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_gather_negative_axis + // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> + // CHECK: return [[RES]] : tensor<3x1x2xf32> +} + +// ----- + +func @test_constant_of_shape_empty_tensor(%arg0 : tensor<0xi64>) -> tensor<*xf32> { + %0 = "onnx.ConstantOfShape"(%arg0) : (tensor<0xi64>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_constant_of_shape_empty_tensor + // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<0.000000e+00> : tensor<1xf32>} : (tensor<0xi64>) -> tensor + // CHECK: return [[RES]] : tensor +} + +// ----- + +func @test_constant_of_shape(%arg0 : tensor<3xi64>) -> tensor<*xf32> { + %0 = "onnx.ConstantOfShape"(%arg0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_constant_of_shape + // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor + // CHECK: return [[RES]] : tensor +} + +// ----- + +func @test_constant_of_shape_constant() -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64> } : () -> tensor<3xi64> + %1 = "onnx.ConstantOfShape"(%0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_constant_of_shape_constant + // CHECK: [[CONSTANT:%.+]] = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64>} : () -> tensor<3xi64> + // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"([[CONSTANT]]) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<3x4x5xf32> + // CHECK: return [[RES]] : tensor<3x4x5xf32> +} + +// ----- + +func @test_slice(%arg0 : tensor<2x4xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>, %arg3: tensor<2xi64>, %arg4: tensor<2xi64>) -> tensor<*xf32> { + %1 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_slice + // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + // CHECK: return [[RES:%.+]] : tensor +} + +// ----- + +func @test_slice_constant_default_axes(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { + %axes = constant unit + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64> + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_slice_constant_default_axes + // CHECK: [[AXES:%.+]] = constant unit + // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<1x2xf32> + // CHECK: return [[RES]] : tensor<1x2xf32> +} + +// ----- + +func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { + %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = constant unit + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_slice_constant_default_steps + // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STEPS:%.+]] = constant unit + // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<1x3xf32> + // CHECK: return [[RES]] : tensor<1x3xf32> +} + +// ----- + +func @test_slice_all_constant(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { + %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64> + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_slice_all_constant + // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32> + // CHECK: return [[RES]] : tensor<1x2xf32> +} + +// ----- + +func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { + %axes = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64> } : () -> tensor<2xi64> + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64> + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_slice_all_constant_negative + // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32> + // CHECK: return [[RES]] : tensor<1x2xf32> +} + +// ----- + +func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { + %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64> + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_slice_all_constant_end_outofbound + // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32> + // CHECK: return [[RES]] : tensor<1x2xf32> +} + +// ----- + +func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { + %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64> } : () -> tensor<2xi64> + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_slice_all_constant_negative_steps + // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32> + // CHECK: return [[RES]] : tensor<1x2xf32> +} diff --git a/test/mlir/transform/attribute_promotion.mlir b/test/mlir/transform/attribute_promotion.mlir index 4bd31d5..e85061f 100644 --- a/test/mlir/transform/attribute_promotion.mlir +++ b/test/mlir/transform/attribute_promotion.mlir @@ -51,3 +51,16 @@ func @test_should_promote_to_attribute1(%arg0 : tensor) -> tensor<*xf32 // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor, none, none) -> tensor<*xf32> // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> } + +// ----- + +func @test_tile_should_promote_to_attribute(%arg0 : tensor) -> tensor<*xf32> { + %shape = constant dense<[6, 7, 42]> : tensor<3xi64> + %0 = "onnx.Tile"(%arg0, %shape) : (tensor, tensor<3xi64>) -> tensor<*xf32> + return %0 : tensor<*xf32> + // CHECK-LABEL: test_tile_should_promote_to_attribute + // CHECK-NEXT: [[NONE:%.+]] = constant unit + // CHECK-NEXT: [[TILE:%.+]] = "onnx.Tile"(%{{.*}}, [[NONE]]) {repeats = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor, none) -> tensor<*xf32> + // CHECK-NEXT: return [[TILE]] : tensor<*xf32> +} + diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 7ebb8e7..b2eed83 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -252,7 +252,7 @@ OpsWithShapeInference = [ 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN', 'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten', 'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger', - 'Squeeze' + 'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice' ] # Operations supporting canonicalization. @@ -266,7 +266,8 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Scaler'] # tuples, whose first item is the attribute/operand name, and the second item is # the index at which such operand occurs in the list of the operation's inputs. OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)], - "Pad": [("pads", 1), ("constant_value", 2)]} + "Pad": [("pads", 1), ("constant_value", 2)], + "Tile": [("repeats", 1)]} # Interface for special handling of type inference # The common code are put into get_type_inference_func @@ -281,7 +282,15 @@ OpsWithResultTypeInference = { '''auto toAttr = to().getSExtValue(); auto builder = mlir::OpBuilder(getContext()); resultTypes.push_back(mlir::UnrankedTensorType::get( - convertONNXTypeToMLIRType(builder, static_cast(toAttr))));''' + convertONNXTypeToMLIRType(builder, static_cast(toAttr))));''', + "ConstantOfShape": + '''if (auto attr = valueAttr()) { + resultTypes.push_back(mlir::UnrankedTensorType::get( + attr.getType().cast().getElementType())); + } else { + resultTypes.push_back(mlir::UnrankedTensorType::get( + FloatType::getF32(getContext()))); + }''' } # Add an Op in this list if the Op needs result type deduction which is required