diff --git a/BUILD b/BUILD
index a030db2..041ed1c 100644
--- a/BUILD
+++ b/BUILD
@@ -658,6 +658,29 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "move_up_dynamic_broadcasts_for_fusion",
+    srcs = ["lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc"],
+    hdrs = [
+        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
+        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
+    ],
+    deps = [
+        ":hlo",
+        ":map_chlo_to_hlo_op",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "lhlo_legalize_to_gpu",
     srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc"],
@@ -1025,6 +1048,7 @@ cc_library(
         ":mhlo_control_flow_to_scf",
         ":mhlo_fusion",
         ":mhlo_to_mhlo_lowering_patterns",
+        ":move_up_dynamic_broadcasts_for_fusion",
         ":sink_constants_to_control_flow",
         ":test_passes",
         ":transform_unranked_hlo",
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index 18171a2..34e7722 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -110,6 +110,12 @@ def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
   let constructor = "createTransformUnrankedHloPass()";
 }
 
+def MoveUpDynamicBroadcastsForFusionPass :
+    Pass<"mhlo-move-up-dynamic-broadcasts-for-fusion", "FuncOp"> {
+  let summary = "Move up dynamic broadcasts and shape computations to allow "
+                "for fusion across broadcasts.";
+  let constructor = "createMoveUpDynamicBroadcastsForFusionPass()";
+}
 
 def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
   let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index 26185cd..82b3d1d 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -67,6 +67,8 @@ std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
 std::unique_ptr<OperationPass<FuncOp>>
 createLegalizeTrigonometricToApproximationPass();
 
+std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
+
 std::unique_ptr<FunctionPass> createOptimizeMhloPass();
 std::unique_ptr<FunctionPass> createLowerComplexPass();
 std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index bcd59b5..2c7c1cf 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -79,6 +79,9 @@ void SetupTransformUnrankedHloLegality(MLIRContext *context,
 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns);
 
+void PopulateDynamicShapeFusionPatterns(MLIRContext *context,
+                                        OwningRewritePatternList *patterns);
+
 // Populate a collection of conversion patterns for un-fusing
 // batch_norm_inference and batch_norm_training into constituent HLO ops.
 // TODO(laurenzo): Implement un-fusing of batch_norm_training.
@@ -90,6 +93,11 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
 void PopulateTrigonometricToApproximationPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
+void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
+
+void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
 }  // namespace mhlo
 
 namespace chlo {
diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt
index d200be6..f27b7a1 100644
--- a/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -48,6 +48,7 @@ add_mlir_library(ChloPasses
 )
 
 add_mlir_library(MhloPasses
+  move_up_dynamic_broadcasts_for_fusion.cc
   legalize_gather_to_torch_index_select.cc
   legalize_trigonometric_to_approximation.cc
   lower_complex.cc
diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
new file mode 100644
index 0000000..31e6f22
--- /dev/null
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -0,0 +1,128 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+==============================================================================*/
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace mhlo {
+namespace {
+
+// Replaces a `shape.shape_of` op by the reified result shape of the op that
+// defines its argument. This applies only when that defining op implements
+// `InferShapedTypeOpInterface` and shape reification succeeds; otherwise the
+// pattern reports failure and the `shape_of` op is left in place.
+struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
+  explicit ShapeOfOpConversion(MLIRContext *context)
+      : OpConversionPattern<shape::ShapeOfOp>(context) {}
+
+  LogicalResult matchAndRewrite(
+      shape::ShapeOfOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    shape::ShapeOfOp::Adaptor transformed(operands);
+
+    // Block arguments have no defining op, hence the null-tolerant cast.
+    auto shape_origin = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
+        transformed.arg().getDefiningOp());
+    if (!shape_origin) return failure();
+
+    llvm::SmallVector<Value, 1> reified_shapes;
+    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
+      return failure();
+
+    assert(reified_shapes.size() == 1);
+    Value reified_shape = reified_shapes.front();
+    // The reified shape may be typed differently from the `shape_of` result;
+    // cast it so that the replacement value has the type of `op`.
+    if (reified_shape.getType() != op.getType()) {
+      reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
+                                                      reified_shape);
+    }
+
+    // Use the possibly cast value, not the raw entry of `reified_shapes`.
+    rewriter.replaceOp(op, reified_shape);
+    return success();
+  }
+};
+
+// Function pass that runs the patterns of this file as a partial dialect
+// conversion and signals pass failure if the conversion fails.
+struct MoveUpDynamicBroadcastsForFusionPass
+    : public PassWrapper<MoveUpDynamicBroadcastsForFusionPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
+  }
+
+  void runOnFunction() override {
+    // Setup target legality.
+    MLIRContext &ctx = getContext();
+    ConversionTarget target(ctx);
+    PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);
+
+    // Populate rewrite patterns.
+    OwningRewritePatternList patterns;
+    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);
+
+    // Apply transformation.
+    if (failed(applyPartialConversion(getFunction(), target,
+                                      std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+// Declares the dialects whose ops are legal for this conversion.
+void PopulateMoveUpDynamicBroadcastsForFusionLegality(
+    ConversionTarget *target) {
+  target->addLegalDialect<MhloDialect, StandardOpsDialect,
+                          tensor::TensorDialect>();
+}
+
+// Appends the rewrite patterns of this pass to `patterns`.
+void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  // clang-format off
+  patterns->insert<ShapeOfOpConversion>(context);
+  // clang-format on
+}
+
+// Creates an instance of the pass defined in this file.
+std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
+  return std::make_unique<MoveUpDynamicBroadcastsForFusionPass>();
+}
+
+}  // namespace mhlo
+}  // namespace mlir
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
new file mode 100644
index 0000000..a07210c
--- /dev/null
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-move-up-dynamic-broadcasts-for-fusion --canonicalize --cse %s | FileCheck %s
+
+// Shape computations shall be reified.
+// CHECK-LABEL: @shape_of_unary
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
+func @shape_of_unary(%arg : tensor<?x32xi16>) {
+  // CHECK-NOT: shape_of
+  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
+  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
+  "use"(%1) : (tensor<?xindex>) -> ()
+  return
+}
+
+// -----
+
+// Shape computations shall be reified.
+// CHECK-LABEL: @shape_of_nary
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
+func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
+  // CHECK-NOT: shape_of
+  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
+  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
+  "use"(%1) : (tensor<?xindex>) -> ()
+  return
+}