From 48022987ce7a88c447909e4bdc0fc3a3d8bb8ad7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 Sep 2020 02:30:26 -0700 Subject: [PATCH] [XLA][MLIR] Lower `tf.Tan` and `tf.Sin` to MLHLO Add `tan` op and lowering to CHLO dialect, move CHLO lowerings to `chlo_legalize_to_hlo_patterns` and extend missing patterns. PiperOrigin-RevId: 331506094 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 19 +++++- lib/Dialect/mhlo/transforms/CMakeLists.txt | 5 ++ .../mhlo/transforms/chlo_legalize_to_hlo.cc | 3 + .../chlo_legalize_to_hlo_patterns.td | 59 +++++++++++++++++++ .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 1 + 5 files changed, 84 insertions(+), 3 deletions(-) create mode 100644 lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 2f3bbef..5b6cf34 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -344,13 +344,13 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< //===----------------------------------------------------------------------===// class HLOClient_UnaryElementwiseOp traits, - Type TensorType>: HLOClient_Op : HLOClient_Op { let arguments = (ins TensorType:$operand); let results = (outs TensorType); } -def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos", +def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { let summary = "Acos operator"; @@ -364,7 +364,20 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos", }]; } -def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like", +def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Tan operation"; + + let description = [{ + Returns `Tan(operand)` element-wise. + + $$ + \tan(x) = \sin(x) / \cos(x) + $$ + }]; +} + +def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", [NoSideEffect, SameOperandsAndResultShape, InferTypeOpInterface, DeclareOpInterfaceMethods, diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt index 5ae0ec2..e02add4 100644 --- a/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -25,6 +25,10 @@ set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td) mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters) add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) +set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td) +mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters) +add_public_tablegen_target(MLIRChloLegalizeToHloIncGen) + add_mlir_library(ChloPasses chlo_legalize_to_hlo.cc @@ -32,6 +36,7 @@ add_mlir_library(ChloPasses DEPENDS MLIRhlo_opsIncGen + MLIRChloLegalizeToHloIncGen LINK_COMPONENTS Core diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index c2db488..de2a99b 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -469,10 +469,13 @@ struct HloCompareAdaptor { } }; +#include "generated_chlo_legalize_to_hlo.inc" } // namespace void PopulateLegalizeChloToHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { + populateWithGenerated(context, patterns); + // Instantiate conversion templates for conforming binary elementwise ops // that do not have different dtypes between operands and results and do // not have special attributes that need to be preserved. diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td new file mode 100644 index 0000000..7b612ff --- /dev/null +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the legalization pattern definition file for CHLO to MHLO. + +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td" + +//===----------------------------------------------------------------------===// +// Unary op patterns. +//===----------------------------------------------------------------------===// + +// Expand acos to MHLO dialect as follows: +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// = pi if x == -1 +def : Pat<(HLOClient_AcosOp $input), + (HLO_SelectOp + (HLO_CompareOp $input, + (HLO_ConstantLike<"0"> $input), + HLO_COMPARISON_DIRECTION_NE + ), + (HLO_MulOp + (HLO_ConstantLike<"2.0f"> $input), + (HLO_Atan2Op + (HLO_SqrtOp + (HLO_SubOp + (HLO_ConstantLike<"1"> $input), + (HLO_MulOp $input, $input) + ) + ), + (HLO_AddOp + (HLO_ConstantLike<"1"> $input), + $input + ) + ) + ), + (HLO_ConstantLike<"M_PI"> $input))>; + +// Express tan in MHLO dialect as +// tan(x) = sin(x) / cos(x). +def : Pat<(HLOClient_TanOp $input), + (HLO_DivOp + (HLO_SinOp $input), + (HLO_CosOp $input) + )>; + diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index e900bae..edf2544 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -515,6 +515,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter,