[MHLO] Add pass to move up dynamic broadcasts for fusion

For now, the pass only reifies the required shape computations. Moving
broadcasts will follow to allow for fusion across them.
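
As an illustration, consider a shape.shape_of that depends on an mhlo op.
The following is a sketch of the intended rewrite, not output taken from the
pass: the exact reified form depends on each op's reifyReturnTypeShapes
implementation, assumed here to build an index tensor from the operand.

  // Before: the shape computation depends on the mhlo.convert result.
  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
  "use"(%1) : (tensor<?xindex>) -> ()

  // After: the shape is reified from the operand %arg, so consumers of the
  // shape no longer depend on %0 and fusion across the op becomes possible.
  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %c0 = constant 0 : index
  %d0 = dim %arg, %c0 : tensor<?x32xi16>
  %c32 = constant 32 : index
  %s = tensor.from_elements %d0, %c32 : tensor<2xindex>
  %1 = tensor.cast %s : tensor<2xindex> to tensor<?xindex>
  "use"(%1) : (tensor<?xindex>) -> ()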

PiperOrigin-RevId: 362033715
A. Unique TensorFlower authored 2021-03-10 06:20:43 -08:00; committed by TensorFlow MLIR Team
parent cabd4d9a06
commit c217a6ef61
7 changed files with 181 additions and 0 deletions

BUILD

@@ -658,6 +658,29 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "move_up_dynamic_broadcasts_for_fusion",
+    srcs = ["lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc"],
+    hdrs = [
+        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
+        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
+    ],
+    deps = [
+        ":hlo",
+        ":map_chlo_to_hlo_op",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "lhlo_legalize_to_gpu",
     srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc"],
@@ -1025,6 +1048,7 @@ cc_library(
         ":mhlo_control_flow_to_scf",
         ":mhlo_fusion",
         ":mhlo_to_mhlo_lowering_patterns",
+        ":move_up_dynamic_broadcasts_for_fusion",
         ":sink_constants_to_control_flow",
        ":test_passes",
        ":transform_unranked_hlo",

include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td

@@ -110,6 +110,12 @@ def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
   let constructor = "createTransformUnrankedHloPass()";
 }
 
+def MoveUpDynamicBroadcastsForFusionPass :
+    Pass<"mhlo-move-up-dynamic-broadcasts-for-fusion", "FuncOp"> {
+  let summary = "Move up dynamic broadcasts and shape computations to allow "
+                "for fusion across broadcasts.";
+  let constructor = "createMoveUpDynamicBroadcastsForFusionPass()";
+}
+
 def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
   let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
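
This definition registers the pass under the
mhlo-move-up-dynamic-broadcasts-for-fusion flag, which is how the test added
at the end of this commit exercises it via mlir-hlo-opt.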

include/mlir-hlo/Dialect/mhlo/transforms/passes.h

@@ -67,6 +67,8 @@ std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
 std::unique_ptr<OperationPass<FuncOp>>
 createLegalizeTrigonometricToApproximationPass();
 
+std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
+
 std::unique_ptr<FunctionPass> createOptimizeMhloPass();
 std::unique_ptr<FunctionPass> createLowerComplexPass();
 std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();

include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h

@@ -79,6 +79,9 @@ void SetupTransformUnrankedHloLegality(MLIRContext *context,
 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns);
 
+void PopulateDynamicShapeFusionPatterns(MLIRContext *context,
+                                        OwningRewritePatternList *patterns);
+
 // Populate a collection of conversion patterns for un-fusing
 // batch_norm_inference and batch_norm_training into constituent HLO ops.
 // TODO(laurenzo): Implement un-fusing of batch_norm_training.
@@ -90,6 +93,11 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
 void PopulateTrigonometricToApproximationPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
+void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
+
+void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
 }  // namespace mhlo
 
 namespace chlo {

lib/Dialect/mhlo/transforms/CMakeLists.txt

@@ -48,6 +48,7 @@ add_mlir_library(ChloPasses
 )
 
 add_mlir_library(MhloPasses
+  move_up_dynamic_broadcasts_for_fusion.cc
   legalize_gather_to_torch_index_select.cc
   legalize_trigonometric_to_approximation.cc
   lower_complex.cc

lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc (new file)

@@ -0,0 +1,115 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace mhlo {
namespace {
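// Converts a shape.shape_of op into the reified shape computation of its
// argument, provided the producing op implements InferShapedTypeOpInterface.
// This decouples the shape computation from the producing op's result.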
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
  explicit ShapeOfOpConversion(MLIRContext *context)
      : OpConversionPattern<shape::ShapeOfOp>(context) {}

  LogicalResult matchAndRewrite(
      shape::ShapeOfOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    shape::ShapeOfOp::Adaptor transformed(operands);

    // Only apply the rewrite if the shape's origin can reify its own result
    // shape.
    auto shape_origin = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
        transformed.arg().getDefiningOp());
    if (!shape_origin) return failure();

    llvm::SmallVector<Value, 1> reified_shapes;
    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
      return failure();
    assert(reified_shapes.size() == 1 && "expected exactly one reified shape");

    // Insert a cast if the reified shape's type does not match the type of
    // the original shape.shape_of op.
    Value reified_shape = reified_shapes.front();
    if (reified_shape.getType() != op.getType()) {
      reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                      reified_shape);
    }

    rewriter.replaceOp(op, reified_shape);
    return success();
  }
};
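
// A pass that exercises the patterns above. For now it only reifies the
// required shape computations; patterns that actually move broadcasts up
// will follow.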
struct MoveUpDynamicBroadcastsForFusionPass
    : public PassWrapper<MoveUpDynamicBroadcastsForFusionPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
  }

  void runOnFunction() override {
    // Set up target legality.
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);

    // Populate rewrite patterns.
    OwningRewritePatternList patterns;
    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);

    // Apply transformation.
    if (failed(applyPartialConversion(getFunction(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};
} // namespace
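
// The reified shape computations are built from mhlo, std, and tensor ops;
// mark these dialects as legal so the conversion leaves them in place.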
void PopulateMoveUpDynamicBroadcastsForFusionLegality(
    ConversionTarget *target) {
  target->addLegalDialect<MhloDialect, StandardOpsDialect,
                          tensor::TensorDialect>();
}

void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<ShapeOfOpConversion>(context);
  // clang-format on
}

std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
  return std::make_unique<MoveUpDynamicBroadcastsForFusionPass>();
}
} // namespace mhlo
} // namespace mlir

tests/move_up_dynamic_broadcasts_for_fusion.mlir (new file)

@@ -0,0 +1,25 @@
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-move-up-dynamic-broadcasts-for-fusion --canonicalize --cse %s | FileCheck %s

// Shape computations shall be reified.
// CHECK-LABEL: @shape_of_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
func @shape_of_unary(%arg : tensor<?x32xi16>) {
  // CHECK-NOT: shape_of
  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
  "use"(%1) : (tensor<?xindex>) -> ()
  return
}

// -----

// Shape computations shall be reified.
// CHECK-LABEL: @shape_of_nary
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
  // CHECK-NOT: shape_of
  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
  "use"(%1) : (tensor<?xindex>) -> ()
  return
}