Shape inference for some ops in RNN models (#207)
* Shape inference for ShapeOp * Shape inference for TileOp * Fix check-doc * Fix onnx-mlir-doc * Shape inference for GatherOp * Check validity of GatherOp's indices tensor * Shape inference for Slice * Tests for SliceOp * Fix importing none inputs * Type inference for constantofshape * Empty tensor in case of ConstantOfShape * Remove unrelated changes Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
		
							parent
							
								
									1263d01968
								
							
						
					
					
						commit
						034f98c00c
					
				|  | @ -5556,7 +5556,7 @@ ONNX Tile operation | |||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values | ||||
| `repeats` | tensor of 64-bit signless integer values or memref of any type values | ||||
| `repeats` | tensor of 64-bit signless integer values or memref of any type values or none type | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  |  | |||
|  | @ -2258,6 +2258,362 @@ LogicalResult ONNXConvIntegerOp::inferShapes() { | |||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Shape
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| LogicalResult ONNXShapeOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!data().getType().isa<RankedTensorType>()) | ||||
|     return emitError("Input tensor not ranked"); | ||||
| 
 | ||||
|   // Output is an 1D int64 tensor containing the shape of the input tensor.
 | ||||
|   int64_t rank = data().getType().cast<RankedTensorType>().getRank(); | ||||
|   SmallVector<int64_t, 1> outDims(1, rank); | ||||
|   getResult().setType( | ||||
|       RankedTensorType::get(outDims, IntegerType::get(64, getContext()))); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Tile
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| LogicalResult ONNXTileOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!input().getType().isa<RankedTensorType>()) | ||||
|     return emitError("Input tensor not ranked"); | ||||
| 
 | ||||
|   // Read 'repeats' value.
 | ||||
|   if (!repeats().getType().isa<RankedTensorType>()) | ||||
|     return emitError("Repeats tensor not ranked"); | ||||
| 
 | ||||
|   auto inputTensorTy = input().getType().cast<RankedTensorType>(); | ||||
|   auto repeatsTensorTy = repeats().getType().cast<RankedTensorType>(); | ||||
| 
 | ||||
|   // 'repeats' tensor is an 1D tensor.
 | ||||
|   if (repeatsTensorTy.getShape().size() != 1) | ||||
|     return emitError("Repeats tensor must have rank one"); | ||||
| 
 | ||||
|   // 'repeats' tensor must have constant shape.
 | ||||
|   int64_t repeatsLength = repeatsTensorTy.getShape()[0]; | ||||
|   if (repeatsLength < 0) | ||||
|     return emitError("Repeats tensor must have constant shape"); | ||||
| 
 | ||||
|   // Check the 1D repeats tensor length.
 | ||||
|   int64_t inputRank = inputTensorTy.getShape().size(); | ||||
|   if (inputRank != repeatsLength) | ||||
|     return emitError("Repeats tensor must have the same length as the input's " | ||||
|                      "dimension number."); | ||||
| 
 | ||||
|   // Check if second argument of TileOp is a constant.
 | ||||
|   auto constantOp = getONNXConstantOp(repeats()); | ||||
| 
 | ||||
|   // Compute output's dimensions: output_dim[i] = input_dim[i] * repeats[i]
 | ||||
|   SmallVector<int64_t, 2> dims(inputRank, -1); | ||||
|   if (constantOp) { | ||||
|     // 1. Initialize output_dim with values from 'input'.
 | ||||
|     //   output_dim[i] = input[i]
 | ||||
|     for (decltype(inputRank) i = 0; i < inputRank; ++i) | ||||
|       dims[i] = inputTensorTy.getShape()[i]; | ||||
| 
 | ||||
|     // 2. Update output_dim using values from 'repeats'.
 | ||||
|     // Do this only for static 'input_dim[i]'.
 | ||||
|     //   if (output_dim[i] != -1) output_dim[i] *= repeats[i]
 | ||||
|     DenseElementsAttr valueAttribute = | ||||
|         constantOp.valueAttr().dyn_cast<DenseElementsAttr>(); | ||||
|     if (!valueAttribute) | ||||
|       return emitError("DenseElementsAttr expected"); | ||||
|     // Get repeat values from valueAttribute.
 | ||||
|     auto valueIt = valueAttribute.getValues<IntegerAttr>().begin(); | ||||
|     for (int i = 0; i < inputRank; ++i) | ||||
|       if (dims[i] != -1) | ||||
|         dims[i] *= (*valueIt++).cast<IntegerAttr>().getInt(); | ||||
| 
 | ||||
|     if (valueIt != valueAttribute.getValues<IntegerAttr>().end()) | ||||
|       return emitError("Constant value must have same length as output's rank"); | ||||
|   } | ||||
| 
 | ||||
