[MLIR][CHLO] Implement type inference for `is_inf`-like operations in CHLO
PiperOrigin-RevId: 354265834
This commit is contained in:
		
							parent
							
								
									fe2e5a175f
								
							
						
					
					
						commit
						c3ddcd6c7f
					
				| 
						 | 
				
			
			@ -524,7 +524,8 @@ def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
 | 
			
		|||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
 | 
			
		||||
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf",
 | 
			
		||||
    [DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
 | 
			
		||||
    HLO_PredTensor> {
 | 
			
		||||
  let summary = "IsInf predicate";
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -533,8 +534,9 @@ def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
 | 
			
		|||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
 | 
			
		||||
    HLO_FpTensor, HLO_PredTensor> {
 | 
			
		||||
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf",
 | 
			
		||||
    [DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
 | 
			
		||||
    HLO_PredTensor> {
 | 
			
		||||
  let summary = "IsNegInf predicate";
 | 
			
		||||
 | 
			
		||||
  let description = [{
 | 
			
		||||
| 
						 | 
				
			
			@ -542,8 +544,9 @@ def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
 | 
			
		|||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", [],
 | 
			
		||||
    HLO_FpTensor, HLO_PredTensor> {
 | 
			
		||||
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf",
 | 
			
		||||
    [DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
 | 
			
		||||
    HLO_PredTensor> {
 | 
			
		||||
  let summary = "IsPosInf predicate";
 | 
			
		||||
 | 
			
		||||
  let description = [{
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -227,12 +227,52 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
 | 
			
		|||
                                                    attributes, element_type,
 | 
			
		||||
                                                    inferedReturnShapes);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
 | 
			
		||||
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
 | 
			
		||||
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
 | 
			
		||||
                                                reifiedReturnShapes);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// IsInfOp
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
static Type getIsInfLikeReturnType(Value operand) {
 | 
			
		||||
  Builder b(operand.getContext());
 | 
			
		||||
  return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
 | 
			
		||||
                                      b.getI1Type());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
LogicalResult IsInfOp::inferReturnTypes(
 | 
			
		||||
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
 | 
			
		||||
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
 | 
			
		||||
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
 | 
			
		||||
  return success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// IsNegInfOp
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
LogicalResult IsNegInfOp::inferReturnTypes(
 | 
			
		||||
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
 | 
			
		||||
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
 | 
			
		||||
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
 | 
			
		||||
  return success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// IsPosInfOp
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
LogicalResult IsPosInfOp::inferReturnTypes(
 | 
			
		||||
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
 | 
			
		||||
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
 | 
			
		||||
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
 | 
			
		||||
  return success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Macros for method definitions that are common to most broadcasting ops.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue