diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 3daadf8..07b2276 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1316,6 +1316,72 @@ class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> { } }; +class DynamicUpdateSliceConverter + : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> { + public: + using OpConversionPattern<mhlo::DynamicUpdateSliceOp>::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DynamicUpdateSliceOp op, ArrayRef<Value> args, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + mhlo::DynamicUpdateSliceOp::Adaptor adaptor(args); + auto operand_type = + adaptor.operand().getType().dyn_cast<RankedTensorType>(); + if (!operand_type || !operand_type.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "require static ranked type for operand"); + } + + auto update_type = adaptor.update().getType().dyn_cast<RankedTensorType>(); + if (!update_type || !update_type.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "require static ranked type for operand"); + } + + // We do not have to clamp sizes because the semantic of `update` + // guarantees that it is always in the bounds.
See + // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice + SmallVector<OpFoldResult, 3> sizes; + for (auto size : update_type.getShape()) { + sizes.push_back(rewriter.getIndexAttr(size)); + } + + auto index_type = rewriter.getIndexType(); + SmallVector<Value, 3> start_indices; + Value zero = rewriter.create<ConstantOp>( + loc, rewriter.getZeroAttr(operand_type.getElementType())); + for (auto en : llvm::enumerate(adaptor.start_indices())) { + // By mhlo.DynamicUpdateSlice definition: + // `start_indices[i] = clamp(start_indices[i], + // 0, operand.dimension_size[i] - update.dimension_size[i])` + Value start_index = rewriter.create<tensor::ExtractOp>(loc, en.value()); + Value ub = rewriter.create<ConstantOp>( + loc, rewriter.getIntegerAttr(operand_type.getElementType(), + operand_type.getDimSize(en.index()) - + update_type.getDimSize(en.index()))); + // TODO(hanchung): This is a workaround to use the method because only + // lmhlo version is defined. The implementation in + // map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it + // to an lmhlo op and call the lmhlo implementation.
+ start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>( + loc, start_index.getType(), + ArrayRef<Type>{start_index.getType(), start_index.getType(), + start_index.getType()}, + ArrayRef<Value>{zero, start_index, ub}, &rewriter); + start_indices.push_back( + rewriter.create<IndexCastOp>(loc, index_type, start_index) + .getResult()); + } + + int64_t rank = operand_type.getRank(); + SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1)); + rewriter.replaceOpWithNewOp<SubTensorInsertOp>( + op, adaptor.update(), adaptor.operand(), start_indices, sizes, strides); + return success(); + } +}; + enum class DotOperationType { kVectorDot = 0, kMatrixVector = 1, @@ -2440,6 +2506,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, ReverseConverter, SliceConverter, DynamicSliceConverter, + DynamicUpdateSliceConverter, TransposeConverter, DotOpOnTensorsConversion, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index f8dbfe3..4528bf8 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -1708,6 +1708,73 @@ func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor<i64>, %star // CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i64 to index // CHECK: subtensor %[[SIGNLESS_ARG0]][%[[START1]], %[[START2]]] [1, 4] [1, 1] +// ----- + +func @dynamic_update_slice(%target: tensor<3x3xi32>, %update: tensor<2x2xi32>, %c0: tensor<i32>) -> tensor<3x3xi32> { + %0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0) + : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<i32>, tensor<i32>) -> tensor<3x3xi32> + return %0 : tensor<3x3xi32> +} +// CHECK-LABEL: func @dynamic_update_slice( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] +// CHECK: %[[C0:.*]] = constant 0 : i32 +// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> +// CHECK: %[[UB1:.*]] = constant 1 : i32 +// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i32 +// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]],
%[[UB1]] : i32 +// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i32 +// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]], %[[C0]] : i32 +// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i32 to index +// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> +// CHECK: %[[UB2:.*]] = constant 1 : i32 +// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i32 +// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i32 +// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i32 +// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i32 +// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i32 to index +// CHECK: %[[RES:.*]] = subtensor_insert %[[ARG1]] into %[[ARG0]] +// CHECK-SAME: [%[[START1]], %[[START2]]] [2, 2] [1, 1] +// CHECK-SAME: : tensor<2x2xi32> into tensor<3x3xi32> +// CHECK: return %[[RES]] : tensor<3x3xi32> + +// ----- + +func @dynamic_update_slice_unsigned(%target: tensor<3x3xui32>, %update: tensor<2x2xui32>, %c0: tensor<i32>) -> tensor<3x3xui32> { + %0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0) + : (tensor<3x3xui32>, tensor<2x2xui32>, tensor<i32>, tensor<i32>) -> tensor<3x3xui32> + return %0 : tensor<3x3xui32> +} +// CHECK-LABEL: func @dynamic_update_slice_unsigned( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] +// CHECK: %[[SIGNLESS_TARGET:.*]] = unrealized_conversion_cast %[[ARG0]] : tensor<3x3xui32> to tensor<3x3xi32> +// CHECK: %[[SIGNLESS_UPDATE:.*]] = unrealized_conversion_cast %[[ARG1]] : tensor<2x2xui32> to tensor<2x2xi32> +// CHECK: %[[C0:.*]] = constant 0 : i32 +// CHECK: %[[SCALAR1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> +// CHECK: %[[UB1:.*]] = constant 1 : i32 +// CHECK: %[[COND1:.*]] = cmpi slt, %[[SCALAR1]], %[[UB1]] : i32 +// CHECK: %[[T1:.*]] = select %[[COND1]], %[[SCALAR1]], %[[UB1]] : i32 +// CHECK: %[[COND2:.*]] = cmpi sgt, %[[T1]], %[[C0]] : i32 +// CHECK: %[[CLAMPED1:.*]] = select %[[COND2]], %[[T1]],
%[[C0]] : i32 +// CHECK: %[[START1:.*]] = index_cast %[[CLAMPED1]] : i32 to index +// CHECK: %[[SCALAR2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> +// CHECK: %[[UB2:.*]] = constant 1 : i32 +// CHECK: %[[COND3:.*]] = cmpi slt, %[[SCALAR2]], %[[UB2]] : i32 +// CHECK: %[[T2:.*]] = select %[[COND3]], %[[SCALAR2]], %[[UB2]] : i32 +// CHECK: %[[COND4:.*]] = cmpi sgt, %[[T2]], %[[C0]] : i32 +// CHECK: %[[CLAMPED2:.*]] = select %[[COND4]], %[[T2]], %[[C0]] : i32 +// CHECK: %[[START2:.*]] = index_cast %[[CLAMPED2]] : i32 to index +// CHECK: %[[SIGNLESS_RES:.*]] = subtensor_insert %[[SIGNLESS_UPDATE]] into %[[SIGNLESS_TARGET]] +// CHECK-SAME: [%[[START1]], %[[START2]]] [2, 2] [1, 1] +// CHECK-SAME: : tensor<2x2xi32> into tensor<3x3xi32> +// CHECK: %[[RES:.*]] = unrealized_conversion_cast %[[SIGNLESS_RES]] : tensor<3x3xi32> to tensor<3x3xui32> +// CHECK: return %[[RES]] : tensor<3x3xui32> + +// ----- + func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> { %0 = constant dense<0.0> : tensor<f32> %1 = "mhlo.pad"(%arg0, %0) {