PR #50271: [MLIR][DISC] Bufferize GatherOp and DynamicGatherOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50271

Supports hlo-to-lhlo conversion for GatherOp and DynamicGatherOp.

Copybara import of the project:

-- 117a1b1bcaac7ecc5224b02863eede5c1b9618fe by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize GatherOp and DynamicGatherOp

PiperOrigin-RevId: 379801972
commit 88cc0c6c46 (parent 34dc5f2a79)
@@ -1618,7 +1618,7 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]> {
   let results = (outs HLO_Tensor);
 }
 
-def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]> {
+def HLO_GatherOp: HLO_ShapedInterfaceOp<"gather", [NoSideEffect]> {
   let arguments = (ins
     HLO_Tensor:$operand,
     HLO_IntTensor:$start_indices,
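Note on the base-class change above: HLO_ShapedInterfaceOp is the TableGen helper this dialect uses to attach the shape-reification interface (InferShapedTypeOpInterface) to an op, which is what lets the bufferization pass size the output buffer of a dynamically shaped result. Concretely, the op now has to provide the hook declared below; its body is added later in this same diff and delegates to GatherShapeInferImpl:

    // Declaration of the interface hook GatherOp now implements (the
    // signature is taken from the implementation added further down).
    LogicalResult GatherOp::reifyReturnTypeShapes(
        OpBuilder& builder, ValueRange operands,
        SmallVectorImpl<Value>& reifiedReturnShapes);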
@@ -2268,7 +2268,7 @@ def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad",
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
+def HLO_DynamicGatherOp: HLO_ShapedInterfaceOp<"dynamic_gather", [NoSideEffect]> {
   string summary = "Dynamic Gather operator";
   string description = [{
     The dynamic shape version of GatherOp. Stitches together several slices of an input
@@ -1460,7 +1460,7 @@ def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
   let arguments = (ins
     Arg<LHLO_Buffer, "", [MemRead]>:$operand,
     Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
-    Arg<LHLO_IntBuffer, "", [MemRead]>:$slice_sizes,
+    Arg<LHLO_DimensionBuffer, "", [MemRead]>:$slice_sizes,
     GatherDimensionNumbers:$dimension_numbers,
     Arg<LHLO_Buffer, "", [MemWrite]>:$output
   );
@@ -56,6 +56,7 @@ MAP_HLO_TO_LHLO(CustomCallOp);
 MAP_HLO_TO_LHLO(DivOp);
 MAP_HLO_TO_LHLO(DotOp);
 MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
+MAP_HLO_TO_LHLO(DynamicGatherOp);
 MAP_HLO_TO_LHLO(DynamicIotaOp);
 MAP_HLO_TO_LHLO(DynamicPadOp);
 MAP_HLO_TO_LHLO(DynamicReshapeOp);
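For context, MAP_HLO_TO_LHLO populates the compile-time table that the generic hlo-to-lhlo conversion pattern consults to find the lmhlo counterpart of an mhlo op. A sketch of the mechanism (the exact spelling is an assumption; the authoritative definitions live in map_hlo_to_lhlo_op.h):

    #include <type_traits>

    // Sketch of the trait machinery behind MAP_HLO_TO_LHLO (assumed
    // spelling). Unmapped ops resolve to std::false_type.
    template <typename HloOpTy>
    struct HloToLhloOpImpl {
      using Type = std::false_type;  // no mapping registered
    };
    template <typename HloOpTy>
    using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

    #define MAP_HLO_TO_LHLO(OpName)          \
      template <>                            \
      struct HloToLhloOpImpl<mhlo::OpName> { \
        using Type = lmhlo::OpName;          \
      }

    // After MAP_HLO_TO_LHLO(DynamicGatherOp), the alias
    // HloToLhloOp<mhlo::DynamicGatherOp> names lmhlo::DynamicGatherOp,
    // so the generic converter pattern can create the buffer form.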
@@ -339,6 +339,147 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
   results.insert<GatherSlice>(context);
 }
 
+namespace {
+
+// Following https://www.tensorflow.org/xla/operation_semantics#gather.
+// The bounds for the output array along dimension i is computed as follows:
+// (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some
+//     k) then we pick the corresponding dimension bounds out of
+//     start_indices.shape, skipping index_vector_dim
+//     (i.e. pick start_indices.shape.dims[k] if k < index_vector_dim and
+//     start_indices.shape.dims[k+1] otherwise).
+// (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some
+//     k) then we pick the corresponding bound out of slice_sizes after
+//     accounting for collapsed_slice_dims
+//     (i.e. we pick adjusted_slice_sizes[k] where adjusted_slice_sizes is
+//     slice_sizes with the bounds at indices collapsed_slice_dims removed).
+
+void GetSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
+                        ValueRange operands,
+                        SmallVectorImpl<Value>& slice_sizes) {
+  for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
+    slice_sizes.push_back(builder.create<ConstantIndexOp>(loc, val));
+  }
+}
+
+void GetSliceSizeValues(DynamicGatherOp* d_gather, OpBuilder& builder,
+                        Location loc, ValueRange operands,
+                        SmallVectorImpl<Value>& slice_size_values) {
+  DynamicGatherOp::Adaptor adaptor(operands);
+  Value slice_sizes = adaptor.slice_sizes();
+  auto slice_sizes_ty = slice_sizes.getType().cast<ShapedType>();
+  for (int64_t i = 0; i < slice_sizes_ty.getDimSize(0); ++i) {
+    Value idx = builder.create<ConstantIndexOp>(loc, i);
+    slice_size_values.push_back(
+        builder.create<tensor::ExtractOp>(loc, slice_sizes, idx));
+  }
+}
+
+template <typename Op>
+LogicalResult GatherShapeInferImpl(
+    Op* op, OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  // Unranked gather is not supported a.t.m.
+  auto result_ty =
+      op->getResult().getType().template dyn_cast<RankedTensorType>();
+  if (!result_ty) return failure();
+
+  typename Op::Adaptor adaptor(operands);
+  Value start_indices = adaptor.start_indices();
+
+  Location loc = op->getLoc();
+  int result_rank = result_ty.getRank();
+  Type shape_scalar_type =
+      start_indices.getType().cast<ShapedType>().getElementType();
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  auto dimension_numbers = op->dimension_numbers();
+  SmallVector<int64_t, 4> collapsed_slice_dims(
+      dimension_numbers.collapsed_slice_dims().template getValues<int64_t>());
+  SmallVector<int64_t, 4> offset_dims(
+      dimension_numbers.offset_dims().template getValues<int64_t>());
+  int64_t index_vector_dim =
+      dimension_numbers.index_vector_dim().getValue().getSExtValue();
+
+  SmallVector<Value, 4> slice_sizes;
+  GetSliceSizeValues(op, builder, loc, operands, slice_sizes);
+  // Convert to `shape_scalar_type`.
+  llvm::transform(slice_sizes, slice_sizes.begin(),
+                  [&](Value v) { return to_shape_scalar_type(v); });
+
+  // We label the dimensions of the output array that are not in offset_dims
+  // as batch_dims.
+  SmallVector<int64_t, 4> batch_dims;
+  for (int64_t i = 0; i < result_rank; ++i) {
+    if (std::find(offset_dims.begin(), offset_dims.end(), i) ==
+        offset_dims.end()) {
+      batch_dims.push_back(i);
+    }
+  }
+  // adjusted_slice_sizes is slice_sizes with the bounds at indices
+  // collapsed_slice_dims removed.
+  SmallVector<Value, 4> adjusted_slice_sizes;
+  for (int64_t i = 0; i < slice_sizes.size(); ++i) {
+    if (std::find(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
+                  i) == collapsed_slice_dims.end()) {
+      adjusted_slice_sizes.push_back(slice_sizes[i]);
+    }
+  }
+
+  SmallVector<Value, 4> shape_values;
+  shape_values.reserve(result_rank);
+  for (int64_t i = 0; i < result_rank; ++i) {
+    auto iter = std::find(batch_dims.begin(), batch_dims.end(), i);
+    if (iter != batch_dims.end()) {
+      // i is present in batch_dims.
+      int64_t k = std::distance(batch_dims.begin(), iter);
+      if (k < index_vector_dim) {
+        shape_values.push_back(to_shape_scalar_type(
+            builder.create<memref::DimOp>(loc, start_indices, k)));
+      } else {
+        shape_values.push_back(to_shape_scalar_type(
+            builder.create<memref::DimOp>(loc, start_indices, k + 1)));
+      }
+    } else {
+      // i is present in offset_dims.
+      auto offset_dims_iter =
+          std::find(offset_dims.begin(), offset_dims.end(), i);
+      assert(offset_dims_iter != offset_dims.end());
+      int64_t k = std::distance(offset_dims.begin(), offset_dims_iter);
+      assert(k < adjusted_slice_sizes.size());
+      shape_values.push_back(adjusted_slice_sizes[k]);
+    }
+  }
+
+  Value output_shape = builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values);
+  reifiedReturnShapes.push_back(output_shape);
+
+  return success();
+}
+
+}  // namespace
+
+LogicalResult GatherOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicGatherOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  return GatherShapeInferImpl(this, builder, operands, reifiedReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // GetDimensionSizeOp
 //===----------------------------------------------------------------------===//
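To make the bound rule concrete, here is a standalone sketch (plain C++, no MLIR; all names are illustrative, not part of the mhlo API) that applies the same computation on static sizes, mirroring GatherShapeInferImpl above. For the @gather test below (start_indices of length N, index_vector_dim = 1, offset_dims = [1], collapsed_slice_dims = [0], slice_sizes = [1, 7]) it yields {N, 7}, which matches the memref<?x7xf32> the bufferized IR allocates:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors GatherShapeInferImpl, but on static sizes.
    std::vector<int64_t> GatherOutputBounds(
        const std::vector<int64_t>& start_indices_shape,
        int64_t index_vector_dim, const std::vector<int64_t>& offset_dims,
        const std::vector<int64_t>& collapsed_slice_dims,
        const std::vector<int64_t>& slice_sizes, int64_t result_rank) {
      // adjusted_slice_sizes: slice_sizes with collapsed_slice_dims removed.
      std::vector<int64_t> adjusted_slice_sizes;
      for (int64_t i = 0; i < static_cast<int64_t>(slice_sizes.size()); ++i) {
        if (!std::count(collapsed_slice_dims.begin(),
                        collapsed_slice_dims.end(), i))
          adjusted_slice_sizes.push_back(slice_sizes[i]);
      }

      std::vector<int64_t> bounds;
      int64_t batch_k = 0;  // position of dim i within batch_dims so far
      for (int64_t i = 0; i < result_rank; ++i) {
        auto it = std::find(offset_dims.begin(), offset_dims.end(), i);
        if (it == offset_dims.end()) {
          // Case (1): batch dim. Pick from start_indices.shape, skipping
          // index_vector_dim.
          int64_t k = batch_k++;
          bounds.push_back(
              start_indices_shape[k < index_vector_dim ? k : k + 1]);
        } else {
          // Case (2): offset dim. Pick from adjusted_slice_sizes.
          bounds.push_back(
              adjusted_slice_sizes[std::distance(offset_dims.begin(), it)]);
        }
      }
      return bounds;
    }

    int main() {
      // The @gather test below, with the dynamic dim of start_indices
      // fixed at 5: output bounds {5, 7}.
      for (int64_t d : GatherOutputBounds({5}, 1, {1}, {0}, {1, 7}, 2))
        std::cout << d << " ";  // prints: 5 7
    }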
@@ -666,6 +666,7 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
     OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
+                   HloToLhloOpConverter<mhlo::DynamicGatherOp>,
                    HloToLhloOpConverter<mhlo::DynamicIotaOp>,
                    HloToLhloOpConverter<mhlo::DynamicPadOp>,
                    HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
@@ -184,3 +184,51 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
   } : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
   return %concat : tensor<?x?xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @gather
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?xi32>) -> memref<?x?xf32>
+func @gather(%operand: tensor<?x?xf32>, %idxs: tensor<?xi32>)
+    -> tensor<?x?xf32> {
+  // CHECK: %[[ARG1_DIM0:.*]] = memref.dim %[[ARG1]], %c0 : memref<?xi32>
+  // CHECK: %[[TMP:.*]] = memref.alloc(%0) : memref<?x7xf32>
+  // CHECK: %[[OUT:.*]] = memref.cast %[[TMP:.*]] : memref<?x7xf32> to memref<?x?xf32>
+  // CHECK: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[OUT]])
+  %result =
+    "mhlo.gather"(%operand, %idxs)
+      { dimension_numbers =
+        { collapsed_slice_dims = dense<0> : tensor<1xi64>
+        , index_vector_dim = 1 : i64
+        , offset_dims = dense<1> : tensor<1xi64>
+        , start_index_map = dense<0> : tensor<1xi64> }
+      , indices_are_sorted = false
+      , name = "gather.71"
+      , slice_sizes = dense<[1, 7]> : tensor<2xi64> }
+      : (tensor<?x?xf32>, tensor<?xi32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_gather
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: memref<?xi32>, %[[ARG2:.*]]: memref<2xi32>) -> memref<?x?xf32>
+func @dynamic_gather(%operand: tensor<?x?xf32>, %idxs: tensor<?xi32>, %slice_sizes: tensor<2xi32>)
+    -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[SIZE1_i32:.*]] = memref.load %[[ARG2]][%c1] : memref<2xi32>
+  // CHECK-DAG: %[[ARG1_DIM0:.*]] = memref.dim %[[ARG1]], %c0 : memref<?xi32>
+  // CHECK-DAG: %[[SIZE:.*]] = index_cast %[[SIZE1_i32]] : i32 to index
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[ARG1_DIM0]], %[[SIZE]]) : memref<?x?xf32>
+  // CHECK: "lmhlo.dynamic_gather"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[OUT]])
+  %result =
+    "mhlo.dynamic_gather"(%operand, %idxs, %slice_sizes)
+      { dimension_numbers =
+        { collapsed_slice_dims = dense<0> : tensor<1xi64>
+        , index_vector_dim = 1 : i64
+        , offset_dims = dense<1> : tensor<1xi64>
+        , start_index_map = dense<0> : tensor<1xi64> }
+      , indices_are_sorted = false
+      , name = "gather.71"}
+      : (tensor<?x?xf32>, tensor<?xi32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}