[XLA][MLIR] Lower `tf.Tan` and `tf.Sin` to MLHLO

Add `tan` op and lowering to CHLO dialect, move CHLO lowerings to
`chlo_legalize_to_hlo_patterns` and extend missing patterns.

PiperOrigin-RevId: 331506094
This commit is contained in:
A. Unique TensorFlower 2020-09-14 02:30:26 -07:00 committed by TensorFlow MLIR Team
parent ce1c8a1ebc
commit 48022987ce
5 changed files with 84 additions and 3 deletions

View File

@ -364,6 +364,19 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
}]; }];
} }
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
let summary = "Tan operation";
let description = [{
Returns `Tan(operand)` element-wise.
$$
\tan(x) = \sin(x) / \cos(x)
$$
}];
}
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
[NoSideEffect, SameOperandsAndResultShape, [NoSideEffect, SameOperandsAndResultShape,
InferTypeOpInterface, InferTypeOpInterface,

View File

@ -25,6 +25,10 @@ set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td)
mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters) mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters)
add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen)
set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td)
mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters)
add_public_tablegen_target(MLIRChloLegalizeToHloIncGen)
add_mlir_library(ChloPasses add_mlir_library(ChloPasses
chlo_legalize_to_hlo.cc chlo_legalize_to_hlo.cc
@ -32,6 +36,7 @@ add_mlir_library(ChloPasses
DEPENDS DEPENDS
MLIRhlo_opsIncGen MLIRhlo_opsIncGen
MLIRChloLegalizeToHloIncGen
LINK_COMPONENTS LINK_COMPONENTS
Core Core

View File

@ -469,10 +469,13 @@ struct HloCompareAdaptor {
} }
}; };
#include "generated_chlo_legalize_to_hlo.inc"
} // namespace } // namespace
void PopulateLegalizeChloToHloPatterns(MLIRContext *context, void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) { OwningRewritePatternList *patterns) {
populateWithGenerated(context, patterns);
// Instantiate conversion templates for conforming binary elementwise ops // Instantiate conversion templates for conforming binary elementwise ops
// that do not have different dtypes between operands and results and do // that do not have different dtypes between operands and results and do
// not have special attributes that need to be preserved. // not have special attributes that need to be preserved.

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the legalization pattern definition file for CHLO to MHLO.
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//
// Expand acos to MHLO dialect as follows:
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
// = pi if x == -1
def : Pat<(HLOClient_AcosOp $input),
(HLO_SelectOp
(HLO_CompareOp $input,
(HLO_ConstantLike<"0"> $input),
HLO_COMPARISON_DIRECTION_NE
),
(HLO_MulOp
(HLO_ConstantLike<"2.0f"> $input),
(HLO_Atan2Op
(HLO_SqrtOp
(HLO_SubOp
(HLO_ConstantLike<"1"> $input),
(HLO_MulOp $input, $input)
)
),
(HLO_AddOp
(HLO_ConstantLike<"1"> $input),
$input
)
)
),
(HLO_ConstantLike<"M_PI"> $input))>;
// Express tan in MHLO dialect as
// tan(x) = sin(x) / cos(x).
def : Pat<(HLOClient_TanOp $input),
(HLO_DivOp
(HLO_SinOp $input),
(HLO_CosOp $input)
)>;

View File

@ -515,6 +515,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<mhlo::ReshapeOp>, HloToLhloOpConverter<mhlo::ReshapeOp>,
HloToLhloOpConverter<mhlo::SelectOp>, HloToLhloOpConverter<mhlo::SelectOp>,
HloToLhloOpConverter<mhlo::SignOp>, HloToLhloOpConverter<mhlo::SignOp>,
HloToLhloOpConverter<mhlo::SinOp>,
HloToLhloOpConverter<mhlo::SliceOp>, HloToLhloOpConverter<mhlo::SliceOp>,
HloToLhloOpConverter<mhlo::SqrtOp>, HloToLhloOpConverter<mhlo::SqrtOp>,
HloToLhloOpConverter<mhlo::SubOp>, HloToLhloOpConverter<mhlo::SubOp>,