From c8919f8419b7d4f5e340f994e865dfeb6e947353 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 30 Sep 2020 02:01:45 -0700 Subject: [PATCH] Implement InferShapedTypeOpInterface and use inferReturnTypes for mhlo.imag and mhlo.real This makes the lhlo lowering work with dynamic shapes. PiperOrigin-RevId: 334553472 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 18 ++++++---------- lib/Dialect/mhlo/IR/hlo_ops.cc | 14 ++++++++---- tests/hlo-legalize-to-lhlo.mlir | 24 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 86d6e34..6a021fb 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -193,12 +193,9 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; -def HLO_ImagOp: HLO_Op< - "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); +def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", + [NoSideEffect, DeclareOpInterfaceMethods], + HLO_ComplexTensor>, BASE_HLO_ImagOp { let results = (outs HLO_FpTensor); let hasFolder = 1; } @@ -237,12 +234,9 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_PopulationCountOp; -def HLO_RealOp: HLO_Op< - "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); +def HLO_RealOp: HLO_UnaryElementwiseOp<"real", + [NoSideEffect, DeclareOpInterfaceMethods], + HLO_ComplexTensor>, BASE_HLO_RealOp { let results = (outs HLO_FpTensor); let hasFolder = 1; } diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 31de088..0bca35b 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -932,8 +932,11 @@ Type CreateRealType(Type type) { } } // namespace -void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) { - build(builder, state, CreateRealType(val.getType()), val); +LogicalResult ImagOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(CreateRealType(operands[0].getType())); + return success(); } OpFoldResult ImagOp::fold(ArrayRef operands) { @@ -945,8 +948,11 @@ OpFoldResult ImagOp::fold(ArrayRef operands) { return {}; } -void RealOp::build(OpBuilder& builder, OperationState& state, Value val) { - build(builder, state, CreateRealType(val.getType()), val); +LogicalResult RealOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(CreateRealType(operands[0].getType())); + return success(); } OpFoldResult RealOp::fold(ArrayRef operands) { diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index ed38833..6d67c60 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -248,6 +248,18 @@ func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @real_dyn +func @real_dyn(%operand: memref>, %result: memref) { + %tensor_operand = tensor_load %operand : memref> + %tensor_result = "mhlo.real"(%tensor_operand) + : (tensor>) -> tensor + // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref + return +} + +// ----- + // BOTH-LABEL: func @imag func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> @@ -260,6 +272,18 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @imag_dyn +func @imag_dyn(%operand: memref>, %result: memref) { + %tensor_operand = tensor_load %operand : memref> + %tensor_result = "mhlo.imag"(%tensor_operand) + : (tensor>) -> tensor + // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref + return +} + +// ----- + // BOTH-LABEL: func @iota func @iota(%result: memref<10xi32>) { %tensor_result = "mhlo.iota"()