diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 11f2fee..39e93e5 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -93,13 +93,23 @@ Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type, } SmallVector ExtractDynamicSizes(OpBuilder& b, Location loc, - Value tensor) { + Value tensor, + Value shape_tensor = nullptr) { auto tensor_type = tensor.getType().dyn_cast(); if (!tensor_type) return {}; SmallVector dyn_sizes; for (auto& en : llvm::enumerate(tensor_type.getShape())) { if (en.value() != ShapedType::kDynamicSize) continue; - dyn_sizes.push_back(b.create(loc, tensor, en.index())); + // If a shape tensor is present extract from there. + if (shape_tensor) { + Value extract = b.create( + loc, shape_tensor, + ValueRange{b.create(loc, en.index())}); + dyn_sizes.push_back( + b.create(loc, b.getIndexType(), extract)); + } else { + dyn_sizes.push_back(b.create(loc, tensor, en.index())); + } } return dyn_sizes; } @@ -868,17 +878,20 @@ class IotaConverter : public OpConversionPattern { unsigned nloops = result_shaped_type.getRank(); Location loc = iota_op.getLoc(); - auto dyn_sizes = isLHLO - ? SmallVector() - : ExtractDynamicSizes(rewriter, loc, - GetResultValue(iota_op)); + // If this is a dynamic iota, the first argument will be a shape tensor. + Value shape_tensor = args.size() > (isLHLO ? 1 : 0) ? args[0] : nullptr; + auto dyn_sizes = + isLHLO + ? SmallVector() + : ExtractDynamicSizes( + rewriter, loc, GetResultValue(iota_op), shape_tensor); auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/ isLHLO ? ArrayRef{} : ArrayRef{result_shaped_type}, /*inputs=*/ValueRange{}, /*outputBuffers=*/ - isLHLO ? ValueRange{args} + isLHLO ? ValueRange{args.back()} : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type, dyn_sizes)}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), @@ -1636,6 +1649,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, BroadcastConverter, ConstConverter, HloDynamicBroadcastInDimConverter, HloBroadcastInDimConverter, IotaConverter, + IotaConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index c8259d0..b9e37c5 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -839,6 +839,27 @@ func @iota() -> tensor<7x10xf32> { // ----- +// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @iota +// CHECK-SAME: %[[SHAPE:.*]]: tensor +func @iota(%shape: tensor) -> tensor { + %result = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor) -> (tensor) + return %result : tensor +} +// CHECK: %[[E1:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor +// CHECK: %[[I1:.*]] = index_cast %[[E1]] : i32 to index +// CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor +// CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index +// CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor +// CHECK: linalg.indexed_generic +// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[D2:.*]]: index, %{{.*}}: f32): +// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 +// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 +// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 + +// ----- + func @shift_left(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %result = "mhlo.shift_left"(%lhs, %rhs)