[MLIR][HLO] Reshape to scalars in rank specialization
Scalars were incorrectly cast to scalar tensors in cases where they have to be reshaped instead.

PiperOrigin-RevId: 375049088
commit 97e6103933
parent c3147d76d5
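Why a reshape: a value that is a scalar at runtime may still carry rank, e.g. a tensor<1xf32> hidden behind tensor<*xf32>. A tensor.cast only reinterprets the type and acts as a rank-0 assertion, which is wrong for such operands; an mhlo.reshape collapses the single-element operand to rank 0 regardless of its rank. A minimal IR sketch of the distinction (illustrative, not part of the diff):

    // Wrong for a rank-1 single-element value: the cast merely asserts rank 0.
    %0 = tensor.cast %arg : tensor<*xf32> to tensor<f32>
    // Correct: the reshape materializes the rank-0 tensor from the one element.
    %1 = "mhlo.reshape"(%arg) : (tensor<*xf32>) -> tensor<f32>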
@@ -293,9 +293,9 @@ Value MaterializeScalarRankSpecializationCase(
       llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
         if (v == non_scalar_operand) return flat_non_scalar_operand;
         return b
-            .create<tensor::CastOp>(
+            .create<mhlo::ReshapeOp>(
                 loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
-            .dest();
+            .getResult();
       }));

   // Materialize ranked variants for the element-wise operations.
@@ -585,6 +585,8 @@ struct LowerRankSpecializationClusterPattern
   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
     // Restoring the result shape currently relies on all operands being used
     // for a single result. The result shape is then the broadcasted shape of
     // all operands.

@@ -595,7 +597,6 @@ struct LowerRankSpecializationClusterPattern

     // If there is only one unranked operand and all others are known scalars,
     // we can flatten the operands to rank 1.
-    Location loc = op.getLoc();
     if (Optional<Value> non_scalar_operand =
             FindUniqueNonScalar(op.operands())) {
       rewriter.replaceOp(op,

@@ -488,7 +488,7 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG:  %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
 // CHECK-SCF-DAG:  %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
 // CHECK-SCF-DAG:  %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
-// CHECK-SCF-DAG:  %[[SCALAR:.*]] = tensor.cast %[[ARG0]]
+// CHECK-SCF-DAG:  %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG0]])
 // CHECK-SCF-DAG:  %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor<f32>, tensor<?xf32>)
 // CHECK-SCF-DAG:  %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF:      scf.yield %[[INNER_RES_]]

@@ -500,7 +500,7 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG:  %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
 // CHECK-SCF-DAG:  %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
 // CHECK-SCF-DAG:  %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
-// CHECK-SCF-DAG:  %[[SCALAR:.*]] = tensor.cast %[[ARG1]]
+// CHECK-SCF-DAG:  %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG1]])
 // CHECK-SCF-DAG:  %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor<?xf32>, tensor<f32>)
 // CHECK-SCF-DAG:  %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF:      scf.yield %[[INNER_RES_]]
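Taken together, the scalar case now lowers roughly as in the following sketch (illustrative IR assembled from the CHECK lines above; SSA names are made up):

    // Flatten the one non-scalar operand to rank 1.
    %n = shape.num_elements %shape_arg1 : tensor<?xindex> -> index
    %flat_shape = tensor.from_elements %n : tensor<1xindex>
    %flat = "mhlo.dynamic_reshape"(%arg1, %flat_shape)
        : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
    // Reshape (not cast) the scalar operand to rank 0.
    %scalar = "mhlo.reshape"(%arg0) : (tensor<*xf32>) -> tensor<f32>
    %res = chlo.broadcast_multiply %scalar, %flat
        : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>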