From 44d0464d162283c834b6a194effa4b3735457d28 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Wed, 3 Feb 2021 15:02:21 -0800 Subject: [PATCH] Use linalg.fill on tensors instead of tensor.generate in MHLO -> Linalg conversion. linalg.fill on tensors is a structured op that allows the use of tile + fuse to reduce the fill overhead. PiperOrigin-RevId: 355490400 --- .../mhlo/transforms/legalize_to_linalg.cc | 96 ++++++---------- tests/hlo-legalize-to-linalg.mlir | 105 ++++++++++++------ 2 files changed, 107 insertions(+), 94 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index e267ce2..ae05e1b 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -84,22 +84,12 @@ bool VerifyHloOpBufferOrTensorSemantics(Operation* op) { : llvm::all_of(op->getResults(), verify_type); } -// TODO(pifon): Migrate to InitTensorOp when available. -template Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type, - SmallVectorImpl& dyn_sizes) { - if (isLHLO) return nullptr; + ArrayRef dyn_sizes) { return b.create(loc, dyn_sizes, type.getShape(), type.getElementType()); } -template -Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type) { - SmallVector empty; - return GetInitTensor(b, loc, type, empty); -} - -// TODO(pifon): This logic is used everywhere, the code should be shared. 
SmallVector ExtractDynamicSizes(OpBuilder& b, Location loc, Value tensor) { auto tensor_type = tensor.getType().dyn_cast(); @@ -200,7 +190,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { ShapedType result_type = result.getType().template cast(); auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]); output_buffers.push_back( - GetInitTensor(rewriter, loc, result_type, dyn_sizes)); + GetInitTensor(rewriter, loc, result_type, dyn_sizes)); op_result_types.push_back(result.getType()); } body_result_types = llvm::to_vector<4>(llvm::map_range( @@ -397,9 +387,9 @@ class DataMovementOpConverter : public OpConversionPattern { /*resultTensorTypes=*/isLHLO ? ArrayRef{} : result_type, /*inputs=*/args.front(), /*outputBuffers=*/ - isLHLO ? ValueRange{args.back()} - : ValueRange{GetInitTensor(rewriter, loc, result_type, - dyn_sizes)}, + isLHLO + ? ValueRange{args.back()} + : ValueRange{GetInitTensor(rewriter, loc, result_type, dyn_sizes)}, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { nested_builder.create(loc, *args.begin()); @@ -859,6 +849,10 @@ class IotaConverter : public OpConversionPattern { unsigned nloops = result_shaped_type.getRank(); Location loc = iota_op.getLoc(); + auto dyn_sizes = isLHLO + ? SmallVector() + : ExtractDynamicSizes(rewriter, loc, + GetResultValue(iota_op)); auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/ @@ -866,8 +860,8 @@ class IotaConverter : public OpConversionPattern { /*inputs=*/ValueRange{}, /*outputBuffers=*/ isLHLO ? 
ValueRange{args} - : ValueRange{GetInitTensor(rewriter, loc, - result_shaped_type)}, + : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type, + dyn_sizes)}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs, @@ -1107,21 +1101,20 @@ DotOperationType GetDotOperationType(mhlo::DotOp dot_op) { return DotOperationType::kUnsupported; } -SmallVector GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc, +SmallVector GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc, Value lhs, Value rhs, - ShapedType result_type, DotOperationType type) { - SmallVector dyn_shape; + SmallVector dyn_shape; switch (type) { case DotOperationType::kMatrixMatrix: { - if (result_type.isDynamicDim(0)) + if (lhs.getType().cast().isDynamicDim(0)) dyn_shape.push_back(b.create(loc, lhs, 0)); - if (result_type.isDynamicDim(1)) + if (rhs.getType().cast().isDynamicDim(1)) dyn_shape.push_back(b.create(loc, rhs, 1)); break; } case DotOperationType::kMatrixVector: { - if (result_type.isDynamicDim(0)) + if (lhs.getType().cast().isDynamicDim(0)) dyn_shape.push_back(b.create(loc, lhs, 0)); break; } @@ -1148,39 +1141,31 @@ class DotOpOnTensorsConversion : public OpConversionPattern { Type result_type = op.getResult().getType(); auto shaped_type = result_type.cast(); DotOperationType op_type = GetDotOperationType(op); - SmallVector dyn_shape = GetDotOpInitTensorDynSizes( - rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type); auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType()); Value zero = rewriter.create(loc, zero_attr); - auto init_tensor = - rewriter.create(loc, result_type, dyn_shape); - { - OpBuilder::InsertionGuard guard(rewriter); - SmallVector arg_types(shaped_type.getRank(), - rewriter.getIndexType()); - Region& region = init_tensor.body(); - Block* block = rewriter.createBlock(®ion, region.begin(), arg_types); - rewriter.setInsertionPointToEnd(block); - 
rewriter.create(loc, zero); - } + SmallVector dyn_shape = GetDotOpInitTensorDynSizes( + rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type); + auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape); + Value zero_tensor = + rewriter.create(loc, init_tensor, zero).getResult(0); linalg::LinalgOp linalg_op; switch (op_type) { case DotOperationType::kMatrixMatrix: { linalg_op = rewriter.create( loc, TypeRange{result_type}, - ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor}); + ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor}); break; } case DotOperationType::kMatrixVector: { linalg_op = rewriter.create( loc, TypeRange{result_type}, - ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor}); + ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor}); break; } case DotOperationType::kVectorDot: { linalg_op = rewriter.create( loc, TypeRange{result_type}, - ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor}); + ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor}); break; } case DotOperationType::kUnsupported: @@ -1248,21 +1233,13 @@ class DotGeneralOpOnTensorsConversion rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type); auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType()); Value zero = rewriter.create(loc, zero_attr); - auto init_tensor = - rewriter.create(loc, result_type, dyn_shape); - { - OpBuilder::InsertionGuard guard(rewriter); - SmallVector arg_types(shaped_type.getRank(), - rewriter.getIndexType()); - Region& region = init_tensor.body(); - Block* block = rewriter.createBlock(®ion, region.begin(), arg_types); - rewriter.setInsertionPointToEnd(block); - rewriter.create(loc, zero); - } + auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape); + Value zero_tensor = + rewriter.create(loc, init_tensor, zero).getResult(0); auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/TypeRange{result_type}, 
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()}, - /*outputBuffers=*/ValueRange{init_tensor}); + /*outputBuffers=*/ValueRange{zero_tensor}); rewriter.replaceOp(op, linalg_op.getResults()); return success(); } @@ -1375,21 +1352,14 @@ class ReduceOnTensorsConversion : public OpConversionPattern { SmallVector dyn_shape = GetReduceOpInitTensorDynSizes( rewriter, loc, adaptor.operands()[0], result_type.cast(), reduction_dims); - auto init_tensor = - rewriter.create(loc, result_type, dyn_shape); - { - OpBuilder::InsertionGuard guard(rewriter); - SmallVector arg_types(shaped_type.getRank(), - rewriter.getIndexType()); - Region& region = init_tensor.body(); - Block* block = rewriter.createBlock(®ion, region.begin(), arg_types); - rewriter.setInsertionPointToEnd(block); - rewriter.create(loc, init_value); - } + auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape); + Value filled_tensor = + rewriter.create(loc, init_tensor, init_value) + .getResult(0); auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/op.getResultTypes(), inputs, - /*outputBuffers=*/ValueRange{init_tensor}, indexing_maps, + /*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps, GetParallelAndReductionIterators(src_rank, reduction_dims.size())); // Convert the signature of the body. 
The reduce op region apply function diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 677369a..34dcfbf 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -910,10 +910,13 @@ func @dot_matmul(%arg0: tensor<2x3xf32>, return %0 : tensor<2x?xf32> } // CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) -// CHECK: %[[INIT:.*]] = tensor.generate +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: linalg.matmul // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>) -// CHECK-SAME: outs(%[[INIT]] : tensor<2x?xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>) // ----- @@ -924,10 +927,13 @@ func @dot_matvec(%arg0: tensor, return %0 : tensor } // CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<3xf32>) -// CHECK: %[[INIT:.*]] = tensor.generate +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] +// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: linalg.matvec // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<3xf32>) -// CHECK-SAME: outs(%[[INIT]] : tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor) // ----- @@ -937,10 +943,11 @@ func @dot_dot(%arg0: tensor, return %0 : tensor } // CHECK: func @dot_dot(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -// CHECK: %[[INIT:.*]] = tensor.generate +// CHECK: %[[INIT:.*]] = linalg.init_tensor [] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: linalg.dot // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[INIT]] : tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor) // ----- @@ -958,10 +965,40 @@ func @dot_general(%arg0: tensor, return %0 : tensor } // CHECK: func @dot_general(%[[ARG0:.*]]: tensor, 
%[[ARG1:.*]]: tensor) -// CHECK: %[[INIT:.*]] = tensor.generate +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] +// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: linalg.batch_matmul // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[INIT]] : tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor) + +// ----- + +func @batch_matmul_large + (%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x32x32xf32>) -> tensor<2x16x32xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + rhs_batching_dimensions = dense<0> : tensor<1xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, + precision_config = ["DEFAULT", "DEFAULT"]} + : (tensor<2x16x32xf32>, tensor<2x32x32xf32>) -> tensor<2x16x32xf32> + return %0 : tensor<2x16x32xf32> +} +// CHECK: func @batch_matmul_large( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<2x16x32xf32>, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<2x32x32xf32>) +// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 16, 32] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] +// CHECK: %[[DOT:.*]] = linalg.batch_matmul +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x16x32xf32>, tensor<2x32x32xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<2x16x32xf32>) // ----- @@ -1001,13 +1038,14 @@ func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32> { // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: @reduce_add -// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK: %[[INIT_TENSOR:.*]] = 
tensor.generate +// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5] +// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]]) // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) -// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) +// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>) // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): // CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 @@ -1025,13 +1063,14 @@ func @reduce_minimum(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: @reduce_minimum -// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5] +// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]]) // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) -// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) +// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>) // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): // CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 @@ -1050,13 +1089,14 @@ func @reduce_maximum(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: @reduce_maximum 
-// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5] +// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]]) // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) -// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) +// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>) // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): // CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 @@ -1075,13 +1115,14 @@ func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<4xi32> { // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: @reduce_dim0 -// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [4] +// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]]) // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) -// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>) +// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>) // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): // CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 @@ -1101,13 +1142,13 @@ func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> { // CHECK-DAG: #[[MAP0:.*]] = 
affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: @reduce_init_const -// CHECK: %[[INIT:.*]] = constant 0xFF800000 : f32 -// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [1] +// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %{{.*}}) // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>) -// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<1xf32>) +// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<1xf32>) // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): // CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 @@ -1126,13 +1167,14 @@ func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>, // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)> // CHECK-LABEL: @reduce_multi_dimensions -// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [4] +// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]]) // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"] // CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>) -// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>) +// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>) // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): // CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 @@ -1150,15 +1192,16 @@ func @reduce_dynamic(%arg0: tensor, %arg1: tensor) -> tensor (d0, d1)> // CHECK-DAG: #[[MAP1:.*]] = 
affine_map<(d0, d1) -> (d0)> // CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor -// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK-DAG: %[[C0:.*]] = constant 0 : index +// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor +// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]] +// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]]) // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%{{.*}}tensor) -// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor) +// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor) // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): // CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32