Add a transform for Gathers to torch_index_select.

Some gathers can be interpreted as torch index selects. Transforming these cases allows torch_index_select lowerings to be used for certain gathers.

PiperOrigin-RevId: 322255835
This commit is contained in:
parent
cc776071fe
commit
c23ad602c8
|
@ -41,6 +41,10 @@ void PopulateComplexLoweringPatterns(MLIRContext *context,
|
||||||
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
|
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns);
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
// Rewrite patterns for gather to equivalent torch index select legalization.
|
||||||
|
void PopulateGatherToTorchIndexSelectPatterns(
|
||||||
|
mlir::MLIRContext *context, OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
||||||
MLIRContext *ctx);
|
MLIRContext *ctx);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,152 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
namespace mhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Rewrites an mhlo.gather into mhlo.torch_index_select + mhlo.reshape when
// the gather is equivalent to an index select along dimension 0 of the
// operand: the last start_indices dimension is the (size-1) index vector,
// start_index_map is the default [0], the slice covers every operand
// dimension except the first, and dimension 0 is the single collapsed
// slice dimension. All other gathers are left untouched (match failure).
struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    // Both the indices and the operand must have known ranks to reason
    // about the dimension numbers below.
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return failure();
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return failure();
    }

    // The max with 0 handles scalar (rank-0) start_indices, where the
    // implicit index vector dimension is 0.
    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
        index_vector_dim) {
      return failure();
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return failure();
    }

    // Only support the default case for start_index_map, i.e. the single
    // index component addresses operand dimension 0.
    if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map()
                .getValue(0)
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    // Offset dimensions should be the defaults: the trailing result
    // dimensions, in order, immediately after the batch dimensions.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank() - index_vector_dim) {
      return failure();
    }

    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if ((it.index() + index_vector_dim) != it.value()) {
        return failure();
      }
    }

    for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return failure();
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return failure();
      }
    }

    // torch_index_select result shape: start_indices shape followed by the
    // non-indexed operand dimensions.
    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());

    for (auto dim : operand_ty.getShape().drop_front()) {
      index_select_shape.push_back(dim);
    }

    // Exactly dimension 0 must be collapsed, matching the size-1 slice
    // along the indexed dimension checked above.
    if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
        dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
            1 ||
        dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
      return failure();
    }

    // dim = 0 and batch_dims = 0: select along operand dimension 0 with no
    // shared batch dimensions.
    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));

    // Reshape folds away the trailing size-1 index-vector dimension so the
    // result matches the original gather's type.
    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }
};
|
||||||
|
|
||||||
|
// Function pass that greedily rewrites qualifying mhlo.gather ops into
// mhlo.torch_index_select (see GatherIsTorchIndexSelect above).
struct LegalizeGatherToTorchIndexSelect
    : public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
  /// Perform the rewrite of qualifying gathers to torch_index_select.
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Populates `patterns` with the gather -> torch_index_select rewrite so
// other passes/pipelines can reuse it without running the standalone pass.
void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<GatherIsTorchIndexSelect>(context);
}
|
||||||
|
|
||||||
|
// Registers the pass under -mhlo-legalize-gather-to-torch-index-select.
static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
    "mhlo-legalize-gather-to-torch-index-select",
    "Legalizes gathers to a torch index select.");
|
||||||
|
|
||||||
|
} // namespace mhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,41 @@
|
||||||
|
// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
// Canonical case: rank-3 indices with a trailing size-1 index vector and a
// full [1, 4] slice lower to torch_index_select + reshape.
// CHECK-LABEL: @gather_to_index_select
func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME: batch_dims = 0 : i64,
  // CHECK-SAME: dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x3x4xf32>
}
|
||||||
|
|
||||||
|
// Scalar (rank-0) start_indices: the implicit index vector dimension is 0,
// so the gather still lowers to torch_index_select + reshape.
// CHECK-LABEL: @scalar_gather_to_index_select
func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME: batch_dims = 0 : i64,
  // CHECK-SAME: dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x4xf32>
}
|
||||||
|
|
||||||
|
// Negative case: slice size 3 is smaller than operand dimension 4, so the
// gather reads a sub-slice and is not an index select; it must be kept.
// CHECK-LABEL: @gather_no_lowering_subslice
func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
  return %0 : tensor<1x3x3xf32>
}
|
||||||
|
|
||||||
|
// Negative case: the trailing index-vector dimension has size 2 (multi-
// component indices), which index select cannot express; gather is kept.
// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
  return %0 : tensor<1x3x4xf32>
}
|
Loading…
Reference in New Issue