PR #50271: [MLIR][DISC] Bufferize GatherOp and DynamicGatherOp

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50271

support hlo-to-lhlo conversion for GatherOp and DynamicGatherOp
Copybara import of the project:

--
117a1b1bcaac7ecc5224b02863eede5c1b9618fe by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize GatherOp and DynamicGatherOp

PiperOrigin-RevId: 379801972
This commit is contained in:
Wenyi Zhao 2021-06-16 13:44:21 -07:00 committed by TensorFlow MLIR Team
parent 34dc5f2a79
commit 88cc0c6c46
6 changed files with 194 additions and 3 deletions

View File

@ -1618,7 +1618,7 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]> {
let results = (outs HLO_Tensor);
}
def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]> {
def HLO_GatherOp: HLO_ShapedInterfaceOp<"gather", [NoSideEffect]> {
let arguments = (ins
HLO_Tensor:$operand,
HLO_IntTensor:$start_indices,
@ -2268,7 +2268,7 @@ def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad",
let hasCustomHLOConverter = 1;
}
def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
def HLO_DynamicGatherOp: HLO_ShapedInterfaceOp<"dynamic_gather", [NoSideEffect]> {
string summary = "Dynamic Gather operator";
string description = [{
The dynamic shape version of GatherOp. Stitches together several slices of an input

View File

@ -1460,7 +1460,7 @@ def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
Arg<LHLO_IntBuffer, "", [MemRead]>:$slice_sizes,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$slice_sizes,
GatherDimensionNumbers:$dimension_numbers,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);

View File

@ -56,6 +56,7 @@ MAP_HLO_TO_LHLO(CustomCallOp);
MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
MAP_HLO_TO_LHLO(DynamicGatherOp);
MAP_HLO_TO_LHLO(DynamicIotaOp);
MAP_HLO_TO_LHLO(DynamicPadOp);
MAP_HLO_TO_LHLO(DynamicReshapeOp);

View File

@ -339,6 +339,147 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
results.insert<GatherSlice>(context);
}
namespace {
// following https://www.tensorflow.org/xla/operation_semantics#gather
// The bounds for the output array along dimension i is computed as follows:
// (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some k)
// then we pick
// the corresponding dimension bounds out of start_indices.shape, skipping
// index_vector_dim
// (i.e. pick start_indices.shape.dims[k] if k < index_vector_dim and
// start_indices.shape.dims[k+1] otherwise).
// (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some k)
// then we pick
// the corresponding bound out of slice_sizes after accounting for
// collapsed_slice_dims
// (i.e. we pick adjusted_slice_sizes[k] where adjusted_slice_sizes is
// slice_sizes with the bounds at indices collapsed_slice_dims removed).
void GetSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
ValueRange operands,
SmallVectorImpl<Value>& slice_sizes) {
for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
slice_sizes.push_back(builder.create<ConstantIndexOp>(loc, val));
}
}
void GetSliceSizeValues(DynamicGatherOp* d_gather, OpBuilder& builder,
Location loc, ValueRange operands,
SmallVectorImpl<Value>& slice_size_values) {
DynamicGatherOp::Adaptor adaptor(operands);
Value slice_sizes = adaptor.slice_sizes();
auto slice_sizes_ty = slice_sizes.getType().cast<ShapedType>();
for (int64_t i = 0; i < slice_sizes_ty.getDimSize(0); ++i) {
Value idx = builder.create<ConstantIndexOp>(loc, i);
slice_size_values.push_back(
builder.create<tensor::ExtractOp>(loc, slice_sizes, idx));
}
}
template <typename Op>
LogicalResult GatherShapeInferImpl(
Op* op, OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
// Not support unranked pad a.t.m.
auto result_ty =
op->getResult().getType().template dyn_cast<RankedTensorType>();
if (!result_ty) return failure();
typename Op::Adaptor adaptor(operands);
Value start_indices = adaptor.start_indices();
Location loc = op->getLoc();
int result_rank = result_ty.getRank();
Type shape_scalar_type =
start_indices.getType().cast<ShapedType>().getElementType();
auto to_shape_scalar_type = [&](Value v) {
return MaybeCastTo(builder, loc, v, shape_scalar_type);
};
auto dimension_numbers = op->dimension_numbers();
SmallVector<int64_t, 4> collapsed_slice_dims(
dimension_numbers.collapsed_slice_dims().template getValues<int64_t>());
SmallVector<int64_t, 4> offset_dims(
dimension_numbers.offset_dims().template getValues<int64_t>());
int64_t index_vector_dim =
dimension_numbers.index_vector_dim().getValue().getSExtValue();
SmallVector<Value, 4> slice_sizes;
GetSliceSizeValues(op, builder, loc, operands, slice_sizes);
// Convert to `shape_scalar_type`
llvm::transform(slice_sizes, slice_sizes.begin(),
[&](Value v) { return to_shape_scalar_type(v); });
// we label dimensions in the output array not in offset_dims as batch_dims
SmallVector<int64_t, 4> batch_dims;
for (int64_t i = 0; i < result_rank; ++i) {
if (std::find(offset_dims.begin(), offset_dims.end(), i) ==
offset_dims.end()) {
batch_dims.push_back(i);
}
}
// adjusted_slice_sizes is slice_sizes with the bounds at indices
// collapsed_slice_dims removed
SmallVector<Value, 4> adjusted_slice_sizes;
for (int64_t i = 0; i < slice_sizes.size(); ++i) {
if (std::find(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
i) == collapsed_slice_dims.end()) {
adjusted_slice_sizes.push_back(slice_sizes[i]);
}
}
SmallVector<Value, 4> shape_values;
shape_values.reserve(result_rank);
for (int64_t i = 0; i < result_rank; ++i) {
auto iter = std::find(batch_dims.begin(), batch_dims.end(), i);
if (iter != batch_dims.end()) {
// i is present in batch_dims
int64_t k = std::distance(batch_dims.begin(), iter);
if (k < index_vector_dim) {
shape_values.push_back(to_shape_scalar_type(
builder.create<memref::DimOp>(loc, start_indices, k)));
} else {
shape_values.push_back(to_shape_scalar_type(
builder.create<memref::DimOp>(loc, start_indices, k + 1)));
}
} else {
// i is present in offset_dims
auto offset_dims_iter =
std::find(offset_dims.begin(), offset_dims.end(), i);
assert(offset_dims_iter != offset_dims.end());
int64_t k = std::distance(offset_dims.begin(), offset_dims_iter);
assert(k < adjusted_slice_sizes.size());
shape_values.push_back(adjusted_slice_sizes[k]);
}
}
Value output_shape = builder.create<tensor::FromElementsOp>(
loc, shape_scalar_type, shape_values);
reifiedReturnShapes.push_back(output_shape);
return success();
}
} // namespace
LogicalResult GatherOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// DynamicGatherOp
//===----------------------------------------------------------------------===//
//
LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//

View File

@ -666,6 +666,7 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
HloToLhloOpConverter<mhlo::DynamicGatherOp>,
HloToLhloOpConverter<mhlo::DynamicIotaOp>,
HloToLhloOpConverter<mhlo::DynamicPadOp>,
HloToLhloOpConverter<mhlo::DynamicReshapeOp>,

View File

@ -184,3 +184,51 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %concat : tensor<?x?xi32>
}
// -----
// CHECK-LABEL: func @gather
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?xi32>) -> memref<?x?xf32>
func @gather(%operand: tensor<?x?xf32>, %idxs: tensor<?xi32>)
-> tensor<?x?xf32> {
// CHECK: %[[ARG1_DIM0:.*]] = memref.dim %[[ARG1]], %c0 : memref<?xi32>
// CHECK: %[[TMP:.*]] = memref.alloc(%0) : memref<?x7xf32>
// CHECK: %[[OUT:.*]] = memref.cast %[[TMP:.*]] : memref<?x7xf32> to memref<?x?xf32>
// CHECK: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[OUT]])
%result =
"mhlo.gather"(%operand, %idxs)
{ dimension_numbers =
{ collapsed_slice_dims = dense<0> : tensor<1xi64>
, index_vector_dim = 1 : i64
, offset_dims = dense<1> : tensor<1xi64>
, start_index_map = dense<0> : tensor<1xi64> }
, indices_are_sorted = false
, name = "gather.71"
, slice_sizes = dense<[1, 7]> : tensor<2xi64> }
: (tensor<?x?xf32>, tensor<?xi32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @dynamic_gather
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?xi32>, %[[ARG2:.*]]: memref<2xi32>) -> memref<?x?xf32>
func @dynamic_gather(%operand: tensor<?x?xf32>, %idxs: tensor<?xi32>, %slice_sizes: tensor<2xi32>)
-> tensor<?x?xf32> {
// CHECK-DAG: %[[SIZE1_i32:.*]] = memref.load %[[ARG2]][%c1] : memref<2xi32>
// CHECK-DAG: %[[ARG1_DIM0:.*]] = memref.dim %[[ARG1]], %c0 : memref<?xi32>
// CHECK-DAG: %[[SIZE:.*]] = index_cast %[[SIZE1_i32]] : i32 to index
// CHECK: %[[OUT:.*]] = memref.alloc(%[[ARG1_DIM0]], %[[SIZE]]) : memref<?x?xf32>
// CHECK: "lmhlo.dynamic_gather"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[OUT]])
%result =
"mhlo.dynamic_gather"(%operand, %idxs, %slice_sizes)
{ dimension_numbers =
{ collapsed_slice_dims = dense<0> : tensor<1xi64>
, index_vector_dim = 1 : i64
, offset_dims = dense<1> : tensor<1xi64>
, start_index_map = dense<0> : tensor<1xi64> }
, indices_are_sorted = false
, name = "gather.71"}
: (tensor<?x?xf32>, tensor<?xi32>, tensor<2xi32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}