[MLIR][HLO:Linalg] Lower mhlo.dynamic_iota to indexed_generic
This is the same as iota, but instead of taking the dimensions from the result tensor we use the supplied shape extents tensor. PiperOrigin-RevId: 362298548
This commit is contained in:
parent
09f8046816
commit
94f9740c67
|
@ -93,14 +93,24 @@ Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
|
|||
}
|
||||
|
||||
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
|
||||
Value tensor) {
|
||||
Value tensor,
|
||||
Value shape_tensor = nullptr) {
|
||||
auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensor_type) return {};
|
||||
SmallVector<Value, 2> dyn_sizes;
|
||||
for (auto& en : llvm::enumerate(tensor_type.getShape())) {
|
||||
if (en.value() != ShapedType::kDynamicSize) continue;
|
||||
// If a shape tensor is present extract from there.
|
||||
if (shape_tensor) {
|
||||
Value extract = b.create<tensor::ExtractOp>(
|
||||
loc, shape_tensor,
|
||||
ValueRange{b.create<ConstantIndexOp>(loc, en.index())});
|
||||
dyn_sizes.push_back(
|
||||
b.create<IndexCastOp>(loc, b.getIndexType(), extract));
|
||||
} else {
|
||||
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
|
||||
}
|
||||
}
|
||||
return dyn_sizes;
|
||||
}
|
||||
|
||||
|
@ -868,17 +878,20 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
|||
unsigned nloops = result_shaped_type.getRank();
|
||||
|
||||
Location loc = iota_op.getLoc();
|
||||
auto dyn_sizes = isLHLO
|
||||
// If this is a dynamic iota, the first argument will be a shape tensor.
|
||||
Value shape_tensor = args.size() > (isLHLO ? 1 : 0) ? args[0] : nullptr;
|
||||
auto dyn_sizes =
|
||||
isLHLO
|
||||
? SmallVector<Value, 2>()
|
||||
: ExtractDynamicSizes(rewriter, loc,
|
||||
GetResultValue<isLHLO>(iota_op));
|
||||
: ExtractDynamicSizes(
|
||||
rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor);
|
||||
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
|
||||
loc,
|
||||
/*resultTensorTypes=*/
|
||||
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
|
||||
/*inputs=*/ValueRange{},
|
||||
/*outputBuffers=*/
|
||||
isLHLO ? ValueRange{args}
|
||||
isLHLO ? ValueRange{args.back()}
|
||||
: ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
|
||||
dyn_sizes)},
|
||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||
|
@ -1636,6 +1649,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
||||
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
|
||||
IotaConverter<mhlo::DynamicIotaOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||
|
|
|
@ -839,6 +839,27 @@ func @iota() -> tensor<7x10xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-LABEL: func @iota
|
||||
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
|
||||
func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
|
||||
%result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
|
||||
return %result : tensor<?x?x8xf32>
|
||||
}
|
||||
// CHECK: %[[E1:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor<?xi32>
|
||||
// CHECK: %[[I1:.*]] = index_cast %[[E1]] : i32 to index
|
||||
// CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32>
|
||||
// CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index
|
||||
// CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xf32>
|
||||
// CHECK: linalg.indexed_generic
|
||||
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
|
||||
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[D2:.*]]: index, %{{.*}}: f32):
|
||||
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
|
||||
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
|
||||
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
func @shift_left(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
%result = "mhlo.shift_left"(%lhs, %rhs)
|
||||
|
|
Loading…
Reference in New Issue