[MLIR][CHLO] Add `is_inf`, `is_pos_inf`, and `is_neg_inf` to CHLO dialect
Also add the respective lowerings to MHLO. PiperOrigin-RevId: 354101955
This commit is contained in:
parent
f4f728f18e
commit
d77c9ad6fa
|
@ -71,6 +71,9 @@ Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
|
|||
|
||||
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val);
|
||||
|
||||
Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
|
||||
bool negative);
|
||||
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -337,16 +337,19 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||
Type TensorType> : HLOClient_Op<mnemonic, !listconcat(traits, [
|
||||
InferFusibilityOpInterface, NoSideEffect, SameOperandsAndResultType])> {
|
||||
let arguments = (ins TensorType:$operand);
|
||||
let results = (outs TensorType:$result);
|
||||
Type ArgTensorType, Type ResultTensorType> : HLOClient_Op<mnemonic,
|
||||
!listconcat(traits, [InferFusibilityOpInterface, NoSideEffect,
|
||||
SameOperandsAndResultShape])> {
|
||||
let arguments = (ins ArgTensorType:$operand);
|
||||
let results = (outs ResultTensorType:$result);
|
||||
|
||||
let assemblyFormat = "$operand attr-dict `:` type($operand)";
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Acos operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -359,8 +362,8 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Acosh operation";
|
||||
|
||||
let description = [{
|
||||
|
@ -373,8 +376,8 @@ def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Asin operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -386,8 +389,8 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Asinh operation";
|
||||
|
||||
let description = [{
|
||||
|
@ -399,8 +402,8 @@ def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Atan operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -412,8 +415,8 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Atanh operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -426,8 +429,8 @@ def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Conj operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -439,8 +442,8 @@ def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Cosh operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -452,8 +455,8 @@ def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Sinh operation";
|
||||
|
||||
let description = [{
|
||||
|
@ -466,8 +469,8 @@ def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
||||
[SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
|
||||
let summary = "Tan operation";
|
||||
|
||||
let description = [{
|
||||
|
@ -498,8 +501,7 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
|||
}
|
||||
|
||||
def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
|
||||
[NoSideEffect, SameOperandsAndResultShape],
|
||||
HLO_FpTensor> {
|
||||
[SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
|
||||
let summary = "Erfc operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -511,8 +513,7 @@ def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
|
|||
}
|
||||
|
||||
def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
|
||||
[NoSideEffect, SameOperandsAndResultShape],
|
||||
HLO_FpTensor> {
|
||||
[SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
|
||||
let summary = "Erfc operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -523,6 +524,33 @@ def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
|
||||
HLO_PredTensor> {
|
||||
let summary = "IsInf predicate";
|
||||
|
||||
let description = [{
|
||||
Returns if a value is +/-inf element-wise.
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
|
||||
HLO_FpTensor, HLO_PredTensor> {
|
||||
let summary = "IsNegInf predicate";
|
||||
|
||||
let description = [{
|
||||
Returns if a value is -inf element-wise.
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", [],
|
||||
HLO_FpTensor, HLO_PredTensor> {
|
||||
let summary = "IsPosInf predicate";
|
||||
|
||||
let description = [{
|
||||
Returns if a value is +inf element-wise.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Broadcasting compare op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -33,6 +33,12 @@ class HLO_ConstantLike<string value> : NativeCodeCall<
|
|||
def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
|
||||
"chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
|
||||
|
||||
def HLO_ConstantLikePosInfValue : NativeCodeCall<
|
||||
"chlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">;
|
||||
|
||||
def HLO_ConstantLikeNegInfValue : NativeCodeCall<
|
||||
"chlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/true)">;
|
||||
|
||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||
|
||||
def BinBroadcastDimensions : NativeCodeCall<
|
||||
|
@ -50,4 +56,8 @@ class GetScalarOfType<int value> : NativeCodeCall<
|
|||
def IdentityBroadcastDims : AttrConstraint<
|
||||
CPred<"hlo::IsSequenceStartingWith0($_self)">>;
|
||||
|
||||
def NonComplexElementType : Type<
|
||||
CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">,
|
||||
"Non-complex element type">;
|
||||
|
||||
#endif // HLO_UTILS
|
||||
|
|
|
@ -39,6 +39,13 @@ Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
|
|||
b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
|
||||
}
|
||||
|
||||
Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
|
||||
bool negative) {
|
||||
auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
|
||||
return getConstantLike(
|
||||
b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
|
||||
}
|
||||
|
||||
Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
|
||||
Value val) {
|
||||
Type ty = getElementTypeOrSelf(val.getType());
|
||||
|
|
|
@ -23,10 +23,6 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
|
|||
// Unary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def NonComplexElementType : Type<
|
||||
CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">,
|
||||
"Non complex element type">;
|
||||
|
||||
// Expand acos to MHLO dialect as follows:
|
||||
// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
|
||||
// = pi if x == -1
|
||||
|
@ -285,6 +281,33 @@ def : Pat<(HLOClient_CoshOp NonComplexElementType:$input),
|
|||
)
|
||||
)>;
|
||||
|
||||
// Express `is_inf` as
|
||||
// is_inf(x) = is_pos_inf(|x|)
|
||||
def : Pat<(HLOClient_IsInfOp NonComplexElementType:$input),
|
||||
(HLOClient_IsPosInfOp
|
||||
(HLO_AbsOp $input)
|
||||
)>;
|
||||
|
||||
// Express `is_pos_inf` as
|
||||
// is_pos_inf(x) = (x == +inf)
|
||||
def : Pat<(HLOClient_IsPosInfOp NonComplexElementType:$input),
|
||||
(HLO_CompareOp
|
||||
$input,
|
||||
(HLO_ConstantLikePosInfValue $input),
|
||||
HLO_COMPARISON_DIRECTION_EQ,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||
)>;
|
||||
|
||||
// Express `is_neg_inf` as
|
||||
// is_neg_inf(x) = (x == -inf)
|
||||
def : Pat<(HLOClient_IsNegInfOp NonComplexElementType:$input),
|
||||
(HLO_CompareOp
|
||||
$input,
|
||||
(HLO_ConstantLikeNegInfValue $input),
|
||||
HLO_COMPARISON_DIRECTION_EQ,
|
||||
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||
)>;
|
||||
|
||||
// Express `sinh` as
|
||||
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
||||
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
|
||||
|
|
|
@ -652,3 +652,34 @@ func @erfc_f16(%arg : tensor<f16>) -> tensor<f16> {
|
|||
%1 = "chlo.erfc"(%arg) : (tensor<f16>) -> tensor<f16>
|
||||
return %1 : tensor<f16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @is_inf_f32
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
|
||||
func @is_inf_f32(%arg : tensor<f32>) -> tensor<i1> {
|
||||
// CHECK: %[[ABS:.*]] = "mhlo.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[POS_INF:.*]] = mhlo.constant dense<0x7F800000> : tensor<f32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.compare"(%[[ABS]], %[[POS_INF]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: return %[[RESULT]] : tensor<i1>
|
||||
%1 = chlo.is_inf %arg : tensor<f32> -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @is_pos_inf_f32
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
|
||||
func @is_pos_inf_f32(%arg : tensor<f32>) -> tensor<i1> {
|
||||
// CHECK: %[[POS_INF:.*]] = mhlo.constant dense<0x7F800000> : tensor<f32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.compare"(%[[ARG]], %[[POS_INF]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: return %[[RESULT]] : tensor<i1>
|
||||
%1 = chlo.is_pos_inf %arg : tensor<f32> -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @is_neg_inf_f32
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
|
||||
func @is_neg_inf_f32(%arg : tensor<f32>) -> tensor<i1> {
|
||||
// CHECK: %[[NEG_INF:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.compare"(%[[ARG]], %[[NEG_INF]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: return %[[RESULT]] : tensor<i1>
|
||||
%1 = chlo.is_neg_inf %arg : tensor<f32> -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
|
|
@ -90,10 +90,10 @@ func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32>
|
||||
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32> -> tensor<?xf32>
|
||||
// CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[B]] : tensor<*xf32>
|
||||
%result = chlo.tan %a : tensor<*xf32>
|
||||
%result = chlo.tan %a : tensor<*xf32> -> tensor<*xf32>
|
||||
return %result : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue