[XLA][MLIR] Lower `tf.Tan` and `tf.Sin` to MLHLO
Add `tan` op and lowering to CHLO dialect, move CHLO lowerings to `chlo_legalize_to_hlo_patterns` and extend missing patterns. PiperOrigin-RevId: 331506094
This commit is contained in:
parent
ce1c8a1ebc
commit
48022987ce
|
@ -364,6 +364,19 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
||||
let summary = "Tan operation";
|
||||
|
||||
let description = [{
|
||||
Returns `Tan(operand)` element-wise.
|
||||
|
||||
$$
|
||||
\tan(x) = \sin(x) / \cos(x)
|
||||
$$
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
||||
[NoSideEffect, SameOperandsAndResultShape,
|
||||
InferTypeOpInterface,
|
||||
|
|
|
@ -25,6 +25,10 @@ set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td)
|
|||
mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td)
|
||||
mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRChloLegalizeToHloIncGen)
|
||||
|
||||
|
||||
add_mlir_library(ChloPasses
|
||||
chlo_legalize_to_hlo.cc
|
||||
|
@ -32,6 +36,7 @@ add_mlir_library(ChloPasses
|
|||
|
||||
DEPENDS
|
||||
MLIRhlo_opsIncGen
|
||||
MLIRChloLegalizeToHloIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
|
|
@ -469,10 +469,13 @@ struct HloCompareAdaptor {
|
|||
}
|
||||
};
|
||||
|
||||
#include "generated_chlo_legalize_to_hlo.inc"
|
||||
} // namespace
|
||||
|
||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
populateWithGenerated(context, patterns);
|
||||
|
||||
// Instantiate conversion templates for conforming binary elementwise ops
|
||||
// that do not have different dtypes between operands and results and do
|
||||
// not have special attributes that need to be preserved.
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the legalization pattern definition file for CHLO to MHLO.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Expand acos to MHLO dialect as follows:
|
||||
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
|
||||
// = pi if x == -1
|
||||
def : Pat<(HLOClient_AcosOp $input),
|
||||
(HLO_SelectOp
|
||||
(HLO_CompareOp $input,
|
||||
(HLO_ConstantLike<"0"> $input),
|
||||
HLO_COMPARISON_DIRECTION_NE
|
||||
),
|
||||
(HLO_MulOp
|
||||
(HLO_ConstantLike<"2.0f"> $input),
|
||||
(HLO_Atan2Op
|
||||
(HLO_SqrtOp
|
||||
(HLO_SubOp
|
||||
(HLO_ConstantLike<"1"> $input),
|
||||
(HLO_MulOp $input, $input)
|
||||
)
|
||||
),
|
||||
(HLO_AddOp
|
||||
(HLO_ConstantLike<"1"> $input),
|
||||
$input
|
||||
)
|
||||
)
|
||||
),
|
||||
(HLO_ConstantLike<"M_PI"> $input))>;
|
||||
|
||||
// Express tan in MHLO dialect as
|
||||
// tan(x) = sin(x) / cos(x).
|
||||
def : Pat<(HLOClient_TanOp $input),
|
||||
(HLO_DivOp
|
||||
(HLO_SinOp $input),
|
||||
(HLO_CosOp $input)
|
||||
)>;
|
||||
|
|
@ -515,6 +515,7 @@ void populateHLOToLHLOConversionPattern(
|
|||
HloToLhloOpConverter<mhlo::ReshapeOp>,
|
||||
HloToLhloOpConverter<mhlo::SelectOp>,
|
||||
HloToLhloOpConverter<mhlo::SignOp>,
|
||||
HloToLhloOpConverter<mhlo::SinOp>,
|
||||
HloToLhloOpConverter<mhlo::SliceOp>,
|
||||
HloToLhloOpConverter<mhlo::SqrtOp>,
|
||||
HloToLhloOpConverter<mhlo::SubOp>,
|
||||
|
|
Loading…
Reference in New Issue