|   getResult().setType( | ||||
|       RankedTensorType::get(dims, inputTensorTy.getElementType())); | ||||
| 
 | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Gather
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| LogicalResult ONNXGatherOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!data().getType().isa<RankedTensorType>()) | ||||
|     return emitError("Input tensor not ranked"); | ||||
|   if (!indices().getType().isa<RankedTensorType>()) | ||||
|     return emitError("Indices tensor not ranked"); | ||||
| 
 | ||||
|   auto inputShape = data().getType().cast<RankedTensorType>().getShape(); | ||||
|   auto indicesShape = indices().getType().cast<RankedTensorType>().getShape(); | ||||
|   int64_t inputRank = inputShape.size(); | ||||
|   int64_t indicesRank = indicesShape.size(); | ||||
| 
 | ||||
|   if (inputRank < 1) | ||||
|     return emitError("Input tensor must have rank >= 1"); | ||||
| 
 | ||||
|   // Read 'axis' attribute.
 | ||||
|   auto axisIndex = axis().getSExtValue(); | ||||
|   // 'axis' must be in [-rank, rank-1]
 | ||||
|   if (axisIndex < -inputRank || axisIndex >= inputRank) | ||||
|     return emitError("Gather axis value out of bound"); | ||||
|   // Convert a negative axis to a positive axis.
 | ||||
|   if (axisIndex < 0) { | ||||
|     axisIndex += inputRank; | ||||
|     auto builder = mlir::Builder(getContext()); | ||||
|     axisAttr(builder.getI64IntegerAttr(axisIndex)); | ||||
|   } | ||||
| 
 | ||||
|   // If 'indices' is a constant, check whether its values are valid or not.
 | ||||
