diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc index 7f9c129..5365e88 100644 --- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc +++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc @@ -387,7 +387,13 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns( MoveUpBroadcastInDimOpPattern, ShapeReificationPattern>(context); // clang-format on + mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns, + context); + mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context); + shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context); + shape::AssumingOp::getCanonicalizationPatterns(*patterns, context); shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context); + shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context); tensor::CastOp::getCanonicalizationPatterns(*patterns, context); } diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index e50c4ee..4d17904 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -394,7 +394,7 @@ Value MaterializeEqualShapesRankSpecializationCase( Value MaterializeTargetRankSpecializationCase( OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, const SmallVector &shapes, int64_t target_rank) { - // Reshape operands to match the target rank. + // Reshape unranked operands to match the target rank. RankedTensorType extent_tensor_ty = shape::getExtentTensorType(b.getContext(), target_rank); Value all_ones_shape = b.create(