diff --git a/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 2bbd469..8dec4bb 100644 --- a/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -155,9 +155,16 @@ struct GeneralDotConvert : public OpRewritePattern { dot_numbers.rhs_contracting_dimensions(), /*outer_dims_first=*/false, &rewriter); + // Accept only static shaped types. + auto lhs_shape_type = lhs.getType().dyn_cast_or_null(); + auto rhs_shape_type = rhs.getType().dyn_cast_or_null(); + if (!lhs_shape_type || !rhs_shape_type) return failure(); + if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape()) + return failure(); + // Dot resulting shape. - auto lhs_shape = lhs.getType().cast().getShape(); - auto rhs_shape = rhs.getType().cast().getShape(); + auto lhs_shape = lhs_shape_type.getShape(); + auto rhs_shape = rhs_shape_type.getShape(); auto new_dot_type = RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);