[MLIR][HLO] Annotate `mhlo.clamp` and `mhlo.select` as element-wise broadcasting
The operations allow for a limited form of broadcasting which allows some operands to be scalars. As such they are neither strictly `Elementwise`, nor `Broadcasting`. They do fulfill the requirements for `BroadcastingElementwise` though. PiperOrigin-RevId: 379719961
This commit is contained in:
parent
a65cf627c4
commit
82696f8598
1
BUILD
1
BUILD
|
@ -505,6 +505,7 @@ cc_library(
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h",
|
"include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h",
|
||||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h",
|
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h",
|
||||||
|
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h",
|
||||||
"include/mlir-hlo/utils/broadcast_utils.h",
|
"include/mlir-hlo/utils/broadcast_utils.h",
|
||||||
"include/mlir-hlo/utils/hlo_utils.h",
|
"include/mlir-hlo/utils/hlo_utils.h",
|
||||||
],
|
],
|
||||||
|
|
|
@ -53,10 +53,6 @@ namespace mlir {
|
||||||
namespace chlo {
|
namespace chlo {
|
||||||
namespace OpTrait {
|
namespace OpTrait {
|
||||||
|
|
||||||
template <typename ConcreteType>
|
|
||||||
class BroadcastingElementwise
|
|
||||||
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
|
|
||||||
|
|
||||||
template <typename ConcreteType>
|
template <typename ConcreteType>
|
||||||
class Broadcasting
|
class Broadcasting
|
||||||
: public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};
|
: public mlir::OpTrait::TraitBase<ConcreteType, Broadcasting> {};
|
||||||
|
|
|
@ -68,10 +68,6 @@ class HLOClient_NativeOpTrait<string name> : NativeOpTrait<name> {
|
||||||
def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
|
def HLOClient_Broadcasting : HLOClient_NativeOpTrait<"Broadcasting"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_BroadcastingElementwise
|
|
||||||
: HLOClient_NativeOpTrait<"BroadcastingElementwise"> {
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// CHLO binary elementwise op definitions.
|
// CHLO binary elementwise op definitions.
|
||||||
// From the client perspective, each of these support both explicit rank
|
// From the client perspective, each of these support both explicit rank
|
||||||
|
@ -89,7 +85,7 @@ def HLOClient_BroadcastingElementwise
|
||||||
|
|
||||||
class HLOClient_BroadcastBinaryElementwiseOp<
|
class HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
|
string mnemonic, list<OpTrait> traits> : HLOClient_Op<mnemonic, traits # [
|
||||||
HLOClient_BroadcastingElementwise, HLOClient_Broadcasting,
|
HLO_BroadcastingElementwise, HLOClient_Broadcasting,
|
||||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
|
||||||
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
|
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
@ -581,7 +577,7 @@ def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
||||||
[NoSideEffect, HLOClient_Broadcasting, HLOClient_BroadcastingElementwise,
|
[NoSideEffect, HLOClient_Broadcasting, HLO_BroadcastingElementwise,
|
||||||
SameOperandsAndResultShape, InferTypeOpInterface,
|
SameOperandsAndResultShape, InferTypeOpInterface,
|
||||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||||
["inferReturnTypeComponents"]>,
|
["inferReturnTypeComponents"]>,
|
||||||
|
@ -711,7 +707,7 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
|
def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
|
||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
HLOClient_Broadcasting,
|
HLOClient_Broadcasting,
|
||||||
HLOClient_BroadcastingElementwise,
|
HLO_BroadcastingElementwise,
|
||||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
|
||||||
"inferReturnTypeComponents"]>]> {
|
"inferReturnTypeComponents"]>]> {
|
||||||
string summary = "Select operator (with optional numpy-style broadcasting)";
|
string summary = "Select operator (with optional numpy-style broadcasting)";
|
||||||
|
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||||
|
|
|
@ -1361,8 +1361,8 @@ def HLO_CholeskyOp : HLO_Op<"cholesky",
|
||||||
let results = (outs HLO_FpOrComplexTensor);
|
let results = (outs HLO_FpOrComplexTensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_ClampOp : HLO_Op<"clamp",
|
def HLO_ClampOp : HLO_Op<"clamp", [NoSideEffect,
|
||||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
SameOperandsAndResultElementType, HLO_BroadcastingElementwise]> {
|
||||||
let summary = "Clamp operator";
|
let summary = "Clamp operator";
|
||||||
let description = [{
|
let description = [{
|
||||||
Clamps an operand to within the range between a minimum and maximum value.
|
Clamps an operand to within the range between a minimum and maximum value.
|
||||||
|
@ -1373,12 +1373,12 @@ def HLO_ClampOp : HLO_Op<"clamp",
|
||||||
|
|
||||||
See https://www.tensorflow.org/xla/operation_semantics#clamp.
|
See https://www.tensorflow.org/xla/operation_semantics#clamp.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$min,
|
HLO_Tensor:$min,
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
HLO_Tensor:$max
|
HLO_Tensor:$max
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1743,11 +1743,10 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jpienaar): Add broadcastable trait.
|
// TODO(jpienaar): Add broadcastable trait.
|
||||||
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect,
|
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, HLO_BroadcastingElementwise,
|
||||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||||
["inferReturnTypeComponents", "reifyReturnTypeShapes"]>,
|
["inferReturnTypeComponents", "reifyReturnTypeShapes"]>,
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||||
]> {
|
|
||||||
let summary = "Select operator";
|
let summary = "Select operator";
|
||||||
let description = [{
|
let description = [{
|
||||||
Constructs an output tensor from the elements of `on_true` and `on_false`
|
Constructs an output tensor from the elements of `on_true` and `on_false`
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace mhlo {
|
||||||
|
namespace OpTrait {
|
||||||
|
|
||||||
|
template <typename ConcreteType>
|
||||||
|
class BroadcastingElementwise
|
||||||
|
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
|
||||||
|
|
||||||
|
} // namespace OpTrait
|
||||||
|
} // namespace mhlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif
|
|
@ -152,4 +152,16 @@ def ConvolutionAttributes {
|
||||||
class BASE_HLO_ConvOp {
|
class BASE_HLO_ConvOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Common traits
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class HLO_NativeOpTrait<string name> : NativeOpTrait<name> {
|
||||||
|
let cppNamespace = "::mlir::mhlo::OpTrait";
|
||||||
|
}
|
||||||
|
|
||||||
|
// An operation that is essentially element-wise but may implement broadcasting
|
||||||
|
// semantics.
|
||||||
|
def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">;
|
||||||
|
|
||||||
#endif // HLO_OPS_BASE
|
#endif // HLO_OPS_BASE
|
||||||
|
|
|
@ -1561,7 +1561,8 @@ class DynamicReshapeOpSameShapeOpResult
|
||||||
LogicalResult matchAndRewrite(DynamicReshapeOp op,
|
LogicalResult matchAndRewrite(DynamicReshapeOp op,
|
||||||
PatternRewriter& rewriter) const override {
|
PatternRewriter& rewriter) const override {
|
||||||
Operation* def_op = op.operand().getDefiningOp();
|
Operation* def_op = op.operand().getDefiningOp();
|
||||||
if (!def_op || !def_op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
|
if (!def_op ||
|
||||||
|
!def_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
|
Operation* input_def_op = def_op->getOperand(0).getDefiningOp();
|
||||||
|
@ -2098,7 +2099,7 @@ Operation* ReduceWindowOp::getReductionOp(int result_index) {
|
||||||
if (arg0_num == result_index && arg1_num == other_arg_index)
|
if (arg0_num == result_index && arg1_num == other_arg_index)
|
||||||
return compute_op;
|
return compute_op;
|
||||||
if (arg0_num == other_arg_index && arg1_num == result_index &&
|
if (arg0_num == other_arg_index && arg1_num == result_index &&
|
||||||
compute_op->hasTrait<OpTrait::IsCommutative>())
|
compute_op->hasTrait<mlir::OpTrait::IsCommutative>())
|
||||||
return compute_op;
|
return compute_op;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -177,8 +177,8 @@ struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
|
||||||
LogicalResult matchAndRewrite(Operation *op,
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// Apply to all elementwise and broadcasting elementwise operations.
|
// Apply to all elementwise and broadcasting elementwise operations.
|
||||||
if (!op->hasTrait<OpTrait::Elementwise>() &&
|
if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||||
!op->hasTrait<chlo::OpTrait::BroadcastingElementwise>())
|
!op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
|
return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
|
||||||
|
@ -336,8 +336,8 @@ struct EarlyBroadcastInDimOpPattern
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Operation *producer_op = bcast_op.operand().getDefiningOp();
|
Operation *producer_op = bcast_op.operand().getDefiningOp();
|
||||||
if (!producer_op ||
|
if (!producer_op ||
|
||||||
!producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
|
!producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
|
||||||
!producer_op->hasTrait<OpTrait::Elementwise>()) {
|
!producer_op->hasTrait<mlir::OpTrait::Elementwise>()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -66,9 +66,9 @@ namespace {
|
||||||
bool IsClusterable(Operation *op) {
|
bool IsClusterable(Operation *op) {
|
||||||
if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
|
if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
|
||||||
if (op->getNumOperands() == 0) return false;
|
if (op->getNumOperands() == 0) return false;
|
||||||
return (op->hasTrait<OpTrait::Elementwise>() &&
|
return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||||
op->hasTrait<OpTrait::SameOperandsAndResultShape>()) ||
|
op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
|
||||||
(op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
|
(op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>() &&
|
||||||
op->hasTrait<chlo::OpTrait::Broadcasting>());
|
op->hasTrait<chlo::OpTrait::Broadcasting>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -729,7 +729,7 @@ SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
|
||||||
for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
|
for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
|
||||||
};
|
};
|
||||||
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
||||||
if (nested_op.hasTrait<OpTrait::SameOperandsAndResultShape>()) {
|
if (nested_op.hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
|
||||||
union_sets(nested_op.getOperands());
|
union_sets(nested_op.getOperands());
|
||||||
union_sets(nested_op.getResults());
|
union_sets(nested_op.getResults());
|
||||||
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
|
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
|
||||||
|
|
|
@ -65,7 +65,7 @@ class SinkConstantsToControlFlowPass
|
||||||
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
|
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
|
||||||
Value constant = use->get();
|
Value constant = use->get();
|
||||||
auto op = constant.getDefiningOp();
|
auto op = constant.getDefiningOp();
|
||||||
if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return;
|
if (!op || !op->hasTrait<mlir::OpTrait::ConstantLike>()) return;
|
||||||
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
|
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
|
||||||
if (!map_entry.second) {
|
if (!map_entry.second) {
|
||||||
// This constant has already been cloned into the region, reuse it.
|
// This constant has already been cloned into the region, reuse it.
|
||||||
|
|
Loading…
Reference in New Issue