PR #50271: [MLIR][DISC] Bufferize GatherOp and DynamicGatherOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50271

Support hlo-to-lhlo conversion for GatherOp and DynamicGatherOp.

Copybara import of the project:

-- 117a1b1bcaac7ecc5224b02863eede5c1b9618fe by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize GatherOp and DynamicGatherOp

PiperOrigin-RevId: 379801972
Parent: 34dc5f2a79
Commit: 88cc0c6c46
```diff
@@ -1618,7 +1618,7 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]> {
   let results = (outs HLO_Tensor);
 }
 
-def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]> {
+def HLO_GatherOp: HLO_ShapedInterfaceOp<"gather", [NoSideEffect]> {
   let arguments = (ins
     HLO_Tensor:$operand,
     HLO_IntTensor:$start_indices,
```
```diff
@@ -2268,7 +2268,7 @@ def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad",
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
+def HLO_DynamicGatherOp: HLO_ShapedInterfaceOp<"dynamic_gather", [NoSideEffect]> {
   string summary = "Dynamic Gather operator";
   string description = [{
     The dynamic shape version of GatherOp. Stitches together several slices of an input
```
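Both base-class changes above are what make the bufferization possible: `HLO_ShapedInterfaceOp` attaches MLIR's `InferShapedTypeOpInterface`, whose `reifyReturnTypeShapes` hook lets the hlo-to-lhlo converter materialize the result's extents as SSA values and allocate an output buffer before rewriting the op. The sketch below shows the general shape of that allocation step; it is a simplified illustration, and the helper name `allocateResultBuffer` and its exact signature are ours, not the converter's.

```cpp
// Simplified sketch: allocate the output buffer for an op implementing
// InferShapedTypeOpInterface (as GatherOp/DynamicGatherOp now do).
Value allocateResultBuffer(OpBuilder& b, Location loc,
                           InferShapedTypeOpInterface shapedOp,
                           ValueRange operands, MemRefType resultType) {
  SmallVector<Value, 1> shapes;
  if (failed(shapedOp.reifyReturnTypeShapes(b, operands, shapes)))
    return nullptr;
  // shapes[0] is a rank-1 tensor holding one extent per result dimension.
  SmallVector<Value, 4> dynSizes;
  for (int64_t i = 0; i < resultType.getRank(); ++i) {
    if (!resultType.isDynamicDim(i)) continue;
    Value idx = b.create<ConstantIndexOp>(loc, i);
    Value extent = b.create<tensor::ExtractOp>(loc, shapes[0], idx);
    // Extents carry the start_indices element type (e.g. i32); memref.alloc
    // wants `index` operands, hence the cast (the `index_cast` in the tests).
    dynSizes.push_back(b.create<IndexCastOp>(loc, b.getIndexType(), extent));
  }
  return b.create<memref::AllocOp>(loc, resultType, dynSizes);
}
```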
```diff
@@ -1460,7 +1460,7 @@ def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
   let arguments = (ins
     Arg<LHLO_Buffer, "", [MemRead]>:$operand,
     Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
-    Arg<LHLO_IntBuffer, "", [MemRead]>:$slice_sizes,
+    Arg<LHLO_DimensionBuffer, "", [MemRead]>:$slice_sizes,
     GatherDimensionNumbers:$dimension_numbers,
     Arg<LHLO_Buffer, "", [MemWrite]>:$output
   );
```
```diff
@@ -56,6 +56,7 @@ MAP_HLO_TO_LHLO(CustomCallOp);
 MAP_HLO_TO_LHLO(DivOp);
 MAP_HLO_TO_LHLO(DotOp);
 MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
+MAP_HLO_TO_LHLO(DynamicGatherOp);
 MAP_HLO_TO_LHLO(DynamicIotaOp);
 MAP_HLO_TO_LHLO(DynamicPadOp);
 MAP_HLO_TO_LHLO(DynamicReshapeOp);
```
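For readers unfamiliar with this header, `MAP_HLO_TO_LHLO` is a compile-time map from an mhlo op type to its lmhlo counterpart, which `HloToLhloOpConverter` consults when rewriting. A sketch of the idiom (close to, but not copied verbatim from, the actual header):

```cpp
// Type-level map: HloToLhloOp<mhlo::FooOp> names lmhlo::FooOp.
template <typename HloOpTy>
struct HloToLhloOpImpl {
  using Type = std::false_type;  // "no mapping" sentinel
};
template <typename HloOpTy>
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

// Each MAP_HLO_TO_LHLO(OpName) line above expands to a specialization:
#define MAP_HLO_TO_LHLO(OpName)          \
  template <>                            \
  struct HloToLhloOpImpl<mhlo::OpName> { \
    using Type = lmhlo::OpName;          \
  }
```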
```diff
@@ -339,6 +339,147 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
   results.insert<GatherSlice>(context);
 }
 
+namespace {
+
+// Following https://www.tensorflow.org/xla/operation_semantics#gather.
+// The bounds for the output array along dimension i are computed as follows:
+// (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some
+//     k) then we pick the corresponding dimension bounds out of
+//     start_indices.shape, skipping index_vector_dim (i.e. pick
+//     start_indices.shape.dims[k] if k < index_vector_dim and
+//     start_indices.shape.dims[k+1] otherwise).
+// (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some
+//     k) then we pick the corresponding bound out of slice_sizes after
+//     accounting for collapsed_slice_dims (i.e. we pick
+//     adjusted_slice_sizes[k] where adjusted_slice_sizes is slice_sizes with
+//     the bounds at indices collapsed_slice_dims removed).
+
+void GetSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
+                        ValueRange operands,
+                        SmallVectorImpl<Value>& slice_sizes) {
+  for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
+    slice_sizes.push_back(builder.create<ConstantIndexOp>(loc, val));
+  }
+}
+
+void GetSliceSizeValues(DynamicGatherOp* d_gather, OpBuilder& builder,
+                        Location loc, ValueRange operands,
+                        SmallVectorImpl<Value>& slice_size_values) {
+  DynamicGatherOp::Adaptor adaptor(operands);
+  Value slice_sizes = adaptor.slice_sizes();
+  auto slice_sizes_ty = slice_sizes.getType().cast<ShapedType>();
+  for (int64_t i = 0; i < slice_sizes_ty.getDimSize(0); ++i) {
+    Value idx = builder.create<ConstantIndexOp>(loc, i);
+    slice_size_values.push_back(
+        builder.create<tensor::ExtractOp>(loc, slice_sizes, idx));
+  }
+}
+
+template <typename Op>
+LogicalResult GatherShapeInferImpl(
+    Op* op, OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  // Unranked gather is not supported at the moment.
+  auto result_ty =
+      op->getResult().getType().template dyn_cast<RankedTensorType>();
+  if (!result_ty) return failure();
+
+  typename Op::Adaptor adaptor(operands);
+  Value start_indices = adaptor.start_indices();
+
+  Location loc = op->getLoc();
+  int result_rank = result_ty.getRank();
+  Type shape_scalar_type =
+      start_indices.getType().cast<ShapedType>().getElementType();
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  auto dimension_numbers = op->dimension_numbers();
+  SmallVector<int64_t, 4> collapsed_slice_dims(
+      dimension_numbers.collapsed_slice_dims().template getValues<int64_t>());
+  SmallVector<int64_t, 4> offset_dims(
+      dimension_numbers.offset_dims().template getValues<int64_t>());
+  int64_t index_vector_dim =
+      dimension_numbers.index_vector_dim().getValue().getSExtValue();
+
+  SmallVector<Value, 4> slice_sizes;
+  GetSliceSizeValues(op, builder, loc, operands, slice_sizes);
+  // Convert to `shape_scalar_type`.
+  llvm::transform(slice_sizes, slice_sizes.begin(),
+                  [&](Value v) { return to_shape_scalar_type(v); });
+
+  // We label dimensions in the output array not in offset_dims as batch_dims.
+  SmallVector<int64_t, 4> batch_dims;
+  for (int64_t i = 0; i < result_rank; ++i) {
+    if (std::find(offset_dims.begin(), offset_dims.end(), i) ==
+        offset_dims.end()) {
+      batch_dims.push_back(i);
+    }
+  }
+  // adjusted_slice_sizes is slice_sizes with the bounds at indices
+  // collapsed_slice_dims removed.
+  SmallVector<Value, 4> adjusted_slice_sizes;
+  for (int64_t i = 0; i < slice_sizes.size(); ++i) {
+    if (std::find(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
+                  i) == collapsed_slice_dims.end()) {
+      adjusted_slice_sizes.push_back(slice_sizes[i]);
+    }
+  }
+
+  SmallVector<Value, 4> shape_values;
+  shape_values.reserve(result_rank);
+  for (int64_t i = 0; i < result_rank; ++i) {
+    auto iter = std::find(batch_dims.begin(), batch_dims.end(), i);
+    if (iter != batch_dims.end()) {
+      // i is present in batch_dims.
+      int64_t k = std::distance(batch_dims.begin(), iter);
+      if (k < index_vector_dim) {
+        shape_values.push_back(to_shape_scalar_type(
+            builder.create<memref::DimOp>(loc, start_indices, k)));
+      } else {
+        shape_values.push_back(to_shape_scalar_type(
+            builder.create<memref::DimOp>(loc, start_indices, k + 1)));
+      }
+    } else {
+      // i is present in offset_dims.
+      auto offset_dims_iter =
+          std::find(offset_dims.begin(), offset_dims.end(), i);
+      assert(offset_dims_iter != offset_dims.end());
+      int64_t k = std::distance(offset_dims.begin(), offset_dims_iter);
+      assert(k < adjusted_slice_sizes.size());
+      shape_values.push_back(adjusted_slice_sizes[k]);
+    }
+  }
+
+  Value output_shape = builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values);
+  reifiedReturnShapes.push_back(output_shape);
+
+  return success();
+}
+
+}  // namespace
+
+LogicalResult GatherOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicGatherOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // GetDimensionSizeOp
 //===----------------------------------------------------------------------===//
```
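To make the bounds rule in the comment concrete, here is a standalone, plain-C++ rendering of the same arithmetic, exercised with the attribute values used by the @gather test later in this diff (slice_sizes = [1, 7], collapsed_slice_dims = [0], offset_dims = [1], index_vector_dim = 1). `GatherOutputBounds` and `kDynamic` are illustrative names, not part of the patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamic = -1;  // stands in for a dynamic extent ('?')

std::vector<int64_t> GatherOutputBounds(
    const std::vector<int64_t>& start_indices_shape,
    const std::vector<int64_t>& slice_sizes,
    const std::vector<int64_t>& collapsed_slice_dims,
    const std::vector<int64_t>& offset_dims, int64_t index_vector_dim,
    int64_t result_rank) {
  // adjusted_slice_sizes: slice_sizes minus the collapsed dimensions.
  std::vector<int64_t> adjusted;
  for (int64_t i = 0; i < (int64_t)slice_sizes.size(); ++i)
    if (!std::count(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
                    i))
      adjusted.push_back(slice_sizes[i]);

  std::vector<int64_t> bounds;
  int64_t batch_k = 0;  // position of i within batch_dims
  for (int64_t i = 0; i < result_rank; ++i) {
    auto it = std::find(offset_dims.begin(), offset_dims.end(), i);
    if (it == offset_dims.end()) {
      // Batch dim: take a start_indices extent, skipping index_vector_dim.
      int64_t k = batch_k++;
      bounds.push_back(start_indices_shape[k < index_vector_dim ? k : k + 1]);
    } else {
      // Offset dim: take the adjusted slice size.
      bounds.push_back(adjusted[it - offset_dims.begin()]);
    }
  }
  return bounds;
}

int main() {
  // start_indices : tensor<?xi32>, slice_sizes = [1, 7].
  auto bounds = GatherOutputBounds(/*start_indices_shape=*/{kDynamic},
                                   /*slice_sizes=*/{1, 7},
                                   /*collapsed_slice_dims=*/{0},
                                   /*offset_dims=*/{1},
                                   /*index_vector_dim=*/1,
                                   /*result_rank=*/2);
  // Prints "-1 7", i.e. memref<?x7xf32> -- matching the test's alloc.
  for (int64_t b : bounds) std::cout << b << " ";
  std::cout << "\n";
}
```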
```diff
@@ -666,6 +666,7 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
     OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
+                   HloToLhloOpConverter<mhlo::DynamicGatherOp>,
                    HloToLhloOpConverter<mhlo::DynamicIotaOp>,
                    HloToLhloOpConverter<mhlo::DynamicPadOp>,
                    HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
```
```diff
@@ -184,3 +184,51 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
   } : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
   return %concat : tensor<?x?xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @gather
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?xi32>) -> memref<?x?xf32>
+func @gather(%operand: tensor<?x?xf32>, %idxs: tensor<?xi32>)
+    -> tensor<?x?xf32> {
+  // CHECK: %[[ARG1_DIM0:.*]] = memref.dim %[[ARG1]], %c0 : memref<?xi32>
+  // CHECK: %[[TMP:.*]] = memref.alloc(%0) : memref<?x7xf32>
+  // CHECK: %[[OUT:.*]] = memref.cast %[[TMP:.*]] : memref<?x7xf32> to memref<?x?xf32>
+  // CHECK: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[OUT]])
+  %result =
+    "mhlo.gather"(%operand, %idxs)
+      { dimension_numbers =
+        { collapsed_slice_dims = dense<0> : tensor<1xi64>
+        , index_vector_dim = 1 : i64
+        , offset_dims = dense<1> : tensor<1xi64>
+        , start_index_map = dense<0> : tensor<1xi64> }
+      , indices_are_sorted = false
+      , name = "gather.71"
+      , slice_sizes = dense<[1, 7]> : tensor<2xi64> }
+      : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_gather
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?xi32>, %[[ARG2:.*]]: memref<2xi32>) -> memref<?x?xf32>
+func @dynamic_gather(%operand: tensor<?x?xf32>, %idxs: tensor<?xi32>, %slice_sizes: tensor<2xi32>)
+    -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[SIZE1_i32:.*]] = memref.load %[[ARG2]][%c1] : memref<2xi32>
+  // CHECK-DAG: %[[ARG1_DIM0:.*]] = memref.dim %[[ARG1]], %c0 : memref<?xi32>
+  // CHECK-DAG: %[[SIZE:.*]] = index_cast %[[SIZE1_i32]] : i32 to index
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[ARG1_DIM0]], %[[SIZE]]) : memref<?x?xf32>
+  // CHECK: "lmhlo.dynamic_gather"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[OUT]])
+  %result =
+    "mhlo.dynamic_gather"(%operand, %idxs, %slice_sizes)
+      { dimension_numbers =
+        { collapsed_slice_dims = dense<0> : tensor<1xi64>
+        , index_vector_dim = 1 : i64
+        , offset_dims = dense<1> : tensor<1xi64>
+        , start_index_map = dense<0> : tensor<1xi64> }
+      , indices_are_sorted = false
+      , name = "gather.71"}
+      : (tensor<?x?xf32>, tensor<?xi32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
```
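The two tests encode the two sides of the shape computation. For @gather the slice sizes are a compile-time attribute, so dimension 1 folds to the constant 7 and only dimension 0 (an extent of %arg1) stays dynamic, hence the memref<?x7xf32> alloc followed by a cast. For @dynamic_gather the slice sizes live in a buffer, so both extents are runtime values. Reusing the hypothetical GatherOutputBounds sketch from above:

```cpp
// Dynamic case: slice sizes are runtime values, so every offset dimension
// is dynamic at compile time (hence the two-operand memref.alloc above).
auto bounds = GatherOutputBounds(/*start_indices_shape=*/{kDynamic},
                                 /*slice_sizes=*/{kDynamic, kDynamic},
                                 /*collapsed_slice_dims=*/{0},
                                 /*offset_dims=*/{1},
                                 /*index_vector_dim=*/1,
                                 /*result_rank=*/2);
// bounds == {kDynamic, kDynamic}: the converter must memref.load
// slice_sizes[1], index_cast it to `index`, and pass both extents to alloc.
```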