Use InferTypeOpInterface for HLO AbsOp and fix result shape inference

Shape inference in case of ops with complex element types need to use the element type of complex as the result element type and not the full operand type.

Before:
"mhlo.abs"(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xtensor<4xcomplex<f32>>>
After:
"mhlo.abs"(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
PiperOrigin-RevId: 348123967
This commit is contained in:
Smit Hinsu 2020-12-17 17:36:05 -08:00 committed by TensorFlow MLIR Team
parent 43ede42ce1
commit 8d051723c0
2 changed files with 17 additions and 13 deletions

View File

@ -146,10 +146,9 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
// Abs supports complex to real, so element type is not guaranteed to match. // Abs supports complex to real, so element type is not guaranteed to match.
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
[NoSideEffect, SameOperandsAndResultShape], [NoSideEffect, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferTypeOpInterface>],
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
let builders = [
OpBuilderDAG<(ins "Value":$operand)>];
} }
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",

View File

@ -454,18 +454,23 @@ static LogicalResult Verify(DynamicUpdateSliceOp op) {
// AbsOp // AbsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) { LogicalResult AbsOp::inferReturnTypes(
auto shaped_type = operand.getType().cast<ShapedType>(); MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
Type new_type; RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
if (!shaped_type.getElementType().isa<ComplexType>()) { auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
new_type = operand.getType(); Type element_ty = operand_ty.getElementType();
} else if (shaped_type.hasRank()) { if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType()); element_ty = complex_ty.getElementType();
} else {
new_type = UnrankedTensorType::get(operand.getType());
} }
return AbsOp::build(builder, result, new_type, operand); Type result_ty;
if (operand_ty.hasRank()) {
result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
} else {
result_ty = UnrankedTensorType::get(element_ty);
}
inferredReturnTypes.push_back(result_ty);
return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//