Separate CHLO transforms for expanding compositions and lowering broadcasts.

* The former is typically invariant regardless of backend.
* The latter may need to be done differently depending on capabilities of the lowering target.

PiperOrigin-RevId: 374492924
This commit is contained in:
Stella Laurenzo 2021-05-18 13:32:58 -07:00 committed by TensorFlow MLIR Team
parent e0d9e9bffd
commit 0fe07e3814
7 changed files with 33 additions and 21 deletions

View File

@ -19,8 +19,10 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
let summary = "Legalize CHLO to HLO.";
let constructor = "createChloLegalizeToHloPass()";
let options = [
Option<"broadcast_only_", "broadcast-only", "bool",
/*default=*/"false", "Only lower broadcasting chlo to non-broadcasting equivalents">,
Option<"legalize_broadcasts_", "legalize-broadcasts", "bool",
/*default=*/"true", "Legalize implicit broadcasts to explicit HLO broadcasting forms">,
Option<"expand_compositions_", "expand-compositions", "bool",
/*default=*/"true", "Expands client-centric compositions to HLO primitives">,
];
}

View File

@ -46,7 +46,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
/// Lowers from the CHLO dialect to the HLO dialect.
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
bool broadcast_only = false);
bool legalize_broadcasts = true, bool expand_compositions = true);
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary.

View File

@ -116,10 +116,11 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
// Populates a collection of conversion patterns for legalizing client-HLO to
// HLO. Includes decomposition of operations and inserting of explicit
// broadcasts.
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
// HLO by decomposing client-operations to corresponding sequences of more
// primitive operations. This does not include the
// PopulateChloBroadcastingPatterns above.
void PopulateDecomposeChloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace chlo

View File

@ -1234,17 +1234,16 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5);
patterns->insert<ConvertSelectOp>(context);
patterns->insert<ConvertConstantLikeOp>(context);
}
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
void PopulateDecomposeChloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
populateWithGenerated(*patterns);
PopulateChloBroadcastingPatterns(context, patterns);
// Other patterns.
// clang-format off
patterns->insert<ConvertConstantLikeOp,
ConvertDigammaOp,
patterns->insert<ConvertDigammaOp,
ConvertErfOp,
ConvertErfcOp,
ConvertLgammaOp,

View File

@ -31,10 +31,12 @@ namespace {
struct ChloLegalizeToHloPass
: public ChloLegalizeToHloPassBase<ChloLegalizeToHloPass> {
explicit ChloLegalizeToHloPass(bool broadcast_only)
explicit ChloLegalizeToHloPass(bool legalize_broadcasts,
bool expand_compositions)
: ChloLegalizeToHloPassBase<
ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() {
this->broadcast_only_ = broadcast_only;
this->legalize_broadcasts_ = legalize_broadcasts;
this->expand_compositions_ = expand_compositions;
}
void getDependentDialects(DialectRegistry &registry) const override {
@ -53,13 +55,15 @@ struct ChloLegalizeToHloPass
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
if (broadcast_only_) {
if (legalize_broadcasts_) {
chlo::PopulateChloBroadcastingPatterns(&getContext(),
&conversionPatterns);
conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>();
}
if (expand_compositions_) {
chlo::PopulateDecomposeChloPatterns(&getContext(), &conversionPatterns);
} else {
chlo::PopulateLegalizeChloToHloPatterns(&getContext(),
&conversionPatterns);
conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>();
}
if (failed(applyPartialConversion(getOperation(), conversionTarget,
@ -71,8 +75,10 @@ struct ChloLegalizeToHloPass
} // namespace
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(bool broadcast_only) {
return std::make_unique<ChloLegalizeToHloPass>(broadcast_only);
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
bool legalize_broadcasts, bool expand_compositions) {
return std::make_unique<ChloLegalizeToHloPass>(legalize_broadcasts,
expand_compositions);
}
} // namespace mhlo

View File

@ -14,6 +14,10 @@ limitations under the License.
==============================================================================*/
// This is the legalization pattern definition file for CHLO to MHLO.
// These are included in the PopulateDecomposeChloPatterns factory
// and should only include canonical expansions which are not actually
// ambiguous/different for various backends. Avoid patterns that are actually
// lowering to non-canonical forms.
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=true expand-compositions=false" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics.