diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index ec57da4..ea28d37 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1290,9 +1290,13 @@ class DynamicSliceConverter : public OpConversionPattern { int64_t rank = arg_type.getRank(); SmallVector strides(rank, rewriter.getI64IntegerAttr(1)); - rewriter.replaceOpWithNewOp( - dynamic_slice_op, dynamic_slice_op.getType().cast(), - adaptor.operand(), start_indices, sizes, strides); + auto result_type = + this->typeConverter->convertType(dynamic_slice_op.getType()) + .cast(); + + rewriter.replaceOpWithNewOp(dynamic_slice_op, result_type, + adaptor.operand(), start_indices, + sizes, strides); return success(); } }; diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index e7dc9c5..8c52cd8 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -1596,7 +1596,7 @@ func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor, %start2: tensor } : (tensor<3x4xf32>, tensor, tensor) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } -// CHECK-LABEL: func @dynamic_slice +// CHECK-LABEL: func @dynamic_slice( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] @@ -1619,6 +1619,35 @@ func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor, %start2: tensor // ----- +func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor, %start2: tensor) -> tensor<1x4xui32> { + %0 = "mhlo.dynamic-slice"(%arg, %start1, %start2) { + slice_sizes = dense<[1, 4]> : tensor<2xi64> + } : (tensor<3x4xui32>, tensor, tensor) -> tensor<1x4xui32> + return %0 : tensor<1x4xui32> +} + +// CHECK-LABEL: func @dynamic_slice_unsigned( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] +// CHECK: %[[SIGNLESS_ARG0:.*]] = unrealized_conversion_cast %[[ARG0]] : tensor<3x4xui32> to tensor<3x4xi32> +// CHECK: %[[C0:.*]] = constant 0 : i64 +// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG1]][] : tensor +// CHECK: %[[UB1:.*]] = constant 2 : i64 +// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i64 +// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i64 +// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i64 +// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i64 +// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i64 to index +// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor +// CHECK: %[[UB2:.*]] = constant 0 : i64 +// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i64 +// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i64 +// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i64 +// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i64 +// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i64 to index +// CHECK: subtensor %[[SIGNLESS_ARG0]][%[[START1]], %[[START2]]] [1, 4] [1, 1] + func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> { %0 = constant dense<0.0> : tensor %1 = "mhlo.pad"(%arg0, %0) {