Use InferTypeOpInterface for HLO_SliceOp.

This replaces the custom builder that was previously needed to construct a
slice op without an explicit return type.

PiperOrigin-RevId: 339058864
Richard Uhler 2020-10-26 09:53:21 -07:00 committed by TensorFlow MLIR Team
parent 444fae9bac
commit f9843fabe1
3 changed files with 65 additions and 67 deletions
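With InferTypeOpInterface declared on the op, ODS also generates a builder that takes no result type and derives it from SliceOp::inferReturnTypes. A minimal sketch of what a caller can now write (the names `b`, `loc`, `input`, and the attribute values below are illustrative, not part of this change):

// Assumes an OpBuilder `b`, a Location `loc`, an operand `input` of type
// tensor<3x4xi32>, and DenseIntElementsAttr values `start`, `limit`, and
// `strides` holding dense<[1, 0]>, dense<[2, 4]>, and dense<[1, 2]>.
// No result type is passed; the ODS-generated builder calls
// SliceOp::inferReturnTypes to compute it.
auto slice = b.create<mhlo::SliceOp>(loc, input, start, limit, strides);
// slice.getResult() has the inferred type tensor<1x2xi32>.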


@@ -699,7 +699,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
 def HLO_SliceOp: HLO_Op<
     "slice",
     [NoSideEffect, SameOperandsAndResultElementType,
-     AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
+     AllTypesMatch<["start_indices", "limit_indices", "strides"]>,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let arguments = (ins
     HLO_Tensor:$operand,
     I64ElementsAttr:$start_indices,
@@ -711,21 +712,6 @@ def HLO_SliceOp: HLO_Op<
   let hasCanonicalizer = 1;
   let hasFolder = 1;
-
-  let builders = [OpBuilder<
-    "OpBuilder &builder, OperationState &result, Value operand, "
-    "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
-    "DenseIntElementsAttr strides"
-  >];
-
-  let extraClassDeclaration = [{
-    // Infers output type for given operand and attributes. Result type is
-    // unranked if any of the attributes is illegal.
-    static Type InferOutputTypes(Builder *builder, Value operand,
-                                 DenseIntElementsAttr start_indices,
-                                 DenseIntElementsAttr limit_indices,
-                                 DenseIntElementsAttr strides);
-  }];
 }

 def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
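The removed extraClassDeclaration is no longer needed because DeclareOpInterfaceMethods<InferTypeOpInterface> makes ODS declare the interface hook on the generated op class; only its definition (added in the C++ diff below) stays hand-written. The generated declaration looks roughly like this (a sketch; the exact generated code may differ):

static LogicalResult inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes);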


@@ -2142,14 +2142,66 @@ BINARY_FOLDER(MinOp, min);
 // SliceOp
 //===----------------------------------------------------------------------===//

-void SliceOp::build(OpBuilder& builder, OperationState& result, Value operand,
-                    DenseIntElementsAttr start_indices,
-                    DenseIntElementsAttr limit_indices,
-                    DenseIntElementsAttr strides) {
-  return build(builder, result,
-               InferOutputTypes(&builder, operand, start_indices, limit_indices,
-                                strides),
-               operand, start_indices, limit_indices, strides);
+// Returns output dimension size for slice result for the given arguments.
+// Returns -1 if arguments are illegal.
+static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
+                             int64_t stride) {
+  if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
+      stride == 0)
+    return -1;
+
+  return llvm::divideCeil(end - start, stride);
+}
+
+LogicalResult SliceOp::inferReturnTypes(
+    MLIRContext* context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type>& inferredReturnTypes) {
+  SliceOpAdaptor slice(operands, attributes);
+  // TODO(jpienaar): Update this code after refactoring verify.
+  if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) {
+    return failure();
+  }
+
+  Type ty = slice.operand().getType();
+  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
+  if (!ranked_ty) {
+    // The operand type is unranked, so the best we can infer for the result
+    // type is an unranked tensor with the same element type as the operand
+    // type.
+    inferredReturnTypes.assign({ty});
+    return success();
+  }
+
+  int64_t rank = ranked_ty.getRank();
+  ShapedType attr_ty = slice.start_indices().getType();
+  if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
+      !attr_ty.getElementType().isSignlessInteger(64) ||
+      slice.limit_indices().getType() != attr_ty ||
+      slice.strides().getType() != attr_ty) {
+    // Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
+    // having been verified at this point. Emit an error message that matches
+    // the one that would be reported by AllTypesMatch for a more consistent
+    // user experience.
+    // TODO(b/171567182): Clean this up after AllTypesMatch has been refactored.
+    return emitOptionalError(location,
+                             "failed to verify that all of {start_indices, "
+                             "limit_indices, strides} have same type");
+  }
+
+  SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
+  SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
+  SmallVector<int64_t, 4> stride_vals(slice.strides().getValues<int64_t>());
+
+  SmallVector<int64_t, 4> shape;
+  shape.reserve(rank);
+  for (int64_t i = 0, e = rank; i != e; i++) {
+    shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
+                                  stride_vals[i]));
+  }
+  inferredReturnTypes.assign(
+      {RankedTensorType::get(shape, ranked_ty.getElementType())});
+  return success();
 }

 template <typename I, typename E>
@@ -2332,46 +2384,6 @@ void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
   results.insert<SimplifyConcatSlice>(context);
 }

-// Returns output dimension size for slice result for the given arguments.
-// Returns -1 if arguments are illegal.
-static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
-                             int64_t stride) {
-  if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
-      stride == 0)
-    return -1;
-
-  return llvm::divideCeil(end - start, stride);
-}
-
-Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
-                               DenseIntElementsAttr start_indices,
-                               DenseIntElementsAttr limit_indices,
-                               DenseIntElementsAttr strides) {
-  Type ty = operand.getType();
-  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
-  if (!ranked_ty) return ty;
-  int64_t rank = ranked_ty.getRank();
-
-  // Illegal attributes.
-  ShapedType attr_ty = start_indices.getType();
-  if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
-      !attr_ty.getElementType().isSignlessInteger(64) ||
-      limit_indices.getType() != attr_ty || strides.getType() != attr_ty)
-    return ty;
-
-  SmallVector<int64_t, 4> start(start_indices.getValues<int64_t>());
-  SmallVector<int64_t, 4> limit(limit_indices.getValues<int64_t>());
-  SmallVector<int64_t, 4> stride_vals(strides.getValues<int64_t>());
-
-  SmallVector<int64_t, 4> shape;
-  shape.reserve(rank);
-  for (int64_t i = 0, e = rank; i != e; i++) {
-    shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
-                                  stride_vals[i]));
-  }
-  return RankedTensorType::get(shape, ranked_ty.getElementType());
-}
-
 //===----------------------------------------------------------------------===//
 // SortOp
 //===----------------------------------------------------------------------===//


@@ -691,9 +691,9 @@ func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf3
 // -----

 // CHECK-LABEL: func @slice
-func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
-  %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
-  return %0 : tensor<1x4xi32>
+func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
+  %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+  return %0 : tensor<1x2xi32>
 }

 // -----
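The updated expected type in this test follows directly from InferSliceDim above: each output dimension is divideCeil(limit - start, stride). A small illustrative check of the arithmetic (not part of the change):

#include "llvm/Support/MathExtras.h"
#include <cassert>

int main() {
  // Dimension 0: start 1, limit 2, stride 1 on a dimension of size 3.
  assert(llvm::divideCeil(2 - 1, 1) == 1);
  // Dimension 1: start 0, limit 4, stride 2 on a dimension of size 4.
  assert(llvm::divideCeil(4 - 0, 2) == 2);
  // Hence the inferred result type is tensor<1x2xi32>, not tensor<1x4xi32>.
  return 0;
}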