Add a transform for Gathers to torch_index_select.
Some gathers can be interpreted as torch index selects. Transforming these cases allows torch_index_select lowerings to be used for certain gathers. PiperOrigin-RevId: 322255835
This commit is contained in:
parent
cc776071fe
commit
c23ad602c8
|
@ -41,6 +41,10 @@ void PopulateComplexLoweringPatterns(MLIRContext *context,
|
|||
// Collects rewrite patterns that optimize MHLO operations.
// NOTE(review): the exact pattern set lives in the corresponding .cc file.
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
                                  OwningRewritePatternList *patterns);

// Rewrite patterns for gather to equivalent torch index select legalization.
void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns);

// Collects rewrite patterns that lower MHLO operations to the standard
// dialect.
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                               MLIRContext *ctx);
|
|
|
@ -0,0 +1,152 @@
|
|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
// Rewrites an mhlo.gather into mhlo.torch_index_select + mhlo.reshape when
// the gather is equivalent to selecting entire slices of the operand along
// its leading dimension (i.e. a torch-style index select). All the checks
// below conservatively bail out (return failure()) on any gather that does
// not match the canonical index-select form.
struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    // Both operand and start_indices must have ranked types so the shape
    // checks below are meaningful.
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return failure();
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return failure();
    }

    // For rank-0 (scalar) start_indices the index vector dim is 0; otherwise
    // it must be the trailing dimension of start_indices.
    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
        index_vector_dim) {
      return failure();
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return failure();
    }

    // Only support the default case for start_index_map: a single entry
    // mapping the index component to operand dimension 0.
    if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map()
                .getValue(0)
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    // Offset dimensions should be the defaults: exactly the trailing result
    // dimensions after the batch dims contributed by start_indices.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank() - index_vector_dim) {
      return failure();
    }

    // Each offset dim must be consecutive starting at index_vector_dim.
    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if ((it.index() + index_vector_dim) != it.value()) {
        return failure();
      }
    }

    for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
      // First shape value must be 1: one row is selected per index.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return failure();
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return failure();
      }
    }

    // The torch_index_select result keeps the full start_indices shape
    // (including the trailing unit index dim, if any) followed by the
    // non-gathered operand dims; the reshape below folds this to the
    // gather's declared result type.
    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());

    for (auto dim : operand_ty.getShape().drop_front()) {
      index_select_shape.push_back(dim);
    }

    // Exactly dimension 0 must be collapsed (the selected-row dimension).
    if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
        dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
            1 ||
        dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
      return failure();
    }

    // dim = 0 and batch_dims = 0: select rows of the operand by index.
    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }
};
|
||||
|
||||
struct LegalizeGatherToTorchIndexSelect
|
||||
: public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
|
||||
/// Perform the lowering of standard dialect operations to approximations.
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Registers the gather-to-torch_index_select rewrite pattern into `patterns`.
// Declared in the transforms rewriters header so other pipelines can pick up
// the same pattern set.
void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<GatherIsTorchIndexSelect>(context);
}
|
||||
|
||||
// Registers the pass under the -mhlo-legalize-gather-to-torch-index-select
// flag so it can be invoked from mlir-opt style tools.
static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
    "mhlo-legalize-gather-to-torch-index-select",
    "Legalizes gathers to a torch index select.");
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
|
@ -0,0 +1,41 @@
|
|||
// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s

// Positive case: the gather selects whole 4-element rows of the 5x4 operand,
// one per index, via a trailing unit index dimension — it should become a
// torch_index_select (dim 0, no batch dims) followed by a reshape.
// CHECK-LABEL: @gather_to_index_select
func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME: batch_dims = 0 : i64,
  // CHECK-SAME: dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x3x4xf32>
}

// Positive case: rank-0 (scalar) start_indices — index_vector_dim is 0 and
// the lowering still applies.
// CHECK-LABEL: @scalar_gather_to_index_select
func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME: batch_dims = 0 : i64,
  // CHECK-SAME: dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x4xf32>
}

// Negative case: slice_sizes = [1, 3] only covers part of each 4-element row,
// so the gather cannot be expressed as an index select and must be kept.
// CHECK-LABEL: @gather_no_lowering_subslice
func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
  return %0 : tensor<1x3x3xf32>
}

// Negative case: the trailing start_indices dimension is 2 (multi-component
// index vectors), which torch_index_select cannot express — gather is kept.
// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
  return %0 : tensor<1x3x4xf32>
}
|
Loading…
Reference in New Issue