[XLA][MLIR] Lower `tf.Tan` and `tf.Sin` to MLHLO
Add `tan` op and lowering to CHLO dialect, move CHLO lowerings to `chlo_legalize_to_hlo_patterns` and extend missing patterns. PiperOrigin-RevId: 331128170
This commit is contained in:
parent
90927f6b53
commit
a7a7184eb6
|
@ -344,13 +344,13 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||||
Type TensorType> : HLOClient_Op<mnemonic,
|
Type TensorType>: HLOClient_Op<mnemonic,
|
||||||
!listconcat(traits, [InferFusibilityOpInterface])> {
|
!listconcat(traits, [InferFusibilityOpInterface])> {
|
||||||
let arguments = (ins TensorType:$operand);
|
let arguments = (ins TensorType:$operand);
|
||||||
let results = (outs TensorType);
|
let results = (outs TensorType);
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
||||||
let summary = "Acos operator";
|
let summary = "Acos operator";
|
||||||
|
|
||||||
|
@ -364,20 +364,7 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
|
||||||
let summary = "Tan operation";
|
|
||||||
|
|
||||||
let description = [{
|
|
||||||
Returns `Tan(operand)` element-wise.
|
|
||||||
|
|
||||||
$$
|
|
||||||
\tan(x) = \sin(x) / \cos(x)
|
|
||||||
$$
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
|
||||||
[NoSideEffect, SameOperandsAndResultShape,
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
InferTypeOpInterface,
|
InferTypeOpInterface,
|
||||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
|
||||||
|
|
|
@ -25,10 +25,6 @@ set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td)
|
||||||
mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters)
|
mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen)
|
add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td)
|
|
||||||
mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters)
|
|
||||||
add_public_tablegen_target(MLIRChloLegalizeToHloIncGen)
|
|
||||||
|
|
||||||
|
|
||||||
add_mlir_library(ChloPasses
|
add_mlir_library(ChloPasses
|
||||||
chlo_legalize_to_hlo.cc
|
chlo_legalize_to_hlo.cc
|
||||||
|
@ -36,7 +32,6 @@ add_mlir_library(ChloPasses
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRhlo_opsIncGen
|
MLIRhlo_opsIncGen
|
||||||
MLIRChloLegalizeToHloIncGen
|
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
Core
|
Core
|
||||||
|
|
|
@ -469,13 +469,10 @@ struct HloCompareAdaptor {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#include "generated_chlo_legalize_to_hlo.inc"
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
populateWithGenerated(context, patterns);
|
|
||||||
|
|
||||||
// Instantiate conversion templates for conforming binary elementwise ops
|
// Instantiate conversion templates for conforming binary elementwise ops
|
||||||
// that do not have different dtypes between operands and results and do
|
// that do not have different dtypes between operands and results and do
|
||||||
// not have special attributes that need to be preserved.
|
// not have special attributes that need to be preserved.
|
||||||
|
|
|
@ -1,59 +0,0 @@
|
||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// This is the legalization pattern definition file for CHLO to MHLO.
|
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Unary op patterns.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// Expand acos to MHLO dialect as follows:
|
|
||||||
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
|
|
||||||
// = pi if x == -1
|
|
||||||
def : Pat<(HLOClient_AcosOp $input),
|
|
||||||
(HLO_SelectOp
|
|
||||||
(HLO_CompareOp $input,
|
|
||||||
(HLO_ConstantLike<"0"> $input),
|
|
||||||
HLO_COMPARISON_DIRECTION_NE
|
|
||||||
),
|
|
||||||
(HLO_MulOp
|
|
||||||
(HLO_ConstantLike<"2.0f"> $input),
|
|
||||||
(HLO_Atan2Op
|
|
||||||
(HLO_SqrtOp
|
|
||||||
(HLO_SubOp
|
|
||||||
(HLO_ConstantLike<"1"> $input),
|
|
||||||
(HLO_MulOp $input, $input)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
(HLO_AddOp
|
|
||||||
(HLO_ConstantLike<"1"> $input),
|
|
||||||
$input
|
|
||||||
)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
(HLO_ConstantLike<"M_PI"> $input))>;
|
|
||||||
|
|
||||||
// Express tan in MHLO dialect as
|
|
||||||
// tan(x) = sin(x) / cos(x).
|
|
||||||
def : Pat<(HLOClient_TanOp $input),
|
|
||||||
(HLO_DivOp
|
|
||||||
(HLO_SinOp $input),
|
|
||||||
(HLO_CosOp $input)
|
|
||||||
)>;
|
|
||||||
|
|
|
@ -515,7 +515,6 @@ void populateHLOToLHLOConversionPattern(
|
||||||
HloToLhloOpConverter<mhlo::ReshapeOp>,
|
HloToLhloOpConverter<mhlo::ReshapeOp>,
|
||||||
HloToLhloOpConverter<mhlo::SelectOp>,
|
HloToLhloOpConverter<mhlo::SelectOp>,
|
||||||
HloToLhloOpConverter<mhlo::SignOp>,
|
HloToLhloOpConverter<mhlo::SignOp>,
|
||||||
HloToLhloOpConverter<mhlo::SinOp>,
|
|
||||||
HloToLhloOpConverter<mhlo::SliceOp>,
|
HloToLhloOpConverter<mhlo::SliceOp>,
|
||||||
HloToLhloOpConverter<mhlo::SqrtOp>,
|
HloToLhloOpConverter<mhlo::SqrtOp>,
|
||||||
HloToLhloOpConverter<mhlo::SubOp>,
|
HloToLhloOpConverter<mhlo::SubOp>,
|
||||||
|
|
Loading…
Reference in New Issue