Only apply GeneralDotOpLoweringPatterns for static shaped inputs

PiperOrigin-RevId: 333439680
This commit is contained in:
Ahmed S. Taei 2020-09-23 21:41:07 -07:00 committed by TensorFlow MLIR Team
parent 34c6844dcc
commit 9c6640cbb6
1 changed files with 9 additions and 2 deletions

View File

@ -155,9 +155,16 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
dot_numbers.rhs_contracting_dimensions(),
/*outer_dims_first=*/false, &rewriter);
// Accept only static shaped types.
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<mlir::ShapedType>();
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<mlir::ShapedType>();
if (!lhs_shape_type || !rhs_shape_type) return failure();
if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape())
return failure();
// Dot resulting shape.
auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
auto lhs_shape = lhs_shape_type.getShape();
auto rhs_shape = rhs_shape_type.getShape();
auto new_dot_type =
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);