Avoid Broadcast op if all shapes are (known to) be equal.
The rank specialization case for shapes which are either of the same shape or a scalar doesn't need to compute the final result shape. PiperOrigin-RevId: 380129316
This commit is contained in:
		
							parent
							
								
									d4a7901284
								
							
						
					
					
						commit
						4c282fb542
					
				| 
						 | 
				
			
			@ -15,6 +15,7 @@ limitations under the License.
 | 
			
		|||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "llvm/ADT/EquivalenceClasses.h"
 | 
			
		||||
#include "llvm/ADT/Optional.h"
 | 
			
		||||
#include "llvm/ADT/STLExtras.h"
 | 
			
		||||
#include "llvm/ADT/SmallSet.h"
 | 
			
		||||
#include "llvm/ADT/SmallVector.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -341,17 +342,20 @@ SmallVector<Value, 8> MaterializeRankedOperations(
 | 
			
		|||
 | 
			
		||||
SmallVector<Value, 8> MaterializeFinalReshape(
 | 
			
		||||
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
 | 
			
		||||
    ValueRange unshaped_results) {
 | 
			
		||||
    ValueRange unshaped_results, llvm::Optional<Value> shape = {}) {
 | 
			
		||||
  if (!shape) {
 | 
			
		||||
    // Compute result shape.
 | 
			
		||||
    auto non_scalar_operands = llvm::make_filter_range(
 | 
			
		||||
      op.operands(), [](Value v) { return !IsScalarTensorType(v.getType()); });
 | 
			
		||||
        op.operands(),
 | 
			
		||||
        [](Value v) { return !IsScalarTensorType(v.getType()); });
 | 
			
		||||
    SmallVector<Value, 8> results;
 | 
			
		||||
    auto operand_shapes =
 | 
			
		||||
        llvm::to_vector<8>(llvm::map_range(non_scalar_operands, [&](Value v) {
 | 
			
		||||
          return b.create<shape::ShapeOfOp>(loc, v).result();
 | 
			
		||||
        }));
 | 
			
		||||
  auto shape = b.create<shape::BroadcastOp>(
 | 
			
		||||
    shape = b.create<shape::BroadcastOp>(
 | 
			
		||||
        loc, shape::getExtentTensorType(b.getContext()), operand_shapes);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Reshape results.
 | 
			
		||||
  return llvm::to_vector<8>(
 | 
			
		||||
| 
						 | 
				
			
			@ -359,7 +363,7 @@ SmallVector<Value, 8> MaterializeFinalReshape(
 | 
			
		|||
        return b
 | 
			
		||||
            .create<mhlo::DynamicReshapeOp>(
 | 
			
		||||
                loc, DeriveUnrankedTensorTypes(unshaped.getType()), unshaped,
 | 
			
		||||
                shape)
 | 
			
		||||
                shape.getValue())
 | 
			
		||||
            .result();
 | 
			
		||||
      }));
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -664,7 +668,8 @@ MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
 | 
			
		|||
      MaterializeRankedOperations(b, loc, bvm, op);
 | 
			
		||||
 | 
			
		||||
  // Restore the results' expected shape.
 | 
			
		||||
  return MaterializeFinalReshape(b, loc, op, unshaped_results);
 | 
			
		||||
  return MaterializeFinalReshape(b, loc, op, unshaped_results,
 | 
			
		||||
                                 non_scalar_shapes.front());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -159,8 +159,7 @@ func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		|||
// CHECK-SCF:       %[[TMP0:.*]] = "mhlo.sqrt"(%[[FLAT_ARG]]) : (tensor<?xf32>)
 | 
			
		||||
// CHECK-SCF:       %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) : (tensor<?xf32>)
 | 
			
		||||
// CHECK-SCF:       %[[UNSHAPED_RES:.*]] = "mhlo.sqrt"(%[[TMP1]]) : (tensor<?xf32>)
 | 
			
		||||
// CHECK-SCF:       %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
 | 
			
		||||
