From 9c6640cbb63dfbddaf80a9210233f3f0e857d5c0 Mon Sep 17 00:00:00 2001 From: "Ahmed S. Taei" Date: Wed, 23 Sep 2020 21:41:07 -0700 Subject: [PATCH] Only apply GeneralDotOpLoweringPatterns for static shaped inputs PiperOrigin-RevId: 333439680 --- lib/Dialect/mhlo/transforms/lower_general_dot.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 2bbd469..8dec4bb 100644 --- a/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -155,9 +155,16 @@ struct GeneralDotConvert : public OpRewritePattern { dot_numbers.rhs_contracting_dimensions(), /*outer_dims_first=*/false, &rewriter); + // Accept only static shaped types. + auto lhs_shape_type = lhs.getType().dyn_cast_or_null(); + auto rhs_shape_type = rhs.getType().dyn_cast_or_null(); + if (!lhs_shape_type || !rhs_shape_type) return failure(); + if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape()) + return failure(); + // Dot resulting shape. - auto lhs_shape = lhs.getType().cast().getShape(); - auto rhs_shape = rhs.getType().cast().getShape(); + auto lhs_shape = lhs_shape_type.getShape(); + auto rhs_shape = rhs_shape_type.getShape(); auto new_dot_type = RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);