Add compare_type optional attribute to CompareOp in HLO dialects
If unspecified, `compare_type` is FLOAT for float element types, SIGNED for signed element types and UNSIGNED for unsigned element types. compare_type can be TOTALORDER for float element types. - Added import and export support the attribute. - Restricted legalization from HLO to TF to the default compare types. - Updated existing usage of the CompareOp PiperOrigin-RevId: 339099219
This commit is contained in:
parent
f9843fabe1
commit
6eda9ed273
|
@ -427,7 +427,10 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
string summary = "Compare operator (with optional broadcasting)";
|
string summary = "Compare operator (with optional broadcasting)";
|
||||||
|
|
||||||
string description = [{
|
string description = [{
|
||||||
Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
|
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
|
||||||
|
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
|
||||||
|
types, SIGNED for signed element types and UNSIGNED for unsigned element
|
||||||
|
types.
|
||||||
|
|
||||||
See
|
See
|
||||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
|
||||||
|
@ -437,13 +440,15 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
HLO_Tensor:$lhs,
|
HLO_Tensor:$lhs,
|
||||||
HLO_Tensor:$rhs,
|
HLO_Tensor:$rhs,
|
||||||
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
|
||||||
HLO_ComparisonDirectionAttr:$comparison_direction
|
HLO_ComparisonDirectionAttr:$comparison_direction,
|
||||||
|
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
|
||||||
);
|
);
|
||||||
let results = (outs HLO_PredTensor);
|
let results = (outs HLO_PredTensor);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
||||||
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
|
"DenseIntElementsAttr broadcast_dimensions, "
|
||||||
|
"StringAttr comparison_direction, StringAttr compare_type = {}"
|
||||||
>];
|
>];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -680,16 +680,19 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$lhs,
|
HLO_Tensor:$lhs,
|
||||||
HLO_Tensor:$rhs,
|
HLO_Tensor:$rhs,
|
||||||
HLO_ComparisonDirectionAttr:$comparison_direction
|
HLO_ComparisonDirectionAttr:$comparison_direction,
|
||||||
|
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
|
||||||
);
|
);
|
||||||
let results = (outs HLO_PredTensor);
|
let results = (outs HLO_PredTensor);
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [
|
||||||
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
OpBuilder<"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
||||||
"StringAttr comparison_direction"
|
"StringAttr comparison_direction, StringAttr compare_type = {}">,
|
||||||
>];
|
];
|
||||||
|
|
||||||
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -749,11 +749,30 @@ def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
|
||||||
HLO_COMPARISON_DIRECTION_LT
|
HLO_COMPARISON_DIRECTION_LT
|
||||||
]>;
|
]>;
|
||||||
|
|
||||||
|
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
|
||||||
|
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
|
||||||
|
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
|
||||||
|
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
|
||||||
|
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
|
||||||
|
|
||||||
|
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
|
||||||
|
"Which comparison type to use.",
|
||||||
|
[
|
||||||
|
HLO_COMPARISON_TYPE_FLOAT,
|
||||||
|
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
|
||||||
|
HLO_COMPARISON_TYPE_SIGNED,
|
||||||
|
HLO_COMPARISON_TYPE_UNSIGNED
|
||||||
|
]>;
|
||||||
|
|
||||||
|
|
||||||
class BASE_HLO_CompareOp {
|
class BASE_HLO_CompareOp {
|
||||||
string summary = "Comparison operator";
|
string summary = "Comparison operator";
|
||||||
|
|
||||||
string description = [{
|
string description = [{
|
||||||
Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
|
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
|
||||||
|
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
|
||||||
|
types, SIGNED for signed element types and UNSIGNED for unsigned element
|
||||||
|
types.
|
||||||
|
|
||||||
See
|
See
|
||||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
|
||||||
|
|
|
@ -284,7 +284,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
|
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
|
||||||
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
|
||||||
HLO_ComparisonDirectionAttr:$comparison_direction
|
HLO_ComparisonDirectionAttr:$comparison_direction,
|
||||||
|
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -190,11 +190,12 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
|
||||||
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
|
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
|
||||||
Value lhs, Value rhs,
|
Value lhs, Value rhs,
|
||||||
DenseIntElementsAttr broadcast_dimensions,
|
DenseIntElementsAttr broadcast_dimensions,
|
||||||
StringAttr comparison_direction) {
|
StringAttr comparison_direction,
|
||||||
|
StringAttr compare_type) {
|
||||||
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
|
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
|
||||||
builder.getI1Type(), broadcast_dimensions);
|
builder.getI1Type(), broadcast_dimensions);
|
||||||
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
|
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
|
||||||
comparison_direction);
|
comparison_direction, compare_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
|
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
|
||||||
|
|
|
@ -2611,10 +2611,12 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
|
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
|
||||||
Value rhs, StringAttr comparison_direction) {
|
Value rhs, StringAttr comparison_direction,
|
||||||
|
StringAttr compare_type) {
|
||||||
auto new_type =
|
auto new_type =
|
||||||
UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
|
UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
|
||||||
build(builder, result, new_type, lhs, rhs, comparison_direction);
|
build(builder, result, new_type, lhs, rhs, comparison_direction,
|
||||||
|
compare_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult CompareOp::inferReturnTypeComponents(
|
LogicalResult CompareOp::inferReturnTypeComponents(
|
||||||
|
|
|
@ -505,9 +505,9 @@ struct HloCompareAdaptor {
|
||||||
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
||||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||||
OpBuilder &builder) {
|
OpBuilder &builder) {
|
||||||
return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
|
return builder.create<mhlo::CompareOp>(
|
||||||
broadcasted_lhs, broadcasted_rhs,
|
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
|
||||||
from_op.comparison_direction());
|
from_op.comparison_direction(), from_op.compare_typeAttr());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,8 @@ def : Pat<(HLOClient_AcosOp $input),
|
||||||
(HLO_CompareOp
|
(HLO_CompareOp
|
||||||
$input,
|
$input,
|
||||||
(HLO_ConstantLike<"-1"> $input),
|
(HLO_ConstantLike<"-1"> $input),
|
||||||
HLO_COMPARISON_DIRECTION_NE
|
HLO_COMPARISON_DIRECTION_NE,
|
||||||
|
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||||
),
|
),
|
||||||
(HLO_MulOp
|
(HLO_MulOp
|
||||||
(HLO_ConstantLike<"2"> $input),
|
(HLO_ConstantLike<"2"> $input),
|
||||||
|
@ -67,7 +68,8 @@ def : Pat<(HLOClient_SinhOp $input),
|
||||||
(HLO_CompareOp
|
(HLO_CompareOp
|
||||||
(HLO_AbsOp $input),
|
(HLO_AbsOp $input),
|
||||||
(HLO_ConstantLike<"1"> $input),
|
(HLO_ConstantLike<"1"> $input),
|
||||||
HLO_COMPARISON_DIRECTION_LT
|
HLO_COMPARISON_DIRECTION_LT,
|
||||||
|
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||||
),
|
),
|
||||||
(HLO_DivOp
|
(HLO_DivOp
|
||||||
(HLO_SubOp
|
(HLO_SubOp
|
||||||
|
|
Loading…
Reference in New Issue