[MLIR] Add cbrt, reduce-precision, and bitcast ops to MHLO.
PiperOrigin-RevId: 335109804
This commit is contained in:
		
							parent
							
								
									bcf6fbf612
								
							
						
					
					
						commit
						c708bfd6d0
					
				| 
						 | 
				
			
			@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
 | 
			
		|||
  >];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
 | 
			
		||||
    [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
 | 
			
		||||
 | 
			
		||||
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
 | 
			
		||||
    [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1423,4 +1426,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
 | 
			
		|||
  let hasCustomHLOConverter = 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This is an op for purposes internal to XLA/GPU.
 | 
			
		||||
def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
 | 
			
		||||
  let arguments = (ins HLO_Tensor:$operand);
 | 
			
		||||
  let results = (outs HLO_Tensor);
 | 
			
		||||
  let hasCustomHLOConverter = 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
 | 
			
		||||
                           BASE_HLO_ReducePrecisionOp {
 | 
			
		||||
  let arguments = (ins
 | 
			
		||||
    HLO_FpTensor:$operand,
 | 
			
		||||
    I32Attr:$exponent_bits,
 | 
			
		||||
    I32Attr:$mantissa_bits
 | 
			
		||||
  );
 | 
			
		||||
  let results = (outs HLO_FpTensor:$output);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // HLO_OPS
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
 | 
			
		|||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class BASE_HLO_CbrtOp {
 | 
			
		||||
  string summary = "Cubic root operator";
 | 
			
		||||
 | 
			
		||||
  string description = [{
 | 
			
		||||
    Returns element-wise cubic root of the operand.
 | 
			
		||||
 | 
			
		||||
    See
 | 
			
		||||
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
 | 
			
		||||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class BASE_HLO_CeilOp {
 | 
			
		||||
  string summary = "Ceil operator";
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1336,4 +1347,17 @@ class BASE_HLO_WhileOp {
 | 
			
		|||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class BASE_HLO_BitcastOp {
 | 
			
		||||
  string summary = "Bitcast operator";
 | 
			
		||||
 | 
			
		||||
  string description = [{
 | 
			
		||||
    This op changes the shape of the input in the way that the physical
 | 
			
		||||
    arranggment of elements are unchanged.
 | 
			
		||||
 | 
			
		||||
    However, the op needs layout information to make sense of "physical
 | 
			
		||||
    arrangement of elements". Layout support in MHLO is currently under
 | 
			
		||||
    exploration.
 | 
			
		||||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // HLO_OPS_BASE
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1193,3 +1193,24 @@ func @incompatible_shapes(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tens
 | 
			
		|||
  %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?xf32>
 | 
			
		||||
  return %0 : tensor<?xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
 | 
			
		||||
  %0 = "mhlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
 | 
			
		||||
  return %0 : tensor<2x4xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
 | 
			
		||||
  %0 = "mhlo.bitcast"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
 | 
			
		||||
  return %0 : tensor<2x4xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
 | 
			
		||||
  %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
 | 
			
		||||
  return %0 : tensor<2x4xf32>
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue