Fix the MHLO to LMHLO lowering of 'gather'
The lowering assumes that the 'gather' op attributes are identical in both MHLO and LMHLO. But that's not true; some time ago the MHLO version was changed to pack 4 of its attributes into a struct. By doing the same for the LMHLO version we both fix the lowering for this op and resolve a longstanding TODO. PiperOrigin-RevId: 337943946
This commit is contained in:
parent
204cf7d544
commit
33c450e4cb
|
@ -602,11 +602,8 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
|
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
|
||||||
I64Attr:$index_vector_dim,
|
GatherDimensionNumbers:$dimension_numbers,
|
||||||
I64ElementsAttr:$offset_dims,
|
|
||||||
I64ElementsAttr:$slice_sizes,
|
I64ElementsAttr:$slice_sizes,
|
||||||
I64ElementsAttr:$collapsed_slice_dims,
|
|
||||||
I64ElementsAttr:$start_index_map,
|
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -287,6 +287,28 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// BOTH-LABEL: func @gather
|
||||||
|
func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) {
|
||||||
|
%tensor_operand = tensor_load %operand : memref<13x7xf32>
|
||||||
|
%tensor_idxs = tensor_load %idxs : memref<5xi32>
|
||||||
|
%tensor_result =
|
||||||
|
"mhlo.gather"(%tensor_operand, %tensor_idxs)
|
||||||
|
{ dimension_numbers =
|
||||||
|
{ collapsed_slice_dims = dense<0> : tensor<1xi64>
|
||||||
|
, index_vector_dim = 1 : i64
|
||||||
|
, offset_dims = dense<1> : tensor<1xi64>
|
||||||
|
, start_index_map = dense<0> : tensor<1xi64> }
|
||||||
|
, indices_are_sorted = false
|
||||||
|
, name = "gather.71"
|
||||||
|
, slice_sizes = dense<[1, 7]> : tensor<2xi64> }
|
||||||
|
: (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
|
||||||
|
// BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<5x7xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// BOTH-LABEL: func @imag_dyn
|
// BOTH-LABEL: func @imag_dyn
|
||||||
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
|
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
|
||||||
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
|
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
|
||||||
|
|
Loading…
Reference in New Issue