[MLIR][HLO] Annotate `mhlo.clamp` and `mhlo.select` as element-wise broadcasting

The operations allow for a limited form of broadcasting which allows some
operands to be scalars. As such they are neither strictly `Elementwise`, nor
`Broadcasting`. They do fulfill the requirements for `BroadcastingElementwise`
though.

PiperOrigin-RevId: 379719961
This commit is contained in:
A. Unique TensorFlower 2021-06-16 07:58:09 -07:00 committed by TensorFlow MLIR Team
parent a65cf627c4
commit 82696f8598
11 changed files with 67 additions and 28 deletions

1
BUILD
View File

@@ -505,6 +505,7 @@ cc_library(
hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h",
"include/mlir-hlo/utils/broadcast_utils.h",
"include/mlir-hlo/utils/hlo_utils.h",
],

View File

@@ -53,10 +53,6 @@ namespace mlir {
namespace chlo {
namespace OpTrait {
template <typename ConcreteType>
class BroadcastingElementwise
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
template <typename ConcreteType>
class Broadcasting
: public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};

View File

@@ -68,10 +68,6 @@ class HLOClient_NativeOpTrait<string name> : NativeOpTrait<name> {
def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
}
def HLOClient_BroadcastingElementwise
: HLOClient_NativeOpTrait<"BroadcastingElementwise"> {
}
//===----------------------------------------------------------------------===//
// CHLO binary elementwise op definitions.
// From the client perspective, each of these support both explicit rank
@@ -89,7 +85,7 @@ def HLOClient_BroadcastingElementwise
class HLOClient_BroadcastBinaryElementwiseOp<
string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
HLOClient_BroadcastingElementwise, HLOClient_Broadcasting,
HLO_BroadcastingElementwise, HLOClient_Broadcasting,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
let arguments = (ins
@@ -581,7 +577,7 @@ def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
}
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
[NoSideEffect, HLOClient_Broadcasting, HLOClient_BroadcastingElementwise,
[NoSideEffect, HLOClient_Broadcasting, HLO_BroadcastingElementwise,
SameOperandsAndResultShape, InferTypeOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
@@ -711,7 +707,7 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
NoSideEffect,
HLOClient_Broadcasting,
HLOClient_BroadcastingElementwise,
HLO_BroadcastingElementwise,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
"inferReturnTypeComponents"]>]> {
string summary = "Select operator (with optional numpy-style broadcasting)";

View File

@@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/Interfaces/SideEffectInterfaces.h"
// clang-format off
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"

View File

@@ -1361,8 +1361,8 @@ def HLO_CholeskyOp : HLO_Op<"cholesky",
let results = (outs HLO_FpOrComplexTensor);
}
def HLO_ClampOp : HLO_Op<"clamp",
[NoSideEffect, SameOperandsAndResultElementType]> {
def HLO_ClampOp : HLO_Op<"clamp", [NoSideEffect,
SameOperandsAndResultElementType, HLO_BroadcastingElementwise]> {
let summary = "Clamp operator";
let description = [{
Clamps an operand to within the range between a minimum and maximum value.
@@ -1373,12 +1373,12 @@ def HLO_ClampOp : HLO_Op<"clamp",
See https://www.tensorflow.org/xla/operation_semantics#clamp.
}];
let arguments = (ins
HLO_Tensor:$min,
HLO_Tensor:$operand,
HLO_Tensor:$max
);
let results = (outs HLO_Tensor);
}
@@ -1743,11 +1743,10 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]> {
}
// TODO(jpienaar): Add broadcastable trait.
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect,
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, HLO_BroadcastingElementwise,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents", "reifyReturnTypeShapes"]>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
]> {
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Select operator";
let description = [{
Constructs an output tensor from the elements of `on_true` and `on_false`

View File

@@ -0,0 +1,33 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// C++ side of the MHLO op traits; the TableGen counterpart declares the
// matching `HLO_NativeOpTrait` in hlo_ops_base.td.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace mhlo {
namespace OpTrait {
// Marks an operation that is essentially element-wise but allows a limited
// form of broadcasting (some operands may be scalars), so it is neither
// strictly `Elementwise` nor fully `Broadcasting` (e.g. mhlo.clamp,
// mhlo.select).
template <typename ConcreteType>
class BroadcastingElementwise
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
} // namespace OpTrait
} // namespace mhlo
} // namespace mlir
#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_

View File

@@ -152,4 +152,16 @@ def ConvolutionAttributes {
class BASE_HLO_ConvOp {
}
//===----------------------------------------------------------------------===//
// Common traits
//===----------------------------------------------------------------------===//
class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
let cppNamespace = "::mlir::mhlo::OpTrait";
}
// An operation that is essentially element-wise but may implement broadcasting
// semantics.
def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">;
#endif // HLO_OPS_BASE

View File

@@ -1561,7 +1561,8 @@ class DynamicReshapeOpSameShapeOpResult
LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override {
Operation* def_op = op.operand().getDefiningOp();
if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
if (!def_op ||
!def_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
return failure();
}
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
@@ -2098,7 +2099,7 @@ Operation* ReduceWindowOp::getReductionOp(int result_index) {
if (arg0_num == result_index && arg1_num == other_arg_index)
return compute_op;
if (arg0_num == other_arg_index && arg1_num == result_index &&
compute_op->hasTrait<OpTrait::IsCommutative>())
compute_op->hasTrait<mlir::OpTrait::IsCommutative>())
return compute_op;
return nullptr;
}

View File

@@ -177,8 +177,8 @@ struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Apply to all elementwise and broadcasting elementwise operations.
if (!op->hasTrait<OpTrait::Elementwise>() &&
!op->hasTrait<chlo::OpTrait::BroadcastingElementwise>())
if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>())
return failure();
return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
@@ -336,8 +336,8 @@ struct EarlyBroadcastInDimOpPattern
PatternRewriter &rewriter) const override {
Operation *producer_op = bcast_op.operand().getDefiningOp();
if (!producer_op ||
!producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
!producer_op->hasTrait<OpTrait::Elementwise>()) {
!producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
!producer_op->hasTrait<mlir::OpTrait::Elementwise>()) {
return failure();
}

View File

@@ -66,9 +66,9 @@ namespace {
bool IsClusterable(Operation *op) {
if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
if (op->getNumOperands() == 0) return false;
return (op->hasTrait<OpTrait::Elementwise>() &&
op->hasTrait<OpTrait::SameOperandsAndResultShape>()) ||
(op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
(op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>() &&
op->hasTrait<chlo::OpTrait::Broadcasting>());
}
@@ -729,7 +729,7 @@ SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
};
for (Operation &nested_op : op.getBody()->without_terminator()) {
if (nested_op.hasTrait<OpTrait::SameOperandsAndResultShape>()) {
if (nested_op.hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
union_sets(nested_op.getOperands());
union_sets(nested_op.getResults());
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())

View File

@@ -65,7 +65,7 @@ class SinkConstantsToControlFlowPass
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
Value constant = use->get();
auto op = constant.getDefiningOp();
if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return;
if (!op || !op->hasTrait<mlir::OpTrait::ConstantLike>()) return;
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
if (!map_entry.second) {
// This constant has already been cloned into the region, reuse it.