diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
index 7ae241d..1fcca87 100644
--- a/lib/Dialect/mhlo/transforms/rank_specialization.cc
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -293,9 +293,9 @@ Value MaterializeScalarRankSpecializationCase(
       llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
         if (v == non_scalar_operand) return flat_non_scalar_operand;
         return b
-            .create<tensor::CastOp>(
+            .create<mhlo::ReshapeOp>(
                 loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
-            .dest();
+            .getResult();
       }));
 
   // Materialize ranked variants for the element-wise operations.
@@ -585,6 +585,8 @@ struct LowerRankSpecializationClusterPattern
 
   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
     // Restoring the result shape currently relies on all operands being used
     // for a single result. The result shape is then the broadcasted shape of
     // all operands.
@@ -595,7 +597,6 @@ struct LowerRankSpecializationClusterPattern
 
     // If there is only one unranked operand and all others are known scalars,
     // we can flatten the operands to rank 1.
-    Location loc = op.getLoc();
     if (Optional<Value> non_scalar_operand =
             FindUniqueNonScalar(op.operands())) {
       rewriter.replaceOp(op,
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
index 45001f1..535bc69 100644
--- a/tests/rank-specialization.mlir
+++ b/tests/rank-specialization.mlir
@@ -488,7 +488,7 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
 // CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
-// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG0]]
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG0]])
 // CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor<f32>, tensor<?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
@@ -500,7 +500,7 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
 // CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
 // CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
-// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG1]]
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG1]])
 // CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor<?xf32>, tensor<f32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]