Fix type bug in mhlo.dynamic-update-slice lowering.
The operand type can be f32. We should not use operand type to do clamp operations. PiperOrigin-RevId: 376286524
This commit is contained in:
parent
b5f444232f
commit
402b74ed7f
|
@ -1349,15 +1349,19 @@ class DynamicUpdateSliceConverter
|
||||||
|
|
||||||
auto index_type = rewriter.getIndexType();
|
auto index_type = rewriter.getIndexType();
|
||||||
SmallVector<OpFoldResult, 3> start_indices;
|
SmallVector<OpFoldResult, 3> start_indices;
|
||||||
|
Type start_index_type = adaptor.start_indices()[0]
|
||||||
|
.getType()
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
Value zero = rewriter.create<ConstantOp>(
|
Value zero = rewriter.create<ConstantOp>(
|
||||||
loc, rewriter.getZeroAttr(operand_type.getElementType()));
|
loc, rewriter.getZeroAttr(start_index_type));
|
||||||
for (auto en : llvm::enumerate(adaptor.start_indices())) {
|
for (auto en : llvm::enumerate(adaptor.start_indices())) {
|
||||||
// By mhlo.DynamicUpdateSlice definition:
|
// By mhlo.DynamicUpdateSlice definition:
|
||||||
// `start_indices[i] = clamp(start_indices[i],
|
// `start_indices[i] = clamp(start_indices[i],
|
||||||
// 0, operand.dimension_size[i] - update.dimension_size[i])`
|
// 0, operand.dimension_size[i] - update.dimension_size[i])`
|
||||||
Value start_index = rewriter.create<tensor::ExtractOp>(loc, en.value());
|
Value start_index = rewriter.create<tensor::ExtractOp>(loc, en.value());
|
||||||
Value ub = rewriter.create<ConstantOp>(
|
Value ub = rewriter.create<ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(operand_type.getElementType(),
|
loc, rewriter.getIntegerAttr(start_index_type,
|
||||||
operand_type.getDimSize(en.index()) -
|
operand_type.getDimSize(en.index()) -
|
||||||
update_type.getDimSize(en.index())));
|
update_type.getDimSize(en.index())));
|
||||||
// TODO(hanchung): This is a workaround to use the method because only
|
// TODO(hanchung): This is a workaround to use the method because only
|
||||||
|
@ -1365,9 +1369,8 @@ class DynamicUpdateSliceConverter
|
||||||
// map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
|
// map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
|
||||||
// to an lmhlo op and call the lmhlo implementation.
|
// to an lmhlo op and call the lmhlo implementation.
|
||||||
start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
|
start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
|
||||||
loc, start_index.getType(),
|
loc, start_index_type,
|
||||||
ArrayRef<Type>{start_index.getType(), start_index.getType(),
|
ArrayRef<Type>{start_index_type, start_index_type, start_index_type},
|
||||||
start_index.getType()},
|
|
||||||
ArrayRef<Value>{zero, start_index, ub}, &rewriter);
|
ArrayRef<Value>{zero, start_index, ub}, &rewriter);
|
||||||
start_indices.push_back(
|
start_indices.push_back(
|
||||||
rewriter.create<IndexCastOp>(loc, index_type, start_index)
|
rewriter.create<IndexCastOp>(loc, index_type, start_index)
|
||||||
|
|
|
@ -1775,6 +1775,39 @@ func @dynamic_update_slice_unsigned(%target: tensor<3x3xui32>, %update: tensor<2
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @dynamic_update_slice_float(%target: tensor<3x3xf32>,
|
||||||
|
%update: tensor<2x2xf32>,
|
||||||
|
%c0: tensor<i32>) -> tensor<3x3xf32> {
|
||||||
|
%0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0)
|
||||||
|
: (tensor<3x3xf32>, tensor<2x2xf32>, tensor<i32>, tensor<i32>) -> tensor<3x3xf32>
|
||||||
|
return %0 : tensor<3x3xf32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @dynamic_update_slice_float(
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK: %[[C0:.*]] = constant 0 : i32
|
||||||
|
// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32>
|
||||||
|
// CHECK: %[[UB1:.*]] = constant 1 : i32
|
||||||
|
// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i32
|
||||||
|
// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i32
|
||||||
|
// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i32
|
||||||
|
// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i32
|
||||||
|
// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i32 to index
|
||||||
|
// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32>
|
||||||
|
// CHECK: %[[UB2:.*]] = constant 1 : i32
|
||||||
|
// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i32
|
||||||
|
// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i32
|
||||||
|
// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i32
|
||||||
|
// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i32
|
||||||
|
// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i32 to index
|
||||||
|
// CHECK: %[[RES:.*]] = subtensor_insert %[[ARG1]] into %[[ARG0]]
|
||||||
|
// CHECK-SAME: [%[[START1]], %[[START2]]] [2, 2] [1, 1]
|
||||||
|
// CHECK-SAME: : tensor<2x2xf32> into tensor<3x3xf32>
|
||||||
|
// CHECK: return %[[RES]] : tensor<3x3xf32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
|
func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
|
||||||
%0 = constant dense<0.0> : tensor<f32>
|
%0 = constant dense<0.0> : tensor<f32>
|
||||||
%1 = "mhlo.pad"(%arg0, %0) {
|
%1 = "mhlo.pad"(%arg0, %0) {
|
||||||
|
|
Loading…
Reference in New Issue