2020-07-07 04:57:00 +08:00
|
|
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
==============================================================================*/
|
|
|
|
|
|
|
|
// This is the legalization pattern that converts complex operations into
|
|
|
|
// equivalent real value operations.
|
|
|
|
|
2020-07-29 07:12:08 +08:00
|
|
|
include "mlir/IR/OpBase.td"
|
|
|
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
|
|
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Binary op patterns.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Add and subtraction are elementwise and can be distributed across the real
|
|
|
|
// and imaginary components.
|
|
|
|
foreach ComponentwiseOp = [HLO_AddOp, HLO_SubOp] in
  // Rewrites a complex add/sub as the same op applied separately to the real
  // parts and to the imaginary parts, recombined with HLO_ComplexOp.
  def : Pat<
      (ComponentwiseOp HLO_ComplexTensor:$a, HLO_ComplexTensor:$b),
      (HLO_ComplexOp
          (ComponentwiseOp (HLO_RealOp $a), (HLO_RealOp $b)),
          (ComponentwiseOp (HLO_ImagOp $a), (HLO_ImagOp $b)))>;
|
|
|
|
|
|
|
|
// Complex multiplication results in a cross product multiplication between the
|
|
|
|
// real and imaginary components such that:
|
|
|
|
// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
|
|
|
|
// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
|
|
|
|
def : Pat<
    (HLO_MulOp HLO_ComplexTensor:$a, HLO_ComplexTensor:$b),
    (HLO_ComplexOp
        // Real part: re(a) * re(b) - im(a) * im(b). The four component
        // extractions are bound to names here so they can be reused below.
        (HLO_SubOp
            (HLO_MulOp (HLO_RealOp:$a_re $a), (HLO_RealOp:$b_re $b)),
            (HLO_MulOp (HLO_ImagOp:$a_im $a), (HLO_ImagOp:$b_im $b))),
        // Imaginary part: re(a) * im(b) + im(a) * re(b).
        (HLO_AddOp
            (HLO_MulOp $a_re, $b_im),
            (HLO_MulOp $a_im, $b_re)))>;
|
|
|
|
|
|
|
|
|
|
|
|
// Division is performed by normalizing the denominator by multiplying by the
|
|
|
|
// conjugate of the rhs.
|
|
|
|
// numerator = lhs * conj(rhs)
|
|
|
|
// denominator = rhs * conj(rhs)
|
|
|
|
def : Pat<
    (HLO_DivOp HLO_ComplexTensor:$n, HLO_ComplexTensor:$d),
    (HLO_ComplexOp
        // Real part: re(n * conj(d)) / (re(d)^2 + im(d)^2).
        (HLO_DivOp
            (HLO_RealOp
                (HLO_MulOp:$scaled_num $n,
                    // conj(d) is built as complex(re(d), -im(d)).
                    (HLO_ComplexOp:$d_conj
                        (HLO_RealOp $d),
                        (HLO_NegOp (HLO_ImagOp $d))))),
            (HLO_AddOp:$norm_sq
                (HLO_MulOp (HLO_RealOp $d), (HLO_RealOp $d)),
                (HLO_MulOp (HLO_ImagOp $d), (HLO_ImagOp $d)))),
        // Imaginary part: im(n * conj(d)) over the same real denominator.
        (HLO_DivOp (HLO_ImagOp $scaled_num), $norm_sq))>;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Absolute value is evaluated as:
|
|
|
|
// result = sqrt(val.real * val.real + val.imag * val.imag)
|
|
|
|
def : Pat<
    (HLO_AbsOp HLO_ComplexTensor:$z),
    // sqrt(re(z) * re(z) + im(z) * im(z)); each component is extracted once
    // and the bound result is reused as the second multiplicand.
    (HLO_SqrtOp
        (HLO_AddOp
            (HLO_MulOp (HLO_RealOp:$re $z), $re),
            (HLO_MulOp (HLO_ImagOp:$im $z), $im)))>;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Exponential can be lowered to an exponential on the real component and a
|
|
|
|
// sum of sinusoids of the imaginary component, which equates to a normal
|
|
|
|
// exponential operator multiplied by Euler's formula.
|
|
|
|
//
|
2020-11-07 07:23:11 +08:00
|
|
|
// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * Cos(b) + i * Exp(a) * Sin(b)
|
2020-07-07 04:57:00 +08:00
|
|
|
def : Pat<
    (HLO_ExpOp HLO_ComplexTensor:$z),
    (HLO_ComplexOp
        // Real part: cos(im(z)) * exp(re(z)). The imaginary component and the
        // real exponential are bound for reuse in the imaginary part.
        (HLO_MulOp
            (HLO_CosOp (HLO_ImagOp:$angle $z)),
            (HLO_ExpOp:$magnitude (HLO_RealOp:$re $z))),
        // Imaginary part: sin(im(z)) * exp(re(z)).
        (HLO_MulOp (HLO_SinOp $angle), $magnitude))>;
|
2021-04-15 15:34:29 +08:00
|
|
|
|
|
|
|
// Complex comparison is lowered to two real comparisons (real parts and
// imaginary parts) with the same direction, joined by a logical op: OR for
// NE, AND for EQ. Each list entry is [comparison direction, combining op].
foreach dirAndCombiner = [[HLO_COMPARISON_DIRECTION_NE, HLO_OrOp],
                          [HLO_COMPARISON_DIRECTION_EQ, HLO_AndOp]] in {
  def : Pat<
      (HLO_CompareOp HLO_ComplexTensor:$a, HLO_ComplexTensor:$b,
          dirAndCombiner[0], $compare_type),
      (dirAndCombiner[1]
          (HLO_CompareOp (HLO_RealOp $a), (HLO_RealOp $b),
              dirAndCombiner[0], $compare_type),
          (HLO_CompareOp (HLO_ImagOp $a), (HLO_ImagOp $b),
              dirAndCombiner[0], $compare_type))>;
}
|