|   auto constantOp = getONNXConstantOp(indices()); | ||||
|   if (constantOp && inputShape[axisIndex] != -1) { | ||||
|     DenseElementsAttr valueAttribute = | ||||
|         constantOp.valueAttr().dyn_cast<DenseElementsAttr>(); | ||||
|     if (!valueAttribute) | ||||
|       return emitError("DenseElementsAttr expected"); | ||||
|     for (auto value : valueAttribute.getValues<IntegerAttr>()) { | ||||
|       auto index = value.cast<IntegerAttr>().getInt(); | ||||
|       if (index < -inputShape[axisIndex] || index >= inputShape[axisIndex]) | ||||
|         return emitError("Indices tensor contains an out-of-bound index"); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Output has rank of 'indicesRank + (inputRank - 1).
 | ||||
|   // Output shape is constructed from 'input' by:
 | ||||
|   //    replacing the dimension at 'axis' in 'input' by the shape of 'indices'.
 | ||||
|   SmallVector<int64_t, 1> outDims; | ||||
|   for (decltype(inputRank) i = 0; i < inputRank; ++i) { | ||||
|     if (i == axisIndex) | ||||
|       for (decltype(indicesRank) j = 0; j < indicesRank; ++j) | ||||
|         outDims.emplace_back(indicesShape[j]); | ||||
|     else | ||||
|       outDims.emplace_back(inputShape[i]); | ||||
|   } | ||||
| 
 | ||||
|   getResult().setType(RankedTensorType::get( | ||||
|       outDims, data().getType().cast<RankedTensorType>().getElementType())); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // ConstantOfShape
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| LogicalResult ONNXConstantOfShapeOp::inferShapes() { | ||||
|   Type elementType; | ||||
| 
 | ||||
|   // 'value' attribute is a one-element tensor whose value and datatype are used
 | ||||
|   // to set the output tensor's value and datatype..
 | ||||
|   if (value().hasValue()) { | ||||
|     elementType = | ||||
|         valueAttr().cast<DenseElementsAttr>().getType().getElementType(); | ||||
|   } else { | ||||
|     // If 'value' attribute is not specified, it defaults to a tensor of value 0
 | ||||
|     // and datatype float32.
 | ||||
|     elementType = FloatType::getF32(getContext()); | ||||
| 
 | ||||
|     llvm::SmallVector<int64_t, 2> dims(1, 1); | ||||
|     auto tensorType = mlir::RankedTensorType::get(dims, elementType); | ||||
| 
 | ||||
|     llvm::SmallVector<float, 1> values(1, 0.); | ||||
|     valueAttr( | ||||
|         mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values))); | ||||
|   } | ||||
| 
 | ||||
|   // 'input' must be a 1D tensor.
 | ||||
|   auto inputShape = input().getType().cast<RankedTensorType>().getShape(); | ||||
|   if (inputShape.size() != 1) | ||||
|     return emitError("Input tensor must be a 1D tensor"); | ||||
|   if (inputShape[0] == -1) | ||||
|     return emitError("Input tensor must have static shape"); | ||||
|   if (inputShape[0] == 0) { | ||||
|     // If 'input' is an empty tensor, the output would be a scalar.
 | ||||
|     getResult().setType(RankedTensorType::get({}, elementType)); | ||||
|     return success(); | ||||
|   } | ||||
| 
 | ||||
|   // Calculate output dimensions.
 | ||||
|   SmallVector<int64_t, 4> outputDims(inputShape[0], -1); | ||||
|   // If 'input' is a constant, check whether its values are valid or not.
 | ||||
|   // If the values are valid, it is possible to infer shape.
 | ||||
|   if (auto constantOp = getONNXConstantOp(input())) { | ||||
|     DenseElementsAttr valueAttribute = | ||||
|         constantOp.valueAttr().dyn_cast<DenseElementsAttr>(); | ||||
|     // Get repeat values from valueAttribute.
 | ||||
|     auto valueIt = valueAttribute.getValues<IntegerAttr>().begin(); | ||||
|     for (int i = 0; i < inputShape[0]; ++i) { | ||||
|       auto dim = (*valueIt++).cast<IntegerAttr>().getInt(); | ||||
|       if (dim < 0) | ||||
|         return emitError("All values of the input tensor must be >=0"); | ||||
|       outputDims[i] = dim; | ||||
|     } | ||||
| 
 | ||||
|     if (valueIt != valueAttribute.getValues<IntegerAttr>().end()) | ||||
|       return emitError("Constant value must have same length as output's rank"); | ||||
|   } | ||||
| 
 | ||||
|   getResult().setType(RankedTensorType::get(outputDims, elementType)); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Slice
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| LogicalResult ONNXSliceOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!data().getType().isa<RankedTensorType>()) | ||||
|     return emitError("Input tensor not ranked"); | ||||
| 
 | ||||
|   auto elementType = data().getType().cast<ShapedType>().getElementType(); | ||||
|   auto dataShape = data().getType().cast<ShapedType>().getShape(); | ||||
|   int64_t numDims = dataShape.size(); | ||||
| 
 | ||||
|   SmallVector<int64_t, 2> outputDims(numDims, -1); | ||||
|   // If 'starts', 'ends', 'axes', and 'steps' are constants, check whether their
 | ||||
|   // values are valid or not. If the values are valid, it is possible to infer
 | ||||
|   // shape.
 | ||||
|   //
 | ||||
|   // 'starts', 'ends', and 'steps' are for each axis in the list of axes, so
 | ||||
|   // processing 'axes' first.
 | ||||
| 
 | ||||
|   // Check and get 'axes' tensor.
 | ||||