// CHECK-SCF:       %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-SCF:       %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-SCF:       return %[[RES]]
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -229,8 +228,7 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		|||
// CHECK-SCF:      %[[TMP0:.*]] = chlo.tan %[[FLAT_ARG]] : tensor<?xf32>
 | 
			
		||||
// CHECK-SCF:      %[[TMP1:.*]] = chlo.tan %[[TMP0]] : tensor<?xf32>
 | 
			
		||||
// CHECK-SCF:      %[[UNSHAPED_RES:.*]] = chlo.tan %[[TMP1]] : tensor<?xf32>
 | 
			
		||||
// CHECK-SCF:      %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
 | 
			
		||||
// CHECK-SCF:      %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-SCF:      %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-SCF:      return %[[RES]]
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -295,8 +293,7 @@ func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		|||
// CHECK-SCF:       %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
 | 
			
		||||
// CHECK-SCF:       %[[FLAT_ARG:.*]] = "mhlo.dynamic_reshape"(%[[ARG]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 | 
			
		||||
// CHECK-SCF:       %[[UNSHAPED_RES:.*]] = chlo.broadcast_maximum %[[FLAT_ARG]], %[[C0]] : (tensor<?xf32>, tensor<f32>)
 | 
			
		||||
// CHECK-SCF:       %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
 | 
			
		||||
// CHECK-SCF:       %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-SCF:       %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-SCF:       return %[[RES]]
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -327,8 +324,7 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
 | 
			
		|||
// CHECK-SCF:       %[[IMAG:.*]] = "mhlo.imag"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
 | 
			
		||||
// CHECK-SCF:       %[[REAL:.*]] = "mhlo.real"(%[[FLAT_ARG]]) : (tensor<?xcomplex<f32>>)
 | 
			
		||||
// CHECK-SCF:       %[[UNSHAPED_RES:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] : tensor<?xf32>
 | 
			
		||||
// CHECK-SCF:       %[[RES_SHAPE:.*]] = shape.shape_of %[[ARG]]
 | 
			
		||||
// CHECK-SCF:       %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-SCF:       %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 | 
			
		||||
// CHECK-SCF:       return %[[RES]]
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -610,10 +606,7 @@ func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
 | 
			
		|||
// CHECK-SCF-DAG:   %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]])
 | 
			
		||||
// CHECK-SCF-DAG:   %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]])
 | 
			
		||||
// CHECK-SCF:       %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]]
 | 
			
		||||
// CHECK-SCF-DAG:   %[[S0:.*]] = shape.shape_of %[[ARG0]]
 | 
			
		||||
// CHECK-SCF-DAG:   %[[S1:.*]] = shape.shape_of %[[ARG1]]
 | 
			
		||||
// CHECK-SCF-DAG:   %[[RES_S:.*]] = shape.broadcast %8, %9
 | 
			
		||||
// CHECK-SCF-DAG:   %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]])
 | 
			
		||||
// CHECK-SCF-DAG:   %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[S0]])
 | 
			
		||||
// CHECK-SCF:       return %[[RES]]
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
| 
						 | 
				
			
			@ -647,8 +640,5 @@ func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
 | 
			
		|||
// CHECK-SCF-DAG:   %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32}
 | 
			
		||||
// CHECK-SCF-DAG:   %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"}
 | 
			
		||||
// CHECK-SCF:       %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]])
 | 
			
		||||
// CHECK-SCF-DAG:   %[[S0:.*]] = shape.shape_of %[[ARG0]]
 | 
			
		||||
// CHECK-SCF-DAG:   %[[S1:.*]] = shape.shape_of %[[ARG1]]
 | 
			
		||||
// CHECK-SCF-DAG:   %[[RES_SHAPE:.*]] = shape.broadcast %[[S1]], %[[S0]]
 | 
			
		||||
// CHECK-SCF-DAG:   %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]])
 | 
			
		||||
// CHECK-SCF-DAG:   %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[S1]])
 | 
			
		||||
// CHECK-SCF:       return %[[RES]]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue