[NFC] Remove all_gather_dimension from all-to-all in the unit test
PiperOrigin-RevId: 372463706
This commit is contained in:
parent
f8c629207c
commit
ce4c76314e
|
@ -58,7 +58,7 @@ func @invalid_alltoall(%input0: memref<2xf32>, %output: memref<8xf32>) {
|
|||
"lmhlo.all_to_all"(%input0, %output)
|
||||
{channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
|
||||
replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
|
||||
use_global_device_ids = false, all_gather_dimension = 0 : i64} : (memref<2xf32>, memref<8xf32>) -> ()
|
||||
use_global_device_ids = false} : (memref<2xf32>, memref<8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -69,7 +69,7 @@ func @invalid_alltoall(%input0: memref<2xf32>, %output: memref<8xf32>) {
|
|||
"lmhlo.all_to_all"(%input0, %output)
|
||||
{channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
|
||||
replica_groups = dense<0> : tensor<1xi64>,
|
||||
use_global_device_ids = false, all_gather_dimension = 0 : i64} : (memref<2xf32>, memref<8xf32>) -> ()
|
||||
use_global_device_ids = false} : (memref<2xf32>, memref<8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue