Shape inference for some ops in RNN models (#207)
* Shape inference for ShapeOp * Shape inference for TileOp * Fix check-doc * Fix onnx-mlir-doc * Shape inference for GatherOp * Check validity of GatherOp's indices tensor * Shape inference for Slice * Tests for SliceOp * Fix importing none inputs * Type inference for constantofshape * Empty tensor in case of ConstantOfShape * Remove unrelated changes Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
		
							parent
							
								
									1263d01968
								
							
						
					
					
						commit
						034f98c00c
					
				| 
						 | 
					@ -5556,7 +5556,7 @@ ONNX Tile operation
 | 
				
			||||||
| Operand | Description |
 | 
					| Operand | Description |
 | 
				
			||||||
| :-----: | ----------- |
 | 
					| :-----: | ----------- |
 | 
				
			||||||
`input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values
 | 
					`input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values
 | 
				
			||||||
`repeats` | tensor of 64-bit signless integer values or memref of any type values
 | 
					`repeats` | tensor of 64-bit signless integer values or memref of any type values or none type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#### Results:
 | 
					#### Results:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2258,6 +2258,362 @@ LogicalResult ONNXConvIntegerOp::inferShapes() {
 | 
				
			||||||
  return success();
 | 
					  return success();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Shape
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					LogicalResult ONNXShapeOp::inferShapes() {
 | 
				
			||||||
 | 
					  // Cannot infer shape if no shape exists.
 | 
				
			||||||
 | 
					  if (!data().getType().isa<RankedTensorType>())
 | 
				
			||||||
 | 
					    return emitError("Input tensor not ranked");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Output is an 1D int64 tensor containing the shape of the input tensor.
 | 
				
			||||||
 | 
					  int64_t rank = data().getType().cast<RankedTensorType>().getRank();
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 1> outDims(1, rank);
 | 
				
			||||||
 | 
					  getResult().setType(
 | 
				
			||||||
 | 
					      RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Tile
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					LogicalResult ONNXTileOp::inferShapes() {
 | 
				
			||||||
 | 
					  // Cannot infer shape if no shape exists.
 | 
				
			||||||
 | 
					  if (!input().getType().isa<RankedTensorType>())
 | 
				
			||||||
 | 
					    return emitError("Input tensor not ranked");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Read 'repeats' value.
 | 
				
			||||||
 | 
					  if (!repeats().getType().isa<RankedTensorType>())
 | 
				
			||||||
 | 
					    return emitError("Repeats tensor not ranked");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto inputTensorTy = input().getType().cast<RankedTensorType>();
 | 
				
			||||||
 | 
					  auto repeatsTensorTy = repeats().getType().cast<RankedTensorType>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // 'repeats' tensor is an 1D tensor.
 | 
				
			||||||
 | 
					  if (repeatsTensorTy.getShape().size() != 1)
 | 
				
			||||||
 | 
					    return emitError("Repeats tensor must have rank one");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // 'repeats' tensor must have constant shape.
 | 
				
			||||||
 | 
					  int64_t repeatsLength = repeatsTensorTy.getShape()[0];
 | 
				
			||||||
 | 
					  if (repeatsLength < 0)
 | 
				
			||||||
 | 
					    return emitError("Repeats tensor must have constant shape");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check the 1D repeats tensor length.
 | 
				
			||||||
 | 
					  int64_t inputRank = inputTensorTy.getShape().size();
 | 
				
			||||||
 | 
					  if (inputRank != repeatsLength)
 | 
				
			||||||
 | 
					    return emitError("Repeats tensor must have the same length as the input's "
 | 
				
			||||||
 | 
					                     "dimension number.");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check if second argument of TileOp is a constant.
 | 
				
			||||||
 | 
					  auto constantOp = getONNXConstantOp(repeats());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Compute output's dimensions: output_dim[i] = input_dim[i] * repeats[i]
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 2> dims(inputRank, -1);
 | 
				
			||||||
 | 
					  if (constantOp) {
 | 
				
			||||||
 | 
					    // 1. Initialize output_dim with values from 'input'.
 | 
				
			||||||
 | 
					    //   output_dim[i] = input[i]
 | 
				
			||||||
 | 
					    for (decltype(inputRank) i = 0; i < inputRank; ++i)
 | 
				
			||||||
 | 
					      dims[i] = inputTensorTy.getShape()[i];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // 2. Update output_dim using values from 'repeats'.
 | 
				
			||||||
 | 
					    // Do this only for static 'input_dim[i]'.
 | 
				
			||||||
 | 
					    //   if (output_dim[i] != -1) output_dim[i] *= repeats[i]
 | 
				
			||||||
 | 
					    DenseElementsAttr valueAttribute =
 | 
				
			||||||
 | 
					        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
 | 
				
			||||||
 | 
					    if (!valueAttribute)
 | 
				
			||||||
 | 
					      return emitError("DenseElementsAttr expected");
 | 
				
			||||||
 | 
					    // Get repeat values from valueAttribute.
 | 
				
			||||||
 | 
					    auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
 | 
				
			||||||
 | 
					    for (int i = 0; i < inputRank; ++i)
 | 
				
			||||||
 | 
					      if (dims[i] != -1)
 | 
				
			||||||
 | 
					        dims[i] *= (*valueIt++).cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
 | 
				
			||||||
 | 
					      return emitError("Constant value must have same length as output's rank");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  getResult().setType(
 | 
				
			||||||
 | 
					      RankedTensorType::get(dims, inputTensorTy.getElementType()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Gather
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					LogicalResult ONNXGatherOp::inferShapes() {
 | 
				
			||||||
 | 
					  // Cannot infer shape if no shape exists.
 | 
				
			||||||
 | 
					  if (!data().getType().isa<RankedTensorType>())
 | 
				
			||||||
 | 
					    return emitError("Input tensor not ranked");
 | 
				
			||||||
 | 
					  if (!indices().getType().isa<RankedTensorType>())
 | 
				
			||||||
 | 
					    return emitError("Indices tensor not ranked");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto inputShape = data().getType().cast<RankedTensorType>().getShape();
 | 
				
			||||||
 | 
					  auto indicesShape = indices().getType().cast<RankedTensorType>().getShape();
 | 
				
			||||||
 | 
					  int64_t inputRank = inputShape.size();
 | 
				
			||||||
 | 
					  int64_t indicesRank = indicesShape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (inputRank < 1)
 | 
				
			||||||
 | 
					    return emitError("Input tensor must have rank >= 1");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Read 'axis' attribute.
 | 
				
			||||||
 | 
					  auto axisIndex = axis().getSExtValue();
 | 
				
			||||||
 | 
					  // 'axis' must be in [-rank, rank-1]
 | 
				
			||||||
 | 
					  if (axisIndex < -inputRank || axisIndex >= inputRank)
 | 
				
			||||||
 | 
					    return emitError("Gather axis value out of bound");
 | 
				
			||||||
 | 
					  // Convert a negative axis to a positive axis.
 | 
				
			||||||
 | 
					  if (axisIndex < 0) {
 | 
				
			||||||
 | 
					    axisIndex += inputRank;
 | 
				
			||||||
 | 
					    auto builder = mlir::Builder(getContext());
 | 
				
			||||||
 | 
					    axisAttr(builder.getI64IntegerAttr(axisIndex));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // If 'indices' is a constant, check whether its values are valid or not.
 | 
				
			||||||
 | 
					  auto constantOp = getONNXConstantOp(indices());
 | 
				
			||||||
 | 
					  if (constantOp && inputShape[axisIndex] != -1) {
 | 
				
			||||||
 | 
					    DenseElementsAttr valueAttribute =
 | 
				
			||||||
 | 
					        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
 | 
				
			||||||
 | 
					    if (!valueAttribute)
 | 
				
			||||||
 | 
					      return emitError("DenseElementsAttr expected");
 | 
				
			||||||
 | 
					    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
 | 
				
			||||||
 | 
					      auto index = value.cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					      if (index < -inputShape[axisIndex] || index >= inputShape[axisIndex])
 | 
				
			||||||
 | 
					        return emitError("Indices tensor contains an out-of-bound index");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Output has rank of 'indicesRank + (inputRank - 1).
 | 
				
			||||||
 | 
					  // Output shape is constructed from 'input' by:
 | 
				
			||||||
 | 
					  //    replacing the dimension at 'axis' in 'input' by the shape of 'indices'.
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 1> outDims;
 | 
				
			||||||
 | 
					  for (decltype(inputRank) i = 0; i < inputRank; ++i) {
 | 
				
			||||||
 | 
					    if (i == axisIndex)
 | 
				
			||||||
 | 
					      for (decltype(indicesRank) j = 0; j < indicesRank; ++j)
 | 
				
			||||||
 | 
					        outDims.emplace_back(indicesShape[j]);
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      outDims.emplace_back(inputShape[i]);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  getResult().setType(RankedTensorType::get(
 | 
				
			||||||
 | 
					      outDims, data().getType().cast<RankedTensorType>().getElementType()));
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// ConstantOfShape
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					LogicalResult ONNXConstantOfShapeOp::inferShapes() {
 | 
				
			||||||
 | 
					  Type elementType;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // 'value' attribute is a one-element tensor whose value and datatype are used
 | 
				
			||||||
 | 
					  // to set the output tensor's value and datatype..
 | 
				
			||||||
 | 
					  if (value().hasValue()) {
 | 
				
			||||||
 | 
					    elementType =
 | 
				
			||||||
 | 
					        valueAttr().cast<DenseElementsAttr>().getType().getElementType();
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // If 'value' attribute is not specified, it defaults to a tensor of value 0
 | 
				
			||||||
 | 
					    // and datatype float32.
 | 
				
			||||||
 | 
					    elementType = FloatType::getF32(getContext());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    llvm::SmallVector<int64_t, 2> dims(1, 1);
 | 
				
			||||||
 | 
					    auto tensorType = mlir::RankedTensorType::get(dims, elementType);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    llvm::SmallVector<float, 1> values(1, 0.);
 | 
				
			||||||
 | 
					    valueAttr(
 | 
				
			||||||
 | 
					        mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // 'input' must be a 1D tensor.
 | 
				
			||||||
 | 
					  auto inputShape = input().getType().cast<RankedTensorType>().getShape();
 | 
				
			||||||
 | 
					  if (inputShape.size() != 1)
 | 
				
			||||||
 | 
					    return emitError("Input tensor must be a 1D tensor");
 | 
				
			||||||
 | 
					  if (inputShape[0] == -1)
 | 
				
			||||||
 | 
					    return emitError("Input tensor must have static shape");
 | 
				
			||||||
 | 
					  if (inputShape[0] == 0) {
 | 
				
			||||||
 | 
					    // If 'input' is an empty tensor, the output would be a scalar.
 | 
				
			||||||
 | 
					    getResult().setType(RankedTensorType::get({}, elementType));
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Calculate output dimensions.
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 4> outputDims(inputShape[0], -1);
 | 
				
			||||||
 | 
					  // If 'input' is a constant, check whether its values are valid or not.
 | 
				
			||||||
 | 
					  // If the values are valid, it is possible to infer shape.
 | 
				
			||||||
 | 
					  if (auto constantOp = getONNXConstantOp(input())) {
 | 
				
			||||||
 | 
					    DenseElementsAttr valueAttribute =
 | 
				
			||||||
 | 
					        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
 | 
				
			||||||
 | 
					    // Get repeat values from valueAttribute.
 | 
				
			||||||
 | 
					    auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
 | 
				
			||||||
 | 
					    for (int i = 0; i < inputShape[0]; ++i) {
 | 
				
			||||||
 | 
					      auto dim = (*valueIt++).cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					      if (dim < 0)
 | 
				
			||||||
 | 
					        return emitError("All values of the input tensor must be >=0");
 | 
				
			||||||
 | 
					      outputDims[i] = dim;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
 | 
				
			||||||
 | 
					      return emitError("Constant value must have same length as output's rank");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  getResult().setType(RankedTensorType::get(outputDims, elementType));
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Slice
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					LogicalResult ONNXSliceOp::inferShapes() {
 | 
				
			||||||
 | 
					  // Cannot infer shape if no shape exists.
 | 
				
			||||||
 | 
					  if (!data().getType().isa<RankedTensorType>())
 | 
				
			||||||
 | 
					    return emitError("Input tensor not ranked");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto elementType = data().getType().cast<ShapedType>().getElementType();
 | 
				
			||||||
 | 
					  auto dataShape = data().getType().cast<ShapedType>().getShape();
 | 
				
			||||||
 | 
					  int64_t numDims = dataShape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 2> outputDims(numDims, -1);
 | 
				
			||||||
 | 
					  // If 'starts', 'ends', 'axes', and 'steps' are constants, check whether their
 | 
				
			||||||
 | 
					  // values are valid or not. If the values are valid, it is possible to infer
 | 
				
			||||||
 | 
					  // shape.
 | 
				
			||||||
 | 
					  //
 | 
				
			||||||
 | 
					  // 'starts', 'ends', and 'steps' are for each axis in the list of axes, so
 | 
				
			||||||
 | 
					  // processing 'axes' first.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check and get 'axes' tensor.
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 2> axesValue;
 | 
				
			||||||
 | 
					  if (axes().getType().isa<NoneType>()) {
 | 
				
			||||||
 | 
					    // If `axes` are omitted, they are set to `[0, ..., ndim-1]`."
 | 
				
			||||||
 | 
					    for (int i = 0; i < numDims; ++i)
 | 
				
			||||||
 | 
					      axesValue.emplace_back(i);
 | 
				
			||||||
 | 
					  } else if (auto constantOp = getONNXConstantOp(axes())) {
 | 
				
			||||||
 | 
					    // If `axes` are constants, read them."
 | 
				
			||||||
 | 
					    DenseElementsAttr valueAttribute =
 | 
				
			||||||
 | 
					        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
 | 
				
			||||||
 | 
					    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
 | 
				
			||||||
 | 
					      int64_t axis = value.cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					      if (axis < -numDims || axis >= numDims)
 | 
				
			||||||
 | 
					        return emitError("Axes contains an out-of-bound index");
 | 
				
			||||||
 | 
					      if (axis < 0)
 | 
				
			||||||
 | 
					        axis += numDims;
 | 
				
			||||||
 | 
					      if (dataShape[axis] == -1) {
 | 
				
			||||||
 | 
					        // It is unsafe to infer shape for an axis with an unknown dimension,
 | 
				
			||||||
 | 
					        // since we can not validate 'start' and 'end' values from this
 | 
				
			||||||
 | 
					        // dimension.
 | 
				
			||||||
 | 
					        getResult().setType(RankedTensorType::get(outputDims, elementType));
 | 
				
			||||||
 | 
					        return success();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      axesValue.emplace_back(axis);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // Cannot infer a static shape.
 | 
				
			||||||
 | 
					    getResult().setType(RankedTensorType::get(outputDims, elementType));
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check 'starts' tensor.
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 2> startsValue;
 | 
				
			||||||
 | 
					  if (auto constantOp = getONNXConstantOp(starts())) {
 | 
				
			||||||
 | 
					    DenseElementsAttr valueAttribute =
 | 
				
			||||||
 | 
					        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
 | 
				
			||||||
 | 
					    int i = 0;
 | 
				
			||||||
 | 
					    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
 | 
				
			||||||
 | 
					      int64_t axis = axesValue[i];
 | 
				
			||||||
 | 
					      int64_t index = value.cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					      if (index < -dataShape[axis])
 | 
				
			||||||
 | 
					        index = 0;
 | 
				
			||||||
 | 
					      else if (index > dataShape[axis])
 | 
				
			||||||
 | 
					        index = dataShape[axis];
 | 
				
			||||||
 | 
					      else if (index < 0)
 | 
				
			||||||
 | 
					        index += dataShape[axis];
 | 
				
			||||||
 | 
					      startsValue.emplace_back(index);
 | 
				
			||||||
 | 
					      i++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (i != axesValue.size())
 | 
				
			||||||
 | 
					      emitError("starts and axes tensors must have the same length");
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // Cannot infer a static shape.
 | 
				
			||||||
 | 
					    getResult().setType(RankedTensorType::get(outputDims, elementType));
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check 'ends' tensor.
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 2> endsValue;
 | 
				
			||||||
 | 
					  if (auto constantOp = getONNXConstantOp(ends())) {
 | 
				
			||||||
 | 
					    DenseElementsAttr valueAttribute =
 | 
				
			||||||
 | 
					        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
 | 
				
			||||||
 | 
					    int i = 0;
 | 
				
			||||||
 | 
					    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
 | 
				
			||||||
 | 
					      int64_t axis = axesValue[i];
 | 
				
			||||||
 | 
					      int64_t index = value.cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					      if (index < -dataShape[axis])
 | 
				
			||||||
 | 
					        index = 0;
 | 
				
			||||||
 | 
					      else if (index > dataShape[axis])
 | 
				
			||||||
 | 
					        index = dataShape[axis];
 | 
				
			||||||
 | 
					      else if (index < 0)
 | 
				
			||||||
 | 
					        index += dataShape[axis];
 | 
				
			||||||
 | 
					      endsValue.emplace_back(index);
 | 
				
			||||||
 | 
					      i++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (i != axesValue.size())
 | 
				
			||||||
 | 
					      emitError("ends and axes tensors must have the same length");
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // Cannot infer a static shape.
 | 
				
			||||||
 | 
					    getResult().setType(RankedTensorType::get(outputDims, elementType));
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check and get 'steps' tensor.
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 2> stepsValue;
 | 
				
			||||||
 | 
					  if (steps().getType().isa<NoneType>()) {
 | 
				
			||||||
 | 
					    // If `steps` are omitted, they are set to `[1, ..., 1]` of len(starts)."
 | 
				
			||||||
 | 
					    for (int i = 0; i < startsValue.size(); ++i)
 | 
				
			||||||
 | 
					      stepsValue.emplace_back(1);
 | 
				
			||||||
 | 
					  } else if (auto constantOp = getONNXConstantOp(steps())) {
 | 
				
			||||||
 | 
					    // If `steps` are constants, read them."
 | 
				
			||||||
 | 
					    DenseElementsAttr valueAttribute =
 | 
				
			||||||
 | 
					        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
 | 
				
			||||||
 | 
					    int i = 0;
 | 
				
			||||||
 | 
					    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
 | 
				
			||||||
 | 
					      int64_t index = value.cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					      if (index == 0)
 | 
				
			||||||
 | 
					        emitError("step cannot be zero");
 | 
				
			||||||
 | 
					      stepsValue.emplace_back(index);
 | 
				
			||||||
 | 
					      i++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (i != axesValue.size())
 | 
				
			||||||
 | 
					      emitError("steps and axes tensors must have the same length");
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    // Cannot infer a static shape.
 | 
				
			||||||
 | 
					    getResult().setType(RankedTensorType::get(outputDims, elementType));
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // All 'starts', 'ends', 'steps' values are valid. Now calculate output
 | 
				
			||||||
 | 
					  // dimensions for axes in 'axes'.
 | 
				
			||||||
 | 
					  for (int i = 0; i < axesValue.size(); i++) {
 | 
				
			||||||
 | 
					    int64_t axis = axesValue[i];
 | 
				
			||||||
 | 
					    int64_t start = startsValue[i];
 | 
				
			||||||
 | 
					    int64_t end = endsValue[i];
 | 
				
			||||||
 | 
					    int64_t step = stepsValue[i];
 | 
				
			||||||
 | 
					    if (step < 0)
 | 
				
			||||||
 | 
					      step = -step;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int64_t q = (end - start) / step;
 | 
				
			||||||
 | 
					    int64_t r = (end - start) % step;
 | 
				
			||||||
 | 
					    if (r != 0)
 | 
				
			||||||
 | 
					      q += 1;
 | 
				
			||||||
 | 
					    outputDims[axis] = q;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  getResult().setType(RankedTensorType::get(outputDims, elementType));
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// TableGen'd op method definitions
 | 
					// TableGen'd op method definitions
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -662,7 +662,7 @@ def ONNXConstantOp:ONNX_Op<"Constant",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
 | 
					def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
 | 
				
			||||||
  [NoSideEffect]> {
 | 
					  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
 | 
				
			||||||
  let summary = "ONNX ConstantOfShape operation";
 | 
					  let summary = "ONNX ConstantOfShape operation";
 | 
				
			||||||
  let description = [{
 | 
					  let description = [{
 | 
				
			||||||
  "Generate a tensor with given value and shape."
 | 
					  "Generate a tensor with given value and shape."
 | 
				
			||||||
| 
						 | 
					@ -680,6 +680,17 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
 | 
				
			||||||
    static std::vector<int> getTypeMap() {
 | 
					    static std::vector<int> getTypeMap() {
 | 
				
			||||||
      return {-1};
 | 
					      return {-1};
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    std::vector<mlir::Type> resultTypeInference() {
 | 
				
			||||||
 | 
					      std::vector<mlir::Type> resultTypes;
 | 
				
			||||||
 | 
					      if (auto attr = valueAttr()) {
 | 
				
			||||||
 | 
					        resultTypes.push_back(mlir::UnrankedTensorType::get(
 | 
				
			||||||
 | 
					          attr.getType().cast<ShapedType>().getElementType()));
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        resultTypes.push_back(mlir::UnrankedTensorType::get(
 | 
				
			||||||
 | 
					          FloatType::getF32(getContext())));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      return resultTypes;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
  }];
 | 
					  }];
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1438,7 +1449,7 @@ def ONNXGRUOp:ONNX_Op<"GRU",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXGatherOp:ONNX_Op<"Gather",
 | 
					def ONNXGatherOp:ONNX_Op<"Gather",
 | 
				
			||||||
  [NoSideEffect]> {
 | 
					  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
				
			||||||
  let summary = "ONNX Gather operation";
 | 
					  let summary = "ONNX Gather operation";
 | 
				
			||||||
  let description = [{
 | 
					  let description = [{
 | 
				
			||||||
  "Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather"
 | 
					  "Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather"
 | 
				
			||||||
| 
						 | 
					@ -4695,7 +4706,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXShapeOp:ONNX_Op<"Shape",
 | 
					def ONNXShapeOp:ONNX_Op<"Shape",
 | 
				
			||||||
  [NoSideEffect]> {
 | 
					  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
				
			||||||
  let summary = "ONNX Shape operation";
 | 
					  let summary = "ONNX Shape operation";
 | 
				
			||||||
  let description = [{
 | 
					  let description = [{
 | 
				
			||||||
  "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
 | 
					  "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
 | 
				
			||||||
| 
						 | 
					@ -4850,7 +4861,7 @@ def ONNXSizeOp:ONNX_Op<"Size",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXSliceOp:ONNX_Op<"Slice",
 | 
					def ONNXSliceOp:ONNX_Op<"Slice",
 | 
				
			||||||
  [NoSideEffect]> {
 | 
					  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
				
			||||||
  let summary = "ONNX Slice operation";
 | 
					  let summary = "ONNX Slice operation";
 | 
				
			||||||
  let description = [{
 | 
					  let description = [{
 | 
				
			||||||
  "Produces a slice of the input tensor along multiple axes. Similar to numpy:"
 | 
					  "Produces a slice of the input tensor along multiple axes. Similar to numpy:"
 | 
				
			||||||
| 
						 | 
					@ -5345,7 +5356,7 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXTileOp:ONNX_Op<"Tile",
 | 
					def ONNXTileOp:ONNX_Op<"Tile",
 | 
				
			||||||
  [NoSideEffect]> {
 | 
					  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
 | 
				
			||||||
  let summary = "ONNX Tile operation";
 | 
					  let summary = "ONNX Tile operation";
 | 
				
			||||||
  let description = [{
 | 
					  let description = [{
 | 
				
			||||||
  "Constructs a tensor by tiling a given tensor."
 | 
					  "Constructs a tensor by tiling a given tensor."
 | 
				
			||||||
| 
						 | 
					@ -5353,7 +5364,7 @@ def ONNXTileOp:ONNX_Op<"Tile",
 | 
				
			||||||
  "For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]"
 | 
					  "For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]"
 | 
				
			||||||
  }];
 | 
					  }];
 | 
				
			||||||
  let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$input,
 | 
					  let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$input,
 | 
				
			||||||
    AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$repeats);
 | 
					    AnyTypeOf<[TensorOf<[I64]>, AnyMemRef, NoneType]>:$repeats);
 | 
				
			||||||
  let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$output);
 | 
					  let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$output);
 | 
				
			||||||
  let extraClassDeclaration = [{
 | 
					  let extraClassDeclaration = [{
 | 
				
			||||||
    static int getNumberOfOperands() {
 | 
					    static int getNumberOfOperands() {
 | 
				
			||||||
| 
						 | 
					@ -5365,6 +5376,9 @@ def ONNXTileOp:ONNX_Op<"Tile",
 | 
				
			||||||
    static std::vector<int> getTypeMap() {
 | 
					    static std::vector<int> getTypeMap() {
 | 
				
			||||||
      return {20};
 | 
					      return {20};
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    std::map<std::string, size_t> promotableConstOperands() {
 | 
				
			||||||
 | 
					      return {{"repeats", 1}};
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
  }];
 | 
					  }];
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1128,3 +1128,230 @@ func @test_convinteger_11(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi
 | 
				
			||||||
  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
 | 
					  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor<i8>, tensor<i8>) -> tensor<1x5x32x64xi32>
 | 
				
			||||||
  // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
 | 
					  // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_shape(%arg0: tensor<?x3x2xf32>) -> tensor<*xi64> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Shape"(%arg0) : (tensor<?x3x2xf32>) -> tensor<*xi64>
 | 
				
			||||||
 | 
					  return %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_shape
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Shape"(%arg0) : (tensor<?x3x2xf32>) -> tensor<3xi64>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<3xi64>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_tile_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_tile_dynamic
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<?x?x?x?xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_tile_constant(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi64> } : () -> tensor<4xi64>
 | 
				
			||||||
 | 
					  %1 = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_tile_constant
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Tile"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<25x25x16x64xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<25x25x16x64xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_gather_axis0(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Gather"(%arg0, %arg1) {axis = 0} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_gather_axis0
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 0 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<1x2x3xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<1x2x3xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_gather_axis1(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Gather"(%arg0, %arg1) {axis = 1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_gather_axis1
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<3x1x2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_gather_negative_axis(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Gather"(%arg0, %arg1) {axis = -1} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_gather_negative_axis
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Gather"(%arg0, %arg1) {axis = 1 : i64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<3x1x2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_constant_of_shape_empty_tensor(%arg0 : tensor<0xi64>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.ConstantOfShape"(%arg0) : (tensor<0xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_constant_of_shape_empty_tensor
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<0.000000e+00> : tensor<1xf32>} : (tensor<0xi64>) -> tensor<f32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<f32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_constant_of_shape(%arg0 : tensor<3xi64>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.ConstantOfShape"(%arg0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_constant_of_shape
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"(%arg0) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<?x?x?xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<?x?x?xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_constant_of_shape_constant() -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64> } : () -> tensor<3xi64>
 | 
				
			||||||
 | 
					  %1 = "onnx.ConstantOfShape"(%0) {value = dense<[1.0]> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_constant_of_shape_constant
 | 
				
			||||||
 | 
					  // CHECK: [[CONSTANT:%.+]] = "onnx.Constant"() {value = dense<[3, 4, 5]> : tensor<3xi64>} : () -> tensor<3xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.ConstantOfShape"([[CONSTANT]]) {value = dense<1.000000e+00> : tensor<1xf32>} : (tensor<3xi64>) -> tensor<3x4x5xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<3x4x5xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_slice(%arg0 : tensor<2x4xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2xi64>, %arg3: tensor<2xi64>, %arg4: tensor<2xi64>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %1 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_slice
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES:%.+]] : tensor<?x?xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_slice_constant_default_axes(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %axes = constant unit
 | 
				
			||||||
 | 
					  %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_slice_constant_default_axes
 | 
				
			||||||
 | 
					  // CHECK: [[AXES:%.+]] = constant unit
 | 
				
			||||||
 | 
					  // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> 
 | 
				
			||||||
 | 
					  // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, none, tensor<2xi64>) -> tensor<1x2xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<1x2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %steps = constant unit
 | 
				
			||||||
 | 
					  %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_slice_constant_default_steps
 | 
				
			||||||
 | 
					  // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> 
 | 
				
			||||||
 | 
					  // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STEPS:%.+]] = constant unit
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<1x3xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<1x3xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_slice_all_constant(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_slice_all_constant
 | 
				
			||||||
 | 
					  // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> 
 | 
				
			||||||
 | 
					  // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<1x2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %axes = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %ends = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_slice_all_constant_negative
 | 
				
			||||||
 | 
					  // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> 
 | 
				
			||||||
 | 
					  // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<1x2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %ends = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %steps = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_slice_all_constant_end_outofbound
 | 
				
			||||||
 | 
					  // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> 
 | 
				
			||||||
 | 
					  // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<1x2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %steps = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64> } : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_slice_all_constant_negative_steps
 | 
				
			||||||
 | 
					  // CHECK: [[AXES:%.+]] = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STARTS:%.+]] = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> 
 | 
				
			||||||
 | 
					  // CHECK: [[ENDS:%.+]] = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[STEPS:%.+]] = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64>} : () -> tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Slice"(%arg0, [[STARTS]], [[ENDS]], [[AXES]], [[STEPS]]) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x2xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<1x2xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -51,3 +51,16 @@ func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32
 | 
				
			||||||
  // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
 | 
					  // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
 | 
				
			||||||
  // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
 | 
					  // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_tile_should_promote_to_attribute(%arg0 : tensor<?x10x10xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %shape = constant dense<[6, 7, 42]> : tensor<3xi64>
 | 
				
			||||||
 | 
					  %0 = "onnx.Tile"(%arg0, %shape) : (tensor<?x10x10xf32>, tensor<3xi64>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<*xf32>
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_tile_should_promote_to_attribute
 | 
				
			||||||
 | 
					  // CHECK-NEXT: [[NONE:%.+]] = constant unit
 | 
				
			||||||
 | 
					  // CHECK-NEXT: [[TILE:%.+]] = "onnx.Tile"(%{{.*}}, [[NONE]]) {repeats = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10x10xf32>, none) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: return [[TILE]] : tensor<*xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -252,7 +252,7 @@ OpsWithShapeInference = [
 | 
				
			||||||
    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
 | 
					    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
 | 
				
			||||||
    'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
 | 
					    'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
 | 
				
			||||||
    'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
 | 
					    'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
 | 
				
			||||||
    'Squeeze'
 | 
					    'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice'
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Operations supporting canonicalization.
 | 
					# Operations supporting canonicalization.
 | 
				
			||||||
| 
						 | 
					@ -266,7 +266,8 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Scaler']
 | 
				
			||||||
# tuples, whose first item is the attribute/operand name, and the second item is
 | 
					# tuples, whose first item is the attribute/operand name, and the second item is
 | 
				
			||||||
# the index at which such operand occurs in the list of the operation's inputs.
 | 
					# the index at which such operand occurs in the list of the operation's inputs.
 | 
				
			||||||
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
 | 
					OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
 | 
				
			||||||
                                  "Pad": [("pads", 1), ("constant_value", 2)]}
 | 
					                                  "Pad": [("pads", 1), ("constant_value", 2)],
 | 
				
			||||||
 | 
					                                  "Tile": [("repeats", 1)]}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Interface for special handling of type inference
 | 
					# Interface for special handling of type inference
 | 
				
			||||||
# The common code are put into get_type_inference_func
 | 
					# The common code are put into get_type_inference_func
 | 
				
			||||||
| 
						 | 
					@ -281,7 +282,15 @@ OpsWithResultTypeInference = {
 | 
				
			||||||
    '''auto toAttr = to().getSExtValue();
 | 
					    '''auto toAttr = to().getSExtValue();
 | 
				
			||||||
      auto builder = mlir::OpBuilder(getContext());
 | 
					      auto builder = mlir::OpBuilder(getContext());
 | 
				
			||||||
      resultTypes.push_back(mlir::UnrankedTensorType::get(
 | 
					      resultTypes.push_back(mlir::UnrankedTensorType::get(
 | 
				
			||||||
        convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));'''
 | 
					        convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));''',
 | 
				
			||||||
 | 
					  "ConstantOfShape":
 | 
				
			||||||
 | 
					  '''if (auto attr = valueAttr()) {
 | 
				
			||||||
 | 
					        resultTypes.push_back(mlir::UnrankedTensorType::get(
 | 
				
			||||||
 | 
					          attr.getType().cast<ShapedType>().getElementType()));
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        resultTypes.push_back(mlir::UnrankedTensorType::get(
 | 
				
			||||||
 | 
					          FloatType::getF32(getContext())));
 | 
				
			||||||
 | 
					      }'''
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Add an Op in this list if the Op needs result type deduction which is required
 | 
					# Add an Op in this list if the Op needs result type deduction which is required
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue