diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index d3aeefe..4ac5f69 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -524,7 +524,8 @@ def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc", }]; } -def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor, +def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", + [DeclareOpInterfaceMethods], HLO_FpTensor, HLO_PredTensor> { let summary = "IsInf predicate"; @@ -533,8 +534,9 @@ def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor, }]; } -def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [], - HLO_FpTensor, HLO_PredTensor> { +def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", + [DeclareOpInterfaceMethods], HLO_FpTensor, + HLO_PredTensor> { let summary = "IsNegInf predicate"; let description = [{ @@ -542,8 +544,9 @@ def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [], }]; } -def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", [], - HLO_FpTensor, HLO_PredTensor> { +def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", + [DeclareOpInterfaceMethods], HLO_FpTensor, + HLO_PredTensor> { let summary = "IsPosInf predicate"; let description = [{ diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index 58fe0e2..f0e7819 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -227,12 +227,52 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents( attributes, element_type, inferedReturnShapes); } + LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), reifiedReturnShapes); } +//===----------------------------------------------------------------------===// +// IsInfOp +//===----------------------------------------------------------------------===// + +static Type getIsInfLikeReturnType(Value operand) { + Builder b(operand.getContext()); + return mhlo::getSameShapeTensorType(operand.getType().cast(), + b.getI1Type()); +} + +LogicalResult IsInfOp::inferReturnTypes( + MLIRContext* ctx, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front())); + return success(); +} + +//===----------------------------------------------------------------------===// +// IsNegInfOp +//===----------------------------------------------------------------------===// + +LogicalResult IsNegInfOp::inferReturnTypes( + MLIRContext* ctx, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front())); + return success(); +} + +//===----------------------------------------------------------------------===// +// IsPosInfOp +//===----------------------------------------------------------------------===// + +LogicalResult IsPosInfOp::inferReturnTypes( + MLIRContext* ctx, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front())); + return success(); +} + //===----------------------------------------------------------------------===// // Macros for method definitions that are common to most broadcasting ops. //===----------------------------------------------------------------------===//