diff --git a/tests/unfuse_batch_norm.mlir b/tests/unfuse_batch_norm.mlir index 867c1d2..c193072 100644 --- a/tests/unfuse_batch_norm.mlir +++ b/tests/unfuse_batch_norm.mlir @@ -110,7 +110,7 @@ func @batchNormInference_dynamic_shape( // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor