[MLIR][HLO] Annotate `mhlo.clamp` and `mhlo.select` as element-wise broadcasting
The operations allow for a limited form of broadcasting which allows some operands to be scalars. As such they are neither strictly `Elementwise`, nor `Broadcasting`. They do fulfill the requirements for `BroadcastingElementwise` though. PiperOrigin-RevId: 379719961
This commit is contained in:
parent
a65cf627c4
commit
82696f8598
1
BUILD
1
BUILD
|
@ -505,6 +505,7 @@ cc_library(
|
|||
hdrs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h",
|
||||
"include/mlir-hlo/utils/broadcast_utils.h",
|
||||
"include/mlir-hlo/utils/hlo_utils.h",
|
||||
],
|
||||
|
|
|
@ -53,10 +53,6 @@ namespace mlir {
|
|||
namespace chlo {
|
||||
namespace OpTrait {
|
||||
|
||||
template <typename ConcreteType>
|
||||
class BroadcastingElementwise
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class Broadcasting
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};
|
||||
|
|
|
@ -68,10 +68,6 @@ class HLOClient_NativeOpTrait<string name> : NativeOpTrait<name> {
|
|||
def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
|
||||
}
|
||||
|
||||
def HLOClient_BroadcastingElementwise
|
||||
: HLOClient_NativeOpTrait<"BroadcastingElementwise"> {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CHLO binary elementwise op definitions.
|
||||
// From the client perspective, each of these support both explicit rank
|
||||
|
@ -89,7 +85,7 @@ def HLOClient_BroadcastingElementwise
|
|||
|
||||
class HLOClient_BroadcastBinaryElementwiseOp<
|
||||
string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
|
||||
HLOClient_BroadcastingElementwise, HLOClient_Broadcasting,
|
||||
HLO_BroadcastingElementwise, HLOClient_Broadcasting,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
|
||||
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
|
||||
let arguments = (ins
|
||||
|
@ -581,7 +577,7 @@ def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
|||
}
|
||||
|
||||
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
||||
[NoSideEffect, HLOClient_Broadcasting, HLOClient_BroadcastingElementwise,
|
||||
[NoSideEffect, HLOClient_Broadcasting, HLO_BroadcastingElementwise,
|
||||
SameOperandsAndResultShape, InferTypeOpInterface,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
|
@ -711,7 +707,7 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
|
||||
NoSideEffect,
|
||||
HLOClient_Broadcasting,
|
||||
HLOClient_BroadcastingElementwise,
|
||||
HLO_BroadcastingElementwise,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
|
||||
"inferReturnTypeComponents"]>]> {
|
||||
string summary = "Select operator (with optional numpy-style broadcasting)";
|
||||
|
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
// clang-format off
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
|
|
|
@ -1361,8 +1361,8 @@ def HLO_CholeskyOp : HLO_Op<"cholesky",
|
|||
let results = (outs HLO_FpOrComplexTensor);
|
||||
}
|
||||
|
||||
def HLO_ClampOp : HLO_Op<"clamp",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
def HLO_ClampOp : HLO_Op<"clamp", [NoSideEffect,
|
||||
SameOperandsAndResultElementType, HLO_BroadcastingElementwise]> {
|
||||
let summary = "Clamp operator";
|
||||
let description = [{
|
||||
Clamps an operand to within the range between a minimum and maximum value.
|
||||
|
@ -1373,12 +1373,12 @@ def HLO_ClampOp : HLO_Op<"clamp",
|
|||
|
||||
See https://www.tensorflow.org/xla/operation_semantics#clamp.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$min,
|
||||
HLO_Tensor:$operand,
|
||||
HLO_Tensor:$max
|
||||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
||||
|
@ -1743,11 +1743,10 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]> {
|
|||
}
|
||||
|
||||
// TODO(jpienaar): Add broadcastable trait.
|
||||
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect,
|
||||
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, HLO_BroadcastingElementwise,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents", "reifyReturnTypeShapes"]>,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
]> {
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Select operator";
|
||||
let description = [{
|
||||
Constructs an output tensor from the elements of `on_true` and `on_false`
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace OpTrait {
|
||||
|
||||
template <typename ConcreteType>
|
||||
class BroadcastingElementwise
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
|
@ -152,4 +152,16 @@ def ConvolutionAttributes {
|
|||
class BASE_HLO_ConvOp {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common traits
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
|
||||
let cppNamespace = "::mlir::mhlo::OpTrait";
|
||||
}
|
||||
|
||||
// An operation that is essentially element-wise but may implement broadcasting
|
||||
// semantics.
|
||||
def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">;
|
||||
|
||||
#endif // HLO_OPS_BASE
|
||||
|
|
|
@ -1561,7 +1561,8 @@ class DynamicReshapeOpSameShapeOpResult
|
|||
LogicalResult matchAndRewrite(DynamicReshapeOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Operation* def_op = op.operand().getDefiningOp();
|
||||
if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
|
||||
if (!def_op ||
|
||||
!def_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
|
||||
return failure();
|
||||
}
|
||||
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
|
||||
|
@ -2098,7 +2099,7 @@ Operation* ReduceWindowOp::getReductionOp(int result_index) {
|
|||
if (arg0_num == result_index && arg1_num == other_arg_index)
|
||||
return compute_op;
|
||||
if (arg0_num == other_arg_index && arg1_num == result_index &&
|
||||
compute_op->hasTrait<OpTrait::IsCommutative>())
|
||||
compute_op->hasTrait<mlir::OpTrait::IsCommutative>())
|
||||
return compute_op;
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -177,8 +177,8 @@ struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
|
|||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Apply to all elementwise and broadcasting elementwise operations.
|
||||
if (!op->hasTrait<OpTrait::Elementwise>() &&
|
||||
!op->hasTrait<chlo::OpTrait::BroadcastingElementwise>())
|
||||
if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||
!op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>())
|
||||
return failure();
|
||||
|
||||
return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
|
||||
|
@ -336,8 +336,8 @@ struct EarlyBroadcastInDimOpPattern
|
|||
PatternRewriter &rewriter) const override {
|
||||
Operation *producer_op = bcast_op.operand().getDefiningOp();
|
||||
if (!producer_op ||
|
||||
!producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
|
||||
!producer_op->hasTrait<OpTrait::Elementwise>()) {
|
||||
!producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
|
||||
!producer_op->hasTrait<mlir::OpTrait::Elementwise>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
|
|
@ -66,9 +66,9 @@ namespace {
|
|||
bool IsClusterable(Operation *op) {
|
||||
if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
|
||||
if (op->getNumOperands() == 0) return false;
|
||||
return (op->hasTrait<OpTrait::Elementwise>() &&
|
||||
op->hasTrait<OpTrait::SameOperandsAndResultShape>()) ||
|
||||
(op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
|
||||
return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||
op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
|
||||
(op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>() &&
|
||||
op->hasTrait<chlo::OpTrait::Broadcasting>());
|
||||
}
|
||||
|
||||
|
@ -729,7 +729,7 @@ SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
|
|||
for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
|
||||
};
|
||||
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
||||
if (nested_op.hasTrait<OpTrait::SameOperandsAndResultShape>()) {
|
||||
if (nested_op.hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
|
||||
union_sets(nested_op.getOperands());
|
||||
union_sets(nested_op.getResults());
|
||||
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
|
||||
|
|
|
@ -65,7 +65,7 @@ class SinkConstantsToControlFlowPass
|
|||
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
|
||||
Value constant = use->get();
|
||||
auto op = constant.getDefiningOp();
|
||||
if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return;
|
||||
if (!op || !op->hasTrait<mlir::OpTrait::ConstantLike>()) return;
|
||||
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
|
||||
if (!map_entry.second) {
|
||||
// This constant has already been cloned into the region, reuse it.
|
||||
|
|
Loading…
Reference in New Issue