diff --git a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td index f8e02d4..280c0a1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td @@ -50,7 +50,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*args=*/(ins), /*methodBody=*/[{}], /*defaultImplementation=*/[{ - /// Return whether this op can be fused withh its consumers + /// Return whether this op can be fused with its consumers return true; }] >, @@ -64,21 +64,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether two inputs have the same shape. Operation *op = this->getOperation(); - assert(lhs < op->getNumOperands() && lhs >= 0 && - rhs < op->getNumOperands() && rhs >= 0); + assert(lhs >= 0 && rhs >= 0); if (lhs == rhs) return true; - - // if both lhs and rhs have static shapes, check them directly - Type lhs_ty = op->getOperand(lhs).getType(); - Type rhs_ty = op->getOperand(rhs).getType(); - auto lhs_shape_type = lhs_ty.dyn_cast_or_null(); - auto rhs_shape_type = rhs_ty.dyn_cast_or_null(); - if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() || - !rhs_shape_type || !rhs_shape_type.hasStaticShape() || - lhs_shape_type.getRank() != rhs_shape_type.getRank()) { - return false; - } - return lhs_shape_type.getShape() == rhs_shape_type.getShape(); + return inferShapeEquality(op->getOperand(lhs), op->getOperand(rhs)); }] >, InterfaceMethod< @@ -91,21 +79,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether two outputs have the same shape. Operation *op = this->getOperation(); - assert(lhs < op->getNumResults() && lhs >= 0 && - rhs < op->getNumResults() && rhs >= 0); + assert(lhs >= 0 && rhs >= 0); if (lhs == rhs) return true; - - // if both lhs and rhs have static shapes, check them directly - Type lhs_ty = op->getResult(lhs).getType(); - Type rhs_ty = op->getResult(rhs).getType(); - auto lhs_shape_type = lhs_ty.dyn_cast_or_null(); - auto rhs_shape_type = rhs_ty.dyn_cast_or_null(); - if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() || - !rhs_shape_type || !rhs_shape_type.hasStaticShape() || - lhs_shape_type.getRank() != rhs_shape_type.getRank()) { - return false; - } - return lhs_shape_type.getShape() == rhs_shape_type.getShape(); + return inferShapeEquality(op->getResult(lhs), op->getResult(rhs)); }] >, InterfaceMethod< @@ -118,20 +94,8 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether the input and the output have the same shape. Operation *op = this->getOperation(); - assert(input < op->getNumOperands() && input >= 0 && - output < op->getNumResults() && output >= 0); - - // if both input and output have static shapes, check them directly - Type input_ty = op->getOperand(input).getType(); - Type output_ty = op->getResult(output).getType(); - auto input_shape_type = input_ty.dyn_cast_or_null(); - auto output_shape_type = output_ty.dyn_cast_or_null(); - if (!input_shape_type || !input_shape_type.hasStaticShape() || - !output_shape_type || !output_shape_type.hasStaticShape() || - input_shape_type.getRank() != output_shape_type.getRank()) { - return false; - } - return input_shape_type.getShape() == output_shape_type.getShape(); + assert(input >= 0 && output >= 0); + return inferShapeEquality(op->getOperand(input), op->getResult(output)); }] >, InterfaceMethod< @@ -156,6 +120,21 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { }] >, ]; + + let extraClassDeclaration = [{ + // Returns whether the given values have the same static shape. + static bool inferShapeEquality(Value first, Value second) { + // If both lhs and rhs have static shapes, check them directly. + auto first_ty = first.getType().dyn_cast(); + auto second_ty = second.getType().dyn_cast(); + if (!first_ty || !first_ty.hasStaticShape() || + !second_ty || !second_ty.hasStaticShape() || + first_ty.getRank() != second_ty.getRank()) { + return false; + } + return first_ty.getShape() == second_ty.getShape(); + } + }]; } #endif