From 82696f85980039756b06c2b8c77bfb26215292b4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Jun 2021 07:58:09 -0700 Subject: [PATCH] [MLIR][HLO] Annotate `mhlo.clamp` and `mhlo.select` as element-wise broadcasting The operations allow for a limited form of broadcasting which allows some operands to be scalars. As such they are neither strictly `Elementwise`, nor `Broadcasting`. They do fulfill the requirements for `BroadcastingElementwise` though. PiperOrigin-RevId: 379719961 --- BUILD | 1 + include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h | 4 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 10 ++---- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h | 1 + include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 11 +++---- .../mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h | 33 +++++++++++++++++++ .../mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td | 12 +++++++ lib/Dialect/mhlo/IR/hlo_ops.cc | 5 +-- .../mhlo/transforms/broadcast_propagation.cc | 8 ++--- .../mhlo/transforms/rank_specialization.cc | 8 ++--- .../sink_constants_to_control_flow.cc | 2 +- 11 files changed, 67 insertions(+), 28 deletions(-) create mode 100644 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h diff --git a/BUILD b/BUILD index ede0238..6215c5f 100644 --- a/BUILD +++ b/BUILD @@ -505,6 +505,7 @@ cc_library( hdrs = [ "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h", + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h", "include/mlir-hlo/utils/broadcast_utils.h", "include/mlir-hlo/utils/hlo_utils.h", ], diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index b778e94..6cb77a0 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -53,10 +53,6 @@ namespace mlir { namespace chlo { namespace OpTrait { -template -class BroadcastingElementwise - : public mlir::OpTrait::TraitBase {}; - template class Broadcasting : public mlir::OpTrait::TraitBase {}; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index fd806d1..3348559 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -68,10 +68,6 @@ class HLOClient_NativeOpTrait : NativeOpTrait { def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> { } -def HLOClient_BroadcastingElementwise - : HLOClient_NativeOpTrait<"BroadcastingElementwise"> { -} - //===----------------------------------------------------------------------===// // CHLO binary elementwise op definitions. // From the client perspective, each of these support both explicit rank @@ -89,7 +85,7 @@ def HLOClient_BroadcastingElementwise class HLOClient_BroadcastBinaryElementwiseOp< string mnemonic, list traits> : HLOClient_Op]> { let arguments = (ins @@ -581,7 +577,7 @@ def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", } def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", - [NoSideEffect, HLOClient_Broadcasting, HLOClient_BroadcastingElementwise, + [NoSideEffect, HLOClient_Broadcasting, HLO_BroadcastingElementwise, SameOperandsAndResultShape, InferTypeOpInterface, DeclareOpInterfaceMethods, @@ -711,7 +707,7 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [ NoSideEffect, HLOClient_Broadcasting, - HLOClient_BroadcastingElementwise, + HLO_BroadcastingElementwise, DeclareOpInterfaceMethods]> { string summary = "Select operator (with optional numpy-style broadcasting)"; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 427ef62..af25e62 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // clang-format off +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index f713beb..1649793 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1361,8 +1361,8 @@ def HLO_CholeskyOp : HLO_Op<"cholesky", let results = (outs HLO_FpOrComplexTensor); } -def HLO_ClampOp : HLO_Op<"clamp", - [NoSideEffect, SameOperandsAndResultElementType]> { +def HLO_ClampOp : HLO_Op<"clamp", [NoSideEffect, + SameOperandsAndResultElementType, HLO_BroadcastingElementwise]> { let summary = "Clamp operator"; let description = [{ Clamps an operand to within the range between a minimum and maximum value. @@ -1373,12 +1373,12 @@ def HLO_ClampOp : HLO_Op<"clamp", See https://www.tensorflow.org/xla/operation_semantics#clamp. }]; + let arguments = (ins HLO_Tensor:$min, HLO_Tensor:$operand, HLO_Tensor:$max ); - let results = (outs HLO_Tensor); } @@ -1743,11 +1743,10 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]> { } // TODO(jpienaar): Add broadcastable trait. -def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, +def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, HLO_BroadcastingElementwise, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - ]> { + DeclareOpInterfaceMethods]> { let summary = "Select operator"; let description = [{ Constructs an output tensor from the elements of `on_true` and `on_false` diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h new file mode 100644 index 0000000..48e158e --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace mhlo { +namespace OpTrait { + +template +class BroadcastingElementwise + : public mlir::OpTrait::TraitBase {}; + +} // namespace OpTrait +} // namespace mhlo +} // namespace mlir + +#endif diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index e7f68c5..bcc16b6 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -152,4 +152,16 @@ def ConvolutionAttributes { class BASE_HLO_ConvOp { } +//===----------------------------------------------------------------------===// +// Common traits +//===----------------------------------------------------------------------===// + +class HLO_NativeOpTrait : NativeOpTrait { + let cppNamespace = "::mlir::mhlo::OpTrait"; +} + +// An operation that is essentially element-wise but may implement broadcasting +// semantics. +def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">; + #endif // HLO_OPS_BASE diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 1101cf2..d98706b 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1561,7 +1561,8 @@ class DynamicReshapeOpSameShapeOpResult LogicalResult matchAndRewrite(DynamicReshapeOp op, PatternRewriter& rewriter) const override { Operation* def_op = op.operand().getDefiningOp(); - if (!def_op || !def_op->hasTrait()) { + if (!def_op || + !def_op->hasTrait()) { return failure(); } Operation* input_def_op = def_op->getOperand(0).getDefiningOp(); @@ -2098,7 +2099,7 @@ Operation* ReduceWindowOp::getReductionOp(int result_index) { if (arg0_num == result_index && arg1_num == other_arg_index) return compute_op; if (arg0_num == other_arg_index && arg1_num == result_index && - compute_op->hasTrait()) + compute_op->hasTrait()) return compute_op; return nullptr; } diff --git a/lib/Dialect/mhlo/transforms/broadcast_propagation.cc b/lib/Dialect/mhlo/transforms/broadcast_propagation.cc index a147855..82ca3e7 100644 --- a/lib/Dialect/mhlo/transforms/broadcast_propagation.cc +++ b/lib/Dialect/mhlo/transforms/broadcast_propagation.cc @@ -177,8 +177,8 @@ struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { // Apply to all elementwise and broadcasting elementwise operations. - if (!op->hasTrait() && - !op->hasTrait()) + if (!op->hasTrait() && + !op->hasTrait()) return failure(); return MoveIntoAssumingOpMatchAndRewrite(op, rewriter); @@ -336,8 +336,8 @@ struct EarlyBroadcastInDimOpPattern PatternRewriter &rewriter) const override { Operation *producer_op = bcast_op.operand().getDefiningOp(); if (!producer_op || - !producer_op->hasTrait() || - !producer_op->hasTrait()) { + !producer_op->hasTrait() || + !producer_op->hasTrait()) { return failure(); } diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index 4db229c..cd9eb00 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -66,9 +66,9 @@ namespace { bool IsClusterable(Operation *op) { if (!llvm::isa(op)) return false; if (op->getNumOperands() == 0) return false; - return (op->hasTrait() && - op->hasTrait()) || - (op->hasTrait() && + return (op->hasTrait() && + op->hasTrait()) || + (op->hasTrait() && op->hasTrait()); } @@ -729,7 +729,7 @@ SmallVector, 4> FindNonScalarShapeEquivalences( for (Value v : vs.drop_front()) eqs.unionSets(repr, v); }; for (Operation &nested_op : op.getBody()->without_terminator()) { - if (nested_op.hasTrait()) { + if (nested_op.hasTrait()) { union_sets(nested_op.getOperands()); union_sets(nested_op.getResults()); if (!nested_op.getOperands().empty() && !nested_op.getResults().empty()) diff --git a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index ae57fd5..869fa87 100644 --- a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -65,7 +65,7 @@ class SinkConstantsToControlFlowPass visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { Value constant = use->get(); auto op = constant.getDefiningOp(); - if (!op || !op->hasTrait()) return; + if (!op || !op->hasTrait()) return; auto map_entry = sunk_constant.try_emplace(constant, nullptr); if (!map_entry.second) { // This constant has already been cloned into the region, reuse it.