Add a transform for Gathers to torch_index_select.
Some gathers can be interpreted as torch index selects. Transforming these cases allows torch_index_select lowerings to be used for certain gathers. PiperOrigin-RevId: 322255835
This commit is contained in:
		
							parent
							
								
									cc776071fe
								
							
						
					
					
						commit
						c23ad602c8
					
				|  | @ -41,6 +41,10 @@ void PopulateComplexLoweringPatterns(MLIRContext *context, | ||||||
| void PopulateOptimizeMHLOPatterns(MLIRContext *context, | void PopulateOptimizeMHLOPatterns(MLIRContext *context, | ||||||
|                                   OwningRewritePatternList *patterns); |                                   OwningRewritePatternList *patterns); | ||||||
| 
 | 
 | ||||||
|  | // Rewrite patterns for gather to equivalent torch index select legalization.
 | ||||||
|  | void PopulateGatherToTorchIndexSelectPatterns( | ||||||
|  |     mlir::MLIRContext *context, OwningRewritePatternList *patterns); | ||||||
|  | 
 | ||||||
| void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, | void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, | ||||||
|                                MLIRContext *ctx); |                                MLIRContext *ctx); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,152 @@ | ||||||
|  | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "third_party/absl/memory/memory.h" | ||||||
|  | #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" | ||||||
|  | #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" | ||||||
|  | #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" | ||||||
|  | #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||||
|  | #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" | ||||||
|  | #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" | ||||||
|  | 
 | ||||||
|  | namespace mlir { | ||||||
|  | 
 | ||||||
|  | namespace mhlo { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
// Pattern that rewrites an mhlo.gather into an mhlo.torch_index_select
// followed by an mhlo.reshape, when the gather is equivalent to selecting
// whole rows of the operand along dimension 0 (i.e. slice_sizes is
// [1, d1, d2, ...] and the dimension numbers are the defaults).
struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    // Both operand and start_indices must be ranked; all the structural
    // checks below rely on static rank information.
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return failure();
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return failure();
    }

    // For a torch_index_select-style gather the index vector lives in the
    // last dimension of start_indices (or dim 0 for scalar indices).
    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
        index_vector_dim) {
      return failure();
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return failure();
    }

    // Only support the default case for start_index_map.
    // i.e. the single index component maps to operand dimension 0.
    if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map()
                .getValue(0)
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    // Offset dimensions should be the defaults.
    // There must be exactly one offset dim per non-batch result dimension.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank() - index_vector_dim) {
      return failure();
    }

    // Offset dims must be the trailing, contiguous result dimensions
    // [index_vector_dim, index_vector_dim + 1, ...].
    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if ((it.index() + index_vector_dim) != it.value()) {
        return failure();
      }
    }

    for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return failure();
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return failure();
      }
    }

    // torch_index_select result shape = start_indices shape (including the
    // trailing size-1 index-vector dim) ++ operand shape minus dim 0. The
    // trailing 1 is folded away by the reshape emitted below.
    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());

    for (auto dim : operand_ty.getShape().drop_front()) {
      index_select_shape.push_back(dim);
    }

    // Only the default collapsed_slice_dims == [0] is supported, matching
    // the size-1 slice along operand dimension 0 checked above.
    if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
        dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
            1 ||
        dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
      return failure();
    }

    // dim = 0 (select along operand dim 0), batch_dims = 0.
    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));

    // Reshape to the gather's original result type (drops the size-1
    // index-vector dimension).
    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }
};
|  | 
 | ||||||
// Function pass that greedily rewrites eligible mhlo.gather ops into
// mhlo.torch_index_select + mhlo.reshape. Gathers that do not match the
// pattern's constraints are left untouched.
struct LegalizeGatherToTorchIndexSelect
    : public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
|  | }  // namespace
 | ||||||
|  | 
 | ||||||
// Adds the gather -> torch_index_select rewrite pattern to `patterns`.
// Exposed so other pipelines can reuse the pattern without running the
// standalone pass.
void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<GatherIsTorchIndexSelect>(context);
}
|  | 
 | ||||||
// Registers the pass under -mhlo-legalize-gather-to-torch-index-select so it
// can be invoked from mlir-hlo-opt and pass pipelines.
static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
    "mhlo-legalize-gather-to-torch-index-select",
    "Legalizes gathers to a torch index select.");
|  | 
 | ||||||
|  | }  // namespace mhlo
 | ||||||
|  | }  // namespace mlir
 | ||||||
|  | @ -0,0 +1,41 @@ | ||||||
// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s

// Default gather (slice_sizes = [1, 4] covering the full row) rewrites to a
// torch_index_select along dim 0 plus a reshape dropping the size-1 index dim.
// CHECK-LABEL: @gather_to_index_select
func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME:   batch_dims = 0 : i64,
  // CHECK-SAME:   dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x3x4xf32>
}

// Scalar (rank-0) start_indices: index_vector_dim clamps to 0 and the rewrite
// still applies.
// CHECK-LABEL: @scalar_gather_to_index_select
func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME:   batch_dims = 0 : i64,
  // CHECK-SAME:   dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x4xf32>
}

// Negative case: slice_sizes = [1, 3] does not cover the full second
// dimension (4), so the gather must be left as-is.
// CHECK-LABEL: @gather_no_lowering_subslice
func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
  return %0 : tensor<1x3x3xf32>
}

// Negative case: index vector of size 2 (multi-dimensional indices) cannot be
// expressed as a single-dim index select, so the gather must be left as-is.
// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
  return %0 : tensor<1x3x4xf32>
}
		Loading…
	
		Reference in New Issue