diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index 1e99ce0..2657550 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -38,6 +38,9 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
 void PopulateComplexLoweringPatterns(MLIRContext *context,
                                      OwningRewritePatternList *patterns);
 
+void PopulateOptimizeMHLOPatterns(MLIRContext *context,
+                                  OwningRewritePatternList *patterns);
+
 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                                MLIRContext *ctx);
 
diff --git a/lib/Dialect/mhlo/transforms/optimize_mhlo.cc b/lib/Dialect/mhlo/transforms/optimize_mhlo.cc
new file mode 100644
index 0000000..0e49c73
--- /dev/null
+++ b/lib/Dialect/mhlo/transforms/optimize_mhlo.cc
@@ -0,0 +1,187 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file provides optional optimization patterns for mhlo, canonicalizing
+// operations to equivalent but potentially more efficient operations.
+
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+
+#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
+
+using mlir::OwningRewritePatternList;
+
+namespace mlir {
+namespace mhlo {
+namespace {
+
+// Returns a 1D 64-bit dense elements attribute with the given values.
+static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
+                                               Builder* builder) {
+  RankedTensorType ty = RankedTensorType::get(
+      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
+  return DenseIntElementsAttr::get(ty, values);
+}
+
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+// Rewrites a GatherOp that behaves like a slice into a DynamicSliceOp
+// followed by a ReshapeOp.
+class GatherIsSlice : public OpRewritePattern<GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherOp gather,
+                                PatternRewriter& rewriter) const override {
+    auto dimension_numbers = gather.dimension_numbers();
+
+    // Inputs need to be ranked and statically shaped to lower.
+    if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
+        !gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
+        !gather.start_indices().getType().cast<ShapedType>().hasRank() ||
+        !gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
+      return failure();
+    }
+
+    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
+      return failure();
+    }
+
+    // TODO(suderman): Handle start index map != {0}.
+    if (!dimension_numbers.start_index_map() ||
+        dimension_numbers.start_index_map().getType().getRank() != 1 ||
+        dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
+        dimension_numbers.start_index_map()
+                .getValue({0})
+                .cast<IntegerAttr>()
+                .getValue() != 0) {
+      return failure();
+    }
+
+    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
+
+    // Requires a ranked output.
+    if (!result_ty) {
+      return failure();
+    }
+    if (dimension_numbers.offset_dims().getType().getNumElements() !=
+        result_ty.getRank()) {
+      return failure();
+    }
+    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
+      if (it.index() != it.value()) {
+        return failure();
+      }
+    }
+
+    // Verify that the gather slice sizes cover every operand dimension.
+    if (gather.slice_sizes().getNumElements() !=
+        gather.operand().getType().cast<ShapedType>().getRank()) {
+      return failure();
+    }
+
+    // Validate that there is a slice size for each result dimension plus one.
+    if (gather.slice_sizes().getType().cast<ShapedType>().getNumElements() <
+        result_ty.getShape().size() + 1) {
+      return failure();
+    }
+
+    for (auto it : llvm::enumerate(result_ty.getShape())) {
+      if (gather.slice_sizes()
+              .getValue(it.index() + 1)
+              .cast<IntegerAttr>()
+              .getValue() != it.value()) {
+        return failure();
+      }
+    }
+
+    auto gather_start_indices = gather.start_indices();
+    auto gather_start_indices_ty =
+        gather_start_indices.getType().cast<ShapedType>();
+
+    llvm::SmallVector<Value, 4> slice_start_indices;
+
+    if (gather_start_indices_ty.getRank() == 0) {
+      slice_start_indices.push_back(gather_start_indices);
+    } else if (gather_start_indices_ty.getRank() == 1) {
+      for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
+        auto start = GetI64ElementsAttr({i}, &rewriter);
+        auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
+        auto stride = GetI64ElementsAttr({1}, &rewriter);
+        auto indicesSlice = rewriter.create<SliceOp>(
+            gather.getLoc(), gather_start_indices, start, limit, stride);
+        auto reshaped = rewriter.create<ReshapeOp>(
+            gather.getLoc(),
+            RankedTensorType::get(
+                {}, indicesSlice.getType().cast<ShapedType>().getElementType()),
+            indicesSlice);
+        slice_start_indices.push_back(reshaped);
+      }
+    } else {
+      return failure();
+    }
+
+    auto sliceSizes = gather.slice_sizes();
+    auto sliceSizesTy = sliceSizes.getType();
+    if (sliceSizesTy.getRank() != 1) {
+      return failure();
+    }
+
+    // Start indices have implicit zeros when not specified. This is because
+    // Gather works similarly to slicing, where full slices are inferred. Add
+    // any missing zeros as necessary.
+    auto zero = rewriter.create<ConstOp>(
+        gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
+                             {}, gather_start_indices_ty.getElementType())));
+    while (slice_start_indices.size() < sliceSizesTy.getDimSize(0)) {
+      slice_start_indices.push_back(zero);
+    }
+
+    SmallVector<int64_t, 4> sliceShape;
+    for (auto shapeValue : gather.slice_sizes().getIntValues()) {
+      sliceShape.push_back(shapeValue.getSExtValue());
+    }
+
+    auto sliceTy =
+        RankedTensorType::get(sliceShape, result_ty.getElementType());
+    auto slice = rewriter.create<DynamicSliceOp>(
+        gather.getLoc(), sliceTy, gather.operand(), slice_start_indices,
+        gather.slice_sizes());
+
+    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);
+
+    return success();
+  }
+};
+
+}  // end anonymous namespace
+
+void PopulateOptimizeMHLOPatterns(MLIRContext* context,
+                                  OwningRewritePatternList* patterns) {
+  patterns->insert<GatherIsSlice>(context);
+}
+}  // end namespace mhlo
+}  // end namespace mlir
diff --git a/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc
new file mode 100644
index 0000000..b4184e2
--- /dev/null
+++ b/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+
+using mlir::FunctionPass;
+using mlir::PassRegistration;
+using mlir::PassWrapper;
+
+namespace {
+class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
+ public:
+  explicit OptimizeMhlo() : PassWrapper() {}
+
+  /// Runs the optional MHLO optimization patterns on a function.
+  void runOnFunction() override;
+};
+}  // end anonymous namespace
+
+// Rewrites operations into equivalent but potentially more efficient forms.
+void OptimizeMhlo::runOnFunction() {
+  // Add optimization patterns to the list.
+  mlir::OwningRewritePatternList patterns;
+  mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
+
+  applyPatternsAndFoldGreedily(getFunction(), patterns);
+}
+
+static PassRegistration<OptimizeMhlo> pass("mhlo-test-optimize",
+                                           "Run optional HLO optimizations.");
diff --git a/tests/optimize-hlo.mlir b/tests/optimize-hlo.mlir
new file mode 100644
index 0000000..c20de0b
--- /dev/null
+++ b/tests/optimize-hlo.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s
+
+// CHECK-LABEL: @gather_is_slice_no_rank
+func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
+  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
+  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
+  %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>
+
+  // CHECK: return [[RESHAPE]]
+  return %res : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @gather_is_slice
+func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
+  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
+  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
+  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])
+
+  %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
+
+  // CHECK: return [[RES]]
+  return %res : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @gather_is_slice_multiple_start_indices
+func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
+  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
+  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
+  // CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
+  %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>
+
+  // CHECK: return [[RES]]
+  return %res : tensor<1x2xi32>
+}