[NFC] Factor repeated code out of InferFusibilityOpInterface.

PiperOrigin-RevId: 348671671
Authored by Rahul Joshi on 2020-12-22 12:03:07 -08:00; committed by TensorFlow MLIR Team
parent bc367971ec
commit 8252eafa99
1 changed file with 22 additions and 43 deletions


@@ -50,7 +50,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        /// Return whether this op can be fused withh its consumers
+        /// Return whether this op can be fused with its consumers
         return true;
       }]
     >,
@@ -64,21 +64,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       /*defaultImplementation=*/[{
         /// Return whether two inputs have the same shape.
         Operation *op = this->getOperation();
-        assert(lhs < op->getNumOperands() && lhs >= 0 &&
-               rhs < op->getNumOperands() && rhs >= 0);
+        assert(lhs >= 0 && rhs >= 0);
         if (lhs == rhs) return true;
-
-        // if both lhs and rhs have static shapes, check them directly
-        Type lhs_ty = op->getOperand(lhs).getType();
-        Type rhs_ty = op->getOperand(rhs).getType();
-        auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
-        auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
-        if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
-            !rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
-            lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
-          return false;
-        }
-        return lhs_shape_type.getShape() == rhs_shape_type.getShape();
+        return inferShapeEquality(op->getOperand(lhs), op->getOperand(rhs));
       }]
     >,
     InterfaceMethod<
@@ -91,21 +79,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       /*defaultImplementation=*/[{
         /// Return whether two outputs have the same shape.
         Operation *op = this->getOperation();
-        assert(lhs < op->getNumResults() && lhs >= 0 &&
-               rhs < op->getNumResults() && rhs >= 0);
+        assert(lhs >= 0 && rhs >= 0);
         if (lhs == rhs) return true;
-
-        // if both lhs and rhs have static shapes, check them directly
-        Type lhs_ty = op->getResult(lhs).getType();
-        Type rhs_ty = op->getResult(rhs).getType();
-        auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
-        auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
-        if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
-            !rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
-            lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
-          return false;
-        }
-        return lhs_shape_type.getShape() == rhs_shape_type.getShape();
+        return inferShapeEquality(op->getResult(lhs), op->getResult(rhs));
       }]
     >,
     InterfaceMethod<
@@ -118,20 +94,8 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       /*defaultImplementation=*/[{
         /// Return whether the input and the output have the same shape.
         Operation *op = this->getOperation();
-        assert(input < op->getNumOperands() && input >= 0 &&
-               output < op->getNumResults() && output >= 0);
-
-        // if both input and output have static shapes, check them directly
-        Type input_ty = op->getOperand(input).getType();
-        Type output_ty = op->getResult(output).getType();
-        auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>();
-        auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>();
-        if (!input_shape_type || !input_shape_type.hasStaticShape() ||
-            !output_shape_type || !output_shape_type.hasStaticShape() ||
-            input_shape_type.getRank() != output_shape_type.getRank()) {
-          return false;
-        }
-        return input_shape_type.getShape() == output_shape_type.getShape();
+        assert(input >= 0 && output >= 0);
+        return inferShapeEquality(op->getOperand(input), op->getResult(output));
       }]
     >,
     InterfaceMethod<
@@ -156,6 +120,21 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       }]
     >,
   ];
+
+  let extraClassDeclaration = [{
+    // Returns whether the given values have the same static shape.
+    static bool inferShapeEquality(Value first, Value second) {
+      // If both lhs and rhs have static shapes, check them directly.
+      auto first_ty = first.getType().dyn_cast<RankedTensorType>();
+      auto second_ty = second.getType().dyn_cast<RankedTensorType>();
+      if (!first_ty || !first_ty.hasStaticShape() ||
+          !second_ty || !second_ty.hasStaticShape() ||
+          first_ty.getRank() != second_ty.getRank()) {
+        return false;
+      }
+      return first_ty.getShape() == second_ty.getShape();
+    }
+  }];
 }
 
 #endif
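
For reference, here is a minimal standalone sketch of what the factored-out check amounts to once the interface is generated into C++. The free-function name haveSameStaticShape, the include paths, and the member-style dyn_cast are illustrative assumptions only (header layout and casting style differ across MLIR versions); in the commit itself this logic is the static interface method inferShapeEquality shown in the last hunk above.

    // Illustrative sketch only: a free-function rendering of the shared
    // shape-equality check factored out by this commit.
    #include "mlir/IR/BuiltinTypes.h"  // RankedTensorType (path assumed; varies by MLIR version)
    #include "mlir/IR/Value.h"

    static bool haveSameStaticShape(mlir::Value first, mlir::Value second) {
      // Only ranked tensor types with fully static shapes can be compared
      // dimension by dimension; anything unranked or dynamic conservatively
      // compares unequal.
      auto first_ty = first.getType().dyn_cast<mlir::RankedTensorType>();
      auto second_ty = second.getType().dyn_cast<mlir::RankedTensorType>();
      if (!first_ty || !first_ty.hasStaticShape() ||
          !second_ty || !second_ty.hasStaticShape() ||
          first_ty.getRank() != second_ty.getRank())
        return false;
      // ArrayRef<int64_t> comparison checks all dimension sizes.
      return first_ty.getShape() == second_ty.getShape();
    }

Because the helper takes plain Values, the three default implementations (operand/operand, result/result, operand/result) all collapse to an index assert plus a single call; only the accessor, op->getOperand(...) versus op->getResult(...), differs per method.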