diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 2657550..88a7758 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -41,6 +41,10 @@ void PopulateComplexLoweringPatterns(MLIRContext *context, void PopulateOptimizeMHLOPatterns(MLIRContext *context, OwningRewritePatternList *patterns); +// Rewrite patterns for gather to equivalent torch index select legalization. +void PopulateGatherToTorchIndexSelectPatterns( + mlir::MLIRContext *context, OwningRewritePatternList *patterns); + void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); diff --git a/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc b/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc new file mode 100644 index 0000000..1fb2c13 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { + +namespace mhlo { +namespace { + +struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> { + using OpRewritePattern<GatherOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherOp gather, + PatternRewriter &rewriter) const override { + auto start_indices = gather.start_indices(); + auto start_indices_ty = start_indices.getType().cast<ShapedType>(); + if (!start_indices_ty.hasRank()) { + return failure(); + } + + auto operand = gather.operand(); + auto operand_ty = operand.getType().cast<ShapedType>(); + if (!operand_ty.hasRank()) { + return failure(); + } + + int64_t index_vector_dim = + std::max<int64_t>(0, start_indices_ty.getRank() - 1); + + // We can use torch_index_select if the last dimension represents the + // gather indices. + auto dimension_numbers = gather.dimension_numbers(); + if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != + index_vector_dim) { + return failure(); + } + + // Index select only works across a single dimension. + if (!start_indices_ty.getShape().empty() && + start_indices_ty.getShape().back() != 1) { + return failure(); + } + + // Only support the default case for start_index_map. 
+ if (dimension_numbers.start_index_map().getType().getRank() != 1 || + dimension_numbers.start_index_map() + .getValue(0) + .cast<IntegerAttr>() + .getValue() != 0) { + return failure(); + } + + auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>(); + if (!result_ty) { + return failure(); + } + + // Offset dimensions should be the defaults. + if (dimension_numbers.offset_dims().getType().getNumElements() != + result_ty.getRank() - index_vector_dim) { + return failure(); + } + + for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) { + if ((it.index() + index_vector_dim) != it.value()) { + return failure(); + } + } + + for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) { + // First shape value must be 1. + if (it.index() == 0) { + if (it.value().getSExtValue() != 1) { + return failure(); + } + continue; + } + + // The gather needs to index the entire slice for each other dimension. + if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) { + return failure(); + } + } + + llvm::SmallVector<int64_t, 4> index_select_shape = + llvm::to_vector<4>(start_indices_ty.getShape()); + + for (auto dim : operand_ty.getShape().drop_front()) { + index_select_shape.push_back(dim); + } + + if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() || + dimension_numbers.collapsed_slice_dims().getType().getNumElements() != + 1 || + dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) { + return failure(); + } + + auto torch_index_select = rewriter.create<TorchIndexSelectOp>( + gather.getLoc(), + RankedTensorType::get(index_select_shape, operand_ty.getElementType()), + operand, gather.start_indices(), rewriter.getI64IntegerAttr(0), + rewriter.getI64IntegerAttr(0)); + + rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), + torch_index_select); + + return success(); + } +}; + +struct LegalizeGatherToTorchIndexSelect + : public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> { + /// Perform the lowering of standard dialect operations to approximations. 
+ void runOnFunction() override { + OwningRewritePatternList patterns; + PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; +} // namespace + +void PopulateGatherToTorchIndexSelectPatterns( + mlir::MLIRContext *context, OwningRewritePatternList *patterns) { + patterns->insert<GatherIsTorchIndexSelect>(context); +} + +static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass( + "mhlo-legalize-gather-to-torch-index-select", + "Legalizes gathers to a torch index select."); + +} // namespace mhlo +} // namespace mlir diff --git a/tests/hlo-legalize-gather-to-torch-index-select.mlir b/tests/hlo-legalize-gather-to-torch-index-select.mlir new file mode 100644 index 0000000..ca90a80 --- /dev/null +++ b/tests/hlo-legalize-gather-to-torch-index-select.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s + +// CHECK-LABEL: @gather_to_index_select +func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> { + // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: batch_dims = 0 : i64, + // CHECK-SAME: dim = 0 : i64 + // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]]) + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32> + + // CHECK: return [[RES]] + return %0 : tensor<1x3x4xf32> +} + +// CHECK-LABEL: @scalar_gather_to_index_select +func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> { + // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: batch_dims = 0 : i64, + // 
CHECK-SAME: dim = 0 : i64 + // CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32> + // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]]) + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32> + + // CHECK: return [[RES]] + return %0 : tensor<1x4xf32> +} + +// CHECK-LABEL: @gather_no_lowering_subslice +func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> { + // CHECK: "mhlo.gather" + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32> + return %0 : tensor<1x3x3xf32> +} + +// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> { + // CHECK: "mhlo.gather" + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32> + return %0 : tensor<1x3x4xf32> +}