diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 850ef00..e482f2b 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1134,6 +1134,70 @@ class SliceConverter : public OpConversionPattern {
   }
 };
 
+class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
+ public:
+  using OpConversionPattern<mhlo::DynamicSliceOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::DynamicSliceOp dynamic_slice_op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    auto loc = dynamic_slice_op.getLoc();
+    mhlo::DynamicSliceOp::Adaptor adaptor(args);
+    auto arg_type = adaptor.operand().getType().dyn_cast<ShapedType>();
+    if (!arg_type || !arg_type.hasRank()) {
+      return rewriter.notifyMatchFailure(dynamic_slice_op,
+                                         "require known-rank args");
+    }
+
+    auto index_type = rewriter.getIndexType();
+    SmallVector<OpFoldResult, 3> start_indices, sizes;
+    Value zero = rewriter.create<ConstantOp>(
+        loc, rewriter.getZeroAttr(adaptor.start_indices()[0]
+                                      .getType()
+                                      .cast<RankedTensorType>()
+                                      .getElementType()));
+    for (auto en : llvm::enumerate(
+             llvm::zip(adaptor.start_indices(),
+                       dynamic_slice_op.slice_sizes().getValues<int64_t>()))) {
+      int64_t size = std::get<1>(en.value());
+      sizes.push_back(rewriter.getI64IntegerAttr(size));
+
+      // By mhlo.DynamicSlice definition:
+      //   `start_indices[i] = clamp(start_indices[i],
+      //       0, operand.dimension_size[i] - size_indices[i])`
+      Value start_index =
+          rewriter.create<tensor::ExtractOp>(loc, std::get<0>(en.value()));
+      Value ub = rewriter.createOrFold<DimOp>(loc, adaptor.operand(),
+                                              en.index());
+      // ClampOp lowering does not support index type, so cast it into integer
+      // type.
+      ub = rewriter.createOrFold<IndexCastOp>(loc, start_index.getType(), ub);
+      ub = rewriter.createOrFold<SubIOp>(
+          loc, ub,
+          rewriter.create<ConstantOp>(
+              loc, rewriter.getIntegerAttr(start_index.getType(), size)));
+      // TODO(hanchung): This is a workaround to use the method because only
+      // the lmhlo version is defined. The implementation in
+      // map_lmhlo_to_scalar_op.h requires passing an mhlo op; it will convert
+      // it to an lmhlo op and call the lmhlo implementation.
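+      // The operands below follow the ClampOp order {min, operand, max},
+      // i.e. start_index = max(min(start_index, ub), 0). With the static
+      // 3x4 operand and slice_sizes = [1, 4] in the test below, both upper
+      // bounds fold to the constants (2 and 0) that the CHECK lines expect.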
+      start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
+          loc, start_index.getType(), ArrayRef<Value>{zero, start_index, ub},
+          &rewriter);
+      start_indices.push_back(
+          rewriter.create<IndexCastOp>(loc, index_type, start_index)
+              .getResult());
+    }
+
+    int64_t rank = arg_type.getRank();
+    SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
+
+    rewriter.replaceOpWithNewOp<SubTensorOp>(
+        dynamic_slice_op, dynamic_slice_op.getType().cast<RankedTensorType>(),
+        adaptor.operand(), start_indices, sizes, strides);
+    return success();
+  }
+};
+
 enum class DotOperationType {
   kVectorDot = 0,
   kMatrixVector = 1,
@@ -2090,6 +2154,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                    ReshapeOpConverter,
                    ReverseConverter,
                    SliceConverter,
+                   DynamicSliceConverter,
                    TransposeConverter,
                    DotOpOnTensorsConversion,
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index a3eb16b..76a1629 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -1532,6 +1532,35 @@ func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
 
 // -----
 
+func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xf32> {
+  %0 = "mhlo.dynamic-slice"(%arg, %start1, %start2) {
+    slice_sizes = dense<[1, 4]> : tensor<2xi64>
+  } : (tensor<3x4xf32>, tensor<i64>, tensor<i64>) -> tensor<1x4xf32>
+  return %0 : tensor<1x4xf32>
+}
+// CHECK-LABEL: func @dynamic_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
+// CHECK: %[[C0:.*]] = constant 0 : i64
+// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG1]][] : tensor<i64>
+// CHECK: %[[UB1:.*]] = constant 2 : i64
+// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i64
+// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i64
+// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i64
+// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i64
+// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i64 to index
+// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i64>
+// CHECK: %[[UB2:.*]] = constant 0 : i64
+// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i64
+// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i64
+// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i64
+// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i64
+// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i64 to index
+// CHECK: subtensor %[[ARG0]][%[[START1]], %[[START2]]] [1, 4] [1, 1]
+
+// -----
+
 func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
   %0 = constant dense<0.0> : tensor<f32>
   %1 = "mhlo.pad"(%arg0, %0) {