[NFC] Factor repeated code out of InferFusibilityOpInterface.
PiperOrigin-RevId: 348671671
parent bc367971ec
commit 8252eafa99
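Note: the three shape-comparison default implementations in the hunks below each repeated the same ranked, fully-static shape check. This commit moves that check into a single static helper, inferShapeEquality, declared in the interface's extraClassDeclaration, and rewrites the default implementations to delegate to it. As a minimal standalone sketch of what the check computes (plain C++ stand-ins for the MLIR types, not code from this commit), two shapes are treated as provably equal only when both are fully static and dimension-for-dimension identical:

    #include <cstdint>
    #include <vector>

    // A ranked shape; -1 marks a dynamic (unknown) dimension, following MLIR's
    // convention. Unranked tensors, which the real helper also rejects, are not
    // modeled in this sketch.
    using Shape = std::vector<int64_t>;

    // Provably equal only if both shapes are fully static and identical.
    static bool shapesDefinitelyEqual(const Shape &lhs, const Shape &rhs) {
      auto isStatic = [](const Shape &s) {
        for (int64_t dim : s)
          if (dim < 0) return false;  // dynamic dimension => not static
        return true;
      };
      if (!isStatic(lhs) || !isStatic(rhs)) return false;
      return lhs == rhs;  // same rank and same extents
    }

Anything dynamic or unranked conservatively compares unequal, which is why the interface defaults return false rather than "unknown" for such shapes.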
@@ -50,7 +50,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        /// Return whether this op can be fused withh its consumers
+        /// Return whether this op can be fused with its consumers
         return true;
       }]
     >,
@@ -64,21 +64,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       /*defaultImplementation=*/[{
         /// Return whether two inputs have the same shape.
         Operation *op = this->getOperation();
-        assert(lhs < op->getNumOperands() && lhs >= 0 &&
-               rhs < op->getNumOperands() && rhs >= 0);
+        assert(lhs >= 0 && rhs >= 0);
         if (lhs == rhs) return true;
-
-        // if both lhs and rhs have static shapes, check them directly
-        Type lhs_ty = op->getOperand(lhs).getType();
-        Type rhs_ty = op->getOperand(rhs).getType();
-        auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
-        auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
-        if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
-            !rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
-            lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
-          return false;
-        }
-        return lhs_shape_type.getShape() == rhs_shape_type.getShape();
+        return inferShapeEquality(op->getOperand(lhs), op->getOperand(rhs));
       }]
     >,
     InterfaceMethod<
@@ -91,21 +79,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       /*defaultImplementation=*/[{
         /// Return whether two outputs have the same shape.
         Operation *op = this->getOperation();
-        assert(lhs < op->getNumResults() && lhs >= 0 &&
-               rhs < op->getNumResults() && rhs >= 0);
+        assert(lhs >= 0 && rhs >= 0);
         if (lhs == rhs) return true;
-
-        // if both lhs and rhs have static shapes, check them directly
-        Type lhs_ty = op->getResult(lhs).getType();
-        Type rhs_ty = op->getResult(rhs).getType();
-        auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
-        auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
-        if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
-            !rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
-            lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
-          return false;
-        }
-        return lhs_shape_type.getShape() == rhs_shape_type.getShape();
+        return inferShapeEquality(op->getResult(lhs), op->getResult(rhs));
       }]
     >,
     InterfaceMethod<
@@ -118,20 +94,8 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       /*defaultImplementation=*/[{
         /// Return whether the input and the output have the same shape.
        Operation *op = this->getOperation();
-        assert(input < op->getNumOperands() && input >= 0 &&
-               output < op->getNumResults() && output >= 0);
-
-        // if both input and output have static shapes, check them directly
-        Type input_ty = op->getOperand(input).getType();
-        Type output_ty = op->getResult(output).getType();
-        auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>();
-        auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>();
-        if (!input_shape_type || !input_shape_type.hasStaticShape() ||
-            !output_shape_type || !output_shape_type.hasStaticShape() ||
-            input_shape_type.getRank() != output_shape_type.getRank()) {
-          return false;
-        }
-        return input_shape_type.getShape() == output_shape_type.getShape();
+        assert(input >= 0 && output >= 0);
+        return inferShapeEquality(op->getOperand(input), op->getResult(output));
       }]
     >,
     InterfaceMethod<
@@ -156,6 +120,21 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
       }]
     >,
   ];
+
+  let extraClassDeclaration = [{
+    // Returns whether the given values have the same static shape.
+    static bool inferShapeEquality(Value first, Value second) {
+      // If both lhs and rhs have static shapes, check them directly.
+      auto first_ty = first.getType().dyn_cast<RankedTensorType>();
+      auto second_ty = second.getType().dyn_cast<RankedTensorType>();
+      if (!first_ty || !first_ty.hasStaticShape() ||
+          !second_ty || !second_ty.hasStaticShape() ||
+          first_ty.getRank() != second_ty.getRank()) {
+        return false;
+      }
+      return first_ty.getShape() == second_ty.getShape();
+    }
+  }];
 }
 
 #endif
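Because the helper is declared in extraClassDeclaration, it is emitted into the generated InferFusibilityOpInterface class as a static member and can be called without casting an operation to the interface. A hypothetical caller-side sketch (the header path, namespace, and surrounding function are assumptions; only inferShapeEquality itself comes from this commit):

    // Header path and namespace are assumptions for this sketch.
    #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
    #include "mlir/IR/Operation.h"

    // Hypothetical fusion-pass helper: true when the first operand and first
    // result of `op` provably have the same static shape.
    static bool firstOperandMatchesFirstResult(mlir::Operation *op) {
      if (op->getNumOperands() == 0 || op->getNumResults() == 0) return false;
      // inferShapeEquality is static, so no cast to the interface is needed.
      return mlir::InferFusibilityOpInterface::inferShapeEquality(
          op->getOperand(0), op->getResult(0));
    }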