From dfe64d3958ad31aab82d65f03c44eae7c9e8c710 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 30 Sep 2020 12:13:21 -0700 Subject: [PATCH] Implement InferShapedTypeOpInterface for mhlo.complex Binary companion for https://github.com/tensorflow/tensorflow/commit/8bcd33e4b70f8b9b44bb71b4eac160c81b7a0cca PiperOrigin-RevId: 334651523 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 14 +++++++------- lib/Dialect/mhlo/IR/hlo_ops.cc | 11 ++++++----- tests/hlo-legalize-to-lhlo.mlir | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 6a021fb..2e17834 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -194,7 +194,8 @@ def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", - [NoSideEffect, DeclareOpInterfaceMethods], + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods], HLO_ComplexTensor>, BASE_HLO_ImagOp { let results = (outs HLO_FpTensor); let hasFolder = 1; @@ -235,7 +236,8 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", BASE_HLO_PopulationCountOp; def HLO_RealOp: HLO_UnaryElementwiseOp<"real", - [NoSideEffect, DeclareOpInterfaceMethods], + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods], HLO_ComplexTensor>, BASE_HLO_RealOp { let results = (outs HLO_FpTensor); let hasFolder = 1; @@ -315,12 +317,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add", def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; -def HLO_ComplexOp: HLO_Op<"complex", - [NoSideEffect, SameOperandsAndResultShape]>, +def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex", + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods]>, BASE_HLO_ComplexOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; - let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); let results = (outs HLO_ComplexTensor); let hasFolder = 1; diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 0bca35b..8e7673f 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -889,9 +889,10 @@ static LogicalResult Verify(ClampOp op) { // ComplexOp //===----------------------------------------------------------------------===// -void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, - Value rhs) { - auto type = lhs.getType(); +LogicalResult ComplexOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto type = operands[0].getType(); auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); Type result_ty; if (auto ranked_type = type.dyn_cast()) { @@ -901,8 +902,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, } else { result_ty = element_ty; } - - build(builder, state, result_ty, lhs, rhs); + inferredReturnTypes.push_back(result_ty); + return success(); } OpFoldResult ComplexOp::fold(ArrayRef operands) { diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 6d67c60..03800fa 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -236,6 +236,21 @@ func @complex(%real: memref<2x2xf32>, // ----- +// BOTH-LABEL: func @complex_dyn +func @complex_dyn(%real: memref, + %imag: memref, + %result: memref>) { + %tensor_real = tensor_load %real : memref + %tensor_imag = tensor_load %imag : memref + %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) + : (tensor, tensor) -> tensor> + // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref> + return +} + +// ----- + // BOTH-LABEL: func @real func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex>