[MLIR][KernelGen] Add MLIR-generated Xlogy kernel
Add the first MLIR-generated kernel that relies on an in-TF lowering. Fusion for this kernel relies on the generalized rank specialization for operation groups. PiperOrigin-RevId: 376805435
This commit is contained in:
parent
6a570502b6
commit
c7c245eaf1
|
@ -387,7 +387,13 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||||
MoveUpBroadcastInDimOpPattern,
|
MoveUpBroadcastInDimOpPattern,
|
||||||
ShapeReificationPattern>(context);
|
ShapeReificationPattern>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
|
||||||
|
context);
|
||||||
|
mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
|
||||||
|
shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context);
|
||||||
|
shape::AssumingOp::getCanonicalizationPatterns(*patterns, context);
|
||||||
shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
|
shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
|
||||||
|
shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context);
|
||||||
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
|
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -394,7 +394,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
|
||||||
Value MaterializeTargetRankSpecializationCase(
|
Value MaterializeTargetRankSpecializationCase(
|
||||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||||
const SmallVector<Value, 8> &shapes, int64_t target_rank) {
|
const SmallVector<Value, 8> &shapes, int64_t target_rank) {
|
||||||
// Reshape operands to match the target rank.
|
// Reshape unranked operands to match the target rank.
|
||||||
RankedTensorType extent_tensor_ty =
|
RankedTensorType extent_tensor_ty =
|
||||||
shape::getExtentTensorType(b.getContext(), target_rank);
|
shape::getExtentTensorType(b.getContext(), target_rank);
|
||||||
Value all_ones_shape = b.create<shape::ConstShapeOp>(
|
Value all_ones_shape = b.create<shape::ConstShapeOp>(
|
||||||
|
|
Loading…
Reference in New Issue