Add an optimization that converts some Gathers to Slices.
Some Gathers can be represented as slices. This lowering transforms these gathers into slices. PiperOrigin-RevId: 321394868
This commit is contained in:
parent
7a6adc6a84
commit
98a1e3b108
|
@ -38,6 +38,9 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
|
|||
void PopulateComplexLoweringPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file provides optional optimization patterns for mhlo, canonicalizing
|
||||
// operations to equivalent but potentially more efficient operations.
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
|
||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
|
||||
using mlir::OwningRewritePatternList;
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
// Returns 1D 64-bit dense elements attribute with the given values.
|
||||
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
|
||||
Builder* builder) {
|
||||
RankedTensorType ty = RankedTensorType::get(
|
||||
{static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
|
||||
return DenseIntElementsAttr::get(ty, values);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Rewrites an mhlo.gather that reads one statically sized window of its
// operand into an mhlo.dynamic-slice followed by an mhlo.reshape.
//
// The pattern only applies when all of the following hold:
//   * operand and start_indices are ranked and statically shaped;
//   * index_vector_dim is 0 and start_index_map is exactly {0};
//   * the result is ranked and offset_dims is [0, 1, ..., result_rank - 1];
//   * slice_sizes is 1-D, has one entry per operand dimension, and
//     slice_sizes[i + 1] equals result dimension i;
//   * start_indices is a scalar, or a 1-D tensor with no more entries than
//     slice_sizes.
//
// Every check runs before any operation is created, so a failed match leaves
// the IR untouched.
class GatherIsSlice : public OpRewritePattern<GatherOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    auto dimension_numbers = gather.dimension_numbers();
    auto operand_ty = gather.operand().getType().cast<ShapedType>();
    auto gather_start_indices = gather.start_indices();
    auto gather_start_indices_ty =
        gather_start_indices.getType().cast<ShapedType>();

    // Inputs need to be ranked and statically shaped to lower.
    if (!operand_ty.hasRank() || !operand_ty.hasStaticShape() ||
        !gather_start_indices_ty.hasRank() ||
        !gather_start_indices_ty.hasStaticShape()) {
      return failure();
    }

    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
      return failure();
    }

    // TODO(suderman): Handle start index map != {0}.
    if (!dimension_numbers.start_index_map() ||
        dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
        dimension_numbers.start_index_map()
                .getValue({0})
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    // Requires a ranked output.
    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    // offset_dims must map window dimension i to result dimension i.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank()) {
      return failure();
    }
    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if (it.index() != it.value()) {
        return failure();
      }
    }

    // slice_sizes is indexed with 1-D indices below, so its rank has to be
    // validated first.
    auto slice_sizes = gather.slice_sizes();
    auto slice_sizes_ty = slice_sizes.getType();
    if (slice_sizes_ty.getRank() != 1) {
      return failure();
    }
    const int64_t num_slice_sizes = slice_sizes_ty.getDimSize(0);

    // There must be one slice size per operand dimension.
    if (num_slice_sizes != operand_ty.getRank()) {
      return failure();
    }

    // slice_sizes[0] belongs to the indexed dimension; the entries after it
    // must reproduce the result shape.
    if (num_slice_sizes < result_ty.getRank() + 1) {
      return failure();
    }
    for (auto it : llvm::enumerate(result_ty.getShape())) {
      if (slice_sizes.getValue(it.index() + 1)
              .cast<IntegerAttr>()
              .getValue() != it.value()) {
        return failure();
      }
    }

    // Only scalar or 1-D start indices are supported, and they may not supply
    // more indices than the slice has dimensions.
    const int64_t start_indices_rank = gather_start_indices_ty.getRank();
    if (start_indices_rank > 1) {
      return failure();
    }
    const int64_t num_start_indices =
        start_indices_rank == 0 ? 1 : gather_start_indices_ty.getDimSize(0);
    if (num_start_indices > num_slice_sizes) {
      return failure();
    }

    // All preconditions hold; IR is only created from this point on.
    auto loc = gather.getLoc();
    llvm::SmallVector<Value, 4> slice_start_indices;

    if (start_indices_rank == 0) {
      // A scalar start index is used directly.
      slice_start_indices.push_back(gather_start_indices);
    } else {
      // Extract each element of the 1-D start indices as its own scalar.
      for (int64_t i = 0; i < num_start_indices; ++i) {
        auto start = GetI64ElementsAttr({i}, &rewriter);
        auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
        auto stride = GetI64ElementsAttr({1}, &rewriter);
        auto indices_slice = rewriter.create<SliceOp>(
            loc, gather_start_indices, start, limit, stride);
        auto reshaped = rewriter.create<ReshapeOp>(
            loc,
            RankedTensorType::get(
                {},
                indices_slice.getType().cast<ShapedType>().getElementType()),
            indices_slice);
        slice_start_indices.push_back(reshaped);
      }
    }

    // Start indices have implicit zeros when not specified. This is because
    // Gather occurs similar to slicing where full slices are inferred. Add any
    // missing zeros as necessary.
    if (static_cast<int64_t>(slice_start_indices.size()) < num_slice_sizes) {
      auto zero = rewriter.create<ConstOp>(
          loc, rewriter.getZeroAttr(RankedTensorType::get(
                   {}, gather_start_indices_ty.getElementType())));
      while (static_cast<int64_t>(slice_start_indices.size()) <
             num_slice_sizes) {
        slice_start_indices.push_back(zero);
      }
    }

    // The dynamic slice keeps every sliced dimension; the trailing reshape
    // converts it to the gather's result type.
    SmallVector<int64_t, 5> slice_shape;
    for (auto size : slice_sizes.getIntValues()) {
      slice_shape.push_back(size.getSExtValue());
    }

    auto slice_ty =
        RankedTensorType::get(slice_shape, result_ty.getElementType());
    auto slice = rewriter.create<DynamicSliceOp>(
        loc, slice_ty, gather.operand(), slice_start_indices, slice_sizes);

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);

    return success();
  }
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
// Adds the optional MHLO optimization patterns defined in this file (the
// GatherIsSlice rewrite) to `patterns`, constructing them with `context`.
void PopulateOptimizeMHLOPatterns(MLIRContext* context,
                                  OwningRewritePatternList* patterns) {
  OwningRewritePatternList& pattern_list = *patterns;
  pattern_list.insert<GatherIsSlice>(context);
}
|
||||
} // end namespace mhlo
|
||||
} // end namespace mlir
|
|
@ -0,0 +1,49 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
using mlir::FunctionPass;
|
||||
using mlir::PassRegistration;
|
||||
using mlir::PassWrapper;
|
||||
|
||||
namespace {
|
||||
// Function pass that runs the optional MHLO optimization patterns.
class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
 public:
  explicit OptimizeMhlo() : PassWrapper<OptimizeMhlo, FunctionPass>() {}

