[KERNEL_GEN] Add unranked Conj kernel.
PiperOrigin-RevId: 344243271
This commit is contained in:
parent
a6948f6b41
commit
5583c63cab
|
@ -372,6 +372,19 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
|
||||||
|
HLO_FpOrComplexTensor> {
|
||||||
|
let summary = "Conj operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Conj(operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\conj(x) = (\real(x), \neg(\imag(x)))
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Sinh operation";
|
let summary = "Sinh operation";
|
||||||
|
|
|
@ -60,6 +60,11 @@ def : Pat<(HLOClient_AtanOp $input),
|
||||||
(HLO_ConstantLike<"1"> $input)
|
(HLO_ConstantLike<"1"> $input)
|
||||||
)>;
|
)>;
|
||||||
|
|
||||||
|
// Express `conj` as
|
||||||
|
// conj(x) = (re(x), -im(x)).
|
||||||
|
def : Pat<(HLOClient_ConjOp $v),
|
||||||
|
(HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>;
|
||||||
|
|
||||||
// Express `sinh` as
|
// Express `sinh` as
|
||||||
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
||||||
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
|
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
|
||||||
|
|
|
@ -50,8 +50,8 @@ namespace {
|
||||||
|
|
||||||
// TODO(herhut): Generate these out of op definitions.
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||||
fn(AcosOp) sep fn(AtanOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) \
|
fn(AcosOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(ErfOp) sep fn(ErfcOp) \
|
||||||
sep fn(TanOp)
|
sep fn(SinhOp) sep fn(TanOp)
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
|
|
|
@ -24,3 +24,14 @@ func @constant_like_dynamic_shape(%arg : tensor<?x?xi64>) -> tensor<?x?xf32> {
|
||||||
return %result : tensor<?x?xf32>
|
return %result : tensor<?x?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @conj
|
||||||
|
func @conj(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> {
|
||||||
|
// CHECK-SAME: ([[INPUT:%.*]]: tensor
|
||||||
|
// CHECK-NEXT: [[R1:%.*]] = "mhlo.real"([[INPUT]])
|
||||||
|
// CHECK-NEXT: [[R2:%.*]] = "mhlo.imag"([[INPUT]])
|
||||||
|
// CHECK-NEXT: [[R3:%.*]] = "mhlo.negate"([[R2]])
|
||||||
|
// CHECK-NEXT: [[R4:%.*]] = "mhlo.complex"([[R1]], [[R3]])
|
||||||
|
%1 = "chlo.conj"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>>
|
||||||
|
return %1 : tensor<3xcomplex<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue