[MHLO:Linalg] Add support for lowering dynamic-slice on unsigned ints
PiperOrigin-RevId: 371979004
This commit is contained in:
parent
fdae82aac7
commit
5a60793b31
|
@ -1290,9 +1290,13 @@ class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
|
|||
int64_t rank = arg_type.getRank();
|
||||
SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
|
||||
|
||||
rewriter.replaceOpWithNewOp<SubTensorOp>(
|
||||
dynamic_slice_op, dynamic_slice_op.getType().cast<RankedTensorType>(),
|
||||
adaptor.operand(), start_indices, sizes, strides);
|
||||
auto result_type =
|
||||
this->typeConverter->convertType(dynamic_slice_op.getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SubTensorOp>(dynamic_slice_op, result_type,
|
||||
adaptor.operand(), start_indices,
|
||||
sizes, strides);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1596,7 +1596,7 @@ func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor<i64>, %start2: tensor
|
|||
} : (tensor<3x4xf32>, tensor<i64>, tensor<i64>) -> tensor<1x4xf32>
|
||||
return %0 : tensor<1x4xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @dynamic_slice
|
||||
// CHECK-LABEL: func @dynamic_slice(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
|
||||
|
@ -1619,6 +1619,35 @@ func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor<i64>, %start2: tensor
|
|||
|
||||
// -----
|
||||
|
||||
func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xui32> {
|
||||
%0 = "mhlo.dynamic-slice"(%arg, %start1, %start2) {
|
||||
slice_sizes = dense<[1, 4]> : tensor<2xi64>
|
||||
} : (tensor<3x4xui32>, tensor<i64>, tensor<i64>) -> tensor<1x4xui32>
|
||||
return %0 : tensor<1x4xui32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @dynamic_slice_unsigned(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[SIGNLESS_ARG0:.*]] = unrealized_conversion_cast %[[ARG0]] : tensor<3x4xui32> to tensor<3x4xi32>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : i64
|
||||
// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG1]][] : tensor<i64>
|
||||
// CHECK: %[[UB1:.*]] = constant 2 : i64
|
||||
// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i64
|
||||
// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i64
|
||||
// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i64
|
||||
// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i64
|
||||
// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i64 to index
|
||||
// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i64>
|
||||
// CHECK: %[[UB2:.*]] = constant 0 : i64
|
||||
// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i64
|
||||
// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i64
|
||||
// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i64
|
||||
// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i64
|
||||
// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i64 to index
|
||||
// CHECK: subtensor %[[SIGNLESS_ARG0]][%[[START1]], %[[START2]]] [1, 4] [1, 1]
|
||||
|
||||
func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
|
||||
%0 = constant dense<0.0> : tensor<f32>
|
||||
%1 = "mhlo.pad"(%arg0, %0) {
|
||||
|
|
Loading…
Reference in New Issue