[MLIR][HLO] Add `Elementwise` trait to unary element-wise ops

PiperOrigin-RevId: 363428909
This commit is contained in:
A. Unique TensorFlower 2021-03-17 08:49:22 -07:00 committed by TensorFlow MLIR Team
parent cd52adb20e
commit f1408e791e
1 changed files with 24 additions and 25 deletions

View File

@ -119,31 +119,30 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType>: HLO_Op<mnemonic,
!listconcat(traits,
[InferShapedTypeOpInterface, InferFusibilityOpInterface,
SameOperandsAndResultShape])> {
let arguments = (ins TensorType:$operand);
let results = (outs TensorType);
let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location,
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure();
}
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
bool inferInputOutputShapeEquality(int input, int output) {
return true;
}
llvm::Optional<Value> inferEffectiveWorkloadShape() {
return getOperation()->getResult(0);
}
}];
Type TensorType> : HLO_Op<mnemonic, traits # [Elementwise,
InferShapedTypeOpInterface, InferFusibilityOpInterface,
SameOperandsAndResultShape]> {
let arguments = (ins TensorType:$operand);
let results = (outs TensorType);
let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location,
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure();
}
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
bool inferInputOutputShapeEquality(int input, int output) {
return true;
}
llvm::Optional<Value> inferEffectiveWorkloadShape() {
return getOperation()->getResult(0);
}
}];
}
// Abs supports complex to real, so element type is not guaranteed to match.