/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines legalization patterns that convert operations on complex
// values into equivalent operations on their real and imaginary components.

include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"

//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//

// Addition and subtraction act independently on the real and imaginary
// components. Each complex add/sub is therefore split into one real-valued op
// per component, and the two results are recombined:
//   result = complex(op(real(x), real(y)), op(imag(x), imag(y)))
foreach perComponentOp = [HLO_AddOp, HLO_SubOp] in
  def : Pat<(perComponentOp HLO_ComplexTensor:$x, HLO_ComplexTensor:$y),
            (HLO_ComplexOp
              (perComponentOp (HLO_RealOp $x), (HLO_RealOp $y)),
              (perComponentOp (HLO_ImagOp $x), (HLO_ImagOp $y)))>;

// Complex multiplication expands into four real multiplications, one for
// every pairing of operand components:
//   result.real = x.real * y.real - x.imag * y.imag
//   result.imag = x.real * y.imag + x.imag * y.real
// The real/imag extractions of both operands are bound to names in the
// real-part subtree. The imaginary-part subtree then reuses those values
// instead of extracting them a second time.
def : Pat<(HLO_MulOp HLO_ComplexTensor:$x, HLO_ComplexTensor:$y),
          (HLO_ComplexOp
            (HLO_SubOp
              (HLO_MulOp (HLO_RealOp:$xRe $x), (HLO_RealOp:$yRe $y)),
              (HLO_MulOp (HLO_ImagOp:$xIm $x), (HLO_ImagOp:$yIm $y))),
            (HLO_AddOp
              (HLO_MulOp $xRe, $yIm),
              (HLO_MulOp $xIm, $yRe)))>;

// Multiplying a complex tensor by a real tensor scales both the real and the
// imaginary component by the real multiplicand:
//   result = complex(real(x) * y, imag(x) * y)
//
// This source pattern is not legal according to the HLO dialect. It exists to
// lower mixed complex/real intermediates generated by other patterns.
def : Pat<(HLO_MulOp HLO_ComplexTensor:$x, HLO_IntOrFpTensor:$y),
          (HLO_ComplexOp
            (HLO_MulOp (HLO_RealOp $x), $y),
            (HLO_MulOp (HLO_ImagOp $x), $y))>;

// Real-by-complex multiplication: the real left-hand operand scales both
// components of the complex right-hand operand:
//   result = complex(x * real(y), x * imag(y))
// Like the complex-by-real case, this source form is not legal HLO. It matches
// intermediates such as the one emitted by the complex exponential lowering.
def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$x, HLO_ComplexTensor:$y),
          (HLO_ComplexOp
            (HLO_MulOp $x, (HLO_RealOp $y)),
            (HLO_MulOp $x, (HLO_ImagOp $y)))>;


// Complex division makes the divisor real by multiplying both numerator and
// divisor by the divisor's conjugate:
//   conj   = complex(real(y), -imag(y))
//   numer  = x * conj
//   denom  = real(y * conj)          (y * conj has a zero imaginary part)
//   result = numer / denom
// The emitted complex multiplications and the complex-by-real division are
// lowered further by the corresponding patterns in this file.
// NOTE(review): this is the textbook formula. For divisors of very large or
// very small magnitude, denom may overflow or underflow even when the true
// quotient is representable.
def : Pat<(HLO_DivOp HLO_ComplexTensor:$x, HLO_ComplexTensor:$y),
          (HLO_DivOp
            (HLO_MulOp:$numer $x,
              (HLO_ComplexOp:$yConj
                (HLO_RealOp $y),
                (HLO_NegOp (HLO_ImagOp $y)))),
            (HLO_RealOp:$denom (HLO_MulOp $y, $yConj)))>;


// Dividing a complex tensor by a real tensor divides each component by the
// real divisor:
//   result = complex(real(x) / y, imag(x) / y)
// This lowers the complex-numerator / real-denominator division emitted by the
// complex/complex division pattern.
def : Pat<(HLO_DivOp HLO_ComplexTensor:$x, HLO_IntOrFpTensor:$y),
          (HLO_ComplexOp
            (HLO_DivOp (HLO_RealOp $x), $y),
            (HLO_DivOp (HLO_ImagOp $x), $y))>;


// The absolute value (magnitude) of a complex value is computed from its
// components:
//   result = sqrt(real(z) * real(z) + imag(z) * imag(z))
// Each extracted component is bound once and fed to both sides of its square.
// NOTE(review): squaring the components can overflow or underflow for values
// near the limits of the element type, even when the true magnitude is
// representable. A scaled (hypot-style) formulation would avoid this.
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$z),
          (HLO_SqrtOp
            (HLO_AddOp
              (HLO_MulOp (HLO_RealOp:$zRe $z), $zRe),
              (HLO_MulOp (HLO_ImagOp:$zIm $z), $zIm)))>;

// The complex exponential factors into a real exponential of the real part
// multiplied by Euler's formula applied to the imaginary part:
//   exp(a + ib) = exp(a) * exp(ib) = exp(a) * (cos(b) + i*sin(b))
// The result is a real-by-complex multiplication, which is lowered by the
// real * complex multiplication pattern. The imaginary part is extracted once
// and shared by the cos and sin ops.
def : Pat<(HLO_ExpOp HLO_ComplexTensor:$z),
          (HLO_MulOp
            (HLO_ExpOp (HLO_RealOp $z)),
            (HLO_ComplexOp
              (HLO_CosOp (HLO_ImagOp:$zIm $z)),
              (HLO_SinOp $zIm)))>;