diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 626b5d3..d68fe92 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -283,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp auto if_op = rewriter.create( loc, result_type, IsScalarTensor(rewriter, op, lhs), true); OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder(); - Value reshaped_lhs = if_lhs_scalar_builder.create( + Value reshaped_lhs = if_lhs_scalar_builder.create( loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs); Value if_lhs_scalar_result = if_lhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{reshaped_lhs, rhs}, @@ -300,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp else_lhs_scalar_builder.create(loc, if_rhs_scalar_op.getResult(0)); OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder(); - Value reshaped_rhs = if_rhs_scalar_builder.create( + Value reshaped_rhs = if_rhs_scalar_builder.create( loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs); Value if_rhs_scalar_result = if_rhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{lhs, reshaped_rhs}, diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir index 94d20dc..60ec26f 100644 --- a/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -325,7 +325,7 @@ func @addUnrankedUnranked( // CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index // Handle scalar LHS case // CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor +// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor // CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor, tensor<*xf32>) -> tensor<*xf32> // CHECK: scf.yield %[[VAL_10]] : tensor<*xf32> // CHECK: } else { @@ -334,7 +334,7 @@ func @addUnrankedUnranked( // CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index // Handle scalar RHS case // CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor +// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor // CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor) -> tensor<*xf32> // CHECK: scf.yield %[[VAL_16]] : tensor<*xf32> // CHECK: } else {