[MLIR][HLO] Add `Elementwise` trait to unary element-wise ops
PiperOrigin-RevId: 363428909
This commit is contained in:
parent
cd52adb20e
commit
f1408e791e
|
@ -119,31 +119,30 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
|
|||
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
|
||||
|
||||
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||
Type TensorType>: HLO_Op<mnemonic,
|
||||
!listconcat(traits,
|
||||
[InferShapedTypeOpInterface, InferFusibilityOpInterface,
|
||||
SameOperandsAndResultShape])> {
|
||||
let arguments = (ins TensorType:$operand);
|
||||
let results = (outs TensorType);
|
||||
let extraClassDeclaration = [{
|
||||
static LogicalResult inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
return failure();
|
||||
}
|
||||
LogicalResult reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
bool inferInputOutputShapeEquality(int input, int output) {
|
||||
return true;
|
||||
}
|
||||
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
||||
return getOperation()->getResult(0);
|
||||
}
|
||||
}];
|
||||
Type TensorType> : HLO_Op<mnemonic, traits # [Elementwise,
|
||||
InferShapedTypeOpInterface, InferFusibilityOpInterface,
|
||||
SameOperandsAndResultShape]> {
|
||||
let arguments = (ins TensorType:$operand);
|
||||
let results = (outs TensorType);
|
||||
let extraClassDeclaration = [{
|
||||
static LogicalResult inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
return failure();
|
||||
}
|
||||
LogicalResult reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
bool inferInputOutputShapeEquality(int input, int output) {
|
||||
return true;
|
||||
}
|
||||
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
||||
return getOperation()->getResult(0);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
// Abs supports complex to real, so element type is not guaranteed to match.
|
||||
|
|
Loading…
Reference in New Issue