|   SmallVector<int64_t, 2> axesValue; | ||||
|   if (axes().getType().isa<NoneType>()) { | ||||
|     // If `axes` are omitted, they are set to `[0, ..., ndim-1]`."
 | ||||
|     for (int i = 0; i < numDims; ++i) | ||||
|       axesValue.emplace_back(i); | ||||
|   } else if (auto constantOp = getONNXConstantOp(axes())) { | ||||
|     // If `axes` are constants, read them."
 | ||||
|     DenseElementsAttr valueAttribute = | ||||
|         constantOp.valueAttr().dyn_cast<DenseElementsAttr>(); | ||||
|     for (auto value : valueAttribute.getValues<IntegerAttr>()) { | ||||
|       int64_t axis = value.cast<IntegerAttr>().getInt(); | ||||
|       if (axis < -numDims || axis >= numDims) | ||||
|         return emitError("Axes contains an out-of-bound index"); | ||||
|       if (axis < 0) | ||||
|         axis += numDims; | ||||
|       if (dataShape[axis] == -1) { | ||||
|         // It is unsafe to infer shape for an axis with an unknown dimension,
 | ||||
|         // since we can not validate 'start' and 'end' values from this
 | ||||
|         // dimension.
 | ||||
|         getResult().setType(RankedTensorType::get(outputDims, elementType)); | ||||
|         return success(); | ||||
|       } | ||||
|       axesValue.emplace_back(axis); | ||||
|     } | ||||
|   } else { | ||||
|     // Cannot infer a static shape.
 | ||||
|     getResult().setType(RankedTensorType::get(outputDims, elementType)); | ||||
|     return success(); | ||||
|   } | ||||
| 
 | ||||
|   // Check 'starts' tensor.
 | ||||
|   SmallVector<int64_t, 2> startsValue; | ||||
|   if (auto constantOp = getONNXConstantOp(starts())) { | ||||
|     DenseElementsAttr valueAttribute = | ||||
|         constantOp.valueAttr().dyn_cast<DenseElementsAttr>(); | ||||
|     int i = 0; | ||||
|     for (auto value : valueAttribute.getValues<IntegerAttr>()) { | ||||
|       int64_t axis = axesValue[i]; | ||||
|       int64_t index = value.cast<IntegerAttr>().getInt(); | ||||
|       if (index < -dataShape[axis]) | ||||
|         index = 0; | ||||
|       else if (index > dataShape[axis]) | ||||
|         index = dataShape[axis]; | ||||
|       else if (index < 0) | ||||
|         index += dataShape[axis]; | ||||
|       startsValue.emplace_back(index); | ||||
|       i++; | ||||
|     } | ||||
|     if (i != axesValue.size()) | ||||
|       emitError("starts and axes tensors must have the same length"); | ||||
|   } else { | ||||
|     // Cannot infer a static shape.
 | ||||
|     getResult().setType(RankedTensorType::get(outputDims, elementType)); | ||||
|     return success(); | ||||
|   } | ||||
| 
 | ||||
|   // Check 'ends' tensor.
 | ||||
|   SmallVector<int64_t, 2> endsValue; | ||||
|   if (auto constantOp = getONNXConstantOp(ends())) { | ||||
|     DenseElementsAttr valueAttribute = | ||||
|         constantOp.valueAttr().dyn_cast<DenseElementsAttr>(); | ||||
|     int i = 0; | ||||
|     for (auto value : valueAttribute.getValues<IntegerAttr>()) { | ||||
|       int64_t axis = axesValue[i]; | ||||
|       int64_t index = value.cast<IntegerAttr>().getInt(); | ||||
|       if (index < -dataShape[axis]) | ||||
|         index = 0; | ||||
|       else if (index > dataShape[axis]) | ||||
|         index = dataShape[axis]; | ||||
|       else if (index < 0) | ||||
|         index += dataShape[axis]; | ||||
|       endsValue.emplace_back(index); | ||||
|       i++; | ||||
|     } | ||||
|     if (i != axesValue.size()) | ||||
|       emitError("ends and axes tensors must have the same length"); | ||||
|   } else { | ||||
|     // Cannot infer a static shape.
 | ||||
|     getResult().setType(RankedTensorType::get(outputDims, elementType)); | ||||
|     return success(); | ||||
|   } | ||||
| 
 | ||||
|   // Check and get 'steps' tensor.
 | ||||
