Use tensor_cast instead of mhlo::reshape in the lowering of unranked binary operations.

We know that the value is already a scalar and we only want to update its type, so there is no need to reshape anything.

PiperOrigin-RevId: 336252315
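A minimal sketch of the IR in the scalar branch before and after this change (value names are illustrative, taken from the CHECK lines in the test below):

  // Before: go through a reshape op even though the operand is already a scalar.
  %scalar_lhs = "mhlo.reshape"(%lhs) : (tensor<*xf32>) -> tensor<f32>

  // After: keep the value as-is and only refine its static type.
  %scalar_lhs = tensor_cast %lhs : tensor<*xf32> to tensor<f32>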
@@ -283,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     auto if_op = rewriter.create<scf::IfOp>(
         loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
     OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
-    Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
+    Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
         loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
     Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
@@ -300,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                  if_rhs_scalar_op.getResult(0));
     OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
-    Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
+    Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
         loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
     Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
@@ -325,7 +325,7 @@ func @addUnrankedUnranked(
 // CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
 // Handle scalar LHS case
 // CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
+// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
 // CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
 // CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
 // CHECK: } else {
@@ -334,7 +334,7 @@ func @addUnrankedUnranked(
 // CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
 // Handle scalar RHS case
 // CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
+// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
 // CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
 // CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
 // CHECK: } else {