Implement InferShapedTypeOpInterface for mhlo.complex
Binary companion for 8bcd33e4b7
PiperOrigin-RevId: 334651523
This commit is contained in:
parent
019c5ef106
commit
dfe64d3958
|
@ -194,7 +194,8 @@ def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
|
||||||
|
|
||||||
def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
|
def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
HLO_ComplexTensor>, BASE_HLO_ImagOp {
|
HLO_ComplexTensor>, BASE_HLO_ImagOp {
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
@ -235,7 +236,8 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt",
|
||||||
BASE_HLO_PopulationCountOp;
|
BASE_HLO_PopulationCountOp;
|
||||||
|
|
||||||
def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
|
def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
HLO_ComplexTensor>, BASE_HLO_RealOp {
|
HLO_ComplexTensor>, BASE_HLO_RealOp {
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
@ -315,12 +317,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
|
||||||
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
|
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
|
||||||
|
|
||||||
def HLO_ComplexOp: HLO_Op<"complex",
|
def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex",
|
||||||
[NoSideEffect, SameOperandsAndResultShape]>,
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
|
||||||
BASE_HLO_ComplexOp {
|
BASE_HLO_ComplexOp {
|
||||||
let builders = [OpBuilder<
|
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];
|
|
||||||
|
|
||||||
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
|
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
|
||||||
let results = (outs HLO_ComplexTensor);
|
let results = (outs HLO_ComplexTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
|
|
@ -889,9 +889,10 @@ static LogicalResult Verify(ClampOp op) {
|
||||||
// ComplexOp
|
// ComplexOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
|
LogicalResult ComplexOp::inferReturnTypes(
|
||||||
Value rhs) {
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
auto type = lhs.getType();
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
auto type = operands[0].getType();
|
||||||
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
||||||
Type result_ty;
|
Type result_ty;
|
||||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||||
|
@ -901,8 +902,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
|
||||||
} else {
|
} else {
|
||||||
result_ty = element_ty;
|
result_ty = element_ty;
|
||||||
}
|
}
|
||||||
|
inferredReturnTypes.push_back(result_ty);
|
||||||
build(builder, state, result_ty, lhs, rhs);
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
|
@ -236,6 +236,21 @@ func @complex(%real: memref<2x2xf32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// BOTH-LABEL: func @complex_dyn
|
||||||
|
func @complex_dyn(%real: memref<?xf32>,
|
||||||
|
%imag: memref<?xf32>,
|
||||||
|
%result: memref<?xcomplex<f32>>) {
|
||||||
|
%tensor_real = tensor_load %real : memref<?xf32>
|
||||||
|
%tensor_imag = tensor_load %imag : memref<?xf32>
|
||||||
|
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
|
||||||
|
: (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>>
|
||||||
|
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<?xcomplex<f32>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// BOTH-LABEL: func @real
|
// BOTH-LABEL: func @real
|
||||||
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||||
|
|
Loading…
Reference in New Issue