From 88cc0c6c46cddbaa6220994bc42afd94e02b4ece Mon Sep 17 00:00:00 2001
From: Wenyi Zhao <951425797@qq.com>
Date: Wed, 16 Jun 2021 13:44:21 -0700
Subject: [PATCH] PR #50271: [MLIR][DISC] Bufferize GatherOp and DynamicGatherOp

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50271

Support hlo-to-lhlo conversion for GatherOp and DynamicGatherOp.

Copybara import of the project:

--
117a1b1bcaac7ecc5224b02863eede5c1b9618fe by Wenyi Zhao <951425797@qq.com>:

[MLIR][DISC] Bufferize GatherOp and DynamicGatherOp

PiperOrigin-RevId: 379801972
---
 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td   |   4 +-
 include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td  |   2 +-
 .../mhlo/transforms/map_hlo_to_lhlo_op.h      |   1 +
 lib/Dialect/mhlo/IR/hlo_ops.cc                | 141 ++++++++++++++++++
 .../mhlo/transforms/hlo_legalize_to_lhlo.cc   |   1 +
 tests/hlo-legalize-to-lhlo-only-dynamic.mlir  |  48 ++++++
 6 files changed, 194 insertions(+), 3 deletions(-)

diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 1649793..085f7da 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1618,7 +1618,7 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]> {
   let results = (outs HLO_Tensor);
 }
 
-def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]> {
+def HLO_GatherOp: HLO_ShapedInterfaceOp<"gather", [NoSideEffect]> {
   let arguments = (ins
     HLO_Tensor:$operand,
     HLO_IntTensor:$start_indices,
@@ -2268,7 +2268,7 @@ def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad",
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
+def HLO_DynamicGatherOp: HLO_ShapedInterfaceOp<"dynamic_gather", [NoSideEffect]> {
   string summary = "Dynamic Gather operator";
   string description = [{
     The dynamic shape version of GatherOp. Stitches together several slices of an input
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index ccdf982..a7fccb5 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -1460,7 +1460,7 @@ def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
   let arguments = (ins
     Arg:$operand,
     Arg:$start_indices,
-    Arg:$slice_sizes,
+    Arg:$slice_sizes,
     GatherDimensionNumbers:$dimension_numbers,
     Arg:$output
   );
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
index cb9b360..5562d42 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
@@ -56,6 +56,7 @@ MAP_HLO_TO_LHLO(CustomCallOp);
 MAP_HLO_TO_LHLO(DivOp);
 MAP_HLO_TO_LHLO(DotOp);
 MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
+MAP_HLO_TO_LHLO(DynamicGatherOp);
 MAP_HLO_TO_LHLO(DynamicIotaOp);
 MAP_HLO_TO_LHLO(DynamicPadOp);
 MAP_HLO_TO_LHLO(DynamicReshapeOp);
diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index d98706b..d7fb7f0 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -339,6 +339,147 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
   results.insert<GatherSlice>(context);
 }
 
+namespace {
+
+// Following https://www.tensorflow.org/xla/operation_semantics#gather.
+// The bound for the output array along dimension i is computed as follows:
+// (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some
+//     k), then we pick the corresponding dimension bound out of
+//     start_indices.shape, skipping index_vector_dim (i.e. pick
+//     start_indices.shape.dims[k] if k < index_vector_dim and
+//     start_indices.shape.dims[k+1] otherwise).
+// (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some
+//     k), then we pick the corresponding bound out of slice_sizes after
+//     accounting for collapsed_slice_dims (i.e. we pick
+//     adjusted_slice_sizes[k], where adjusted_slice_sizes is slice_sizes with
+//     the bounds at indices collapsed_slice_dims removed).
+
+void GetSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
+                        ValueRange operands,
+                        SmallVectorImpl<Value>& slice_sizes) {
+  for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
+    slice_sizes.push_back(builder.create<ConstantIndexOp>(loc, val));
+  }
+}
+
+void GetSliceSizeValues(DynamicGatherOp* d_gather, OpBuilder& builder,
+                        Location loc, ValueRange operands,
+                        SmallVectorImpl<Value>& slice_size_values) {
+  DynamicGatherOp::Adaptor adaptor(operands);
+  Value slice_sizes = adaptor.slice_sizes();
+  auto slice_sizes_ty = slice_sizes.getType().cast<RankedTensorType>();
+  for (int64_t i = 0; i < slice_sizes_ty.getDimSize(0); ++i) {
+    Value idx = builder.create<ConstantIndexOp>(loc, i);
+    slice_size_values.push_back(
+        builder.create<tensor::ExtractOp>(loc, slice_sizes, idx));
+  }
+}
+
+template <typename Op>
+LogicalResult GatherShapeInferImpl(
+    Op* op, OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  // Unranked gather is not supported a.t.m.
+  auto result_ty =
+      op->getResult().getType().template dyn_cast<RankedTensorType>();
+  if (!result_ty) return failure();
+
+  typename Op::Adaptor adaptor(operands);
+  Value start_indices = adaptor.start_indices();
+
+  Location loc = op->getLoc();
+  int result_rank = result_ty.getRank();
+  Type shape_scalar_type =
+      start_indices.getType().cast<ShapedType>().getElementType();
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  auto dimension_numbers = op->dimension_numbers();
+  SmallVector<int64_t, 4> collapsed_slice_dims(
+      dimension_numbers.collapsed_slice_dims().template getValues<int64_t>());
+  SmallVector<int64_t, 4> offset_dims(
+      dimension_numbers.offset_dims().template getValues<int64_t>());
+  int64_t index_vector_dim =
+      dimension_numbers.index_vector_dim().getValue().getSExtValue();
+
+  SmallVector<Value, 4> slice_sizes;
+  GetSliceSizeValues(op, builder, loc, operands, slice_sizes);
+  // Convert to `shape_scalar_type`.
+  llvm::transform(slice_sizes, slice_sizes.begin(),
+                  [&](Value v) { return to_shape_scalar_type(v); });
+
+  // We label dimensions in the output array that are not in offset_dims as
+  // batch_dims.
+  SmallVector<int64_t, 4> batch_dims;
+  for (int64_t i = 0; i < result_rank; ++i) {
+    if (std::find(offset_dims.begin(), offset_dims.end(), i) ==
+        offset_dims.end()) {
+      batch_dims.push_back(i);
+    }
+  }
+  // adjusted_slice_sizes is slice_sizes with the bounds at indices
+  // collapsed_slice_dims removed.
+  SmallVector<Value, 4> adjusted_slice_sizes;
+  for (int64_t i = 0; i < slice_sizes.size(); ++i) {
+    if (std::find(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
+                  i) == collapsed_slice_dims.end()) {
+      adjusted_slice_sizes.push_back(slice_sizes[i]);
+    }
+  }
+
+  SmallVector<Value, 4> shape_values;
+  shape_values.reserve(result_rank);
+  for (int64_t i = 0; i < result_rank; ++i) {
+    auto iter = std::find(batch_dims.begin(), batch_dims.end(), i);
+    if (iter != batch_dims.end()) {
+      // i is present in batch_dims.
+      int64_t k = std::distance(batch_dims.begin(), iter);
+      if (k < index_vector_dim) {
+        shape_values.push_back(to_shape_scalar_type(
+            builder.create<memref::DimOp>(loc, start_indices, k)));
+      } else {
+        shape_values.push_back(to_shape_scalar_type(
+            builder.create<memref::DimOp>(loc, start_indices, k + 1)));
+      }
+    } else {
+      // i is present in offset_dims.
+      auto offset_dims_iter =
+          std::find(offset_dims.begin(), offset_dims.end(), i);
+      assert(offset_dims_iter != offset_dims.end());
+      int64_t k = std::distance(offset_dims.begin(), offset_dims_iter);
+      assert(k < adjusted_slice_sizes.size());
+      shape_values.push_back(adjusted_slice_sizes[k]);
+    }
+  }
+
+  Value output_shape = builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values);
+  reifiedReturnShapes.push_back(output_shape);
+
+  return success();
+}
+
+}  // namespace
+
+LogicalResult GatherOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicGatherOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // GetDimensionSizeOp
 //===----------------------------------------------------------------------===//
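Editorial note, not part of the patch: to make the bound computation in GatherShapeInferImpl concrete, the standalone sketch below re-implements the same rule on plain integer shapes, using -1 to encode a dynamic extent. The helper name GatherOutputBounds and the -1 convention are assumptions made for this illustration only. Run with the configuration of the @gather test added further down (offset_dims = [1], collapsed_slice_dims = [0], index_vector_dim = 1, slice_sizes = [1, 7], rank-1 start_indices), it yields bounds [-1, 7], i.e. a ?x7 result.

    // Illustration only: reference version of the gather output-bound rule.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    using Dims = std::vector<int64_t>;  // -1 encodes a dynamic extent.

    Dims GatherOutputBounds(const Dims& start_indices_shape,
                            const Dims& slice_sizes, const Dims& offset_dims,
                            const Dims& collapsed_slice_dims,
                            int64_t index_vector_dim, int64_t result_rank) {
      // adjusted_slice_sizes = slice_sizes with collapsed_slice_dims removed.
      Dims adjusted;
      for (int64_t i = 0; i < static_cast<int64_t>(slice_sizes.size()); ++i) {
        if (std::find(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
                      i) == collapsed_slice_dims.end())
          adjusted.push_back(slice_sizes[i]);
      }
      // batch_dims = output dimensions not listed in offset_dims.
      Dims batch_dims;
      for (int64_t i = 0; i < result_rank; ++i) {
        if (std::find(offset_dims.begin(), offset_dims.end(), i) ==
            offset_dims.end())
          batch_dims.push_back(i);
      }
      Dims bounds(result_rank);
      for (int64_t i = 0; i < result_rank; ++i) {
        auto it = std::find(batch_dims.begin(), batch_dims.end(), i);
        if (it != batch_dims.end()) {
          // Batch dimension: read start_indices.shape, skipping
          // index_vector_dim.
          int64_t k = it - batch_dims.begin();
          bounds[i] = start_indices_shape[k < index_vector_dim ? k : k + 1];
        } else {
          // Offset dimension: read adjusted_slice_sizes.
          int64_t k = std::find(offset_dims.begin(), offset_dims.end(), i) -
                      offset_dims.begin();
          bounds[i] = adjusted[k];
        }
      }
      return bounds;
    }

    int main() {
      // Configuration of the @gather test case below.
      Dims bounds = GatherOutputBounds(/*start_indices_shape=*/{-1},
                                       /*slice_sizes=*/{1, 7},
                                       /*offset_dims=*/{1},
                                       /*collapsed_slice_dims=*/{0},
                                       /*index_vector_dim=*/1,
                                       /*result_rank=*/2);
      for (int64_t b : bounds) std::cout << b << " ";  // prints: -1 7
      std::cout << "\n";
    }
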
diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index 0fa8af8..1e107bd 100644
--- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -666,6 +666,7 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
     OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
+                   HloToLhloOpConverter<mhlo::DynamicGatherOp>,
                    HloToLhloOpConverter<mhlo::DynamicIotaOp>,
                    HloToLhloOpConverter<mhlo::DynamicPadOp>,
                    HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
diff --git a/tests/hlo-legalize-to-lhlo-only-dynamic.mlir b/tests/hlo-legalize-to-lhlo-only-dynamic.mlir
index 4f3b5ee..15099e5 100644
--- a/tests/hlo-legalize-to-lhlo-only-dynamic.mlir
+++ b/tests/hlo-legalize-to-lhlo-only-dynamic.mlir
@@ -184,3 +184,51 @@ func @concatenate(%a: tensor, %b: tensor, %c: tensor)
   } : (tensor, tensor, tensor) -> tensor
   return %concat : tensor
 }
+
+// -----
+
+// CHECK-LABEL: func @gather
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?xi32>) -> memref<?x7xf32>
+func @gather(%operand: tensor<?x?xf32>, %idxs: tensor<?xi32>)
+    -> tensor<?x7xf32> {
+  // CHECK: %[[ARG1_DIM0:.*]] = memref.dim %[[ARG1]], %c0 : memref<?xi32>
+  // CHECK: %[[TMP:.*]] = memref.alloc(%[[ARG1_DIM0]]) : memref<?x7xf32>
+  // CHECK: %[[OUT:.*]] = memref.cast %[[TMP]] : memref<?x7xf32> to memref<?x7xf32>
+  // CHECK: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[OUT]])
+  %result =
+    "mhlo.gather"(%operand, %idxs)
+      { dimension_numbers =
+        { collapsed_slice_dims = dense<0> : tensor<1xi64>
+        , index_vector_dim = 1 : i64
+        , offset_dims = dense<1> : tensor<1xi64>
+        , start_index_map = dense<0> : tensor<1xi64> }
+      , indices_are_sorted = false
+      , name = "gather.71"
+      , slice_sizes = dense<[1, 7]> : tensor<2xi64> }
+      : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<?x7xf32>
+  return %result : tensor<?x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_gather
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?xi32>, %[[ARG2:.*]]: memref<2xi32>) -> memref<?x?xf32>
+func @dynamic_gather(%operand: tensor<?x?xf32>, %idxs: tensor<?xi32>, %slice_sizes: tensor<2xi32>)
+    -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[SIZE1_i32:.*]] = memref.load %[[ARG2]][%c1] : memref<2xi32>
+  // CHECK-DAG: %[[ARG1_DIM0:.*]] = memref.dim %[[ARG1]], %c0 : memref<?xi32>
+  // CHECK-DAG: %[[SIZE:.*]] = index_cast %[[SIZE1_i32]] : i32 to index
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[ARG1_DIM0]], %[[SIZE]]) : memref<?x?xf32>
+  // CHECK: "lmhlo.dynamic_gather"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[OUT]])
+  %result =
+    "mhlo.dynamic_gather"(%operand, %idxs, %slice_sizes)
+      { dimension_numbers =
+        { collapsed_slice_dims = dense<0> : tensor<1xi64>
+        , index_vector_dim = 1 : i64
+        , offset_dims = dense<1> : tensor<1xi64>
+        , start_index_map = dense<0> : tensor<1xi64> }
+      , indices_are_sorted = false
+      , name = "gather.71" }
+      : (tensor<?x?xf32>, tensor<?xi32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
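
Editorial note, not part of the patch: the reason both gather ops now derive from HLO_ShapedInterfaceOp is that the dynamic hlo-to-lhlo conversion has to materialize the output buffer before it can emit the lmhlo op, and it obtains the required extents through the InferShapedTypeOpInterface that this trait attaches. The sketch below is a minimal, assumed illustration of that query; the helper name reifyGatherResultShape is made up and does not appear in this commit, and the operands are whatever values the caller's conversion framework supplies.

    // Illustration only -- not code from this patch.
    #include "llvm/ADT/SmallVector.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/Interfaces/InferTypeOpInterface.h"

    using namespace mlir;

    // Hypothetical helper: asks `op` (e.g. mhlo.gather or mhlo.dynamic_gather)
    // for the reified shape of its single result. The returned value is a
    // rank-1 tensor of extents that a bufferizing pattern can read element by
    // element to size the memref.alloc for the lmhlo output buffer.
    static LogicalResult reifyGatherResultShape(OpBuilder& builder,
                                                Operation* op,
                                                ValueRange operands,
                                                Value& resultShape) {
      auto shapedOp = dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapedOp) return failure();
      SmallVector<Value, 1> reifiedShapes;
      if (failed(shapedOp.reifyReturnTypeShapes(builder, operands,
                                                reifiedShapes)))
        return failure();
      resultShape = reifiedShapes.front();  // One result, hence one shape.
      return success();
    }

This is the same interface hook that the HloToLhloOpConverter patterns registered in hlo_legalize_to_lhlo.cc rely on, directly or through shared allocation helpers, to produce the memref.alloc operations checked for in the tests above.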