2020-07-07 04:57:00 +08:00
|
|
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
==============================================================================*/
|
|
|
|
|
|
|
|
// This file contains the legalization patterns that convert complex
// operations into equivalent real-valued operations.
|
|
|
|
|
2020-07-29 07:12:08 +08:00
|
|
|
include "mlir/IR/OpBase.td"
|
|
|
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
|
|
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Binary op patterns.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Addition and subtraction act on each component independently, so a complex
// add/sub is rewritten as the same op applied to the real parts and to the
// imaginary parts, with the two results recombined by a complex op:
//   result.real = lhs.real (+|-) rhs.real
//   result.imag = lhs.imag (+|-) rhs.imag
foreach BinOp = [HLO_AddOp, HLO_SubOp] in {
  def : Pat<
    (BinOp HLO_ComplexTensor:$a, HLO_ComplexTensor:$b),
    (HLO_ComplexOp
      (BinOp (HLO_RealOp $a), (HLO_RealOp $b)),
      (BinOp (HLO_ImagOp $a), (HLO_ImagOp $b)))>;
}
|
|
|
|
|
|
|
|
// Complex-by-complex multiplication expands into products of the components:
//   result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
//   result.imag = lhs.real * rhs.imag + lhs.imag * rhs.real
// Each real/imag extraction is bound to a name where it is first built so
// that the imaginary half of the result can reuse it.
def : Pat<
  (HLO_MulOp HLO_ComplexTensor:$a, HLO_ComplexTensor:$b),
  (HLO_ComplexOp
    (HLO_SubOp
      (HLO_MulOp (HLO_RealOp:$a_re $a), (HLO_RealOp:$b_re $b)),
      (HLO_MulOp (HLO_ImagOp:$a_im $a), (HLO_ImagOp:$b_im $b))),
    (HLO_AddOp
      (HLO_MulOp $a_re, $b_im),
      (HLO_MulOp $a_im, $b_re)))>;
|
|
|
|
|
|
|
|
// Multiplying a complex tensor by a real tensor distributes the real
// multiplicand over both the real and the imaginary component.
//
// Note that the source pattern is not legal according to the HLO dialect; it
// exists to handle intermediates generated by other patterns.
def : Pat<
  (HLO_MulOp HLO_ComplexTensor:$cplx, HLO_IntOrFpTensor:$scale),
  (HLO_ComplexOp
    (HLO_MulOp (HLO_RealOp $cplx), $scale),
    (HLO_MulOp (HLO_ImagOp $cplx), $scale))>;
|
|
|
|
|
|
|
|
// Real-by-complex multiplication with the real operand on the left-hand
// side: the real multiplier is applied to both the real and the imaginary
// component of the right-hand side, keeping the original operand order.
def : Pat<
  (HLO_MulOp HLO_IntOrFpTensor:$scale, HLO_ComplexTensor:$cplx),
  (HLO_ComplexOp
    (HLO_MulOp $scale, (HLO_RealOp $cplx)),
    (HLO_MulOp $scale, (HLO_ImagOp $cplx)))>;
|
|
|
|
|
|
|
|
|
|
|
|
// Complex-by-complex division multiplies numerator and denominator by the
// conjugate of the rhs:
//   conj(rhs)   = complex(rhs.real, -rhs.imag)
//   numerator   = lhs * conj(rhs)
//   denominator = real(rhs * conj(rhs))
// The result is a division of a complex numerator by the real part of the
// denominator product.
def : Pat<
  (HLO_DivOp HLO_ComplexTensor:$a, HLO_ComplexTensor:$b),
  (HLO_DivOp
    (HLO_MulOp:$numer
      $a,
      (HLO_ComplexOp:$b_conj
        (HLO_RealOp $b),
        (HLO_NegOp (HLO_ImagOp $b)))),
    (HLO_RealOp:$denom (HLO_MulOp $b, $b_conj)))>;
|
|
|
|
|
|
|
|
|
|
|
|
// Dividing a complex tensor by a real tensor divides the real and the
// imaginary component by that real divisor separately and recombines them.
def : Pat<
  (HLO_DivOp HLO_ComplexTensor:$cplx, HLO_IntOrFpTensor:$divisor),
  (HLO_ComplexOp
    (HLO_DivOp (HLO_RealOp $cplx), $divisor),
    (HLO_DivOp (HLO_ImagOp $cplx), $divisor))>;
|
|
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//

// The absolute value of a complex tensor is its magnitude:
//   result = sqrt(val.real * val.real + val.imag * val.imag)
// Each component is extracted once, bound to a name, and multiplied by
// itself.
def : Pat<
  (HLO_AbsOp HLO_ComplexTensor:$x),
  (HLO_SqrtOp
    (HLO_AddOp
      (HLO_MulOp (HLO_RealOp:$x_re $x), $x_re),
      (HLO_MulOp (HLO_ImagOp:$x_im $x), $x_im)))>;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// The complex exponential is split, via Euler's formula, into a real
// exponential of the real component multiplied by a complex number built
// from the cosine and sine of the imaginary component:
//
//   Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b))
//
// The multiply emitted here has a real lhs and a complex rhs.
def : Pat<
  (HLO_ExpOp HLO_ComplexTensor:$x),
  (HLO_MulOp
    (HLO_ExpOp (HLO_RealOp $x)),
    (HLO_ComplexOp
      (HLO_CosOp (HLO_ImagOp:$x_im $x)),
      (HLO_SinOp $x_im)))>;
|