Fold xla iota across a 1-length dimension into a zero value
Iota across length-1 is just a constant. Fold into it. PiperOrigin-RevId: 320443468
This commit is contained in:
parent
c3be2474dd
commit
06ae59074f
|
@ -76,6 +76,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
|
||||||
// TODO(b/130357376): Iota has special conversion logic to HLO.
|
// TODO(b/130357376): Iota has special conversion logic to HLO.
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
|
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
|
||||||
|
|
|
@ -248,6 +248,17 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||||
results.insert<IotaBroadcast>(context);
|
results.insert<IotaBroadcast>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto dimension = iota_dimension().getLimitedValue();
|
||||||
|
auto result_ty = getResult().getType().cast<ShapedType>();
|
||||||
|
if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
|
||||||
|
Builder builder(getContext());
|
||||||
|
return builder.getZeroAttr(result_ty);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DynamicIotaOp
|
// DynamicIotaOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -432,6 +432,33 @@ func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32>
|
||||||
return %0 : tensor<5x?xi32>
|
return %0 : tensor<5x?xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @dynamic_iota_constant
|
||||||
|
func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> {
|
||||||
|
// CHECK: [[IOTA:%.+]] = mhlo.constant dense<0> : tensor<1xi32>
|
||||||
|
// CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x?xi32>
|
||||||
|
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<1x?xi32>
|
||||||
|
|
||||||
|
// CHECK: return [[BROADCAST]]
|
||||||
|
return %0 : tensor<1x?xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @iota_constant
|
||||||
|
func @iota_constant() -> tensor<1xi32> {
|
||||||
|
// CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32>
|
||||||
|
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1xi32>
|
||||||
|
|
||||||
|
// CHECK: return [[CONST]] : tensor<1xi32>
|
||||||
|
return %0 : tensor<1xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @iota_constant_multi
|
||||||
|
func @iota_constant_multi() -> tensor<1x4xi32> {
|
||||||
|
// CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32>
|
||||||
|
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x4xi32>
|
||||||
|
|
||||||
|
// CHECK: return [[CONST]] : tensor<1x4xi32>
|
||||||
|
return %0 : tensor<1x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @iota_not_lowered_to_constant
|
// CHECK-LABEL: @iota_not_lowered_to_constant
|
||||||
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
|
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
|
||||||
|
|
Loading…
Reference in New Issue