[XLA][MLIR] Lower `tf.Tan` and `tf.Sin` to MLHLO

Add `tan` op and lowering to CHLO dialect, move CHLO lowerings to
`chlo_legalize_to_hlo_patterns` and extend missing patterns.

PiperOrigin-RevId: 331128170
This commit is contained in:
A. Unique TensorFlower 2020-09-11 05:05:26 -07:00 committed by TensorFlow MLIR Team
parent 90927f6b53
commit a7a7184eb6
5 changed files with 3 additions and 84 deletions

View File

@ -344,13 +344,13 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
//===----------------------------------------------------------------------===//
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType> : HLOClient_Op<mnemonic,
Type TensorType>: HLOClient_Op<mnemonic,
!listconcat(traits, [InferFusibilityOpInterface])> {
let arguments = (ins TensorType:$operand);
let results = (outs TensorType);
}
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
let summary = "Acos operator";
@ -364,20 +364,7 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
}];
}
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
let summary = "Tan operation";
let description = [{
Returns `Tan(operand)` element-wise.
$$
\tan(x) = \sin(x) / \cos(x)
$$
}];
}
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like",
[NoSideEffect, SameOperandsAndResultShape,
InferTypeOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,

View File

@ -25,10 +25,6 @@ set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td)
mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters)
add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen)
set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td)
mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters)
add_public_tablegen_target(MLIRChloLegalizeToHloIncGen)
add_mlir_library(ChloPasses
chlo_legalize_to_hlo.cc
@ -36,7 +32,6 @@ add_mlir_library(ChloPasses
DEPENDS
MLIRhlo_opsIncGen
MLIRChloLegalizeToHloIncGen
LINK_COMPONENTS
Core

View File

@ -469,13 +469,10 @@ struct HloCompareAdaptor {
}
};
#include "generated_chlo_legalize_to_hlo.inc"
} // namespace
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
populateWithGenerated(context, patterns);
// Instantiate conversion templates for conforming binary elementwise ops
// that do not have different dtypes between operands and results and do
// not have special attributes that need to be preserved.

View File

@ -1,59 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the legalization pattern definition file for CHLO to MHLO.
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//
// Expand acos to MHLO dialect as follows:
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
// = pi if x == -1
def : Pat<(HLOClient_AcosOp $input),
(HLO_SelectOp
(HLO_CompareOp $input,
(HLO_ConstantLike<"0"> $input),
HLO_COMPARISON_DIRECTION_NE
),
(HLO_MulOp
(HLO_ConstantLike<"2.0f"> $input),
(HLO_Atan2Op
(HLO_SqrtOp
(HLO_SubOp
(HLO_ConstantLike<"1"> $input),
(HLO_MulOp $input, $input)
)
),
(HLO_AddOp
(HLO_ConstantLike<"1"> $input),
$input
)
)
),
(HLO_ConstantLike<"M_PI"> $input))>;
// Express tan in MHLO dialect as
// tan(x) = sin(x) / cos(x).
def : Pat<(HLOClient_TanOp $input),
(HLO_DivOp
(HLO_SinOp $input),
(HLO_CosOp $input)
)>;

View File

@ -515,7 +515,6 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<mhlo::ReshapeOp>,
HloToLhloOpConverter<mhlo::SelectOp>,
HloToLhloOpConverter<mhlo::SignOp>,
HloToLhloOpConverter<mhlo::SinOp>,
HloToLhloOpConverter<mhlo::SliceOp>,
HloToLhloOpConverter<mhlo::SqrtOp>,
HloToLhloOpConverter<mhlo::SubOp>,