Implement constant folding for mhlo.Sign.
PiperOrigin-RevId: 373550014
This commit is contained in:
		
							parent
							
								
									d764806c1e
								
							
						
					
					
						commit
						d2cc74317c
					
				|  | @ -396,6 +396,7 @@ def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", | ||||||
|     See |     See | ||||||
|     https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. |     https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. | ||||||
|   }]; |   }]; | ||||||
|  |   let hasFolder = 1; | ||||||
| } | } | ||||||
| def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", | def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", | ||||||
|     [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { |     [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { | ||||||
|  |  | ||||||
|  | @ -2498,6 +2498,29 @@ struct logical_not { | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | template <typename FloatOrInt> | ||||||
|  | struct sign { | ||||||
|  |   APFloat compute(const APFloat& f) { | ||||||
|  |     if (f.isZero() || f.isNaN()) return f; | ||||||
|  |     double value = f.isNegative() ? -1.0 : 1.0; | ||||||
|  |     APFloat val(value); | ||||||
|  |     bool unused; | ||||||
|  |     val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused); | ||||||
|  |     return val; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   APInt compute(const APInt& i) { | ||||||
|  |     APInt r = i; | ||||||
|  |     if (r == 0) return r; | ||||||
|  |     if (r.isNegative()) { | ||||||
|  |       return APInt(r.getBitWidth(), -1, /*isSigned=*/true); | ||||||
|  |     } | ||||||
|  |     return APInt(r.getBitWidth(), 1, /*isSigned=*/true); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
| #define UNARY_FOLDER(Op, Func)                                                \ | #define UNARY_FOLDER(Op, Func)                                                \ | ||||||
|   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                          \ |   OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                          \ | ||||||
|     if (getElementTypeOrSelf(getType()).isa<FloatType>())                     \ |     if (getElementTypeOrSelf(getType()).isa<FloatType>())                     \ | ||||||
|  | @ -2522,6 +2545,7 @@ struct logical_not { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| UNARY_FOLDER(NegOp, std::negate); | UNARY_FOLDER(NegOp, std::negate); | ||||||
|  | UNARY_FOLDER(SignOp, sign); | ||||||
| UNARY_FOLDER_INT(NotOp, logical_not); | UNARY_FOLDER_INT(NotOp, logical_not); | ||||||
| UNARY_FOLDER_FLOAT(RoundOp, round); | UNARY_FOLDER_FLOAT(RoundOp, round); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -679,6 +679,46 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> { | ||||||
|   return %arg0 : tensor<i64> |   return %arg0 : tensor<i64> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CHECK-LABEL: fold_sign_posi | ||||||
|  | func @fold_sign_posi() -> tensor<i32> { | ||||||
|  |   // CHECK: %0 = mhlo.constant dense<1> : tensor<i32> | ||||||
|  |   %0 = mhlo.constant dense<2> : tensor<i32> | ||||||
|  |   %1 = "mhlo.sign"(%0) : (tensor<i32>) -> tensor<i32> | ||||||
|  |   return %1 : tensor<i32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: fold_sign_negi | ||||||
|  | func @fold_sign_negi() -> tensor<i32> { | ||||||
|  |   // CHECK: %0 = mhlo.constant dense<-1> : tensor<i32> | ||||||
|  |   %0 = mhlo.constant dense<-2> : tensor<i32> | ||||||
|  |   %1 = "mhlo.sign"(%0) : (tensor<i32>) -> tensor<i32> | ||||||
|  |   return %1 : tensor<i32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: fold_sign_posf | ||||||
|  | func @fold_sign_posf() -> tensor<bf16> { | ||||||
|  |   // CHECK: %0 = mhlo.constant dense<1.000000e+00> : tensor<bf16> | ||||||
|  |   %0 = mhlo.constant dense<2.000000e+00> : tensor<bf16> | ||||||
|  |   %1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16> | ||||||
|  |   return %1 : tensor<bf16> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: fold_sign_negf | ||||||
|  | func @fold_sign_negf() -> tensor<bf16> { | ||||||
|  |   // CHECK: %0 = mhlo.constant dense<-1.000000e+00> : tensor<bf16> | ||||||
|  |   %0 = mhlo.constant dense<-2.000000e+00> : tensor<bf16> | ||||||
|  |   %1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16> | ||||||
|  |   return %1 : tensor<bf16> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: fold_sign_negzf | ||||||
|  | func @fold_sign_negzf() -> tensor<bf16> { | ||||||
|  |   // CHECK: %0 = mhlo.constant dense<-0.000000e+00> : tensor<bf16> | ||||||
|  |   %0 = mhlo.constant dense<-0.000000e+00> : tensor<bf16> | ||||||
|  |   %1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16> | ||||||
|  |   return %1 : tensor<bf16> | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: fold_compare_same_eq | // CHECK-LABEL: fold_compare_same_eq | ||||||
| func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> { | func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> { | ||||||
|   // CHECK: %0 = mhlo.constant dense<true> : tensor<i1> |   // CHECK: %0 = mhlo.constant dense<true> : tensor<i1> | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue