Canonicalize multidimensional iota to use broadcast
There is no reason to have a multidimensional iota for codegen. A multidimensional iota should be canonicalized to a single-dimensional iota followed by a broadcast. Restricting iota to a single dimension followed by a broadcast substantially simplifies implementing iota operations. PiperOrigin-RevId: 320095470
This commit is contained in:
		
							parent
							
								
									8900222fed
								
							
						
					
					
						commit
						e1651b6090
					
				| 
						 | 
					@ -78,6 +78,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // TODO(b/130357376): Iota has special conversion logic to HLO.
 | 
					  // TODO(b/130357376): Iota has special conversion logic to HLO.
 | 
				
			||||||
  let hasCustomHLOConverter = 1;
 | 
					  let hasCustomHLOConverter = 1;
 | 
				
			||||||
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
 | 
					def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -214,6 +214,41 @@ static LogicalResult Verify(IotaOp op) {
 | 
				
			||||||
  return success();
 | 
					  return success();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Iota operations across multiple dimensions can be reduced to an iota and a
 | 
				
			||||||
 | 
					// ranked broadcast.
 | 
				
			||||||
 | 
					struct IotaBroadcast : public OpRewritePattern<IotaOp> {
 | 
				
			||||||
 | 
					  using OpRewritePattern<IotaOp>::OpRewritePattern;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult matchAndRewrite(IotaOp iota,
 | 
				
			||||||
 | 
					                                PatternRewriter& rewriter) const override {
 | 
				
			||||||
 | 
					    auto result_ty = iota.getType().cast<ShapedType>();
 | 
				
			||||||
 | 
					    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
 | 
				
			||||||
 | 
					      return failure();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto iota_dimension = iota.iota_dimension();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto iota_type = RankedTensorType::get(
 | 
				
			||||||
 | 
					        {result_ty.getDimSize(iota_dimension.getLimitedValue())},
 | 
				
			||||||
 | 
					        result_ty.getElementType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
 | 
				
			||||||
 | 
					                                            rewriter.getI64IntegerAttr(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto broadcast_attr = DenseIntElementsAttr::get(
 | 
				
			||||||
 | 
					        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
 | 
				
			||||||
 | 
					        {iota_dimension});
 | 
				
			||||||
 | 
					    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
 | 
				
			||||||
 | 
					                                                  broadcast_attr);
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
 | 
				
			||||||
 | 
					                                         MLIRContext* context) {
 | 
				
			||||||
 | 
					  results.insert<IotaBroadcast>(context);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// DynamicIotaOp
 | 
					// DynamicIotaOp
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					@ -235,11 +270,63 @@ struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Dynamic Iota operations across multiple dimensions can be reduced to an iota
 | 
				
			||||||
 | 
					// and a ranked broadcast.
 | 
				
			||||||
 | 
					struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
 | 
				
			||||||
 | 
					  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult matchAndRewrite(DynamicIotaOp iota,
 | 
				
			||||||
 | 
					                                PatternRewriter& rewriter) const override {
 | 
				
			||||||
 | 
					    auto result_ty = iota.getType().cast<ShapedType>();
 | 
				
			||||||
 | 
					    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
 | 
				
			||||||
 | 
					      return failure();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto iota_dimension = iota.iota_dimension();
 | 
				
			||||||
 | 
					    auto iota_dimension_int = iota_dimension.getLimitedValue();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto converted_shape = rewriter.create<IndexCastOp>(
 | 
				
			||||||
 | 
					        iota.getLoc(),
 | 
				
			||||||
 | 
					        RankedTensorType::get(
 | 
				
			||||||
 | 
					            iota.output_shape().getType().cast<ShapedType>().getShape(),
 | 
				
			||||||
 | 
					            rewriter.getI64Type()),
 | 
				
			||||||
 | 
					        iota.output_shape());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto sliced_shape = rewriter.create<SliceOp>(
 | 
				
			||||||
 | 
					        iota.getLoc(), converted_shape,
 | 
				
			||||||
 | 
					        GetI64ElementsAttr(iota_dimension_int, &rewriter),
 | 
				
			||||||
 | 
					        GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
 | 
				
			||||||
 | 
					        GetI64ElementsAttr(1, &rewriter));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto converted_sliced_shape = rewriter.create<IndexCastOp>(
 | 
				
			||||||
 | 
					        iota.getLoc(),
 | 
				
			||||||
 | 
					        RankedTensorType::get(
 | 
				
			||||||
 | 
					            {1},
 | 
				
			||||||
 | 
					            iota.output_shape().getType().cast<ShapedType>().getElementType()),
 | 
				
			||||||
 | 
					        sliced_shape);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto iota_type = RankedTensorType::get(
 | 
				
			||||||
 | 
					        {result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto new_iota = rewriter.create<DynamicIotaOp>(
 | 
				
			||||||
 | 
					        iota.getLoc(), iota_type, converted_sliced_shape,
 | 
				
			||||||
 | 
					        rewriter.getI64IntegerAttr(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto broadcast_attr = DenseIntElementsAttr::get(
 | 
				
			||||||
 | 
					        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
 | 
				
			||||||
 | 
					        {iota_dimension});
 | 
				
			||||||
 | 
					    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
 | 
				
			||||||
 | 
					        iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void DynamicIotaOp::getCanonicalizationPatterns(
 | 
					void DynamicIotaOp::getCanonicalizationPatterns(
 | 
				
			||||||
    OwningRewritePatternList& results, MLIRContext* context) {
 | 
					    OwningRewritePatternList& results, MLIRContext* context) {
 | 
				
			||||||
  results.insert<DynamicIotaIsStatic>(context);
 | 
					  results.insert<DynamicIotaIsStatic>(context);
 | 
				
			||||||
 | 
					  results.insert<DynamicIotaBroadcast>(context);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -391,6 +391,30 @@ func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
 | 
				
			||||||
  return %0 : tensor<4xi32>
 | 
					  return %0 : tensor<4xi32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @dynamic_iota_broadcast
 | 
				
			||||||
 | 
					func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
 | 
				
			||||||
 | 
					  // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
 | 
				
			||||||
 | 
					  // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: return [[BROADCAST]]
 | 
				
			||||||
 | 
					  return %0 : tensor<5x?xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @dynamic_iota_broadcast_second
 | 
				
			||||||
 | 
					func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
 | 
				
			||||||
 | 
					  // CHECK-NEXT: [[CAST1:%.+]] = index_cast %arg0 : tensor<2xindex> to tensor<2xi64>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: [[CAST2:%.+]] = index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<?xi32>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<2xindex>) -> tensor<5x?xi32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: return [[BROADCAST]]
 | 
				
			||||||
 | 
					  return %0 : tensor<5x?xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: @iota_not_lowered_to_constant
 | 
					// CHECK-LABEL: @iota_not_lowered_to_constant
 | 
				
			||||||
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
 | 
					func @iota_not_lowered_to_constant() -> tensor<4xi32> {
 | 
				
			||||||
  // CHECK: [[RESULT:%.*]] = "mhlo.iota"
 | 
					  // CHECK: [[RESULT:%.*]] = "mhlo.iota"
 | 
				
			||||||
| 
						 | 
					@ -399,6 +423,24 @@ func @iota_not_lowered_to_constant() -> tensor<4xi32> {
 | 
				
			||||||
  return %0 : tensor<4xi32>
 | 
					  return %0 : tensor<4xi32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @iota_broadcast
 | 
				
			||||||
 | 
					func @iota_broadcast() -> tensor<5x4xi32> {
 | 
				
			||||||
 | 
					  // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
 | 
				
			||||||
 | 
					  // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return %0 : tensor<5x4xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @iota_broadcast_second
 | 
				
			||||||
 | 
					func @iota_broadcast_second() -> tensor<5x4xi32> {
 | 
				
			||||||
 | 
					  // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
 | 
				
			||||||
 | 
					  // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return %0 : tensor<5x4xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: @unary_einsum
 | 
					// CHECK-LABEL: @unary_einsum
 | 
				
			||||||
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
 | 
					func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
 | 
				
			||||||
  // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
 | 
					  // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue