Add a transform for Gathers to torch_index_select.

Some gathers can be interpreted as torch index selects. Transforming these
cases allows existing torch_index_select lowerings to be reused for such gathers.

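For example (an illustrative sketch drawn from the new tests below), a gather
that selects whole rows of a tensor<5x4xf32>:

  %0 = "mhlo.gather"(%arg0, %arg1) {
      dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>,
                           index_vector_dim = 2 : i64,
                           offset_dims = dense<2> : tensor<1xi64>,
                           start_index_map = dense<0> : tensor<1xi64>},
      indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>}
      : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>

becomes an index select over dimension 0 followed by a reshape:

  %0 = "mhlo.torch_index_select"(%arg0, %arg1)
      {batch_dims = 0 : i64, dim = 0 : i64}
      : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
  %1 = "mhlo.reshape"(%0) : (tensor<1x3x1x4xf32>) -> tensor<1x3x4xf32>
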
PiperOrigin-RevId: 322255835
Authored by Robert Suderman on 2020-07-20 23:58:31 +00:00; committed by Mehdi Amini
parent cc776071fe, commit c23ad602c8
3 changed files with 197 additions and 0 deletions

@@ -41,6 +41,10 @@ void PopulateComplexLoweringPatterns(MLIRContext *context,
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
                                  OwningRewritePatternList *patterns);

// Rewrite patterns for gather to equivalent torch index select legalization.
void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns);

void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                               MLIRContext *ctx);

@@ -0,0 +1,152 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace mhlo {
namespace {

struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
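    // Bail out unless both the indices and the operand have a known rank;
    // the shape checks below depend on it.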
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return failure();
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return failure();
    }

    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
        index_vector_dim) {
      return failure();
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return failure();
    }

    // Only support the default case for start_index_map.
    if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map()
                .getValue(0)
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }
    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    // Offset dimensions should be the defaults.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank() - index_vector_dim) {
      return failure();
    }

    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if ((it.index() + index_vector_dim) != it.value()) {
        return failure();
      }
    }

    for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return failure();
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return failure();
      }
    }
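
    // The index select result shape is the start_indices shape followed by
    // the remaining (non-indexed) operand dimensions.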
    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());
    for (auto dim : operand_ty.getShape().drop_front()) {
      index_select_shape.push_back(dim);
    }
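
    // Only the indexed dimension (dimension 0) may be collapsed away.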
    if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
        dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
            1 ||
        dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
      return failure();
    }
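
    // Build the torch_index_select over dimension 0 with no batch dims, then
    // reshape back to the gather's original result type.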
    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));
    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }
};

struct LegalizeGatherToTorchIndexSelect
    : public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
  /// Rewrite gathers that are equivalent to a torch_index_select into that
  /// op, so existing torch_index_select lowerings can be reused.
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

}  // namespace

void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<GatherIsTorchIndexSelect>(context);
}

static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
    "mhlo-legalize-gather-to-torch-index-select",
    "Legalizes gathers to a torch index select.");

}  // namespace mhlo
}  // namespace mlir

@@ -0,0 +1,41 @@
// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s

// CHECK-LABEL: @gather_to_index_select
func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME: batch_dims = 0 : i64,
  // CHECK-SAME: dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x3x4xf32>
}

// CHECK-LABEL: @scalar_gather_to_index_select
func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME: batch_dims = 0 : i64,
  // CHECK-SAME: dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x4xf32>
}

// CHECK-LABEL: @gather_no_lowering_subslice
func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
  return %0 : tensor<1x3x3xf32>
}

// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
  return %0 : tensor<1x3x4xf32>
}