[MLIR][HLO] Add `Elementwise` trait to unary element-wise ops
PiperOrigin-RevId: 363428909
This commit is contained in:
parent
cd52adb20e
commit
f1408e791e
|
@ -119,31 +119,30 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
|
||||||
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
|
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
|
||||||
|
|
||||||
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||||
Type TensorType>: HLO_Op<mnemonic,
|
Type TensorType> : HLO_Op<mnemonic, traits # [Elementwise,
|
||||||
!listconcat(traits,
|
InferShapedTypeOpInterface, InferFusibilityOpInterface,
|
||||||
[InferShapedTypeOpInterface, InferFusibilityOpInterface,
|
SameOperandsAndResultShape]> {
|
||||||
SameOperandsAndResultShape])> {
|
let arguments = (ins TensorType:$operand);
|
||||||
let arguments = (ins TensorType:$operand);
|
let results = (outs TensorType);
|
||||||
let results = (outs TensorType);
|
let extraClassDeclaration = [{
|
||||||
let extraClassDeclaration = [{
|
static LogicalResult inferReturnTypeComponents(
|
||||||
static LogicalResult inferReturnTypeComponents(
|
MLIRContext* context, Optional<Location> location,
|
||||||
MLIRContext* context, Optional<Location> location,
|
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||||
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
|
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
return failure();
|
||||||
return failure();
|
}
|
||||||
}
|
LogicalResult reifyReturnTypeShapes(
|
||||||
LogicalResult reifyReturnTypeShapes(
|
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
&reifiedReturnShapes);
|
||||||
&reifiedReturnShapes);
|
}
|
||||||
}
|
bool inferInputOutputShapeEquality(int input, int output) {
|
||||||
bool inferInputOutputShapeEquality(int input, int output) {
|
return true;
|
||||||
return true;
|
}
|
||||||
}
|
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
||||||
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
return getOperation()->getResult(0);
|
||||||
return getOperation()->getResult(0);
|
}
|
||||||
}
|
}];
|
||||||
}];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Abs supports complex to real, so element type is not guaranteed to match.
|
// Abs supports complex to real, so element type is not guaranteed to match.
|
||||||
|
|
Loading…
Reference in New Issue