Use InferTypeOpInterface for HLO AbsOp and fix result shape inference

Shape inference in case of ops with complex element types need to use the element type of complex as the result element type and not the full operand type.

Before:
"mhlo.abs"(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xtensor<4xcomplex<f32>>>
After:
"mhlo.abs"(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
PiperOrigin-RevId: 348123967
This commit is contained in:
Smit Hinsu 2020-12-17 17:36:05 -08:00 committed by TensorFlow MLIR Team
parent 43ede42ce1
commit 8d051723c0
2 changed files with 17 additions and 13 deletions

View File

@ -146,10 +146,9 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
// Abs supports complex to real, so element type is not guaranteed to match.
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
[NoSideEffect, SameOperandsAndResultShape],
[NoSideEffect, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferTypeOpInterface>],
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
let builders = [
OpBuilderDAG<(ins "Value":$operand)>];
}
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",

View File

@ -454,18 +454,23 @@ static LogicalResult Verify(DynamicUpdateSliceOp op) {
// AbsOp
//===----------------------------------------------------------------------===//
void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) {
auto shaped_type = operand.getType().cast<ShapedType>();
Type new_type;
if (!shaped_type.getElementType().isa<ComplexType>()) {
new_type = operand.getType();
} else if (shaped_type.hasRank()) {
new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType());
} else {
new_type = UnrankedTensorType::get(operand.getType());
LogicalResult AbsOp::inferReturnTypes(
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
auto operand_ty = (*operands.begin()).getType().cast<ShapedType>();
Type element_ty = operand_ty.getElementType();
if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
element_ty = complex_ty.getElementType();
}
return AbsOp::build(builder, result, new_type, operand);
Type result_ty;
if (operand_ty.hasRank()) {
result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty);
} else {
result_ty = UnrankedTensorType::get(element_ty);
}
inferredReturnTypes.push_back(result_ty);
return success();
}
//===----------------------------------------------------------------------===//