Implement InferShapedTypeOpInterface and use inferReturnTypes for mhlo.imag and mhlo.real
This makes the lhlo lowering work with dynamic shapes. PiperOrigin-RevId: 334553472
This commit is contained in:
parent
39389587d2
commit
c8919f8419
|
@ -193,12 +193,9 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one",
|
||||||
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
|
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
|
||||||
|
|
||||||
def HLO_ImagOp: HLO_Op<
|
def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
|
||||||
"imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp {
|
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
let builders = [OpBuilder<
|
HLO_ComplexTensor>, BASE_HLO_ImagOp {
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value val">];
|
|
||||||
|
|
||||||
let arguments = (ins HLO_ComplexTensor);
|
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
@ -237,12 +234,9 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
|
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
|
||||||
BASE_HLO_PopulationCountOp;
|
BASE_HLO_PopulationCountOp;
|
||||||
|
|
||||||
def HLO_RealOp: HLO_Op<
|
def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
|
||||||
"real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp {
|
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
let builders = [OpBuilder<
|
HLO_ComplexTensor>, BASE_HLO_RealOp {
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value val">];
|
|
||||||
|
|
||||||
let arguments = (ins HLO_ComplexTensor);
|
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -932,8 +932,11 @@ Type CreateRealType(Type type) {
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
|
LogicalResult ImagOp::inferReturnTypes(
|
||||||
build(builder, state, CreateRealType(val.getType()), val);
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
@ -945,8 +948,11 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
|
LogicalResult RealOp::inferReturnTypes(
|
||||||
build(builder, state, CreateRealType(val.getType()), val);
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
|
@ -248,6 +248,18 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// BOTH-LABEL: func @real_dyn
|
||||||
|
func @real_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
|
||||||
|
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
|
||||||
|
%tensor_result = "mhlo.real"(%tensor_operand)
|
||||||
|
: (tensor<?xcomplex<f32>>) -> tensor<?xf32>
|
||||||
|
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// BOTH-LABEL: func @imag
|
// BOTH-LABEL: func @imag
|
||||||
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||||
|
@ -260,6 +272,18 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// BOTH-LABEL: func @imag_dyn
|
||||||
|
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
|
||||||
|
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
|
||||||
|
%tensor_result = "mhlo.imag"(%tensor_operand)
|
||||||
|
: (tensor<?xcomplex<f32>>) -> tensor<?xf32>
|
||||||
|
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// BOTH-LABEL: func @iota
|
// BOTH-LABEL: func @iota
|
||||||
func @iota(%result: memref<10xi32>) {
|
func @iota(%result: memref<10xi32>) {
|
||||||
%tensor_result = "mhlo.iota"()
|
%tensor_result = "mhlo.iota"()
|
||||||
|
|
Loading…
Reference in New Issue