Fold xla iota across a 1-length dimension into a zero value
Iota across length-1 is just a constant. Fold into it. PiperOrigin-RevId: 320443468
This commit is contained in:
		
							parent
							
								
									c3be2474dd
								
							
						
					
					
						commit
						06ae59074f
					
				| 
						 | 
					@ -76,6 +76,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
 | 
				
			||||||
  // TODO(b/130357376): Iota has special conversion logic to HLO.
 | 
					  // TODO(b/130357376): Iota has special conversion logic to HLO.
 | 
				
			||||||
  let hasCustomHLOConverter = 1;
 | 
					  let hasCustomHLOConverter = 1;
 | 
				
			||||||
  let hasCanonicalizer = 1;
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
 | 
					  let hasFolder = 1;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
 | 
					def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -248,6 +248,17 @@ void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
 | 
				
			||||||
  results.insert<IotaBroadcast>(context);
 | 
					  results.insert<IotaBroadcast>(context);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
 | 
				
			||||||
 | 
					  auto dimension = iota_dimension().getLimitedValue();
 | 
				
			||||||
 | 
					  auto result_ty = getResult().getType().cast<ShapedType>();
 | 
				
			||||||
 | 
					  if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
 | 
				
			||||||
 | 
					    Builder builder(getContext());
 | 
				
			||||||
 | 
					    return builder.getZeroAttr(result_ty);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return {};
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// DynamicIotaOp
 | 
					// DynamicIotaOp
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -432,6 +432,33 @@ func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32>
 | 
				
			||||||
  return %0 : tensor<5x?xi32>
 | 
					  return %0 : tensor<5x?xi32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @dynamic_iota_constant
 | 
				
			||||||
 | 
					func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> {
 | 
				
			||||||
 | 
					  // CHECK: [[IOTA:%.+]] = mhlo.constant dense<0> : tensor<1xi32>
 | 
				
			||||||
 | 
					  // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x?xi32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<1x?xi32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: return [[BROADCAST]]
 | 
				
			||||||
 | 
					  return %0 : tensor<1x?xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @iota_constant
 | 
				
			||||||
 | 
					func @iota_constant() -> tensor<1xi32> {
 | 
				
			||||||
 | 
					  // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1xi32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: return [[CONST]] : tensor<1xi32>
 | 
				
			||||||
 | 
					  return %0 : tensor<1xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @iota_constant_multi
 | 
				
			||||||
 | 
					func @iota_constant_multi() -> tensor<1x4xi32> {
 | 
				
			||||||
 | 
					  // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x4xi32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: return [[CONST]] : tensor<1x4xi32>
 | 
				
			||||||
 | 
					  return %0 : tensor<1x4xi32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: @iota_not_lowered_to_constant
 | 
					// CHECK-LABEL: @iota_not_lowered_to_constant
 | 
				
			||||||
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
 | 
					func @iota_not_lowered_to_constant() -> tensor<4xi32> {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue