Use tensor_cast instead of mhlo::reshape in the lowering of unranked binary operations.

We know that the value already is a scalar and we just want to update the type, so no need to reshape anything.

PiperOrigin-RevId: 336252315
This commit is contained in:
Stephan Herhut 2020-10-09 01:46:01 -07:00 committed by TensorFlow MLIR Team
parent 41436ea0d9
commit d986bd7ad7
2 changed files with 4 additions and 4 deletions

View File

@ -283,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
@ -300,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},

View File

@ -325,7 +325,7 @@ func @addUnrankedUnranked(
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
// Handle scalar LHS case
// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
// CHECK: } else {
@ -334,7 +334,7 @@ func @addUnrankedUnranked(
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
// Handle scalar RHS case
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
// CHECK: } else {