Separate CHLO transforms for expanding compositions and lowering broadcasts.

* The former is typically invariant regardless of backend.
* The latter may need to be done differently depending on capabilities of the lowering target.

PiperOrigin-RevId: 374492924
This commit is contained in:
Stella Laurenzo 2021-05-18 13:32:58 -07:00 committed by TensorFlow MLIR Team
parent e0d9e9bffd
commit 0fe07e3814
7 changed files with 33 additions and 21 deletions

View File

@ -19,8 +19,10 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
let summary = "Legalize CHLO to HLO."; let summary = "Legalize CHLO to HLO.";
let constructor = "createChloLegalizeToHloPass()"; let constructor = "createChloLegalizeToHloPass()";
let options = [ let options = [
Option<"broadcast_only_", "broadcast-only", "bool", Option<"legalize_broadcasts_", "legalize-broadcasts", "bool",
/*default=*/"false", "Only lower broadcasting chlo to non-broadcasting equivalents">, /*default=*/"true", "Legalize implicit broadcasts to explicit HLO broadcasting forms">,
Option<"expand_compositions_", "expand-compositions", "bool",
/*default=*/"true", "Expands client-centric compositions to HLO primitives">,
]; ];
} }

View File

@ -46,7 +46,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
/// Lowers from the CHLO dialect to the HLO dialect. /// Lowers from the CHLO dialect to the HLO dialect.
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass( std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
bool broadcast_only = false); bool legalize_broadcasts = true, bool expand_compositions = true);
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. /// buffers if necessary.

View File

@ -116,10 +116,11 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);
// Populates a collection of conversion patterns for legalizing client-HLO to // Populates a collection of conversion patterns for legalizing client-HLO to
// HLO. Includes decomposition of operations and inserting of explicit // HLO by decomposing client-operations to corresponding sequences of more
// broadcasts. // primitive operations. This does not include the
void PopulateLegalizeChloToHloPatterns(MLIRContext *context, // PopulateChloBroadcastingPatterns above.
OwningRewritePatternList *patterns); void PopulateDecomposeChloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace chlo } // namespace chlo

View File

@ -1234,17 +1234,16 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>( PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5); context, patterns, 5);
patterns->insert<ConvertSelectOp>(context); patterns->insert<ConvertSelectOp>(context);
patterns->insert<ConvertConstantLikeOp>(context);
} }
void PopulateLegalizeChloToHloPatterns(MLIRContext *context, void PopulateDecomposeChloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) { OwningRewritePatternList *patterns) {
populateWithGenerated(*patterns); populateWithGenerated(*patterns);
PopulateChloBroadcastingPatterns(context, patterns);
// Other patterns. // Other patterns.
// clang-format off // clang-format off
patterns->insert<ConvertConstantLikeOp, patterns->insert<ConvertDigammaOp,
ConvertDigammaOp,
ConvertErfOp, ConvertErfOp,
ConvertErfcOp, ConvertErfcOp,
ConvertLgammaOp, ConvertLgammaOp,

View File

@ -31,10 +31,12 @@ namespace {
struct ChloLegalizeToHloPass struct ChloLegalizeToHloPass
: public ChloLegalizeToHloPassBase<ChloLegalizeToHloPass> { : public ChloLegalizeToHloPassBase<ChloLegalizeToHloPass> {
explicit ChloLegalizeToHloPass(bool broadcast_only) explicit ChloLegalizeToHloPass(bool legalize_broadcasts,
bool expand_compositions)
: ChloLegalizeToHloPassBase< : ChloLegalizeToHloPassBase<
ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() { ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() {
this->broadcast_only_ = broadcast_only; this->legalize_broadcasts_ = legalize_broadcasts;
this->expand_compositions_ = expand_compositions;
} }
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
@ -53,13 +55,15 @@ struct ChloLegalizeToHloPass
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>(); mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>(); conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
if (broadcast_only_) { if (legalize_broadcasts_) {
chlo::PopulateChloBroadcastingPatterns(&getContext(), chlo::PopulateChloBroadcastingPatterns(&getContext(),
&conversionPatterns); &conversionPatterns);
conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>(); }
if (expand_compositions_) {
chlo::PopulateDecomposeChloPatterns(&getContext(), &conversionPatterns);
} else { } else {
chlo::PopulateLegalizeChloToHloPatterns(&getContext(), conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>();
&conversionPatterns);
} }
if (failed(applyPartialConversion(getOperation(), conversionTarget, if (failed(applyPartialConversion(getOperation(), conversionTarget,
@ -71,8 +75,10 @@ struct ChloLegalizeToHloPass
} // namespace } // namespace
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(bool broadcast_only) { std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
return std::make_unique<ChloLegalizeToHloPass>(broadcast_only); bool legalize_broadcasts, bool expand_compositions) {
return std::make_unique<ChloLegalizeToHloPass>(legalize_broadcasts,
expand_compositions);
} }
} // namespace mhlo } // namespace mhlo

View File

@ -14,6 +14,10 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
// This is the legalization pattern definition file for CHLO to MHLO. // This is the legalization pattern definition file for CHLO to MHLO.
// These are included in the PopulateDecomposeChloPatterns factory
// and should only include canonical expansions which are not actually
// ambiguous/different for various backends. Avoid patterns that are actually
// lowering to non-canonical forms.
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s // RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=true expand-compositions=false" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a // Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics. // representative op for detailed broadcast semantics.