Add an optimization that converts some Gathers to Slices.

Some Gathers can be represented as slices. This lowering transforms
these gathers into slices.

PiperOrigin-RevId: 321394868
This commit is contained in:
Robert Suderman 2020-07-15 17:56:59 +00:00 committed by Mehdi Amini
parent 7a6adc6a84
commit 98a1e3b108
4 changed files with 303 additions and 0 deletions

View File

@ -38,6 +38,9 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
void PopulateComplexLoweringPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);

View File

@ -0,0 +1,187 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file provides optional optimization patterns for mhlo, canonocalizing
// operations to equivalent but potentially more efficient operations.
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
using mlir::OwningRewritePatternList;
namespace mlir {
namespace mhlo {
namespace {
// Returns 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
Builder* builder) {
RankedTensorType ty = RankedTensorType::get(
{static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
return DenseIntElementsAttr::get(ty, values);
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
class GatherIsSlice : public OpRewritePattern<GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter& rewriter) const override {
auto dimension_numbers = gather.dimension_numbers();
// Inputs need to be ranked to lower.
if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
!gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
!gather.start_indices().getType().cast<ShapedType>().hasRank() ||
!gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
return failure();
}
if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
return failure();
}
// TODO(suderman): Handle start index map != {0}.
if (!dimension_numbers.start_index_map() ||
dimension_numbers.start_index_map().getType().getRank() != 1 ||
dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
dimension_numbers.start_index_map()
.getValue({0})
.cast<IntegerAttr>()
.getValue() != 0) {
return failure();
}
auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
// Requires a ranked output.
if (!result_ty) {
return failure();
}
if (dimension_numbers.offset_dims().getType().getNumElements() !=
result_ty.getRank()) {
return failure();
}
for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
if (it.index() != it.value()) {
return failure();
}
}
// Verify the gather slice sizes are correct.
if (gather.slice_sizes().getNumElements() !=
gather.operand().getType().cast<ShapedType>().getRank()) {
return failure();
}
// Validate the slice sizes are correct.
if (gather.slice_sizes().getType().cast<ShapedType>().getNumElements() <
result_ty.getShape().size() + 1) {
return failure();
}
for (auto it : llvm::enumerate(result_ty.getShape())) {
if (gather.slice_sizes()
.getValue(it.index() + 1)
.cast<IntegerAttr>()
.getValue() != it.value()) {
return failure();
}
}
auto gather_start_indices = gather.start_indices();
auto gather_start_indices_ty =
gather_start_indices.getType().cast<ShapedType>();
llvm::SmallVector<Value, 4> slice_start_indices;
if (gather_start_indices_ty.getRank() == 0) {
slice_start_indices.push_back(gather_start_indices);
} else if (gather_start_indices_ty.getRank() == 1) {
for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
auto start = GetI64ElementsAttr({i}, &rewriter);
auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
auto stride = GetI64ElementsAttr({1}, &rewriter);
auto indicesSlice = rewriter.create<SliceOp>(
gather.getLoc(), gather_start_indices, start, limit, stride);
auto reshaped = rewriter.create<ReshapeOp>(
gather.getLoc(),
RankedTensorType::get(
{}, indicesSlice.getType().cast<ShapedType>().getElementType()),
indicesSlice);
slice_start_indices.push_back(reshaped);
}
} else {
return failure();
}
auto sliceSizes = gather.slice_sizes();
auto sliceSizesTy = sliceSizes.getType();
if (sliceSizesTy.getRank() != 1) {
return failure();
}
// Start indices have implicit zeros when not specified. This is because
// Gather occurs similar to slicing where full slices are inferred. Add any
// missing zeros as necessary.
auto zero = rewriter.create<ConstOp>(
gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
{}, gather_start_indices_ty.getElementType())));
while (slice_start_indices.size() < sliceSizesTy.getDimSize(0)) {
slice_start_indices.push_back(zero);
}
SmallVector<int64_t, 5> sliceShape;
for (auto shapeValue : gather.slice_sizes().getIntValues()) {
sliceShape.push_back(shapeValue.getSExtValue());
}
auto sliceTy =
RankedTensorType::get(sliceShape, result_ty.getElementType());
auto slice = rewriter.create<DynamicSliceOp>(
gather.getLoc(), sliceTy, gather.operand(), slice_start_indices,
gather.slice_sizes());
rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);
return success();
}
};
} // end anonymous namespace
void PopulateOptimizeMHLOPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<GatherIsSlice>(context);
}
} // end namespace mhlo
} // end namespace mlir

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
using mlir::FunctionPass;
using mlir::PassRegistration;
using mlir::PassWrapper;
namespace {
class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
public:
explicit OptimizeMhlo() : PassWrapper<OptimizeMhlo, FunctionPass>() {}
/// Performs the lowering to MHLO dialect.
void runOnFunction() override;
};
} // end anonymous namespace
// Lowers the complex operations that can be represented using other operations.
void OptimizeMhlo::runOnFunction() {
// Add lowering patterns to the list.
mlir::OwningRewritePatternList patterns;
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<OptimizeMhlo> pass("mhlo-test-optimize",
"Run optional HLO optimizations.");

64
tests/optimize-hlo.mlir Normal file
View File

@ -0,0 +1,64 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s
// CHECK-LABEL: @gather_is_slice_no_rank
func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
// CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
// CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
%res = "mhlo.gather"(%arg0, %arg1) {
dimension_numbers = {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
start_index_map = dense<0> : tensor<1xi64>
},
slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
} : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>
// CHECK: return [[RESHAPE]]
return %res : tensor<1x2xi32>
}
// CHECK-LABEL: @gather_is_slice
func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
// CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
// CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
// CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])
%res = "mhlo.gather"(%arg0, %arg1) {
dimension_numbers = {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
start_index_map = dense<0> : tensor<1xi64>
},
slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
} : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
// CHECK: return [[RES]]
return %res : tensor<1x2xi32>
}
// CHECK-LABEL: @gather_is_slice_multiple_start_indices
func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
// CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
// CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
// CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
// CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
%res = "mhlo.gather"(%arg0, %arg1) {
dimension_numbers = {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
start_index_map = dense<0> : tensor<1xi64>
},
slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
} : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>
// CHECK: return [[RES]]
return %res : tensor<1x2xi32>
}