Use InferTypeOpInterface for HLO_SliceOp.
Instead of having a custom builder to construct a slice op without an explicit return type. PiperOrigin-RevId: 339058864
This commit is contained in:
parent
444fae9bac
commit
f9843fabe1
|
@ -699,7 +699,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
|
||||||
def HLO_SliceOp: HLO_Op<
|
def HLO_SliceOp: HLO_Op<
|
||||||
"slice",
|
"slice",
|
||||||
[NoSideEffect, SameOperandsAndResultElementType,
|
[NoSideEffect, SameOperandsAndResultElementType,
|
||||||
AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
|
AllTypesMatch<["start_indices", "limit_indices", "strides"]>,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
I64ElementsAttr:$start_indices,
|
I64ElementsAttr:$start_indices,
|
||||||
|
@ -711,21 +712,6 @@ def HLO_SliceOp: HLO_Op<
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
|
||||||
"OpBuilder &builder, OperationState &result, Value operand, "
|
|
||||||
"DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
|
|
||||||
"DenseIntElementsAttr strides"
|
|
||||||
>];
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
// Infers output type for given operand and attributes. Result type is
|
|
||||||
// unranked if any of the attributes is illegal.
|
|
||||||
static Type InferOutputTypes(Builder *builder, Value operand,
|
|
||||||
DenseIntElementsAttr start_indices,
|
|
||||||
DenseIntElementsAttr limit_indices,
|
|
||||||
DenseIntElementsAttr strides);
|
|
||||||
}];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
|
def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
|
||||||
|
|
|
@ -2142,14 +2142,66 @@ BINARY_FOLDER(MinOp, min);
|
||||||
// SliceOp
|
// SliceOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void SliceOp::build(OpBuilder& builder, OperationState& result, Value operand,
|
// Returns output dimension size for slice result for the given arguments.
|
||||||
DenseIntElementsAttr start_indices,
|
// Returns -1 if arguments are illegal.
|
||||||
DenseIntElementsAttr limit_indices,
|
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
|
||||||
DenseIntElementsAttr strides) {
|
int64_t stride) {
|
||||||
return build(builder, result,
|
if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
|
||||||
InferOutputTypes(&builder, operand, start_indices, limit_indices,
|
stride == 0)
|
||||||
strides),
|
return -1;
|
||||||
operand, start_indices, limit_indices, strides);
|
|
||||||
|
return llvm::divideCeil(end - start, stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SliceOp::inferReturnTypes(
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
SliceOpAdaptor slice(operands, attributes);
|
||||||
|
// TODO(jpienaar): Update this code after refactoring verify.
|
||||||
|
if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type ty = slice.operand().getType();
|
||||||
|
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
|
||||||
|
if (!ranked_ty) {
|
||||||
|
// The operand type is unranked, so the best we can infer for the result
|
||||||
|
// type is an unranked tensor with the same element type as the operand
|
||||||
|
// type.
|
||||||
|
inferredReturnTypes.assign({ty});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t rank = ranked_ty.getRank();
|
||||||
|
ShapedType attr_ty = slice.start_indices().getType();
|
||||||
|
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
|
||||||
|
!attr_ty.getElementType().isSignlessInteger(64) ||
|
||||||
|
slice.limit_indices().getType() != attr_ty ||
|
||||||
|
slice.strides().getType() != attr_ty) {
|
||||||
|
// Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
|
||||||
|
// having been verified at this point. Emit an error message that matches
|
||||||
|
// the one that would be reported by AllTypesMatch for a more consistent
|
||||||
|
// user experience.
|
||||||
|
// TODO(b/171567182): Clean this up after AllTypesMatch has been refactored.
|
||||||
|
return emitOptionalError(location,
|
||||||
|
"failed to verify that all of {start_indices, "
|
||||||
|
"limit_indices, strides} have same type");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
|
||||||
|
SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
|
||||||
|
SmallVector<int64_t, 4> stride_vals(slice.strides().getValues<int64_t>());
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> shape;
|
||||||
|
shape.reserve(rank);
|
||||||
|
for (int64_t i = 0, e = rank; i != e; i++) {
|
||||||
|
shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
|
||||||
|
stride_vals[i]));
|
||||||
|
}
|
||||||
|
inferredReturnTypes.assign(
|
||||||
|
{RankedTensorType::get(shape, ranked_ty.getElementType())});
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename I, typename E>
|
template <typename I, typename E>
|
||||||
|
@ -2332,46 +2384,6 @@ void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||||
results.insert<SimplifyConcatSlice>(context);
|
results.insert<SimplifyConcatSlice>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns output dimension size for slice result for the given arguments.
|
|
||||||
// Returns -1 if arguments are illegal.
|
|
||||||
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
|
|
||||||
int64_t stride) {
|
|
||||||
if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
|
|
||||||
stride == 0)
|
|
||||||
return -1;
|
|
||||||
|
|
||||||
return llvm::divideCeil(end - start, stride);
|
|
||||||
}
|
|
||||||
|
|
||||||
Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
|
|
||||||
DenseIntElementsAttr start_indices,
|
|
||||||
DenseIntElementsAttr limit_indices,
|
|
||||||
DenseIntElementsAttr strides) {
|
|
||||||
Type ty = operand.getType();
|
|
||||||
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
|
|
||||||
if (!ranked_ty) return ty;
|
|
||||||
int64_t rank = ranked_ty.getRank();
|
|
||||||
|
|
||||||
// Illegal attributes.
|
|
||||||
ShapedType attr_ty = start_indices.getType();
|
|
||||||
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
|
|
||||||
!attr_ty.getElementType().isSignlessInteger(64) ||
|
|
||||||
limit_indices.getType() != attr_ty || strides.getType() != attr_ty)
|
|
||||||
return ty;
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> start(start_indices.getValues<int64_t>());
|
|
||||||
SmallVector<int64_t, 4> limit(limit_indices.getValues<int64_t>());
|
|
||||||
SmallVector<int64_t, 4> stride_vals(strides.getValues<int64_t>());
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> shape;
|
|
||||||
shape.reserve(rank);
|
|
||||||
for (int64_t i = 0, e = rank; i != e; i++) {
|
|
||||||
shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
|
|
||||||
stride_vals[i]));
|
|
||||||
}
|
|
||||||
return RankedTensorType::get(shape, ranked_ty.getElementType());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// SortOp
|
// SortOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -691,9 +691,9 @@ func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf3
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @slice
|
// CHECK-LABEL: func @slice
|
||||||
func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||||
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||||
return %0 : tensor<1x4xi32>
|
return %0 : tensor<1x2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
Loading…
Reference in New Issue