|   SmallVector<int64_t, 2> stepsValue; | ||||
|   if (steps().getType().isa<NoneType>()) { | ||||
|     // If `steps` are omitted, they are set to `[1, ..., 1]` of len(starts)."
 | ||||
|     for (int i = 0; i < startsValue.size(); ++i) | ||||
|       stepsValue.emplace_back(1); | ||||
|   } else if (auto constantOp = getONNXConstantOp(steps())) { | ||||
|     // If `steps` are constants, read them."
 | ||||
|     DenseElementsAttr valueAttribute = | ||||
|         constantOp.valueAttr().dyn_cast<DenseElementsAttr>(); | ||||
|     int i = 0; | ||||
|     for (auto value : valueAttribute.getValues<IntegerAttr>()) { | ||||
|       int64_t index = value.cast<IntegerAttr>().getInt(); | ||||
|       if (index == 0) | ||||
|         emitError("step cannot be zero"); | ||||
|       stepsValue.emplace_back(index); | ||||
|       i++; | ||||
|     } | ||||
|     if (i != axesValue.size()) | ||||
|       emitError("steps and axes tensors must have the same length"); | ||||
|   } else { | ||||
|     // Cannot infer a static shape.
 | ||||
|     getResult().setType(RankedTensorType::get(outputDims, elementType)); | ||||
|     return success(); | ||||
|   } | ||||
| 
 | ||||
|   // All 'starts', 'ends', 'steps' values are valid. Now calculate output
 | ||||
|   // dimensions for axes in 'axes'.
 | ||||
|   for (int i = 0; i < axesValue.size(); i++) { | ||||
|     int64_t axis = axesValue[i]; | ||||
|     int64_t start = startsValue[i]; | ||||
|     int64_t end = endsValue[i]; | ||||
|     int64_t step = stepsValue[i]; | ||||
|     if (step < 0) | ||||
|       step = -step; | ||||
| 
 | ||||
|     int64_t q = (end - start) / step; | ||||
|     int64_t r = (end - start) % step; | ||||
|     if (r != 0) | ||||
|       q += 1; | ||||
|     outputDims[axis] = q; | ||||
|   } | ||||
| 
 | ||||
|   getResult().setType(RankedTensorType::get(outputDims, elementType)); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // TableGen'd op method definitions
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  |  | |||
|  | @ -662,7 +662,7 @@ def ONNXConstantOp:ONNX_Op<"Constant", | |||
| } | ||||
| 
 | ||||
