[MLIR][HLO] Annotate `mhlo.clamp` and `mhlo.select` as element-wise broadcasting

The operations allow for a limited form of broadcasting which allows some
operands to be scalars. As such they are neither strictly `Elementwise`, nor
`Broadcasting`. They do fulfill the requirements for `BroadcastingElementwise`
though.

PiperOrigin-RevId: 379719961
This commit is contained in:
A. Unique TensorFlower 2021-06-16 07:58:09 -07:00 committed by TensorFlow MLIR Team
parent a65cf627c4
commit 82696f8598
11 changed files with 67 additions and 28 deletions

1
BUILD
View File

@ -505,6 +505,7 @@ cc_library(
hdrs = [ hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h", "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h",
"include/mlir-hlo/utils/broadcast_utils.h", "include/mlir-hlo/utils/broadcast_utils.h",
"include/mlir-hlo/utils/hlo_utils.h", "include/mlir-hlo/utils/hlo_utils.h",
], ],

View File

@ -53,10 +53,6 @@ namespace mlir {
namespace chlo { namespace chlo {
namespace OpTrait { namespace OpTrait {
template <typename ConcreteType>
class BroadcastingElementwise
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
template <typename ConcreteType> template <typename ConcreteType>
class Broadcasting class Broadcasting
: public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {}; : public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};

View File

@ -68,10 +68,6 @@ class HLOClient_NativeOpTrait<string name> : NativeOpTrait<name> {
def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> { def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
} }
def HLOClient_BroadcastingElementwise
: HLOClient_NativeOpTrait<"BroadcastingElementwise"> {
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// CHLO binary elementwise op definitions. // CHLO binary elementwise op definitions.
// From the client perspective, each of these support both explicit rank // From the client perspective, each of these support both explicit rank
@ -89,7 +85,7 @@ def HLOClient_BroadcastingElementwise
class HLOClient_BroadcastBinaryElementwiseOp< class HLOClient_BroadcastBinaryElementwiseOp<
string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [ string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
HLOClient_BroadcastingElementwise, HLOClient_Broadcasting, HLO_BroadcastingElementwise, HLOClient_Broadcasting,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [ DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> { "inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
let arguments = (ins let arguments = (ins
@ -581,7 +577,7 @@ def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
} }
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
[NoSideEffect, HLOClient_Broadcasting, HLOClient_BroadcastingElementwise, [NoSideEffect, HLOClient_Broadcasting, HLO_BroadcastingElementwise,
SameOperandsAndResultShape, InferTypeOpInterface, SameOperandsAndResultShape, InferTypeOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>, ["inferReturnTypeComponents"]>,
@ -711,7 +707,7 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [ def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
NoSideEffect, NoSideEffect,
HLOClient_Broadcasting, HLOClient_Broadcasting,
HLOClient_BroadcastingElementwise, HLO_BroadcastingElementwise,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [ DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
"inferReturnTypeComponents"]>]> { "inferReturnTypeComponents"]>]> {
string summary = "Select operator (with optional numpy-style broadcasting)"; string summary = "Select operator (with optional numpy-style broadcasting)";

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
// clang-format off // clang-format off
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"

View File

@ -1361,8 +1361,8 @@ def HLO_CholeskyOp : HLO_Op<"cholesky",
let results = (outs HLO_FpOrComplexTensor); let results = (outs HLO_FpOrComplexTensor);
} }
def HLO_ClampOp : HLO_Op<"clamp", def HLO_ClampOp : HLO_Op<"clamp", [NoSideEffect,
[NoSideEffect, SameOperandsAndResultElementType]> { SameOperandsAndResultElementType, HLO_BroadcastingElementwise]> {
let summary = "Clamp operator"; let summary = "Clamp operator";
let description = [{ let description = [{
Clamps an operand to within the range between a minimum and maximum value. Clamps an operand to within the range between a minimum and maximum value.
@ -1373,12 +1373,12 @@ def HLO_ClampOp : HLO_Op<"clamp",
See https://www.tensorflow.org/xla/operation_semantics#clamp. See https://www.tensorflow.org/xla/operation_semantics#clamp.
}]; }];
let arguments = (ins let arguments = (ins
HLO_Tensor:$min, HLO_Tensor:$min,
HLO_Tensor:$operand, HLO_Tensor:$operand,
HLO_Tensor:$max HLO_Tensor:$max
); );
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
@ -1743,11 +1743,10 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]> {
} }
// TODO(jpienaar): Add broadcastable trait. // TODO(jpienaar): Add broadcastable trait.
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, HLO_BroadcastingElementwise,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents", "reifyReturnTypeShapes"]>, ["inferReturnTypeComponents", "reifyReturnTypeShapes"]>,
DeclareOpInterfaceMethods<InferTypeOpInterface>, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
]> {
let summary = "Select operator"; let summary = "Select operator";
let description = [{ let description = [{
Constructs an output tensor from the elements of `on_true` and `on_false` Constructs an output tensor from the elements of `on_true` and `on_false`

View File

@ -0,0 +1,33 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace mhlo {
namespace OpTrait {
template <typename ConcreteType>
class BroadcastingElementwise
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
} // namespace OpTrait
} // namespace mhlo
} // namespace mlir
#endif

View File

@ -152,4 +152,16 @@ def ConvolutionAttributes {
class BASE_HLO_ConvOp { class BASE_HLO_ConvOp {
} }
//===----------------------------------------------------------------------===//
// Common traits
//===----------------------------------------------------------------------===//
class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
let cppNamespace = "::mlir::mhlo::OpTrait";
}
// An operation that is essentially element-wise but may implement broadcasting
// semantics.
def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">;
#endif // HLO_OPS_BASE #endif // HLO_OPS_BASE

View File

@ -1561,7 +1561,8 @@ class DynamicReshapeOpSameShapeOpResult
LogicalResult matchAndRewrite(DynamicReshapeOp op, LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override { PatternRewriter& rewriter) const override {
Operation* def_op = op.operand().getDefiningOp(); Operation* def_op = op.operand().getDefiningOp();
if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) { if (!def_op ||
!def_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
return failure(); return failure();
} }
Operation* input_def_op = def_op->getOperand(0).getDefiningOp(); Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
@ -2098,7 +2099,7 @@ Operation* ReduceWindowOp::getReductionOp(int result_index) {
if (arg0_num == result_index && arg1_num == other_arg_index) if (arg0_num == result_index && arg1_num == other_arg_index)
return compute_op; return compute_op;
if (arg0_num == other_arg_index && arg1_num == result_index && if (arg0_num == other_arg_index && arg1_num == result_index &&
compute_op->hasTrait<OpTrait::IsCommutative>()) compute_op->hasTrait<mlir::OpTrait::IsCommutative>())
return compute_op; return compute_op;
return nullptr; return nullptr;
} }

View File

@ -177,8 +177,8 @@ struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op, LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
// Apply to all elementwise and broadcasting elementwise operations. // Apply to all elementwise and broadcasting elementwise operations.
if (!op->hasTrait<OpTrait::Elementwise>() && if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!op->hasTrait<chlo::OpTrait::BroadcastingElementwise>()) !op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>())
return failure(); return failure();
return MoveIntoAssumingOpMatchAndRewrite(op, rewriter); return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
@ -336,8 +336,8 @@ struct EarlyBroadcastInDimOpPattern
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Operation *producer_op = bcast_op.operand().getDefiningOp(); Operation *producer_op = bcast_op.operand().getDefiningOp();
if (!producer_op || if (!producer_op ||
!producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() || !producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
!producer_op->hasTrait<OpTrait::Elementwise>()) { !producer_op->hasTrait<mlir::OpTrait::Elementwise>()) {
return failure(); return failure();
} }

View File

@ -66,9 +66,9 @@ namespace {
bool IsClusterable(Operation *op) { bool IsClusterable(Operation *op) {
if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false; if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
if (op->getNumOperands() == 0) return false; if (op->getNumOperands() == 0) return false;
return (op->hasTrait<OpTrait::Elementwise>() && return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
op->hasTrait<OpTrait::SameOperandsAndResultShape>()) || op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
(op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() && (op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>() &&
op->hasTrait<chlo::OpTrait::Broadcasting>()); op->hasTrait<chlo::OpTrait::Broadcasting>());
} }
@ -729,7 +729,7 @@ SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
for (Value v : vs.drop_front()) eqs.unionSets(repr, v); for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
}; };
for (Operation &nested_op : op.getBody()->without_terminator()) { for (Operation &nested_op : op.getBody()->without_terminator()) {
if (nested_op.hasTrait<OpTrait::SameOperandsAndResultShape>()) { if (nested_op.hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
union_sets(nested_op.getOperands()); union_sets(nested_op.getOperands());
union_sets(nested_op.getResults()); union_sets(nested_op.getResults());
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty()) if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())

View File

@ -65,7 +65,7 @@ class SinkConstantsToControlFlowPass
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
Value constant = use->get(); Value constant = use->get();
auto op = constant.getDefiningOp(); auto op = constant.getDefiningOp();
if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return; if (!op || !op->hasTrait<mlir::OpTrait::ConstantLike>()) return;
auto map_entry = sunk_constant.try_emplace(constant, nullptr); auto map_entry = sunk_constant.try_emplace(constant, nullptr);
if (!map_entry.second) { if (!map_entry.second) {
// This constant has already been cloned into the region, reuse it. // This constant has already been cloned into the region, reuse it.