Add compare_type optional attribute to CompareOp in HLO dialects
If unspecified, `compare_type` is FLOAT for float element types, SIGNED for signed element types and UNSIGNED for unsigned element types. compare_type can be TOTALORDER for float element types. - Added import and export support the attribute. - Restricted legalization from HLO to TF to the default compare types. - Updated existing usage of the CompareOp PiperOrigin-RevId: 339099219
This commit is contained in:
parent
f9843fabe1
commit
6eda9ed273
|
@ -427,7 +427,10 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
string summary = "Compare operator (with optional broadcasting)";
|
||||
|
||||
string description = [{
|
||||
Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
|
||||
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
|
||||
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
|
||||
types, SIGNED for signed element types and UNSIGNED for unsigned element
|
||||
types.
|
||||
|
||||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
|
||||
|
@ -437,13 +440,15 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs,
|
||||
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
|
||||
HLO_ComparisonDirectionAttr:$comparison_direction
|
||||
HLO_ComparisonDirectionAttr:$comparison_direction,
|
||||
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
|
||||
);
|
||||
let results = (outs HLO_PredTensor);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
||||
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
|
||||
"DenseIntElementsAttr broadcast_dimensions, "
|
||||
"StringAttr comparison_direction, StringAttr compare_type = {}"
|
||||
>];
|
||||
}
|
||||
|
||||
|
|
|
@ -680,16 +680,19 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
|
|||
let arguments = (ins
|
||||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs,
|
||||
HLO_ComparisonDirectionAttr:$comparison_direction
|
||||
HLO_ComparisonDirectionAttr:$comparison_direction,
|
||||
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
|
||||
);
|
||||
let results = (outs HLO_PredTensor);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
||||
"StringAttr comparison_direction"
|
||||
>];
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
||||
"StringAttr comparison_direction, StringAttr compare_type = {}">,
|
||||
];
|
||||
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -749,11 +749,30 @@ def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
|
|||
HLO_COMPARISON_DIRECTION_LT
|
||||
]>;
|
||||
|
||||
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
|
||||
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
|
||||
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
|
||||
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
|
||||
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
|
||||
|
||||
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
|
||||
"Which comparison type to use.",
|
||||
[
|
||||
HLO_COMPARISON_TYPE_FLOAT,
|
||||
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
|
||||
HLO_COMPARISON_TYPE_SIGNED,
|
||||
HLO_COMPARISON_TYPE_UNSIGNED
|
||||
]>;
|
||||
|
||||
|
||||
class BASE_HLO_CompareOp {
|
||||
string summary = "Comparison operator";
|
||||
|
||||
string description = [{
|
||||
Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
|
||||
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
|
||||
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
|
||||
types, SIGNED for signed element types and UNSIGNED for unsigned element
|
||||
types.
|
||||
|
||||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
|
||||
|
|
|
@ -284,7 +284,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
|
|||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
|
||||
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
|
||||
HLO_ComparisonDirectionAttr:$comparison_direction
|
||||
HLO_ComparisonDirectionAttr:$comparison_direction,
|
||||
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -190,11 +190,12 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
|
|||
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
|
||||
Value lhs, Value rhs,
|
||||
DenseIntElementsAttr broadcast_dimensions,
|
||||
StringAttr comparison_direction) {
|
||||
StringAttr comparison_direction,
|
||||
StringAttr compare_type) {
|
||||
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
|
||||
builder.getI1Type(), broadcast_dimensions);
|
||||
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
|
||||
comparison_direction);
|
||||
comparison_direction, compare_type);
|
||||
}
|
||||
|
||||
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
|
||||
|
|
|
@ -2611,10 +2611,12 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
|
||||
Value rhs, StringAttr comparison_direction) {
|
||||
Value rhs, StringAttr comparison_direction,
|
||||
StringAttr compare_type) {
|
||||
auto new_type =
|
||||
UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
|
||||
build(builder, result, new_type, lhs, rhs, comparison_direction);
|
||||
build(builder, result, new_type, lhs, rhs, comparison_direction,
|
||||
compare_type);
|
||||
}
|
||||
|
||||
LogicalResult CompareOp::inferReturnTypeComponents(
|
||||
|
|
|
@ -505,9 +505,9 @@ struct HloCompareAdaptor {
|
|||
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||
OpBuilder &builder) {
|
||||
return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
|
||||
broadcasted_lhs, broadcasted_rhs,
|
||||
from_op.comparison_direction());
|
||||
return builder.create<mhlo::CompareOp>(
|
||||
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
|
||||
from_op.comparison_direction(), from_op.compare_typeAttr());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -31,7 +31,8 @@ def : Pat<(HLOClient_AcosOp $input),
|
|||
(HLO_CompareOp
|
||||
$input,
|
||||
(HLO_ConstantLike<"-1"> $input),
|
||||
HLO_COMPARISON_DIRECTION_NE
|
||||
HLO_COMPARISON_DIRECTION_NE,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||
),
|
||||
(HLO_MulOp
|
||||
(HLO_ConstantLike<"2"> $input),
|
||||
|
@ -67,7 +68,8 @@ def : Pat<(HLOClient_SinhOp $input),
|
|||
(HLO_CompareOp
|
||||
(HLO_AbsOp $input),
|
||||
(HLO_ConstantLike<"1"> $input),
|
||||
HLO_COMPARISON_DIRECTION_LT
|
||||
HLO_COMPARISON_DIRECTION_LT,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||
),
|
||||
(HLO_DivOp
|
||||
(HLO_SubOp
|
||||
|
|
Loading…
Reference in New Issue