[MHLO] Add pass to move up dynamic broadcasts for fusion

For now, the pass only reifies the required shape computations. Moving the broadcasts themselves will follow, to allow fusion across them.

PiperOrigin-RevId: 362033715
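For orientation, a minimal before/after sketch of the reification on a unary op, mirroring the test cases at the end of this commit. The "after" IR is illustrative only: it assumes the op's reifyReturnTypeShapes builds the result shape from the operand with dim/from_elements-style ops, and the exact form depends on this revision's shape reification and on canonicalization.

    // Before: the shape computation depends on the mhlo op's result.
    %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
    %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>

    // After (illustrative): the shape is reified from the operand, so the
    // shape computation no longer depends on %0 and fusion across it
    // becomes possible.
    %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
    %c0 = constant 0 : index
    %c32 = constant 32 : index
    %d0 = dim %arg, %c0 : tensor<?x32xi16>
    %e = tensor.from_elements %d0, %c32 : tensor<2xindex>
    %1 = tensor.cast %e : tensor<2xindex> to tensor<?xindex>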
commit c217a6ef61
parent cabd4d9a06

BUILD | 24 ++++++++++++++++++++++++
BUILD:
@@ -658,6 +658,29 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "move_up_dynamic_broadcasts_for_fusion",
+    srcs = ["lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc"],
+    hdrs = [
+        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
+        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
+    ],
+    deps = [
+        ":hlo",
+        ":map_chlo_to_hlo_op",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "lhlo_legalize_to_gpu",
     srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc"],
@@ -1025,6 +1048,7 @@ cc_library(
         ":mhlo_control_flow_to_scf",
         ":mhlo_fusion",
         ":mhlo_to_mhlo_lowering_patterns",
+        ":move_up_dynamic_broadcasts_for_fusion",
        ":sink_constants_to_control_flow",
         ":test_passes",
         ":transform_unranked_hlo",
MHLO pass declarations (TableGen):
@@ -110,6 +110,12 @@ def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
   let constructor = "createTransformUnrankedHloPass()";
 }
 
+def MoveUpDynamicBroadcastsForFusionPass :
+    Pass<"mhlo-move-up-dynamic-broadcasts-for-fusion", "FuncOp"> {
+  let summary = "Move up dynamic broadcasts and shape computations to allow "
+                "for fusion across broadcasts.";
+  let constructor = "createMoveUpDynamicBroadcastsForFusionPass()";
+}
+
 def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
   let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
include/mlir-hlo/Dialect/mhlo/transforms/passes.h:
@@ -67,6 +67,8 @@ std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
 std::unique_ptr<OperationPass<FuncOp>>
 createLegalizeTrigonometricToApproximationPass();
 
+std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
+
 std::unique_ptr<FunctionPass> createOptimizeMhloPass();
 std::unique_ptr<FunctionPass> createLowerComplexPass();
 std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();
include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h:
@@ -79,6 +79,9 @@ void SetupTransformUnrankedHloLegality(MLIRContext *context,
 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns);
 
+void PopulateDynamicShapeFusionPatterns(MLIRContext *context,
+                                        OwningRewritePatternList *patterns);
+
 // Populate a collection of conversion patterns for un-fusing
 // batch_norm_inference and batch_norm_training into constituent HLO ops.
 // TODO(laurenzo): Implement un-fusing of batch_norm_training.
@@ -90,6 +93,11 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
 void PopulateTrigonometricToApproximationPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
+void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
+
+void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
 }  // namespace mhlo
 
 namespace chlo {
CMakeLists.txt (MhloPasses library):
@@ -48,6 +48,7 @@ add_mlir_library(ChloPasses
 )
 
 add_mlir_library(MhloPasses
+  move_up_dynamic_broadcasts_for_fusion.cc
   legalize_gather_to_torch_index_select.cc
   legalize_trigonometric_to_approximation.cc
   lower_complex.cc
lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc (new file):
@@ -0,0 +1,115 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {
namespace {

// Replaces `shape.shape_of` on the result of a shape-inferable op with the
// reified shape computation, so that shape uses no longer depend on the op's
// result.
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
  explicit ShapeOfOpConversion(MLIRContext *context)
      : OpConversionPattern<shape::ShapeOfOp>(context) {}

  LogicalResult matchAndRewrite(
      shape::ShapeOfOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    shape::ShapeOfOp::Adaptor transformed(operands);

    // Only apply to ops that can reify their return type shapes.
    auto shape_origin = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
        transformed.arg().getDefiningOp());
    if (!shape_origin) return failure();

    llvm::SmallVector<Value, 1> reified_shapes;
    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
      return failure();

    assert(reified_shapes.size() == 1);
    Value reified_shape = reified_shapes.front();
    if (reified_shape.getType() != op.getType()) {
      reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                      reified_shape);
    }

    rewriter.replaceOp(op, reified_shape);
    return success();
  }
};

struct MoveUpDynamicBroadcastsForFusionPass
    : public PassWrapper<MoveUpDynamicBroadcastsForFusionPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
  }

  void runOnFunction() override {
    // Setup target legality.
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);

    // Populate rewrite patterns.
    OwningRewritePatternList patterns;
    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);

    // Apply transformation.
    if (failed(applyPartialConversion(getFunction(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

void PopulateMoveUpDynamicBroadcastsForFusionLegality(
    ConversionTarget *target) {
  target->addLegalDialect<MhloDialect, StandardOpsDialect,
                          tensor::TensorDialect>();
}

void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<ShapeOfOpConversion>(context);
  // clang-format on
}

std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
  return std::make_unique<MoveUpDynamicBroadcastsForFusionPass>();
}

}  // namespace mhlo
}  // namespace mlir
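Not part of the commit, but for context: a minimal sketch of wiring the new entry point into a pass pipeline. The helper name runMoveUpBroadcasts is hypothetical; the PassManager APIs are the standard MLIR ones of this era.

    #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    // Hypothetical driver: run the new pass on every function in a module,
    // then canonicalize (mirroring the test's --canonicalize flag).
    mlir::LogicalResult runMoveUpBroadcasts(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      pm.addNestedPass<mlir::FuncOp>(
          mlir::mhlo::createMoveUpDynamicBroadcastsForFusionPass());
      pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
      return pm.run(module);
    }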
New FileCheck test (new file):
@@ -0,0 +1,25 @@
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-move-up-dynamic-broadcasts-for-fusion --canonicalize --cse %s | FileCheck %s

// Shape computations shall be reified.
// CHECK-LABEL: @shape_of_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
func @shape_of_unary(%arg : tensor<?x32xi16>) {
  // CHECK-NOT: shape_of
  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
  "use"(%1) : (tensor<?xindex>) -> ()
  return
}

// -----

// Shape computations shall be reified.
// CHECK-LABEL: @shape_of_nary
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
  // CHECK-NOT: shape_of
  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
  "use"(%1) : (tensor<?xindex>) -> ()
  return
}