Add support for lowering mhlo.dynamic-update-slice ops to Linalg and std ops.

PiperOrigin-RevId: 376042810
This commit is contained in:
Hanhan Wang 2021-05-26 15:30:09 -07:00 committed by TensorFlow MLIR Team
parent 1c7415ba0b
commit 28c411606f
2 changed files with 134 additions and 0 deletions

View File

@ -1316,6 +1316,72 @@ class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
} }
}; };
class DynamicUpdateSliceConverter
: public OpConversionPattern<mhlo::DynamicUpdateSliceOp> {
public:
using OpConversionPattern<mhlo::DynamicUpdateSliceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::DynamicUpdateSliceOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
mhlo::DynamicUpdateSliceOp::Adaptor adaptor(args);
auto operand_type =
adaptor.operand().getType().dyn_cast<RankedTensorType>();
if (!operand_type || !operand_type.hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "require static ranked type for operand");
}
auto update_type = adaptor.update().getType().dyn_cast<RankedTensorType>();
if (!update_type || !update_type.hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "require static ranked type for operand");
}
// We do not have to clamp sizes because the semantic of `update`
// guarantees that it is always in the bounds. See
// https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
SmallVector<OpFoldResult, 3> sizes;
for (auto size : update_type.getShape()) {
sizes.push_back(rewriter.getIndexAttr(size));
}
auto index_type = rewriter.getIndexType();
SmallVector<OpFoldResult, 3> start_indices;
Value zero = rewriter.create<ConstantOp>(
loc, rewriter.getZeroAttr(operand_type.getElementType()));
for (auto en : llvm::enumerate(adaptor.start_indices())) {
// By mhlo.DynamicUpdateSlice definition:
// `start_indices[i] = clamp(start_indices[i],
// 0, operand.dimension_size[i] - update.dimension_size[i])`
Value start_index = rewriter.create<tensor::ExtractOp>(loc, en.value());
Value ub = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(operand_type.getElementType(),
operand_type.getDimSize(en.index()) -
update_type.getDimSize(en.index())));
// TODO(hanchung): This is a workaround to use the method because only
// lmhlo version is defined. The implementation in
// map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
// to an lmhlo op and call the lmhlo implementation.
start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
loc, start_index.getType(),
ArrayRef<Type>{start_index.getType(), start_index.getType(),
start_index.getType()},
ArrayRef<Value>{zero, start_index, ub}, &rewriter);
start_indices.push_back(
rewriter.create<IndexCastOp>(loc, index_type, start_index)
.getResult());
}
int64_t rank = operand_type.getRank();
SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
op, adaptor.update(), adaptor.operand(), start_indices, sizes, strides);
return success();
}
};
enum class DotOperationType { enum class DotOperationType {
kVectorDot = 0, kVectorDot = 0,
kMatrixVector = 1, kMatrixVector = 1,
@ -2440,6 +2506,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
ReverseConverter<mhlo::ReverseOp, false>, ReverseConverter<mhlo::ReverseOp, false>,
SliceConverter<mhlo::SliceOp, false>, SliceConverter<mhlo::SliceOp, false>,
DynamicSliceConverter, DynamicSliceConverter,
DynamicUpdateSliceConverter,
TransposeConverter<mhlo::TransposeOp, false>, TransposeConverter<mhlo::TransposeOp, false>,
DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix, DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
linalg::MatmulOp>, linalg::MatmulOp>,

View File

@ -1708,6 +1708,73 @@ func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor<i64>, %star
// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i64 to index // CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i64 to index
// CHECK: subtensor %[[SIGNLESS_ARG0]][%[[START1]], %[[START2]]] [1, 4] [1, 1] // CHECK: subtensor %[[SIGNLESS_ARG0]][%[[START1]], %[[START2]]] [1, 4] [1, 1]
// -----
func @dynamic_update_slice(%target: tensor<3x3xi32>, %update: tensor<2x2xi32>, %c0: tensor<i32>) -> tensor<3x3xi32> {
%0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0)
: (tensor<3x3xi32>, tensor<2x2xi32>, tensor<i32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
}
// CHECK-LABEL: func @dynamic_update_slice(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.*]] = constant 0 : i32
// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32>
// CHECK: %[[UB1:.*]] = constant 1 : i32
// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i32
// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i32
// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i32
// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i32
// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i32 to index
// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32>
// CHECK: %[[UB2:.*]] = constant 1 : i32
// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i32
// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i32
// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i32
// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i32
// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i32 to index
// CHECK: %[[RES:.*]] = subtensor_insert %[[ARG1]] into %[[ARG0]]
// CHECK-SAME: [%[[START1]], %[[START2]]] [2, 2] [1, 1]
// CHECK-SAME: : tensor<2x2xi32> into tensor<3x3xi32>
// CHECK: return %[[RES]] : tensor<3x3xi32>
// -----
func @dynamic_update_slice_unsigned(%target: tensor<3x3xui32>, %update: tensor<2x2xui32>, %c0: tensor<i32>) -> tensor<3x3xui32> {
%0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0)
: (tensor<3x3xui32>, tensor<2x2xui32>, tensor<i32>, tensor<i32>) -> tensor<3x3xui32>
return %0 : tensor<3x3xui32>
}
// CHECK-LABEL: func @dynamic_update_slice_unsigned(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK: %[[SIGNLESS_TARGET:.*]] = unrealized_conversion_cast %[[ARG0]] : tensor<3x3xui32> to tensor<3x3xi32>
// CHECK: %[[SIGNLESS_UPDATE:.*]] = unrealized_conversion_cast %[[ARG1]] : tensor<2x2xui32> to tensor<2x2xi32>
// CHECK: %[[C0:.*]] = constant 0 : i32
// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32>
// CHECK: %[[UB1:.*]] = constant 1 : i32
// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i32
// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i32
// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i32
// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i32
// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i32 to index
// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32>
// CHECK: %[[UB2:.*]] = constant 1 : i32
// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i32
// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i32
// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i32
// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i32
// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i32 to index
// CHECK: %[[SIGNLESS_RES:.*]] = subtensor_insert %[[SIGNLESS_UPDATE]] into %[[SIGNLESS_TARGET]]
// CHECK-SAME: [%[[START1]], %[[START2]]] [2, 2] [1, 1]
// CHECK-SAME: : tensor<2x2xi32> into tensor<3x3xi32>
// CHECK: %[[RES:.*]] = unrealized_conversion_cast %[[SIGNLESS_RES]] : tensor<3x3xi32> to tensor<3x3xui32>
// CHECK: return %[[RES]] : tensor<3x3xui32>
// -----
func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> { func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
%0 = constant dense<0.0> : tensor<f32> %0 = constant dense<0.0> : tensor<f32>
%1 = "mhlo.pad"(%arg0, %0) { %1 = "mhlo.pad"(%arg0, %0) {