[NFC] Factor out repeated code out of InferFusibilityOpInterface.
PiperOrigin-RevId: 348671671
This commit is contained in:
parent
bc367971ec
commit
8252eafa99
|
@ -50,7 +50,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
|
||||||
/*args=*/(ins),
|
/*args=*/(ins),
|
||||||
/*methodBody=*/[{}],
|
/*methodBody=*/[{}],
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
/// Return whether this op can be fused withh its consumers
|
/// Return whether this op can be fused with its consumers
|
||||||
return true;
|
return true;
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
|
@ -64,21 +64,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
/// Return whether two inputs have the same shape.
|
/// Return whether two inputs have the same shape.
|
||||||
Operation *op = this->getOperation();
|
Operation *op = this->getOperation();
|
||||||
assert(lhs < op->getNumOperands() && lhs >= 0 &&
|
assert(lhs >= 0 && rhs >= 0);
|
||||||
rhs < op->getNumOperands() && rhs >= 0);
|
|
||||||
if (lhs == rhs) return true;
|
if (lhs == rhs) return true;
|
||||||
|
return inferShapeEquality(op->getOperand(lhs), op->getOperand(rhs));
|
||||||
// if both lhs and rhs have static shapes, check them directly
|
|
||||||
Type lhs_ty = op->getOperand(lhs).getType();
|
|
||||||
Type rhs_ty = op->getOperand(rhs).getType();
|
|
||||||
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
|
|
||||||
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
|
|
||||||
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
|
|
||||||
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
|
|
||||||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
|
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
|
@ -91,21 +79,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
/// Return whether two outputs have the same shape.
|
/// Return whether two outputs have the same shape.
|
||||||
Operation *op = this->getOperation();
|
Operation *op = this->getOperation();
|
||||||
assert(lhs < op->getNumResults() && lhs >= 0 &&
|
assert(lhs >= 0 && rhs >= 0);
|
||||||
rhs < op->getNumResults() && rhs >= 0);
|
|
||||||
if (lhs == rhs) return true;
|
if (lhs == rhs) return true;
|
||||||
|
return inferShapeEquality(op->getResult(lhs), op->getResult(rhs));
|
||||||
// if both lhs and rhs have static shapes, check them directly
|
|
||||||
Type lhs_ty = op->getResult(lhs).getType();
|
|
||||||
Type rhs_ty = op->getResult(rhs).getType();
|
|
||||||
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
|
|
||||||
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
|
|
||||||
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
|
|
||||||
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
|
|
||||||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
|
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
|
@ -118,20 +94,8 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
/// Return whether the input and the output have the same shape.
|
/// Return whether the input and the output have the same shape.
|
||||||
Operation *op = this->getOperation();
|
Operation *op = this->getOperation();
|
||||||
assert(input < op->getNumOperands() && input >= 0 &&
|
assert(input >= 0 && output >= 0);
|
||||||
output < op->getNumResults() && output >= 0);
|
return inferShapeEquality(op->getOperand(input), op->getResult(output));
|
||||||
|
|
||||||
// if both input and output have static shapes, check them directly
|
|
||||||
Type input_ty = op->getOperand(input).getType();
|
|
||||||
Type output_ty = op->getResult(output).getType();
|
|
||||||
auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>();
|
|
||||||
auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>();
|
|
||||||
if (!input_shape_type || !input_shape_type.hasStaticShape() ||
|
|
||||||
!output_shape_type || !output_shape_type.hasStaticShape() ||
|
|
||||||
input_shape_type.getRank() != output_shape_type.getRank()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return input_shape_type.getShape() == output_shape_type.getShape();
|
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
|
@ -156,6 +120,21 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// Returns whether the given values have the same static shape.
|
||||||
|
static bool inferShapeEquality(Value first, Value second) {
|
||||||
|
// If both lhs and rhs have static shapes, check them directly.
|
||||||
|
auto first_ty = first.getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto second_ty = second.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!first_ty || !first_ty.hasStaticShape() ||
|
||||||
|
!second_ty || !second_ty.hasStaticShape() ||
|
||||||
|
first_ty.getRank() != second_ty.getRank()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return first_ty.getShape() == second_ty.getShape();
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue