Add compare_type optional attribute to CompareOp in HLO dialects

If unspecified, `compare_type` is FLOAT for float element types, SIGNED for signed element types and UNSIGNED for unsigned element types. compare_type can be TOTALORDER for float element types.

- Added import and export support the attribute.
- Restricted legalization from HLO to TF to the default compare types.
- Updated existing usage of the CompareOp

PiperOrigin-RevId: 339099219
This commit is contained in:
Smit Hinsu 2020-10-26 12:57:48 -07:00 committed by TensorFlow MLIR Team
parent f9843fabe1
commit 6eda9ed273
9 changed files with 53 additions and 20 deletions

View File

@ -427,7 +427,10 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
string summary = "Compare operator (with optional broadcasting)"; string summary = "Compare operator (with optional broadcasting)";
string description = [{ string description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`. Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
@ -437,13 +440,15 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
HLO_Tensor:$lhs, HLO_Tensor:$lhs,
HLO_Tensor:$rhs, HLO_Tensor:$rhs,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions, OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
); );
let results = (outs HLO_PredTensor); let results = (outs HLO_PredTensor);
let builders = [OpBuilder< let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" "DenseIntElementsAttr broadcast_dimensions, "
"StringAttr comparison_direction, StringAttr compare_type = {}"
>]; >];
} }

View File

@ -680,16 +680,19 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
let arguments = (ins let arguments = (ins
HLO_Tensor:$lhs, HLO_Tensor:$lhs,
HLO_Tensor:$rhs, HLO_Tensor:$rhs,
HLO_ComparisonDirectionAttr:$comparison_direction HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
); );
let results = (outs HLO_PredTensor); let results = (outs HLO_PredTensor);
let hasFolder = 1; let hasFolder = 1;
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " OpBuilder<"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"StringAttr comparison_direction" "StringAttr comparison_direction, StringAttr compare_type = {}">,
>]; ];
let hasCustomHLOConverter = 1;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -749,11 +749,30 @@ def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
HLO_COMPARISON_DIRECTION_LT HLO_COMPARISON_DIRECTION_LT
]>; ]>;
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
"Which comparison type to use.",
[
HLO_COMPARISON_TYPE_FLOAT,
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
HLO_COMPARISON_TYPE_SIGNED,
HLO_COMPARISON_TYPE_UNSIGNED
]>;
class BASE_HLO_CompareOp { class BASE_HLO_CompareOp {
string summary = "Comparison operator"; string summary = "Comparison operator";
string description = [{ string description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`. Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.

View File

@ -284,7 +284,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out, Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions, OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
); );
} }

View File

@ -190,11 +190,12 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result, void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
Value lhs, Value rhs, Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dimensions, DenseIntElementsAttr broadcast_dimensions,
StringAttr comparison_direction) { StringAttr comparison_direction,
StringAttr compare_type) {
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(), auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
builder.getI1Type(), broadcast_dimensions); builder.getI1Type(), broadcast_dimensions);
build(builder, result, new_type, lhs, rhs, broadcast_dimensions, build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
comparison_direction); comparison_direction, compare_type);
} }
LogicalResult BroadcastCompareOp::inferReturnTypeComponents( LogicalResult BroadcastCompareOp::inferReturnTypeComponents(

View File

@ -2611,10 +2611,12 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
Value rhs, StringAttr comparison_direction) { Value rhs, StringAttr comparison_direction,
StringAttr compare_type) {
auto new_type = auto new_type =
UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type()); UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
build(builder, result, new_type, lhs, rhs, comparison_direction); build(builder, result, new_type, lhs, rhs, comparison_direction,
compare_type);
} }
LogicalResult CompareOp::inferReturnTypeComponents( LogicalResult CompareOp::inferReturnTypeComponents(

View File

@ -505,9 +505,9 @@ struct HloCompareAdaptor {
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs, Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) { OpBuilder &builder) {
return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type, return builder.create<mhlo::CompareOp>(
broadcasted_lhs, broadcasted_rhs, from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction()); from_op.comparison_direction(), from_op.compare_typeAttr());
} }
}; };

View File

@ -31,7 +31,8 @@ def : Pat<(HLOClient_AcosOp $input),
(HLO_CompareOp (HLO_CompareOp
$input, $input,
(HLO_ConstantLike<"-1"> $input), (HLO_ConstantLike<"-1"> $input),
HLO_COMPARISON_DIRECTION_NE HLO_COMPARISON_DIRECTION_NE,
(HLO_DEFAULT_COMPARISON_TYPE)
), ),
(HLO_MulOp (HLO_MulOp
(HLO_ConstantLike<"2"> $input), (HLO_ConstantLike<"2"> $input),
@ -67,7 +68,8 @@ def : Pat<(HLOClient_SinhOp $input),
(HLO_CompareOp (HLO_CompareOp
(HLO_AbsOp $input), (HLO_AbsOp $input),
(HLO_ConstantLike<"1"> $input), (HLO_ConstantLike<"1"> $input),
HLO_COMPARISON_DIRECTION_LT HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)
), ),
(HLO_DivOp (HLO_DivOp
(HLO_SubOp (HLO_SubOp

View File

@ -28,7 +28,7 @@ func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64 // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64> // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
// CHECK: return %[[SHAPE]] : tensor<2xi64> // CHECK: return %[[SHAPE]] : tensor<2xi64>
%0 = "mhlo.compare"(%a, %b) { comparison_direction = "NE" } %0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1> : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
%1 = "mhlo_test.reify_return_type_shapes"(%0) %1 = "mhlo_test.reify_return_type_shapes"(%0)
: (tensor<2x?xi1>) -> tensor<2xi64> : (tensor<2x?xi1>) -> tensor<2xi64>