diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 7cf1901..c76165a 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -78,6 +78,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp { // TODO(b/130357376): Iota has special conversion logic to HLO. let hasCustomHLOConverter = 1; + let hasCanonicalizer = 1; } def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> { diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 90dac58..93f9ad7 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -214,6 +214,41 @@ static LogicalResult Verify(IotaOp op) { return success(); } +// Iota operations across multiple dimensions can be reduced to an iota and a +// ranked broadcast. +struct IotaBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IotaOp iota, + PatternRewriter& rewriter) const override { + auto result_ty = iota.getType().cast(); + if (!result_ty.hasRank() || result_ty.getRank() < 2) { + return failure(); + } + + auto iota_dimension = iota.iota_dimension(); + + auto iota_type = RankedTensorType::get( + {result_ty.getDimSize(iota_dimension.getLimitedValue())}, + result_ty.getElementType()); + + auto new_iota = rewriter.create(iota.getLoc(), iota_type, + rewriter.getI64IntegerAttr(0)); + + auto broadcast_attr = DenseIntElementsAttr::get( + RankedTensorType::get({1}, rewriter.getIntegerType(64)), + {iota_dimension}); + rewriter.replaceOpWithNewOp(iota, result_ty, new_iota, + broadcast_attr); + return success(); + } +}; + +void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // DynamicIotaOp //===----------------------------------------------------------------------===// @@ -235,11 +270,63 @@ struct DynamicIotaIsStatic : public OpRewritePattern { } }; +// Dynamic Iota operations across multiple dimensions can be reduced to an iota +// and a ranked broadcast. +struct DynamicIotaBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicIotaOp iota, + PatternRewriter& rewriter) const override { + auto result_ty = iota.getType().cast(); + if (!result_ty.hasRank() || result_ty.getRank() < 2) { + return failure(); + } + + auto iota_dimension = iota.iota_dimension(); + auto iota_dimension_int = iota_dimension.getLimitedValue(); + + auto converted_shape = rewriter.create( + iota.getLoc(), + RankedTensorType::get( + iota.output_shape().getType().cast().getShape(), + rewriter.getI64Type()), + iota.output_shape()); + + auto sliced_shape = rewriter.create( + iota.getLoc(), converted_shape, + GetI64ElementsAttr(iota_dimension_int, &rewriter), + GetI64ElementsAttr(iota_dimension_int + 1, &rewriter), + GetI64ElementsAttr(1, &rewriter)); + + auto converted_sliced_shape = rewriter.create( + iota.getLoc(), + RankedTensorType::get( + {1}, + iota.output_shape().getType().cast().getElementType()), + sliced_shape); + + auto iota_type = RankedTensorType::get( + {result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType()); + + auto new_iota = rewriter.create( + iota.getLoc(), iota_type, converted_sliced_shape, + rewriter.getI64IntegerAttr(0)); + + auto broadcast_attr = DenseIntElementsAttr::get( + RankedTensorType::get({1}, rewriter.getIntegerType(64)), + {iota_dimension}); + rewriter.replaceOpWithNewOp( + iota, result_ty, new_iota, iota.output_shape(), broadcast_attr); + return success(); + } +}; + } // namespace void DynamicIotaOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 1124a46..9f37ca5 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -391,6 +391,30 @@ func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { return %0 : tensor<4xi32> } +// CHECK-LABEL: @dynamic_iota_broadcast +func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32> + // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32> + + // CHECK: return [[BROADCAST]] + return %0 : tensor<5x?xi32> +} + +// CHECK-LABEL: @dynamic_iota_broadcast_second +func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> { + // CHECK-NEXT: [[CAST1:%.+]] = index_cast %arg0 : tensor<2xindex> to tensor<2xi64> + // CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK-NEXT: [[CAST2:%.+]] = index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex> + // CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor + // CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor<5x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32> + + // CHECK: return [[BROADCAST]] + return %0 : tensor<5x?xi32> +} + + // CHECK-LABEL: @iota_not_lowered_to_constant func @iota_not_lowered_to_constant() -> tensor<4xi32> { // CHECK: [[RESULT:%.*]] = "mhlo.iota" @@ -399,6 +423,24 @@ func @iota_not_lowered_to_constant() -> tensor<4xi32> { return %0 : tensor<4xi32> } +// CHECK-LABEL: @iota_broadcast +func @iota_broadcast() -> tensor<5x4xi32> { + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32> + // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32> + + return %0 : tensor<5x4xi32> +} + +// CHECK-LABEL: @iota_broadcast +func @iota_broadcast_second() -> tensor<5x4xi32> { + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32> + + return %0 : tensor<5x4xi32> +} + // CHECK-LABEL: @unary_einsum func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor