From 8252eafa9928f9608776f90706b502e0553124cd Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Tue, 22 Dec 2020 12:03:07 -0800 Subject: [PATCH] [NFC] Factor out repeated code out of InferFusibilityOpInterface. PiperOrigin-RevId: 348671671 --- .../mhlo/IR/infer_fusibility_op_interface.td | 65 +++++++------------ 1 file changed, 22 insertions(+), 43 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td index f8e02d4..280c0a1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td @@ -50,7 +50,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*args=*/(ins), /*methodBody=*/[{}], /*defaultImplementation=*/[{ - /// Return whether this op can be fused withh its consumers + /// Return whether this op can be fused with its consumers return true; }] >, @@ -64,21 +64,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether two inputs have the same shape. Operation *op = this->getOperation(); - assert(lhs < op->getNumOperands() && lhs >= 0 && - rhs < op->getNumOperands() && rhs >= 0); + assert(lhs >= 0 && rhs >= 0); if (lhs == rhs) return true; - - // if both lhs and rhs have static shapes, check them directly - Type lhs_ty = op->getOperand(lhs).getType(); - Type rhs_ty = op->getOperand(rhs).getType(); - auto lhs_shape_type = lhs_ty.dyn_cast_or_null(); - auto rhs_shape_type = rhs_ty.dyn_cast_or_null(); - if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() || - !rhs_shape_type || !rhs_shape_type.hasStaticShape() || - lhs_shape_type.getRank() != rhs_shape_type.getRank()) { - return false; - } - return lhs_shape_type.getShape() == rhs_shape_type.getShape(); + return inferShapeEquality(op->getOperand(lhs), op->getOperand(rhs)); }] >, InterfaceMethod< @@ -91,21 +79,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether two outputs have the same shape. Operation *op = this->getOperation(); - assert(lhs < op->getNumResults() && lhs >= 0 && - rhs < op->getNumResults() && rhs >= 0); + assert(lhs >= 0 && rhs >= 0); if (lhs == rhs) return true; - - // if both lhs and rhs have static shapes, check them directly - Type lhs_ty = op->getResult(lhs).getType(); - Type rhs_ty = op->getResult(rhs).getType(); - auto lhs_shape_type = lhs_ty.dyn_cast_or_null(); - auto rhs_shape_type = rhs_ty.dyn_cast_or_null(); - if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() || - !rhs_shape_type || !rhs_shape_type.hasStaticShape() || - lhs_shape_type.getRank() != rhs_shape_type.getRank()) { - return false; - } - return lhs_shape_type.getShape() == rhs_shape_type.getShape(); + return inferShapeEquality(op->getResult(lhs), op->getResult(rhs)); }] >, InterfaceMethod< @@ -118,20 +94,8 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether the input and the output have the same shape. Operation *op = this->getOperation(); - assert(input < op->getNumOperands() && input >= 0 && - output < op->getNumResults() && output >= 0); - - // if both input and output have static shapes, check them directly - Type input_ty = op->getOperand(input).getType(); - Type output_ty = op->getResult(output).getType(); - auto input_shape_type = input_ty.dyn_cast_or_null(); - auto output_shape_type = output_ty.dyn_cast_or_null(); - if (!input_shape_type || !input_shape_type.hasStaticShape() || - !output_shape_type || !output_shape_type.hasStaticShape() || - input_shape_type.getRank() != output_shape_type.getRank()) { - return false; - } - return input_shape_type.getShape() == output_shape_type.getShape(); + assert(input >= 0 && output >= 0); + return inferShapeEquality(op->getOperand(input), op->getResult(output)); }] >, InterfaceMethod< @@ -156,6 +120,21 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { }] >, ]; + + let extraClassDeclaration = [{ + // Returns whether the given values have the same static shape. + static bool inferShapeEquality(Value first, Value second) { + // If both lhs and rhs have static shapes, check them directly. + auto first_ty = first.getType().dyn_cast(); + auto second_ty = second.getType().dyn_cast(); + if (!first_ty || !first_ty.hasStaticShape() || + !second_ty || !second_ty.hasStaticShape() || + first_ty.getRank() != second_ty.getRank()) { + return false; + } + return first_ty.getShape() == second_ty.getShape(); + } + }]; } #endif