[MLIR][HLO:Linalg] Lower mhlo.dynamic_iota to indexed_generic
This is the same as iota, but instead of taking the dimensions from the result tensor we use the supplied shape extents tensor. PiperOrigin-RevId: 362298548
This commit is contained in:
		
							parent
							
								
									09f8046816
								
							
						
					
					
						commit
						94f9740c67
					
				|  | @ -93,14 +93,24 @@ Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type, | |||
| } | ||||
| 
 | ||||
| SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc, | ||||
|                                           Value tensor) { | ||||
|                                           Value tensor, | ||||
|                                           Value shape_tensor = nullptr) { | ||||
|   auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>(); | ||||
|   if (!tensor_type) return {}; | ||||
|   SmallVector<Value, 2> dyn_sizes; | ||||
|   for (auto& en : llvm::enumerate(tensor_type.getShape())) { | ||||
|     if (en.value() != ShapedType::kDynamicSize) continue; | ||||
|     // If a shape tensor is present extract from there.
 | ||||
|     if (shape_tensor) { | ||||
|       Value extract = b.create<tensor::ExtractOp>( | ||||
|           loc, shape_tensor, | ||||
|           ValueRange{b.create<ConstantIndexOp>(loc, en.index())}); | ||||
|       dyn_sizes.push_back( | ||||
|           b.create<IndexCastOp>(loc, b.getIndexType(), extract)); | ||||
|     } else { | ||||
|       dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index())); | ||||
|     } | ||||
|   } | ||||
|   return dyn_sizes; | ||||
| } | ||||
| 
 | ||||
|  | @ -868,17 +878,20 @@ class IotaConverter : public OpConversionPattern<OpTy> { | |||
|     unsigned nloops = result_shaped_type.getRank(); | ||||
| 
 | ||||
|     Location loc = iota_op.getLoc(); | ||||
|     auto dyn_sizes = isLHLO | ||||
|     // If this is a dynamic iota, the first argument will be a shape tensor.
 | ||||
|     Value shape_tensor = args.size() > (isLHLO ? 1 : 0) ? args[0] : nullptr; | ||||
|     auto dyn_sizes = | ||||
|         isLHLO | ||||
|             ? SmallVector<Value, 2>() | ||||
|                          : ExtractDynamicSizes(rewriter, loc, | ||||
|                                                GetResultValue<isLHLO>(iota_op)); | ||||
|             : ExtractDynamicSizes( | ||||
|                   rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor); | ||||
|     auto linalg_op = rewriter.create<linalg::IndexedGenericOp>( | ||||
|         loc, | ||||
|         /*resultTensorTypes=*/ | ||||
|         isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type}, | ||||
|         /*inputs=*/ValueRange{}, | ||||
|         /*outputBuffers=*/ | ||||
|         isLHLO ? ValueRange{args} | ||||
|         isLHLO ? ValueRange{args.back()} | ||||
|                : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type, | ||||
|                                           dyn_sizes)}, | ||||
|         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), | ||||
|  | @ -1636,6 +1649,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, | |||
|       BroadcastConverter<mhlo::BroadcastOp, false>, | ||||
|       ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter, | ||||
|       HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, | ||||
|       IotaConverter<mhlo::DynamicIotaOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::AbsOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::AddOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::AndOp, false>, | ||||
|  |  | |||
|  | @ -839,6 +839,27 @@ func @iota() -> tensor<7x10xf32> { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> | ||||
| // CHECK-LABEL: func @iota | ||||
| // CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32> | ||||
| func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> { | ||||
|   %result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>) | ||||
|   return %result : tensor<?x?x8xf32> | ||||
| } | ||||
| // CHECK: %[[E1:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor<?xi32> | ||||
| // CHECK: %[[I1:.*]] = index_cast %[[E1]] : i32 to index | ||||
| // CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32> | ||||
| // CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index | ||||
| // CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xf32> | ||||
| // CHECK: linalg.indexed_generic | ||||
| // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] | ||||
| // CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[D2:.*]]: index, %{{.*}}: f32): | ||||
| // CHECK-NEXT:   %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 | ||||
| // CHECK-NEXT:   %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 | ||||
| // CHECK-NEXT:   linalg.yield %[[FLOAT_CAST]] : f32 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @shift_left(%lhs: tensor<2x2xi32>, | ||||
|                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { | ||||
|   %result = "mhlo.shift_left"(%lhs, %rhs) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue