[MLIR][HLO] Reshape to scalars in rank specialization

Scalars were incorrectly casted to scalar tensors when they have to be reshaped.

PiperOrigin-RevId: 375049088
This commit is contained in:
A. Unique TensorFlower 2021-05-21 03:11:11 -07:00 committed by TensorFlow MLIR Team
parent c3147d76d5
commit 97e6103933
2 changed files with 6 additions and 5 deletions

View File

@ -293,9 +293,9 @@ Value MaterializeScalarRankSpecializationCase(
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
if (v == non_scalar_operand) return flat_non_scalar_operand;
return b
.create<tensor::CastOp>(
.create<mhlo::ReshapeOp>(
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
.dest();
.getResult();
}));
// Materialize ranked variants for the element-wise operations.
@ -585,6 +585,8 @@ struct LowerRankSpecializationClusterPattern
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Restoring the result shape currently relies on all operands being used
// for a single result. The result shape is then the broadcasted shape of
// all operands.
@ -595,7 +597,6 @@ struct LowerRankSpecializationClusterPattern
// If there is only one unranked operand and all others are known scalars,
// we can flatten the operands to rank 1.
Location loc = op.getLoc();
if (Optional<Value> non_scalar_operand =
FindUniqueNonScalar(op.operands())) {
rewriter.replaceOp(op,

View File

@ -488,7 +488,7 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG0]]
// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG0]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor<f32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
@ -500,7 +500,7 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG1]]
// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG1]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor<?xf32>, tensor<f32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]