Separate CHLO transforms for expanding compositions and lowering broadcasts.
* The former is typically invariant regardless of backend. * The latter may need to be done differently depending on capabilities of the lowering target. PiperOrigin-RevId: 374492924
This commit is contained in:
parent
e0d9e9bffd
commit
0fe07e3814
|
@ -19,8 +19,10 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
|
||||||
let summary = "Legalize CHLO to HLO.";
|
let summary = "Legalize CHLO to HLO.";
|
||||||
let constructor = "createChloLegalizeToHloPass()";
|
let constructor = "createChloLegalizeToHloPass()";
|
||||||
let options = [
|
let options = [
|
||||||
Option<"broadcast_only_", "broadcast-only", "bool",
|
Option<"legalize_broadcasts_", "legalize-broadcasts", "bool",
|
||||||
/*default=*/"false", "Only lower broadcasting chlo to non-broadcasting equivalents">,
|
/*default=*/"true", "Legalize implicit broadcasts to explicit HLO broadcasting forms">,
|
||||||
|
Option<"expand_compositions_", "expand-compositions", "bool",
|
||||||
|
/*default=*/"true", "Expands client-centric compositions to HLO primitives">,
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
||||||
|
|
||||||
/// Lowers from the CHLO dialect to the HLO dialect.
|
/// Lowers from the CHLO dialect to the HLO dialect.
|
||||||
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
|
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
|
||||||
bool broadcast_only = false);
|
bool legalize_broadcasts = true, bool expand_compositions = true);
|
||||||
|
|
||||||
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||||
/// buffers if necessary.
|
/// buffers if necessary.
|
||||||
|
|
|
@ -116,10 +116,11 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns);
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
// Populates a collection of conversion patterns for legalizing client-HLO to
|
// Populates a collection of conversion patterns for legalizing client-HLO to
|
||||||
// HLO. Includes decomposition of operations and inserting of explicit
|
// HLO by decomposing client-operations to corresponding sequences of more
|
||||||
// broadcasts.
|
// primitive operations. This does not include the
|
||||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
// PopulateChloBroadcastingPatterns above.
|
||||||
OwningRewritePatternList *patterns);
|
void PopulateDecomposeChloPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
} // namespace chlo
|
} // namespace chlo
|
||||||
|
|
||||||
|
|
|
@ -1234,17 +1234,16 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
|
||||||
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
|
PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
|
||||||
context, patterns, 5);
|
context, patterns, 5);
|
||||||
patterns->insert<ConvertSelectOp>(context);
|
patterns->insert<ConvertSelectOp>(context);
|
||||||
|
patterns->insert<ConvertConstantLikeOp>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
void PopulateDecomposeChloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
populateWithGenerated(*patterns);
|
populateWithGenerated(*patterns);
|
||||||
PopulateChloBroadcastingPatterns(context, patterns);
|
|
||||||
|
|
||||||
// Other patterns.
|
// Other patterns.
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<ConvertConstantLikeOp,
|
patterns->insert<ConvertDigammaOp,
|
||||||
ConvertDigammaOp,
|
|
||||||
ConvertErfOp,
|
ConvertErfOp,
|
||||||
ConvertErfcOp,
|
ConvertErfcOp,
|
||||||
ConvertLgammaOp,
|
ConvertLgammaOp,
|
||||||
|
|
|
@ -31,10 +31,12 @@ namespace {
|
||||||
|
|
||||||
struct ChloLegalizeToHloPass
|
struct ChloLegalizeToHloPass
|
||||||
: public ChloLegalizeToHloPassBase<ChloLegalizeToHloPass> {
|
: public ChloLegalizeToHloPassBase<ChloLegalizeToHloPass> {
|
||||||
explicit ChloLegalizeToHloPass(bool broadcast_only)
|
explicit ChloLegalizeToHloPass(bool legalize_broadcasts,
|
||||||
|
bool expand_compositions)
|
||||||
: ChloLegalizeToHloPassBase<
|
: ChloLegalizeToHloPassBase<
|
||||||
ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() {
|
ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() {
|
||||||
this->broadcast_only_ = broadcast_only;
|
this->legalize_broadcasts_ = legalize_broadcasts;
|
||||||
|
this->expand_compositions_ = expand_compositions;
|
||||||
}
|
}
|
||||||
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
@ -53,13 +55,15 @@ struct ChloLegalizeToHloPass
|
||||||
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
|
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
|
||||||
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
|
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
|
||||||
|
|
||||||
if (broadcast_only_) {
|
if (legalize_broadcasts_) {
|
||||||
chlo::PopulateChloBroadcastingPatterns(&getContext(),
|
chlo::PopulateChloBroadcastingPatterns(&getContext(),
|
||||||
&conversionPatterns);
|
&conversionPatterns);
|
||||||
conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>();
|
}
|
||||||
|
|
||||||
|
if (expand_compositions_) {
|
||||||
|
chlo::PopulateDecomposeChloPatterns(&getContext(), &conversionPatterns);
|
||||||
} else {
|
} else {
|
||||||
chlo::PopulateLegalizeChloToHloPatterns(&getContext(),
|
conversionTarget.addLegalOp<chlo::ZetaOp, chlo::PolygammaOp>();
|
||||||
&conversionPatterns);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), conversionTarget,
|
if (failed(applyPartialConversion(getOperation(), conversionTarget,
|
||||||
|
@ -71,8 +75,10 @@ struct ChloLegalizeToHloPass
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(bool broadcast_only) {
|
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
|
||||||
return std::make_unique<ChloLegalizeToHloPass>(broadcast_only);
|
bool legalize_broadcasts, bool expand_compositions) {
|
||||||
|
return std::make_unique<ChloLegalizeToHloPass>(legalize_broadcasts,
|
||||||
|
expand_compositions);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
|
|
|
@ -14,6 +14,10 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// This is the legalization pattern definition file for CHLO to MHLO.
|
// This is the legalization pattern definition file for CHLO to MHLO.
|
||||||
|
// These are included in the PopulateDecomposeChloPatterns factory
|
||||||
|
// and should only include canonical expansions which are not actually
|
||||||
|
// ambiguous/different for various backends. Avoid patterns that are actually
|
||||||
|
// lowering to non-canonical forms.
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=true expand-compositions=false" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||||
|
|
||||||
// Check the non-broadcast case for each registered op, then just check a
|
// Check the non-broadcast case for each registered op, then just check a
|
||||||
// representative op for detailed broadcast semantics.
|
// representative op for detailed broadcast semantics.
|
||||||
|
|
Loading…
Reference in New Issue