[MLIR][HLO] Add `Elementwise` trait to unary element-wise ops

PiperOrigin-RevId: 363428909
This commit is contained in:
A. Unique TensorFlower 2021-03-17 08:49:22 -07:00 committed by TensorFlow MLIR Team
parent cd52adb20e
commit f1408e791e
1 changed files with 24 additions and 25 deletions

View File

@ -119,31 +119,30 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits, class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType>: HLO_Op<mnemonic, Type TensorType> : HLO_Op<mnemonic, traits # [Elementwise,
!listconcat(traits, InferShapedTypeOpInterface, InferFusibilityOpInterface,
[InferShapedTypeOpInterface, InferFusibilityOpInterface, SameOperandsAndResultShape]> {
SameOperandsAndResultShape])> { let arguments = (ins TensorType:$operand);
let arguments = (ins TensorType:$operand); let results = (outs TensorType);
let results = (outs TensorType); let extraClassDeclaration = [{
let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents(
static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional<Location> location,
MLIRContext* context, Optional<Location> location, ValueRange operands, DictionaryAttr attributes, RegionRange regions,
ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { return failure();
return failure(); }
} LogicalResult reifyReturnTypeShapes(
LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), &reifiedReturnShapes);
&reifiedReturnShapes); }
} bool inferInputOutputShapeEquality(int input, int output) {
bool inferInputOutputShapeEquality(int input, int output) { return true;
return true; }
} llvm::Optional<Value> inferEffectiveWorkloadShape() {
llvm::Optional<Value> inferEffectiveWorkloadShape() { return getOperation()->getResult(0);
return getOperation()->getResult(0); }
} }];
}];
} }
// Abs supports complex to real, so element type is not guaranteed to match. // Abs supports complex to real, so element type is not guaranteed to match.