  /// Applies the optional MHLO optimization patterns to the current function.
  void runOnFunction() override;
};
|
||||
} // end anonymous namespace
|
||||
|
||||
// Lowers the complex operations that can be represented using other operations.
|
||||
void OptimizeMhlo::runOnFunction() {
|
||||
// Add lowering patterns to the list.
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
|
||||
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
|
||||
// Registers OptimizeMhlo under the command-line name "mhlo-test-optimize".
static PassRegistration<OptimizeMhlo> pass(
    "mhlo-test-optimize", "Run optional HLO optimizations.");
|
|
@ -0,0 +1,64 @@
|
|||
// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s
|
||||
|
||||
// A scalar start index is passed straight to the dynamic slice as its first
// index; the remaining two indices are the zero constant.
// CHECK-LABEL: @gather_is_slice_no_rank
func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>

  // CHECK: return [[RESHAPE]]
  return %res : tensor<1x2xi32>
}
|
||||
|
||||
// A single-element 1-D start index is reshaped to a scalar and used as the
// first dynamic-slice index; the remaining two indices are the zero constant.
// CHECK-LABEL: @gather_is_slice
func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])

  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>

  // CHECK: return [[RES]]
  return %res : tensor<1x2xi32>
}
|
||||
|
||||
// A two-element 1-D start index is split into two scalars (slice + reshape
// each) that become the first two dynamic-slice indices; the third index is
// the zero constant.
// CHECK-LABEL: @gather_is_slice_multiple_start_indices
func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
  // CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>

  // CHECK: return [[RES]]
  return %res : tensor<1x2xi32>
}
|
Loading…
Reference in New Issue