Only apply GeneralDotOpLoweringPatterns for static shaped inputs
PiperOrigin-RevId: 333439680
This commit is contained in:
parent
34c6844dcc
commit
9c6640cbb6
|
@ -155,9 +155,16 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
|
||||||
dot_numbers.rhs_contracting_dimensions(),
|
dot_numbers.rhs_contracting_dimensions(),
|
||||||
/*outer_dims_first=*/false, &rewriter);
|
/*outer_dims_first=*/false, &rewriter);
|
||||||
|
|
||||||
|
// Accept only static shaped types.
|
||||||
|
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<mlir::ShapedType>();
|
||||||
|
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<mlir::ShapedType>();
|
||||||
|
if (!lhs_shape_type || !rhs_shape_type) return failure();
|
||||||
|
if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
// Dot resulting shape.
|
// Dot resulting shape.
|
||||||
auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
|
auto lhs_shape = lhs_shape_type.getShape();
|
||||||
auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
|
auto rhs_shape = rhs_shape_type.getShape();
|
||||||
auto new_dot_type =
|
auto new_dot_type =
|
||||||
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue