Use InferTypeOpInterface for HLO AbsOp and fix result shape inference
Shape inference in case of ops with complex element types need to use the element type of complex as the result element type and not the full operand type. Before: "mhlo.abs"(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xtensor<4xcomplex<f32>>> After: "mhlo.abs"(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xf32> PiperOrigin-RevId: 348123967
This commit is contained in:
parent
43ede42ce1
commit
8d051723c0
|
@ -146,10 +146,9 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||||
|
|
||||||
// Abs supports complex to real, so element type is not guaranteed to match.
|
// Abs supports complex to real, so element type is not guaranteed to match.
|
||||||
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
||||||
[NoSideEffect, SameOperandsAndResultShape],
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
|
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
|
||||||
let builders = [
|
|
||||||
OpBuilderDAG<(ins "Value":$operand)>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
|
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
|
||||||
|
|
|
@ -454,18 +454,23 @@ static LogicalResult Verify(DynamicUpdateSliceOp op) {
|
||||||
// AbsOp
|
// AbsOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) {
|
LogicalResult AbsOp::inferReturnTypes(
|
||||||
auto shaped_type = operand.getType().cast<ShapedType>();
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
Type new_type;
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
if (!shaped_type.getElementType().isa<ComplexType>()) {
|
auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
|
||||||
new_type = operand.getType();
|
Type element_ty = operand_ty.getElementType();
|
||||||
} else if (shaped_type.hasRank()) {
|
if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
|
||||||
new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType());
|
element_ty = complex_ty.getElementType();
|
||||||
} else {
|
|
||||||
new_type = UnrankedTensorType::get(operand.getType());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return AbsOp::build(builder, result, new_type, operand);
|
Type result_ty;
|
||||||
|
if (operand_ty.hasRank()) {
|
||||||
|
result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
|
||||||
|
} else {
|
||||||
|
result_ty = UnrankedTensorType::get(element_ty);
|
||||||
|
}
|
||||||
|
inferredReturnTypes.push_back(result_ty);
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue