[MLIR][CHLO] Add `is_inf`, `is_pos_inf`, and `is_neg_inf` to CHLO dialect

Also add the respective lowerings to MHLO.

PiperOrigin-RevId: 354101955
This commit is contained in:
A. Unique TensorFlower 2021-01-27 08:59:58 -08:00 committed by TensorFlow MLIR Team
parent f4f728f18e
commit d77c9ad6fa
7 changed files with 137 additions and 35 deletions

View File

@ -71,6 +71,9 @@ Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val);
Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
bool negative);
} // namespace chlo
} // namespace mlir

View File

@ -337,16 +337,19 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
//===----------------------------------------------------------------------===//
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType> : HLOClient_Op<mnemonic, !listconcat(traits, [
InferFusibilityOpInterface, NoSideEffect, SameOperandsAndResultType])> {
let arguments = (ins TensorType:$operand);
let results = (outs TensorType:$result);
Type ArgTensorType, Type ResultTensorType> : HLOClient_Op<mnemonic,
!listconcat(traits, [InferFusibilityOpInterface, NoSideEffect,
SameOperandsAndResultShape])> {
let arguments = (ins ArgTensorType:$operand);
let results = (outs ResultTensorType:$result);
let assemblyFormat = "$operand attr-dict `:` type($operand)";
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
HLO_FpOrComplexTensor> {
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Acos operator";
let description = [{
@ -359,8 +362,8 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
}];
}
def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh", [],
HLO_FpOrComplexTensor> {
def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Acosh operation";
let description = [{
@ -373,8 +376,8 @@ def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh", [],
}];
}
def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
HLO_FpOrComplexTensor> {
def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Asin operator";
let description = [{
@ -386,8 +389,8 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
}];
}
def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
HLO_FpOrComplexTensor> {
def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Asinh operation";
let description = [{
@ -399,8 +402,8 @@ def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
}];
}
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
HLO_FpOrComplexTensor> {
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Atan operator";
let description = [{
@ -412,8 +415,8 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
}];
}
def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
HLO_FpOrComplexTensor> {
def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Atanh operator";
let description = [{
@ -426,8 +429,8 @@ def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
}];
}
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
HLO_FpOrComplexTensor> {
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Conj operator";
let description = [{
@ -439,8 +442,8 @@ def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
}];
}
def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh", [],
HLO_FpOrComplexTensor> {
def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Cosh operator";
let description = [{
@ -452,8 +455,8 @@ def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh", [],
}];
}
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
HLO_FpOrComplexTensor> {
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Sinh operation";
let description = [{
@ -466,8 +469,8 @@ def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
}];
}
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [],
HLO_FpOrComplexTensor> {
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
let summary = "Tan operation";
let description = [{
@ -498,8 +501,7 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
}
def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
[NoSideEffect, SameOperandsAndResultShape],
HLO_FpTensor> {
[SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
let summary = "Erfc operator";
let description = [{
@ -511,8 +513,7 @@ def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
}
def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
[NoSideEffect, SameOperandsAndResultShape],
HLO_FpTensor> {
[SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
let summary = "Erfc operator";
let description = [{
@ -523,6 +524,33 @@ def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
}];
}
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
HLO_PredTensor> {
let summary = "IsInf predicate";
let description = [{
Returns if a value is +/-inf element-wise.
}];
}
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
HLO_FpTensor, HLO_PredTensor> {
let summary = "IsNegInf predicate";
let description = [{
Returns if a value is -inf element-wise.
}];
}
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", [],
HLO_FpTensor, HLO_PredTensor> {
let summary = "IsPosInf predicate";
let description = [{
Returns if a value is +inf element-wise.
}];
}
//===----------------------------------------------------------------------===//
// Broadcasting compare op
//===----------------------------------------------------------------------===//

View File

@ -33,6 +33,12 @@ class HLO_ConstantLike<string value> : NativeCodeCall<
def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
"chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
def HLO_ConstantLikePosInfValue : NativeCodeCall<
"chlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">;
def HLO_ConstantLikeNegInfValue : NativeCodeCall<
"chlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/true)">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall<
@ -50,4 +56,8 @@ class GetScalarOfType<int value> : NativeCodeCall<
def IdentityBroadcastDims : AttrConstraint<
CPred<"hlo::IsSequenceStartingWith0($_self)">>;
def NonComplexElementType : Type<
CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">,
"Non-complex element type">;
#endif // HLO_UTILS

View File

@ -39,6 +39,13 @@ Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
}
Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
bool negative) {
auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
return getConstantLike(
b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
}
Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());

View File

@ -23,10 +23,6 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
// Unary op patterns.
//===----------------------------------------------------------------------===//
def NonComplexElementType : Type<
CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">,
"Non complex element type">;
// Expand acos to MHLO dialect as follows:
// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
// = pi if x == -1
@ -285,6 +281,33 @@ def : Pat<(HLOClient_CoshOp NonComplexElementType:$input),
)
)>;
// Express `is_inf` as
// is_inf(x) = is_pos_inf(|x|)
def : Pat<(HLOClient_IsInfOp NonComplexElementType:$input),
(HLOClient_IsPosInfOp
(HLO_AbsOp $input)
)>;
// Express `is_pos_inf` as
// is_pos_inf(x) = (x == +inf)
def : Pat<(HLOClient_IsPosInfOp NonComplexElementType:$input),
(HLO_CompareOp
$input,
(HLO_ConstantLikePosInfValue $input),
HLO_COMPARISON_DIRECTION_EQ,
(HLO_DEFAULT_COMPARISON_TYPE)
)>;
// Express `is_neg_inf` as
// is_neg_inf(x) = (x == -inf)
def : Pat<(HLOClient_IsNegInfOp NonComplexElementType:$input),
(HLO_CompareOp
$input,
(HLO_ConstantLikeNegInfValue $input),
HLO_COMPARISON_DIRECTION_EQ,
(HLO_DEFAULT_COMPARISON_TYPE)
)>;
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.

View File

@ -652,3 +652,34 @@ func @erfc_f16(%arg : tensor<f16>) -> tensor<f16> {
%1 = "chlo.erfc"(%arg) : (tensor<f16>) -> tensor<f16>
return %1 : tensor<f16>
}
// CHECK-LABEL: @is_inf_f32
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func @is_inf_f32(%arg : tensor<f32>) -> tensor<i1> {
// CHECK: %[[ABS:.*]] = "mhlo.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[POS_INF:.*]] = mhlo.constant dense<0x7F800000> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.compare"(%[[ABS]], %[[POS_INF]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: return %[[RESULT]] : tensor<i1>
%1 = chlo.is_inf %arg : tensor<f32> -> tensor<i1>
return %1 : tensor<i1>
}
// CHECK-LABEL: @is_pos_inf_f32
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func @is_pos_inf_f32(%arg : tensor<f32>) -> tensor<i1> {
// CHECK: %[[POS_INF:.*]] = mhlo.constant dense<0x7F800000> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.compare"(%[[ARG]], %[[POS_INF]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: return %[[RESULT]] : tensor<i1>
%1 = chlo.is_pos_inf %arg : tensor<f32> -> tensor<i1>
return %1 : tensor<i1>
}
// CHECK-LABEL: @is_neg_inf_f32
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func @is_neg_inf_f32(%arg : tensor<f32>) -> tensor<i1> {
// CHECK: %[[NEG_INF:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.compare"(%[[ARG]], %[[NEG_INF]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: return %[[RESULT]] : tensor<i1>
%1 = chlo.is_neg_inf %arg : tensor<f32> -> tensor<i1>
return %1 : tensor<i1>
}

View File

@ -90,10 +90,10 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32> -> tensor<?xf32>
// CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[B]] : tensor<*xf32>
%result = chlo.tan %a : tensor<*xf32>
%result = chlo.tan %a : tensor<*xf32> -> tensor<*xf32>
return %result : tensor<*xf32>
}