diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 07b2276..07cf259 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1349,15 +1349,19 @@ class DynamicUpdateSliceConverter auto index_type = rewriter.getIndexType(); SmallVector<Value, 3> start_indices; + Type start_index_type = adaptor.start_indices()[0] + .getType() + .cast<RankedTensorType>() + .getElementType(); Value zero = rewriter.create<ConstantOp>( - loc, rewriter.getZeroAttr(operand_type.getElementType())); + loc, rewriter.getZeroAttr(start_index_type)); for (auto en : llvm::enumerate(adaptor.start_indices())) { // By mhlo.DynamicUpdateSlice definition: // `start_indices[i] = clamp(start_indices[i], // 0, operand.dimension_size[i] - update.dimension_size[i])` Value start_index = rewriter.create<tensor::ExtractOp>(loc, en.value()); Value ub = rewriter.create<ConstantOp>( - loc, rewriter.getIntegerAttr(operand_type.getElementType(), + loc, rewriter.getIntegerAttr(start_index_type, operand_type.getDimSize(en.index()) - update_type.getDimSize(en.index()))); // TODO(hanchung): This is a workaround to use the method because only @@ -1365,9 +1369,8 @@ // map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it // to an lmhlo op and call the lmhlo implementation. 
start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>( - loc, start_index.getType(), - ArrayRef<Type>{start_index.getType(), start_index.getType(), - start_index.getType()}, + loc, start_index_type, + ArrayRef<Type>{start_index_type, start_index_type, start_index_type}, ArrayRef<Value>{zero, start_index, ub}, &rewriter); start_indices.push_back( rewriter.create<IndexCastOp>(loc, index_type, start_index) diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 4528bf8..c49cf55 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -1775,6 +1775,39 @@ func @dynamic_update_slice_unsigned(%target: tensor<3x3xui32>, %update: tensor<2x2xui32>, // ----- +func @dynamic_update_slice_float(%target: tensor<3x3xf32>, + %update: tensor<2x2xf32>, + %c0: tensor<i32>) -> tensor<3x3xf32> { + %0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0) + : (tensor<3x3xf32>, tensor<2x2xf32>, tensor<i32>, tensor<i32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> +} +// CHECK-LABEL: func @dynamic_update_slice_float( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] +// CHECK: %[[C0:.*]] = constant 0 : i32 +// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> +// CHECK: %[[UB1:.*]] = constant 1 : i32 +// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i32 +// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i32 +// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i32 +// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i32 +// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i32 to index +// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> +// CHECK: %[[UB2:.*]] = constant 1 : i32 +// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i32 +// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i32 +// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i32 +// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i32 
+// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i32 to index +// CHECK: %[[RES:.*]] = subtensor_insert %[[ARG1]] into %[[ARG0]] +// CHECK-SAME: [%[[START1]], %[[START2]]] [2, 2] [1, 1] +// CHECK-SAME: : tensor<2x2xf32> into tensor<3x3xf32> +// CHECK: return %[[RES]] : tensor<3x3xf32> + +// ----- + func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> { %0 = constant dense<0.0> : tensor<f32> %1 = "mhlo.pad"(%arg0, %0) {