diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 579e89c..33c13aa 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -699,7 +699,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, def HLO_SliceOp: HLO_Op< "slice", [NoSideEffect, SameOperandsAndResultElementType, - AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { + AllTypesMatch<["start_indices", "limit_indices", "strides"]>, + DeclareOpInterfaceMethods]> { let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$start_indices, @@ -711,21 +712,6 @@ def HLO_SliceOp: HLO_Op< let hasCanonicalizer = 1; let hasFolder = 1; - - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value operand, " - "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, " - "DenseIntElementsAttr strides" - >]; - - let extraClassDeclaration = [{ - // Infers output type for given operand and attributes. Result type is - // unranked if any of the attributes is illegal. - static Type InferOutputTypes(Builder *builder, Value operand, - DenseIntElementsAttr start_indices, - DenseIntElementsAttr limit_indices, - DenseIntElementsAttr strides); - }]; } def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index c04e27d..1725522 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2142,14 +2142,66 @@ BINARY_FOLDER(MinOp, min); // SliceOp //===----------------------------------------------------------------------===// -void SliceOp::build(OpBuilder& builder, OperationState& result, Value operand, - DenseIntElementsAttr start_indices, - DenseIntElementsAttr limit_indices, - DenseIntElementsAttr strides) { - return build(builder, result, - InferOutputTypes(&builder, operand, start_indices, limit_indices, - strides), - operand, start_indices, limit_indices, strides); +// Returns output dimension size for slice result for the given arguments. +// Returns -1 if arguments are illegal. +static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end, + int64_t stride) { + if (input_dim == -1 || start < 0 || start > end || end > input_dim || + stride == 0) + return -1; + + return llvm::divideCeil(end - start, stride); +} + +LogicalResult SliceOp::inferReturnTypes( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + SliceOpAdaptor slice(operands, attributes); + // TODO(jpienaar): Update this code after refactoring verify. + if (failed(slice.verify(location.getValueOr(UnknownLoc::get(context))))) { + return failure(); + } + + Type ty = slice.operand().getType(); + RankedTensorType ranked_ty = ty.dyn_cast(); + if (!ranked_ty) { + // The operand type is unranked, so the best we can infer for the result + // type is an unranked tensor with the same element type as the operand + // type. + inferredReturnTypes.assign({ty}); + return success(); + } + + int64_t rank = ranked_ty.getRank(); + ShapedType attr_ty = slice.start_indices().getType(); + if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank || + !attr_ty.getElementType().isSignlessInteger(64) || + slice.limit_indices().getType() != attr_ty || + slice.strides().getType() != attr_ty) { + // Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp + // having been verified at this point. Emit an error message that matches + // the one that would be reported by AllTypesMatch for a more consistent + // user experience. + // TODO(b/171567182): Clean this up after AllTypesMatch has been refactored. + return emitOptionalError(location, + "failed to verify that all of {start_indices, " + "limit_indices, strides} have same type"); + } + + SmallVector start(slice.start_indices().getValues()); + SmallVector limit(slice.limit_indices().getValues()); + SmallVector stride_vals(slice.strides().getValues()); + + SmallVector shape; + shape.reserve(rank); + for (int64_t i = 0, e = rank; i != e; i++) { + shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i], + stride_vals[i])); + } + inferredReturnTypes.assign( + {RankedTensorType::get(shape, ranked_ty.getElementType())}); + return success(); } template @@ -2332,46 +2384,6 @@ void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results, results.insert(context); } -// Returns output dimension size for slice result for the given arguments. -// Returns -1 if arguments are illegal. -static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end, - int64_t stride) { - if (input_dim == -1 || start < 0 || start > end || end > input_dim || - stride == 0) - return -1; - - return llvm::divideCeil(end - start, stride); -} - -Type SliceOp::InferOutputTypes(Builder* builder, Value operand, - DenseIntElementsAttr start_indices, - DenseIntElementsAttr limit_indices, - DenseIntElementsAttr strides) { - Type ty = operand.getType(); - RankedTensorType ranked_ty = ty.dyn_cast(); - if (!ranked_ty) return ty; - int64_t rank = ranked_ty.getRank(); - - // Illegal attributes. - ShapedType attr_ty = start_indices.getType(); - if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank || - !attr_ty.getElementType().isSignlessInteger(64) || - limit_indices.getType() != attr_ty || strides.getType() != attr_ty) - return ty; - - SmallVector start(start_indices.getValues()); - SmallVector limit(limit_indices.getValues()); - SmallVector stride_vals(strides.getValues()); - - SmallVector shape; - shape.reserve(rank); - for (int64_t i = 0, e = rank; i != e; i++) { - shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i], - stride_vals[i])); - } - return RankedTensorType::get(shape, ranked_ty.getElementType()); -} - //===----------------------------------------------------------------------===// // SortOp //===----------------------------------------------------------------------===// diff --git a/tests/ops.mlir b/tests/ops.mlir index fb4ab62..d22f7d1 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -691,9 +691,9 @@ func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf3 // ----- // CHECK-LABEL: func @slice -func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { - %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> +func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> } // -----