Fix the MHLO to LMHLO lowering of 'gather'
The lowering assumes that the 'gather' op attributes are identical in both MHLO and LMHLO. But that's not true; some time ago the MHLO version was changed to pack 4 of its attributes into a struct. By doing the same for the LMHLO version we both fix the lowering for this op and resolve a longstanding TODO. PiperOrigin-RevId: 337943946
This commit is contained in:
		
							parent
							
								
									204cf7d544
								
							
						
					
					
						commit
						33c450e4cb
					
				|  | @ -602,11 +602,8 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { | ||||||
|   let arguments = (ins |   let arguments = (ins | ||||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$operand, |     Arg<LHLO_Buffer, "", [MemRead]>:$operand, | ||||||
|     Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices, |     Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices, | ||||||
|     I64Attr:$index_vector_dim, |     GatherDimensionNumbers:$dimension_numbers, | ||||||
|     I64ElementsAttr:$offset_dims, |  | ||||||
|     I64ElementsAttr:$slice_sizes, |     I64ElementsAttr:$slice_sizes, | ||||||
|     I64ElementsAttr:$collapsed_slice_dims, |  | ||||||
|     I64ElementsAttr:$start_index_map, |  | ||||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$output |     Arg<LHLO_Buffer, "", [MemWrite]>:$output | ||||||
|   ); |   ); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -287,6 +287,28 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) { | ||||||
| 
 | 
 | ||||||
| // ----- | // ----- | ||||||
| 
 | 
 | ||||||
|  | // BOTH-LABEL: func @gather | ||||||
|  | func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) { | ||||||
|  |   %tensor_operand = tensor_load %operand : memref<13x7xf32> | ||||||
|  |   %tensor_idxs = tensor_load %idxs : memref<5xi32> | ||||||
|  |   %tensor_result = | ||||||
|  |     "mhlo.gather"(%tensor_operand, %tensor_idxs) | ||||||
|  |       { dimension_numbers = | ||||||
|  |         { collapsed_slice_dims = dense<0> : tensor<1xi64> | ||||||
|  |         , index_vector_dim = 1 : i64 | ||||||
|  |         , offset_dims = dense<1> : tensor<1xi64> | ||||||
|  |         , start_index_map = dense<0> : tensor<1xi64> } | ||||||
|  |       , indices_are_sorted = false | ||||||
|  |       , name = "gather.71" | ||||||
|  |       , slice_sizes = dense<[1, 7]> : tensor<2xi64> } | ||||||
|  |       : (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32> | ||||||
|  |   // BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}}) | ||||||
|  |   tensor_store %tensor_result, %result : memref<5x7xf32> | ||||||
|  |   return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
| // BOTH-LABEL: func @imag_dyn | // BOTH-LABEL: func @imag_dyn | ||||||
| func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) { | func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) { | ||||||
|   %tensor_operand = tensor_load %operand : memref<?xcomplex<f32>> |   %tensor_operand = tensor_load %operand : memref<?xcomplex<f32>> | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue