Avoid Broadcast op if all shapes are (known to) be equal.
The rank specialization case for shapes which are either of the same shape or a scalar doesn't need to compute the final result shape. PiperOrigin-RevId: 380129316
This commit is contained in:
parent
d4a7901284
commit
4c282fb542
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
|
||||
#include "llvm/ADT/EquivalenceClasses.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
@ -341,17 +342,20 @@ SmallVector<Value, 8> MaterializeRankedOperations(
|
|||
|
||||
SmallVector<Value, 8> MaterializeFinalReshape(
|
||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||
ValueRange unshaped_results) {
|
||||
// Compute result shape.
|
||||
auto non_scalar_operands = llvm::make_filter_range(
|
||||
op.operands(), [](Value v) { return !IsScalarTensorType(v.getType()); });
|
||||
SmallVector<Value, 8> results;
|
||||
auto operand_shapes =
|
||||
llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) {
|
||||
return b.create<shape::ShapeOfOp>(loc, v).result();
|
||||
}));
|
||||
auto shape = b.create<shape::BroadcastOp>(
|
||||
loc, shape::getExtentTensorType(b.getContext()), operand_shapes);
|
||||
ValueRange unshaped_results, llvm::Optional<Value> shape = {}) {
|
||||
if (!shape) {
|
||||
// Compute result shape.
|
||||
auto non_scalar_operands = llvm::make_filter_range(
|
||||
op.operands(),
|
||||
[](Value v) { return !IsScalarTensorType(v.getType()); });
|
||||
SmallVector<Value, 8> results;
|
||||
auto operand_shapes =
|
||||
llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) {
|
||||
return b.create<shape::ShapeOfOp>(loc, v).result();
|
||||
}));
|
||||
shape = b.create<shape::BroadcastOp>(
|
||||
loc, shape::getExtentTensorType(b.getContext()), operand_shapes);
|
||||
}
|
||||
|
||||
// Reshape results.
|
||||
return llvm::to_vector<8>(
|
||||
|
@ -359,7 +363,7 @@ SmallVector<Value, 8> MaterializeFinalReshape(
|
|||
return b
|
||||
.create<mhlo::DynamicReshapeOp>(
|
||||
loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped,
|
||||
shape)
|
||||
shape.getValue())
|
||||
.result();
|
||||
}));
|
||||
}
|
||||
|
@ -664,7 +668,8 @@ MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
|
|||
MaterializeRankedOperations(b, loc, bvm, op);
|
||||
|
||||
// Restore the results' expected shape.
|
||||
return MaterializeFinalReshape(b, loc, op, unshaped_results);
|
||||
return MaterializeFinalReshape(b, loc, op, unshaped_results,
|
||||
non_scalar_shapes.front());
|
||||
}
|
||||
|
||||
Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
|
||||
|
|
|
@ -159,8 +159,7 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
|
||||
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
@ -229,8 +228,7 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor<?xf32>
|
||||
// CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor<?xf32>
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor<?xf32>
|
||||
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
@ -295,8 +293,7 @@ func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor<?xf32>, tensor<f32>)
|
||||
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
@ -327,8 +324,7 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
|
|||
// CHECK-SCF: %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
|
||||
// CHECK-SCF: %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor<?xf32>
|
||||
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
@ -610,10 +606,7 @@ func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
|
|||
// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]])
|
||||
// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]])
|
||||
// CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]]
|
||||
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %8, %9
|
||||
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]])
|
||||
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[S0]])
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
@ -647,8 +640,5 @@ func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32}
|
||||
// CHECK-SCF-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"}
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]])
|
||||
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[S1]], %[[S0]]
|
||||
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]])
|
||||
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[S1]])
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
|
Loading…
Reference in New Issue