diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index 1ef3f62..b64aa9f 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -15,6 +15,7 @@ limitations under the License. ==============================================================================*/ #include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" @@ -341,17 +342,20 @@ SmallVector MaterializeRankedOperations( SmallVector MaterializeFinalReshape( OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - ValueRange unshaped_results) { - // Compute result shape. - auto non_scalar_operands = llvm::make_filter_range( - op.operands(), [](Value v) { return !IsScalarTensorType(v.getType()); }); - SmallVector results; - auto operand_shapes = - llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) { - return b.create(loc, v).result(); - })); - auto shape = b.create( - loc, shape::getExtentTensorType(b.getContext()), operand_shapes); + ValueRange unshaped_results, llvm::Optional shape = {}) { + if (!shape) { + // Compute result shape. + auto non_scalar_operands = llvm::make_filter_range( + op.operands(), + [](Value v) { return !IsScalarTensorType(v.getType()); }); + SmallVector results; + auto operand_shapes = + llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) { + return b.create(loc, v).result(); + })); + shape = b.create( + loc, shape::getExtentTensorType(b.getContext()), operand_shapes); + } // Reshape results. return llvm::to_vector<8>( @@ -359,7 +363,7 @@ SmallVector MaterializeFinalReshape( return b .create( loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped, - shape) + shape.getValue()) .result(); })); } @@ -664,7 +668,8 @@ MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass( MaterializeRankedOperations(b, loc, bvm, op); // Restore the results' expected shape. - return MaterializeFinalReshape(b, loc, op, unshaped_results); + return MaterializeFinalReshape(b, loc, op, unshaped_results, + non_scalar_shapes.front()); } Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index fba118e..9422c87 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -159,8 +159,7 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor) // CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor) // CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor) -// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-SCF: return %[[RES]] // ----- @@ -229,8 +228,7 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor // CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor // CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor -// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-SCF: return %[[RES]] // ----- @@ -295,8 +293,7 @@ func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] // CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor, tensor) -// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-SCF: return %[[RES]] // ----- @@ -327,8 +324,7 @@ func @angle(%arg : tensor<*xcomplex>) -> tensor<*xf32> { // CHECK-SCF: %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor>) // CHECK-SCF: %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor>) // CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor -// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]] -// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-SCF: return %[[RES]] // ----- @@ -610,10 +606,7 @@ func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) // CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]]) // CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]]) // CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]] -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %8, %9 -// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]]) +// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[S0]]) // CHECK-SCF: return %[[RES]] // ----- @@ -647,8 +640,5 @@ func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { // CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32} // CHECK-SCF-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"} // CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]]) -// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] -// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] -// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[S1]], %[[S0]] -// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) +// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[S1]]) // CHECK-SCF: return %[[RES]]