diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index 8bd261d..9f739d5 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -19,8 +19,10 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> { let summary = "Legalize CHLO to HLO."; let constructor = "createChloLegalizeToHloPass()"; let options = [ - Option<"broadcast_only_", "broadcast-only", "bool", - /*default=*/"false", "Only lower broadcasting chlo to non-broadcasting equivalents">, + Option<"legalize_broadcasts_", "legalize-broadcasts", "bool", + /*default=*/"true", "Legalize implicit broadcasts to explicit HLO broadcasting forms">, + Option<"expand_compositions_", "expand-compositions", "bool", + /*default=*/"true", "Expands client-centric compositions to HLO primitives">, ]; } diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index dcdfe11..0061bc5 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -46,7 +46,7 @@ std::unique_ptr> createLegalizeToStdPass(); /// Lowers from the CHLO dialect to the HLO dialect. std::unique_ptr createChloLegalizeToHloPass( - bool broadcast_only = false); + bool legalize_broadcasts = true, bool expand_compositions = true); /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary /// buffers if necessary. diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index a86c56d..2f7aa80 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -116,10 +116,11 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context, OwningRewritePatternList *patterns); // Populates a collection of conversion patterns for legalizing client-HLO to -// HLO. Includes decomposition of operations and inserting of explicit -// broadcasts. -void PopulateLegalizeChloToHloPatterns(MLIRContext *context, - OwningRewritePatternList *patterns); +// HLO by decomposing client-operations to corresponding sequences of more +// primitive operations. This does not include the +// PopulateChloBroadcastingPatterns above. +void PopulateDecomposeChloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); } // namespace chlo diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 916dc01..bc5d8d2 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -1234,17 +1234,16 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context, PopulateForBroadcastingBinaryOp( context, patterns, 5); patterns->insert(context); + patterns->insert(context); } -void PopulateLegalizeChloToHloPatterns(MLIRContext *context, - OwningRewritePatternList *patterns) { +void PopulateDecomposeChloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { populateWithGenerated(*patterns); - PopulateChloBroadcastingPatterns(context, patterns); // Other patterns. // clang-format off - patterns->insertinsert { - explicit ChloLegalizeToHloPass(bool broadcast_only) + explicit ChloLegalizeToHloPass(bool legalize_broadcasts, + bool expand_compositions) : ChloLegalizeToHloPassBase< ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() { - this->broadcast_only_ = broadcast_only; + this->legalize_broadcasts_ = legalize_broadcasts; + this->expand_compositions_ = expand_compositions; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -53,13 +55,15 @@ struct ChloLegalizeToHloPass mlir::shape::ShapeDialect, mlir::scf::SCFDialect>(); conversionTarget.addLegalOp(); - if (broadcast_only_) { + if (legalize_broadcasts_) { chlo::PopulateChloBroadcastingPatterns(&getContext(), &conversionPatterns); - conversionTarget.addLegalOp(); + } + + if (expand_compositions_) { + chlo::PopulateDecomposeChloPatterns(&getContext(), &conversionPatterns); } else { - chlo::PopulateLegalizeChloToHloPatterns(&getContext(), - &conversionPatterns); + conversionTarget.addLegalOp(); } if (failed(applyPartialConversion(getOperation(), conversionTarget, @@ -71,8 +75,10 @@ struct ChloLegalizeToHloPass } // namespace -std::unique_ptr createChloLegalizeToHloPass(bool broadcast_only) { - return std::make_unique(broadcast_only); +std::unique_ptr createChloLegalizeToHloPass( + bool legalize_broadcasts, bool expand_compositions) { + return std::make_unique(legalize_broadcasts, + expand_compositions); } } // namespace mhlo diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index f585a77..5266686 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -14,6 +14,10 @@ limitations under the License. ==============================================================================*/ // This is the legalization pattern definition file for CHLO to MHLO. +// These are included in the PopulateDecomposeChloPatterns factory +// and should only include canonical expansions which are not actually +// ambiguous/different for various backends. Avoid patterns that are actually +// lowering to non-canonical forms. include "mlir/IR/OpBase.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir index 1d35ccf..abcc327 100644 --- a/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=true expand-compositions=false" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s // Check the non-broadcast case for each registered op, then just check a // representative op for detailed broadcast semantics.