Separate CHLO transforms for expanding compositions and lowering broadcasts.
* The former is typically invariant regardless of backend. * The latter may need to be done differently depending on capabilities of the lowering target. PiperOrigin-RevId: 374492924
This commit is contained in:
parent
e0d9e9bffd
commit
0fe07e3814
|
@ -19,8 +19,10 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
|
|||
let summary = "Legalize CHLO to HLO.";
|
||||
let constructor = "createChloLegalizeToHloPass()";
|
||||
let options = [
|
||||
Option<"broadcast_only_", "broadcast-only", "bool",
|
||||
/*default=*/"false", "Only lower broadcasting chlo to non-broadcasting equivalents">,
|
||||
Option<"legalize_broadcasts_", "legalize-broadcasts", "bool",
|
||||
/*default=*/"true", "Legalize implicit broadcasts to explicit HLO broadcasting forms">,
|
||||
Option<"expand_compositions_", "expand-compositions", "bool",
|
||||
/*default=*/"true", "Expands client-centric compositions to HLO primitives">,
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
|||
|
||||
/// Lowers from the CHLO dialect to the HLO dialect.
|
||||
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
|
||||
bool broadcast_only = false);
|
||||
bool legalize_broadcasts = true, bool expand_compositions = true);
|
||||
|
||||
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||
/// buffers if necessary.
|
||||
|
|
|
@ -116,10 +116,11 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
|
|||
OwningRewritePatternList *patterns);
|
||||
|
||||
// Populates a collection of conversion patterns for legalizing client-HLO to
|
||||
// HLO. Includes decomposition of operations and inserting of explicit
|
||||
// broadcasts.
|
||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
// HLO by decomposing client-operations to corresponding sequences of more
|
||||
// primitive operations. This does not include the
|
||||
// PopulateChloBroadcastingPatterns above.
|
||||
void PopulateDecomposeChloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace chlo
|
||||
|
||||
|
|
|
@ -1234,17 +1234,16 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
|
|||
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
|
||||
context, patterns, 5);
|
||||
patterns->insert<ConvertSelectOp>(context);
|
||||
patterns->insert<ConvertConstantLikeOp>(context);
|
||||
}
|
||||
|
||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
void PopulateDecomposeChloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
populateWithGenerated(*patterns);
|
||||
PopulateChloBroadcastingPatterns(context, patterns);
|
||||
|
||||
// Other patterns.
|
||||
// clang-format off
|
||||
patterns->insert<ConvertConstantLikeOp,
|
||||
ConvertDigammaOp,
|
||||
patterns->insert<ConvertDigammaOp,
|
||||
ConvertErfOp,
|
||||
ConvertErfcOp,
|
||||
ConvertLgammaOp,
|
||||
|
|
|
@ -31,10 +31,12 @@ namespace {
|
|||
|
||||
struct ChloLegalizeToHloPass
|
||||
: public ChloLegalizeToHloPassBase<ChloLegalizeToHloPass> {
|
||||
explicit ChloLegalizeToHloPass(bool broadcast_only)
|
||||
explicit ChloLegalizeToHloPass(bool legalize_broadcasts,
|
||||
bool expand_compositions)
|
||||
: ChloLegalizeToHloPassBase<
|
||||
ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() {
|
||||
this->broadcast_only_ = broadcast_only;
|
||||
this->legalize_broadcasts_ = legalize_broadcasts;
|
||||
this->expand_compositions_ = expand_compositions;
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
|
@ -53,13 +55,15 @@ struct ChloLegalizeToHloPass
|
|||
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
|
||||
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
|
||||
|
||||
if (broadcast_only_) {
|
||||
if (legalize_broadcasts_) {
|
||||
chlo::PopulateChloBroadcastingPatterns(&getContext(),
|
||||
&conversionPatterns);
|
||||
conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>();
|
||||
}
|
||||
|
||||
if (expand_compositions_) {
|
||||
chlo::PopulateDecomposeChloPatterns(&getContext(), &conversionPatterns);
|
||||
} else {
|
||||
chlo::PopulateLegalizeChloToHloPatterns(&getContext(),
|
||||
&conversionPatterns);
|
||||
conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>();
|
||||
}
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), conversionTarget,
|
||||
|
@ -71,8 +75,10 @@ struct ChloLegalizeToHloPass
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(bool broadcast_only) {
|
||||
return std::make_unique<ChloLegalizeToHloPass>(broadcast_only);
|
||||
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
|
||||
bool legalize_broadcasts, bool expand_compositions) {
|
||||
return std::make_unique<ChloLegalizeToHloPass>(legalize_broadcasts,
|
||||
expand_compositions);
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
|
|
|
@ -14,6 +14,10 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
|
||||
// This is the legalization pattern definition file for CHLO to MHLO.
|
||||
// These are included in the PopulateDecomposeChloPatterns factory
|
||||
// and should only include canonical expansions which are not actually
|
||||
// ambiguous/different for various backends. Avoid patterns that are actually
|
||||
// lowering to non-canonical forms.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=true expand-compositions=false" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
// Check the non-broadcast case for each registered op, then just check a
|
||||
// representative op for detailed broadcast semantics.
|
||||
|
|
Loading…
Reference in New Issue