From 06ae59074f70ad01fa2c96f50cf31d8b36af1330 Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Thu, 9 Jul 2020 18:58:28 +0000 Subject: [PATCH] Fold xla iota across a 1-length dimension into a zero value Iota across length-1 is just a constant. Fold into it. PiperOrigin-RevId: 320443468 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 1 + lib/Dialect/mhlo/IR/hlo_ops.cc | 11 +++++++++ tests/canonicalize.mlir | 27 +++++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index c714ef2..4ce8ba5 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -76,6 +76,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp { // TODO(b/130357376): Iota has special conversion logic to HLO. let hasCustomHLOConverter = 1; let hasCanonicalizer = 1; + let hasFolder = 1; } def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> { diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index c9be037..16c19cb 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -248,6 +248,17 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results, results.insert(context); } +OpFoldResult IotaOp::fold(ArrayRef operands) { + auto dimension = iota_dimension().getLimitedValue(); + auto result_ty = getResult().getType().cast(); + if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) { + Builder builder(getContext()); + return builder.getZeroAttr(result_ty); + } + + return {}; +} + //===----------------------------------------------------------------------===// // DynamicIotaOp //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 2f67427..8777412 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -432,6 +432,33 @@ func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> return %0 : tensor<5x?xi32> } +// CHECK-LABEL: @dynamic_iota_constant +func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> { + // CHECK: [[IOTA:%.+]] = mhlo.constant dense<0> : tensor<1xi32> + // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x?xi32> + %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<1x?xi32> + + // CHECK: return [[BROADCAST]] + return %0 : tensor<1x?xi32> +} + +// CHECK-LABEL: @iota_constant +func @iota_constant() -> tensor<1xi32> { + // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1xi32> + + // CHECK: return [[CONST]] : tensor<1xi32> + return %0 : tensor<1xi32> +} + +// CHECK-LABEL: @iota_constant_multi +func @iota_constant_multi() -> tensor<1x4xi32> { + // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32> + %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x4xi32> + + // CHECK: return [[CONST]] : tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} // CHECK-LABEL: @iota_not_lowered_to_constant func @iota_not_lowered_to_constant() -> tensor<4xi32> {