Only apply GeneralDotOpLoweringPatterns for static shaped inputs
PiperOrigin-RevId: 333439680
This commit is contained in:
parent
34c6844dcc
commit
9c6640cbb6
|
@ -155,9 +155,16 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
|
|||
dot_numbers.rhs_contracting_dimensions(),
|
||||
/*outer_dims_first=*/false, &rewriter);
|
||||
|
||||
// Accept only static shaped types.
|
||||
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<mlir::ShapedType>();
|
||||
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<mlir::ShapedType>();
|
||||
if (!lhs_shape_type || !rhs_shape_type) return failure();
|
||||
if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Dot resulting shape.
|
||||
auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
|
||||
auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
|
||||
auto lhs_shape = lhs_shape_type.getShape();
|
||||
auto rhs_shape = rhs_shape_type.getShape();
|
||||
auto new_dot_type =
|
||||
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
||||
|
||||
|
|
Loading…
Reference in New Issue