Add support for lowering mhlo.dynamic_slice to Linalg ops.
PiperOrigin-RevId: 368033540
This commit is contained in:
parent
0ec0a23e61
commit
a3fc99efe0
|
@ -1134,6 +1134,70 @@ class SliceConverter : public OpConversionPattern<OpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<mhlo::DynamicSliceOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::DynamicSliceOp dynamic_slice_op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = dynamic_slice_op.getLoc();
|
||||||
|
mhlo::DynamicSliceOp::Adaptor adaptor(args);
|
||||||
|
auto arg_type = adaptor.operand().getType().dyn_cast<ShapedType>();
|
||||||
|
if (!arg_type || !arg_type.hasRank()) {
|
||||||
|
return rewriter.notifyMatchFailure(dynamic_slice_op,
|
||||||
|
"require known-rank args");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto index_type = rewriter.getIndexType();
|
||||||
|
SmallVector<OpFoldResult, 3> start_indices, sizes;
|
||||||
|
Value zero = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getZeroAttr(adaptor.start_indices()[0]
|
||||||
|
.getType()
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType()));
|
||||||
|
for (auto en : llvm::enumerate(
|
||||||
|
llvm::zip(adaptor.start_indices(),
|
||||||
|
dynamic_slice_op.slice_sizes().getValues<int64_t>()))) {
|
||||||
|
int64_t size = std::get<1>(en.value());
|
||||||
|
sizes.push_back(rewriter.getI64IntegerAttr(size));
|
||||||
|
|
||||||
|
// By mhlo.DynamicSlice definition:
|
||||||
|
// `start_indices[i] = clamp(start_indices[i],
|
||||||
|
// 0, operand.dimension_size[i] - size_indices[i])`
|
||||||
|
Value start_index =
|
||||||
|
rewriter.create<tensor::ExtractOp>(loc, std::get<0>(en.value()));
|
||||||
|
Value ub = rewriter.createOrFold<memref::DimOp>(loc, adaptor.operand(),
|
||||||
|
en.index());
|
||||||
|
// ClampOp lowering does not support index type, so cast it into integer
|
||||||
|
// type.
|
||||||
|
ub = rewriter.createOrFold<IndexCastOp>(loc, start_index.getType(), ub);
|
||||||
|
ub = rewriter.createOrFold<SubIOp>(
|
||||||
|
loc, ub,
|
||||||
|
rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(start_index.getType(), size)));
|
||||||
|
// TODO(hanchung): This is a workaround to use the method because only
|
||||||
|
// lmhlo version is defined. The implementation in
|
||||||
|
// map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
|
||||||
|
// to an lmhlo op and call the lmhlo implementation.
|
||||||
|
start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
|
||||||
|
loc, start_index.getType(), ArrayRef<Value>{zero, start_index, ub},
|
||||||
|
&rewriter);
|
||||||
|
start_indices.push_back(
|
||||||
|
rewriter.create<IndexCastOp>(loc, index_type, start_index)
|
||||||
|
.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t rank = arg_type.getRank();
|
||||||
|
SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<SubTensorOp>(
|
||||||
|
dynamic_slice_op, dynamic_slice_op.getType().cast<RankedTensorType>(),
|
||||||
|
adaptor.operand(), start_indices, sizes, strides);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
enum class DotOperationType {
|
enum class DotOperationType {
|
||||||
kVectorDot = 0,
|
kVectorDot = 0,
|
||||||
kMatrixVector = 1,
|
kMatrixVector = 1,
|
||||||
|
@ -2090,6 +2154,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||||
ReverseConverter<mhlo::ReverseOp, false>,
|
ReverseConverter<mhlo::ReverseOp, false>,
|
||||||
SliceConverter<mhlo::SliceOp, false>,
|
SliceConverter<mhlo::SliceOp, false>,
|
||||||
|
DynamicSliceConverter,
|
||||||
TransposeConverter<mhlo::TransposeOp, false>,
|
TransposeConverter<mhlo::TransposeOp, false>,
|
||||||
DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
|
DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
|
||||||
linalg::MatmulOp>,
|
linalg::MatmulOp>,
|
||||||
|
|
|
@ -1532,6 +1532,35 @@ func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xf32> {
|
||||||
|
%0 = "mhlo.dynamic-slice"(%arg, %start1, %start2) {
|
||||||
|
slice_sizes = dense<[1, 4]> : tensor<2xi64>
|
||||||
|
} : (tensor<3x4xf32>, tensor<i64>, tensor<i64>) -> tensor<1x4xf32>
|
||||||
|
return %0 : tensor<1x4xf32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @dynamic_slice
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK: %[[C0:.*]] = constant 0 : i64
|
||||||
|
// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG1]][] : tensor<i64>
|
||||||
|
// CHECK: %[[UB1:.*]] = constant 2 : i64
|
||||||
|
// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i64
|
||||||
|
// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i64
|
||||||
|
// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i64
|
||||||
|
// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i64
|
||||||
|
// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i64 to index
|
||||||
|
// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i64>
|
||||||
|
// CHECK: %[[UB2:.*]] = constant 0 : i64
|
||||||
|
// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i64
|
||||||
|
// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i64
|
||||||
|
// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i64
|
||||||
|
// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i64
|
||||||
|
// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i64 to index
|
||||||
|
// CHECK: subtensor %[[ARG0]][%[[START1]], %[[START2]]] [1, 4] [1, 1]
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
|
func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
|
||||||
%0 = constant dense<0.0> : tensor<f32>
|
%0 = constant dense<0.0> : tensor<f32>
|
||||||
%1 = "mhlo.pad"(%arg0, %0) {
|
%1 = "mhlo.pad"(%arg0, %0) {
|
||||||
|
|
Loading…
Reference in New Issue