Fold xla iota across a 1-length dimension into a zero value

Iota across length-1 is just a constant. Fold into it.

PiperOrigin-RevId: 320443468
This commit is contained in:
Robert Suderman 2020-07-09 18:58:28 +00:00 committed by Mehdi Amini
parent c3be2474dd
commit 06ae59074f
3 changed files with 39 additions and 0 deletions

View File

@ -76,6 +76,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
// TODO(b/130357376): Iota has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {

View File

@ -248,6 +248,17 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
results.insert<IotaBroadcast>(context);
}
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
auto dimension = iota_dimension().getLimitedValue();
auto result_ty = getResult().getType().cast<ShapedType>();
if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
Builder builder(getContext());
return builder.getZeroAttr(result_ty);
}
return {};
}
//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//

View File

@ -432,6 +432,33 @@ func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32>
return %0 : tensor<5x?xi32>
}
// CHECK-LABEL: @dynamic_iota_constant
func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> {
// CHECK: [[IOTA:%.+]] = mhlo.constant dense<0> : tensor<1xi32>
// CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x?xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<1x?xi32>
// CHECK: return [[BROADCAST]]
return %0 : tensor<1x?xi32>
}
// CHECK-LABEL: @iota_constant
func @iota_constant() -> tensor<1xi32> {
// CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1xi32>
// CHECK: return [[CONST]] : tensor<1xi32>
return %0 : tensor<1xi32>
}
// CHECK-LABEL: @iota_constant_multi
func @iota_constant_multi() -> tensor<1x4xi32> {
// CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x4xi32>
// CHECK: return [[CONST]] : tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {