Avoid Broadcast op if all shapes are (known to) be equal.

The rank specialization case for shapes which are either of the same shape or a
scalar doesn't need to compute the final result shape.

PiperOrigin-RevId: 380129316
This commit is contained in:
Adrian Kuegel 2021-06-18 00:21:34 -07:00 committed by TensorFlow MLIR Team
parent d4a7901284
commit 4c282fb542
2 changed files with 24 additions and 29 deletions

View File

@ -15,6 +15,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
@ -341,17 +342,20 @@ SmallVector<Value, 8> MaterializeRankedOperations(
SmallVector<Value, 8> MaterializeFinalReshape( SmallVector<Value, 8> MaterializeFinalReshape(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
ValueRange unshaped_results) { ValueRange unshaped_results, llvm::Optional<Value> shape = {}) {
if (!shape) {
// Compute result shape. // Compute result shape.
auto non_scalar_operands = llvm::make_filter_range( auto non_scalar_operands = llvm::make_filter_range(
op.operands(), [](Value v) { return !IsScalarTensorType(v.getType()); }); op.operands(),
[](Value v) { return !IsScalarTensorType(v.getType()); });
SmallVector<Value, 8> results; SmallVector<Value, 8> results;
auto operand_shapes = auto operand_shapes =
llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) { llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) {
return b.create<shape::ShapeOfOp>(loc, v).result(); return b.create<shape::ShapeOfOp>(loc, v).result();
})); }));
auto shape = b.create<shape::BroadcastOp>( shape = b.create<shape::BroadcastOp>(
loc, shape::getExtentTensorType(b.getContext()), operand_shapes); loc, shape::getExtentTensorType(b.getContext()), operand_shapes);
}
// Reshape results. // Reshape results.
return llvm::to_vector<8>( return llvm::to_vector<8>(
@ -359,7 +363,7 @@ SmallVector<Value, 8> MaterializeFinalReshape(
return b return b
.create<mhlo::DynamicReshapeOp>( .create<mhlo::DynamicReshapeOp>(
loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped, loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped,
shape) shape.getValue())
.result(); .result();
})); }));
} }
@ -664,7 +668,8 @@ MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
MaterializeRankedOperations(b, loc, bvm, op); MaterializeRankedOperations(b, loc, bvm, op);
// Restore the results' expected shape. // Restore the results' expected shape.
return MaterializeFinalReshape(b, loc, op, unshaped_results); return MaterializeFinalReshape(b, loc, op, unshaped_results,
non_scalar_shapes.front());
} }
Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(

View File

@ -159,8 +159,7 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>) // CHECK-SCF: %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
// CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>) // CHECK-SCF: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>) // CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]] // CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]] // CHECK-SCF: return %[[RES]]
// ----- // -----
@ -229,8 +228,7 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor<?xf32> // CHECK-SCF: %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor<?xf32>
// CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor<?xf32> // CHECK-SCF: %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor<?xf32>
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor<?xf32> // CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor<?xf32>
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]] // CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]] // CHECK-SCF: return %[[RES]]
// ----- // -----
@ -295,8 +293,7 @@ func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] // CHECK-SCF: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> // CHECK-SCF: %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor<?xf32>, tensor<f32>) // CHECK-SCF: %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor<?xf32>, tensor<f32>)
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]] // CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]] // CHECK-SCF: return %[[RES]]
// ----- // -----
@ -327,8 +324,7 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
// CHECK-SCF: %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>) // CHECK-SCF: %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
// CHECK-SCF: %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>) // CHECK-SCF: %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor<?xf32> // CHECK-SCF: %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor<?xf32>
// CHECK-SCF: %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]] // CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-SCF: return %[[RES]] // CHECK-SCF: return %[[RES]]
// ----- // -----
@ -610,10 +606,7 @@ func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]]) // CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]])
// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]]) // CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]])
// CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]] // CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]]
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] // CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[S0]])
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %8, %9
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]])
// CHECK-SCF: return %[[RES]] // CHECK-SCF: return %[[RES]]
// ----- // -----
@ -647,8 +640,5 @@ func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32} // CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32}
// CHECK-SCF-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"} // CHECK-SCF-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"}
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]]) // CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]])
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] // CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[S1]])
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[S1]], %[[S0]]
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]])
// CHECK-SCF: return %[[RES]] // CHECK-SCF: return %[[RES]]