[XLA][MLIR] Lower `tf.Tan` and `tf.Sin` to MLHLO
Add `tan` op and lowering to CHLO dialect, move CHLO lowerings to `chlo_legalize_to_hlo_patterns` and extend missing patterns. PiperOrigin-RevId: 331506094
This commit is contained in:
parent
ce1c8a1ebc
commit
48022987ce
|
@ -344,13 +344,13 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||||
Type TensorType>: HLOClient_Op<mnemonic,
|
Type TensorType> : HLOClient_Op<mnemonic,
|
||||||
!listconcat(traits, [InferFusibilityOpInterface])> {
|
!listconcat(traits, [InferFusibilityOpInterface])> {
|
||||||
let arguments = (ins TensorType:$operand);
|
let arguments = (ins TensorType:$operand);
|
||||||
let results = (outs TensorType);
|
let results = (outs TensorType);
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
|
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
||||||
let summary = "Acos operator";
|
let summary = "Acos operator";
|
||||||
|
|
||||||
|
@ -364,7 +364,20 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like",
|
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
||||||
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
||||||
|
let summary = "Tan operation";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Tan(operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\tan(x) = \sin(x) / \cos(x)
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
||||||
[NoSideEffect, SameOperandsAndResultShape,
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
InferTypeOpInterface,
|
InferTypeOpInterface,
|
||||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
|
||||||
|
|
|
@ -25,6 +25,10 @@ set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td)
|
||||||
mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters)
|
mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen)
|
add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td)
|
||||||
|
mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(MLIRChloLegalizeToHloIncGen)
|
||||||
|
|
||||||
|
|
||||||
add_mlir_library(ChloPasses
|
add_mlir_library(ChloPasses
|
||||||
chlo_legalize_to_hlo.cc
|
chlo_legalize_to_hlo.cc
|
||||||
|
@ -32,6 +36,7 @@ add_mlir_library(ChloPasses
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRhlo_opsIncGen
|
MLIRhlo_opsIncGen
|
||||||
|
MLIRChloLegalizeToHloIncGen
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
Core
|
Core
|
||||||
|
|
|
@ -469,10 +469,13 @@ struct HloCompareAdaptor {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include "generated_chlo_legalize_to_hlo.inc"
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
|
populateWithGenerated(context, patterns);
|
||||||
|
|
||||||
// Instantiate conversion templates for conforming binary elementwise ops
|
// Instantiate conversion templates for conforming binary elementwise ops
|
||||||
// that do not have different dtypes between operands and results and do
|
// that do not have different dtypes between operands and results and do
|
||||||
// not have special attributes that need to be preserved.
|
// not have special attributes that need to be preserved.
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This is the legalization pattern definition file for CHLO to MHLO.
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Unary op patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Expand acos to MHLO dialect as follows:
|
||||||
|
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
|
||||||
|
// = pi if x == -1
|
||||||
|
def : Pat<(HLOClient_AcosOp $input),
|
||||||
|
(HLO_SelectOp
|
||||||
|
(HLO_CompareOp $input,
|
||||||
|
(HLO_ConstantLike<"0"> $input),
|
||||||
|
HLO_COMPARISON_DIRECTION_NE
|
||||||
|
),
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_ConstantLike<"2.0f"> $input),
|
||||||
|
(HLO_Atan2Op
|
||||||
|
(HLO_SqrtOp
|
||||||
|
(HLO_SubOp
|
||||||
|
(HLO_ConstantLike<"1"> $input),
|
||||||
|
(HLO_MulOp $input, $input)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_ConstantLike<"1"> $input),
|
||||||
|
$input
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(HLO_ConstantLike<"M_PI"> $input))>;
|
||||||
|
|
||||||
|
// Express tan in MHLO dialect as
|
||||||
|
// tan(x) = sin(x) / cos(x).
|
||||||
|
def : Pat<(HLOClient_TanOp $input),
|
||||||
|
(HLO_DivOp
|
||||||
|
(HLO_SinOp $input),
|
||||||
|
(HLO_CosOp $input)
|
||||||
|
)>;
|
||||||
|
|
|
@ -515,6 +515,7 @@ void populateHLOToLHLOConversionPattern(
|
||||||
HloToLhloOpConverter<mhlo::ReshapeOp>,
|
HloToLhloOpConverter<mhlo::ReshapeOp>,
|
||||||
HloToLhloOpConverter<mhlo::SelectOp>,
|
HloToLhloOpConverter<mhlo::SelectOp>,
|
||||||
HloToLhloOpConverter<mhlo::SignOp>,
|
HloToLhloOpConverter<mhlo::SignOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::SinOp>,
|
||||||
HloToLhloOpConverter<mhlo::SliceOp>,
|
HloToLhloOpConverter<mhlo::SliceOp>,
|
||||||
HloToLhloOpConverter<mhlo::SqrtOp>,
|
HloToLhloOpConverter<mhlo::SqrtOp>,
|
||||||
HloToLhloOpConverter<mhlo::SubOp>,
|
HloToLhloOpConverter<mhlo::SubOp>,
|
||||||
|
|
Loading…
Reference in New Issue