| def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", | ||||
|   [NoSideEffect]> { | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> { | ||||
|   let summary = "ONNX ConstantOfShape operation"; | ||||
|   let description = [{ | ||||
|   "Generate a tensor with given value and shape." | ||||
|  | @ -680,6 +680,17 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", | |||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|     std::vector<mlir::Type> resultTypeInference() { | ||||
|       std::vector<mlir::Type> resultTypes; | ||||
|       if (auto attr = valueAttr()) { | ||||
|         resultTypes.push_back(mlir::UnrankedTensorType::get( | ||||
|           attr.getType().cast<ShapedType>().getElementType())); | ||||
|       } else { | ||||
|         resultTypes.push_back(mlir::UnrankedTensorType::get( | ||||
|           FloatType::getF32(getContext()))); | ||||
|       } | ||||
|       return resultTypes; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
|  | @ -1438,7 +1449,7 @@ def ONNXGRUOp:ONNX_Op<"GRU", | |||
| } | ||||
| 
 | ||||
| def ONNXGatherOp:ONNX_Op<"Gather", | ||||
|   [NoSideEffect]> { | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   let summary = "ONNX Gather operation"; | ||||
|   let description = [{ | ||||
|   "Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather" | ||||
|  | @ -4695,7 +4706,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", | |||
| } | ||||
| 
 | ||||
| def ONNXShapeOp:ONNX_Op<"Shape", | ||||
|   [NoSideEffect]> { | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   let summary = "ONNX Shape operation"; | ||||
|   let description = [{ | ||||
|   "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor." | ||||
|  | @ -4850,7 +4861,7 @@ def ONNXSizeOp:ONNX_Op<"Size", | |||
| } | ||||
| 
 | ||||
| def ONNXSliceOp:ONNX_Op<"Slice", | ||||
|   [NoSideEffect]> { | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   let summary = "ONNX Slice operation"; | ||||
|   let description = [{ | ||||
|   "Produces a slice of the input tensor along multiple axes. Similar to numpy:" | ||||
|  | @ -5345,7 +5356,7 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu", | |||
| } | ||||
| 
 | ||||
| def ONNXTileOp:ONNX_Op<"Tile", | ||||
|   [NoSideEffect]> { | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> { | ||||
|   let summary = "ONNX Tile operation"; | ||||
|   let description = [{ | ||||
|   "Constructs a tensor by tiling a given tensor." | ||||
|  | @ -5353,7 +5364,7 @@ def ONNXTileOp:ONNX_Op<"Tile", | |||
|   "For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]" | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$input, | ||||
|     AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$repeats); | ||||
|     AnyTypeOf<[TensorOf<[I64]>, AnyMemRef, NoneType]>:$repeats); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$output); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|  | @ -5365,6 +5376,9 @@ def ONNXTileOp:ONNX_Op<"Tile", | |||
|     static std::vector<int> getTypeMap() { | ||||
|       return {20}; | ||||
|     } | ||||
|     std::map<std::string, size_t> promotableConstOperands() { | ||||
|       return {{"repeats", 1}}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -1128,3 +1128,230 @@ func @test_convinteger_11(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi | |||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32> | ||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_shape(%arg0: tensor<?x3x2xf32>) -> tensor<*xi64> { | ||||
|   %0 = "onnx.Shape"(%arg0) : (tensor<?x3x2xf32>) -> tensor<*xi64> | ||||
|   return %0 : tensor<*xi64> | ||||
| 
 | ||||
|   // CHECK-LABEL: test_shape | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Shape"(%arg0) : (tensor<?x3x2xf32>) -> tensor<3xi64> | ||||
|   // CHECK: return [[RES]] : tensor<3xi64> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_tile_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_tile_dynamic | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32> | ||||
|   // CHECK: return [[RES]] : tensor<?x?x?x?xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_tile_constant(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi64> } : () -> tensor<4xi64> | ||||
|   %1 = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_tile_constant | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<25x25x16x64xf32> | ||||
|   // CHECK: return [[RES]] : tensor<25x25x16x64xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_gather_axis0(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.Gather"(%arg0, %arg1) {axis = 0} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_gather_axis0 | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 0 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x2x3xf32> | ||||
|   // CHECK: return [[RES]] : tensor<1x2x3xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_gather_axis1(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.Gather"(%arg0, %arg1) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_gather_axis1 | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> | ||||
|   // CHECK: return [[RES]] : tensor<3x1x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_gather_negative_axis(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.Gather"(%arg0, %arg1) {axis = -1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_gather_negative_axis | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> | ||||
|   // CHECK: return [[RES]] : tensor<3x1x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_constant_of_shape_empty_tensor(%arg0 : tensor<0xi64>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.ConstantOfShape"(%arg0) : (tensor<0xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_constant_of_shape_empty_tensor | ||||
|   // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<0.000000e+00> : tensor<1xf32>} : (tensor<0xi64>) -> tensor<f32> | ||||
|   // CHECK: return [[RES]] : tensor<f32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_constant_of_shape(%arg0 : tensor<3xi64>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.ConstantOfShape"(%arg0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_constant_of_shape | ||||
|   // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<?x?x?xf32> | ||||
|   // CHECK: return [[RES]] : tensor<?x?x?xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_constant_of_shape_constant() -> tensor<*xf32> { | ||||
|   %0 = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64> } : () -> tensor<3xi64> | ||||
|   %1 = "onnx.ConstantOfShape"(%0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_constant_of_shape_constant | ||||
|   // CHECK: [[CONSTANT:%.+]] = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64>} : () -> tensor<3xi64> | ||||
|   // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"([[CONSTANT]]) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<3x4x5xf32> | ||||
|   // CHECK: return [[RES]] : tensor<3x4x5xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_slice(%arg0 : tensor<2x4xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>, %arg3: tensor<2xi64>, %arg4: tensor<2xi64>) -> tensor<*xf32> { | ||||
|   %1 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_slice | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32> | ||||
|   // CHECK: return [[RES:%.+]] : tensor<?x?xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_slice_constant_default_axes(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { | ||||
|   %axes = constant unit | ||||
|   %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_slice_constant_default_axes | ||||
|   // CHECK: [[AXES:%.+]] = constant unit | ||||
|   // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>  | ||||
|   // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<1x2xf32> | ||||
|   // CHECK: return [[RES]] : tensor<1x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { | ||||
|   %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %steps = constant unit | ||||
|   %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_slice_constant_default_steps | ||||
|   // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>  | ||||
|   // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STEPS:%.+]] = constant unit | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<1x3xf32> | ||||
|   // CHECK: return [[RES]] : tensor<1x3xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_slice_all_constant(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { | ||||
|   %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_slice_all_constant | ||||
|   // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>  | ||||
|   // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32> | ||||
|   // CHECK: return [[RES]] : tensor<1x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { | ||||
|   %axes = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %ends = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_slice_all_constant_negative | ||||
|   // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>  | ||||
|   // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32> | ||||
|   // CHECK: return [[RES]] : tensor<1x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { | ||||
|   %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %ends = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_slice_all_constant_end_outofbound | ||||
|   // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>  | ||||
|   // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32> | ||||
|   // CHECK: return [[RES]] : tensor<1x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> { | ||||
|   %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %steps = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64> } : () -> tensor<2xi64> | ||||
|   %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> | ||||
|   "std.return"(%1) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_slice_all_constant_negative_steps | ||||
|   // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>  | ||||
|   // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64>} : () -> tensor<2xi64> | ||||
|   // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32> | ||||
|   // CHECK: return [[RES]] : tensor<1x2xf32> | ||||
| } | ||||
|  |  | |||
|  | @ -51,3 +51,16 @@ func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32 | |||
|   // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32> | ||||
|   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_tile_should_promote_to_attribute(%arg0 : tensor<?x10x10xf32>) -> tensor<*xf32> { | ||||
|   %shape = constant dense<[6, 7, 42]> : tensor<3xi64> | ||||
|   %0 = "onnx.Tile"(%arg0, %shape) : (tensor<?x10x10xf32>, tensor<3xi64>) -> tensor<*xf32> | ||||
|   return %0 : tensor<*xf32> | ||||
|   // CHECK-LABEL: test_tile_should_promote_to_attribute | ||||
|   // CHECK-NEXT: [[NONE:%.+]] = constant unit | ||||
|   // CHECK-NEXT: [[TILE:%.+]] = "onnx.Tile"(%{{.*}}, [[NONE]]) {repeats = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10x10xf32>, none) -> tensor<*xf32> | ||||
|   // CHECK-NEXT: return [[TILE]] : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -252,7 +252,7 @@ OpsWithShapeInference = [ | |||
|     'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN', | ||||
|     'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten', | ||||
|     'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger', | ||||
|     'Squeeze' | ||||
|     'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice' | ||||
| ] | ||||
| 
 | ||||
| # Operations supporting canonicalization. | ||||
|  | @ -266,7 +266,8 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Scaler'] | |||
| # tuples, whose first item is the attribute/operand name, and the second item is | ||||
| # the index at which such operand occurs in the list of the operation's inputs. | ||||
| OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)], | ||||
|                                   "Pad": [("pads", 1), ("constant_value", 2)]} | ||||
|                                   "Pad": [("pads", 1), ("constant_value", 2)], | ||||
|                                   "Tile": [("repeats", 1)]} | ||||
| 
 | ||||
| # Interface for special handling of type inference | ||||
| # The common code are put into get_type_inference_func | ||||
|  | @ -281,7 +282,15 @@ OpsWithResultTypeInference = { | |||
|     '''auto toAttr = to().getSExtValue(); | ||||
|       auto builder = mlir::OpBuilder(getContext()); | ||||
|       resultTypes.push_back(mlir::UnrankedTensorType::get( | ||||
|         convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));''' | ||||
|         convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));''', | ||||
|   "ConstantOfShape": | ||||
|   '''if (auto attr = valueAttr()) { | ||||
|         resultTypes.push_back(mlir::UnrankedTensorType::get( | ||||
|           attr.getType().cast<ShapedType>().getElementType())); | ||||
|       } else { | ||||
|         resultTypes.push_back(mlir::UnrankedTensorType::get( | ||||
|           FloatType::getF32(getContext()))); | ||||
|       }''' | ||||
| } | ||||
| 
 | ||||
| # Add an Op in this list if the Op needs result type deduction which is required | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue