Use linalg.fill on tensors instead of tensor.generate in MHLO -> Linalg conversion.
linalg.fill on tensors is a structured op that allows use tile + fuse to reduce the fill overhead. PiperOrigin-RevId: 355490400
This commit is contained in:
		
							parent
							
								
									10cd797d6d
								
							
						
					
					
						commit
						44d0464d16
					
				| 
						 | 
					@ -84,22 +84,12 @@ bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
 | 
				
			||||||
                : llvm::all_of(op->getResults(), verify_type);
 | 
					                : llvm::all_of(op->getResults(), verify_type);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO(pifon): Migrate to InitTensorOp when available.
 | 
					 | 
				
			||||||
template <bool isLHLO>
 | 
					 | 
				
			||||||
Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
 | 
					Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
 | 
				
			||||||
                    SmallVectorImpl<Value>& dyn_sizes) {
 | 
					                    ArrayRef<Value> dyn_sizes) {
 | 
				
			||||||
  if (isLHLO) return nullptr;
 | 
					 | 
				
			||||||
  return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
 | 
					  return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
 | 
				
			||||||
                                        type.getElementType());
 | 
					                                        type.getElementType());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <bool isLHLO>
 | 
					 | 
				
			||||||
Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type) {
 | 
					 | 
				
			||||||
  SmallVector<Value, 0> empty;
 | 
					 | 
				
			||||||
  return GetInitTensor<isLHLO>(b, loc, type, empty);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// TODO(pifon): This logic is used everywhere, the code should be shared.
 | 
					 | 
				
			||||||
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
 | 
					SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
 | 
				
			||||||
                                          Value tensor) {
 | 
					                                          Value tensor) {
 | 
				
			||||||
  auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
 | 
					  auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
 | 
				
			||||||
| 
						 | 
					@ -200,7 +190,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
      ShapedType result_type = result.getType().template cast<ShapedType>();
 | 
					      ShapedType result_type = result.getType().template cast<ShapedType>();
 | 
				
			||||||
      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
 | 
					      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
 | 
				
			||||||
      output_buffers.push_back(
 | 
					      output_buffers.push_back(
 | 
				
			||||||
          GetInitTensor<isLHLO>(rewriter, loc, result_type, dyn_sizes));
 | 
					          GetInitTensor(rewriter, loc, result_type, dyn_sizes));
 | 
				
			||||||
      op_result_types.push_back(result.getType());
 | 
					      op_result_types.push_back(result.getType());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    body_result_types = llvm::to_vector<4>(llvm::map_range(
 | 
					    body_result_types = llvm::to_vector<4>(llvm::map_range(
 | 
				
			||||||
| 
						 | 
					@ -397,9 +387,9 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
 | 
					        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
 | 
				
			||||||
        /*inputs=*/args.front(),
 | 
					        /*inputs=*/args.front(),
 | 
				
			||||||
        /*outputBuffers=*/
 | 
					        /*outputBuffers=*/
 | 
				
			||||||
        isLHLO ? ValueRange{args.back()}
 | 
					        isLHLO
 | 
				
			||||||
               : ValueRange{GetInitTensor<isLHLO>(rewriter, loc, result_type,
 | 
					            ? ValueRange{args.back()}
 | 
				
			||||||
                                                  dyn_sizes)},
 | 
					            : ValueRange{GetInitTensor(rewriter, loc, result_type, dyn_sizes)},
 | 
				
			||||||
        indexing_maps, GetNParallelLoopsAttrs(nloops),
 | 
					        indexing_maps, GetNParallelLoopsAttrs(nloops),
 | 
				
			||||||
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
 | 
					        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
 | 
				
			||||||
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
 | 
					          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
 | 
				
			||||||
| 
						 | 
					@ -859,6 +849,10 @@ class IotaConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
    unsigned nloops = result_shaped_type.getRank();
 | 
					    unsigned nloops = result_shaped_type.getRank();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Location loc = iota_op.getLoc();
 | 
					    Location loc = iota_op.getLoc();
 | 
				
			||||||
 | 
					    auto dyn_sizes = isLHLO
 | 
				
			||||||
 | 
					                         ? SmallVector<Value, 2>()
 | 
				
			||||||
 | 
					                         : ExtractDynamicSizes(rewriter, loc,
 | 
				
			||||||
 | 
					                                               GetResultValue<isLHLO>(iota_op));
 | 
				
			||||||
    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
 | 
					    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
 | 
				
			||||||
        loc,
 | 
					        loc,
 | 
				
			||||||
        /*resultTensorTypes=*/
 | 
					        /*resultTensorTypes=*/
 | 
				
			||||||
| 
						 | 
					@ -866,8 +860,8 @@ class IotaConverter : public OpConversionPattern<OpTy> {
 | 
				
			||||||
        /*inputs=*/ValueRange{},
 | 
					        /*inputs=*/ValueRange{},
 | 
				
			||||||
        /*outputBuffers=*/
 | 
					        /*outputBuffers=*/
 | 
				
			||||||
        isLHLO ? ValueRange{args}
 | 
					        isLHLO ? ValueRange{args}
 | 
				
			||||||
               : ValueRange{GetInitTensor<isLHLO>(rewriter, loc,
 | 
					               : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
 | 
				
			||||||
                                                  result_shaped_type)},
 | 
					                                          dyn_sizes)},
 | 
				
			||||||
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
 | 
					        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
 | 
				
			||||||
        GetNParallelLoopsAttrs(nloops),
 | 
					        GetNParallelLoopsAttrs(nloops),
 | 
				
			||||||
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
 | 
					        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
 | 
				
			||||||
| 
						 | 
					@ -1107,21 +1101,20 @@ DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
 | 
				
			||||||
  return DotOperationType::kUnsupported;
 | 
					  return DotOperationType::kUnsupported;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SmallVector<Value, 8> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
 | 
					SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
 | 
				
			||||||
                                                 Value lhs, Value rhs,
 | 
					                                                 Value lhs, Value rhs,
 | 
				
			||||||
                                                 ShapedType result_type,
 | 
					 | 
				
			||||||
                                                 DotOperationType type) {
 | 
					                                                 DotOperationType type) {
 | 
				
			||||||
  SmallVector<Value, 8> dyn_shape;
 | 
					  SmallVector<Value, 2> dyn_shape;
 | 
				
			||||||
  switch (type) {
 | 
					  switch (type) {
 | 
				
			||||||
    case DotOperationType::kMatrixMatrix: {
 | 
					    case DotOperationType::kMatrixMatrix: {
 | 
				
			||||||
      if (result_type.isDynamicDim(0))
 | 
					      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
 | 
				
			||||||
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
 | 
					        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
 | 
				
			||||||
      if (result_type.isDynamicDim(1))
 | 
					      if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
 | 
				
			||||||
        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
 | 
					        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
 | 
				
			||||||
      break;
 | 
					      break;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    case DotOperationType::kMatrixVector: {
 | 
					    case DotOperationType::kMatrixVector: {
 | 
				
			||||||
      if (result_type.isDynamicDim(0))
 | 
					      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
 | 
				
			||||||
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
 | 
					        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
 | 
				
			||||||
      break;
 | 
					      break;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -1148,39 +1141,31 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 | 
				
			||||||
    Type result_type = op.getResult().getType();
 | 
					    Type result_type = op.getResult().getType();
 | 
				
			||||||
    auto shaped_type = result_type.cast<ShapedType>();
 | 
					    auto shaped_type = result_type.cast<ShapedType>();
 | 
				
			||||||
    DotOperationType op_type = GetDotOperationType(op);
 | 
					    DotOperationType op_type = GetDotOperationType(op);
 | 
				
			||||||
    SmallVector<Value, 8> dyn_shape = GetDotOpInitTensorDynSizes(
 | 
					 | 
				
			||||||
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type);
 | 
					 | 
				
			||||||
    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
 | 
					    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
 | 
				
			||||||
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
 | 
					    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
 | 
				
			||||||
    auto init_tensor =
 | 
					    SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
 | 
				
			||||||
        rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
 | 
					        rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
 | 
				
			||||||
    {
 | 
					    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
 | 
				
			||||||
      OpBuilder::InsertionGuard guard(rewriter);
 | 
					    Value zero_tensor =
 | 
				
			||||||
      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
 | 
					        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
 | 
				
			||||||
                                     rewriter.getIndexType());
 | 
					 | 
				
			||||||
      Region& region = init_tensor.body();
 | 
					 | 
				
			||||||
      Block* block = rewriter.createBlock(®ion, region.begin(), arg_types);
 | 
					 | 
				
			||||||
      rewriter.setInsertionPointToEnd(block);
 | 
					 | 
				
			||||||
      rewriter.create<tensor::YieldOp>(loc, zero);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    linalg::LinalgOp linalg_op;
 | 
					    linalg::LinalgOp linalg_op;
 | 
				
			||||||
    switch (op_type) {
 | 
					    switch (op_type) {
 | 
				
			||||||
      case DotOperationType::kMatrixMatrix: {
 | 
					      case DotOperationType::kMatrixMatrix: {
 | 
				
			||||||
        linalg_op = rewriter.create<linalg::MatmulOp>(
 | 
					        linalg_op = rewriter.create<linalg::MatmulOp>(
 | 
				
			||||||
            loc, TypeRange{result_type},
 | 
					            loc, TypeRange{result_type},
 | 
				
			||||||
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
 | 
					            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      case DotOperationType::kMatrixVector: {
 | 
					      case DotOperationType::kMatrixVector: {
 | 
				
			||||||
        linalg_op = rewriter.create<linalg::MatvecOp>(
 | 
					        linalg_op = rewriter.create<linalg::MatvecOp>(
 | 
				
			||||||
            loc, TypeRange{result_type},
 | 
					            loc, TypeRange{result_type},
 | 
				
			||||||
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
 | 
					            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      case DotOperationType::kVectorDot: {
 | 
					      case DotOperationType::kVectorDot: {
 | 
				
			||||||
        linalg_op = rewriter.create<linalg::DotOp>(
 | 
					        linalg_op = rewriter.create<linalg::DotOp>(
 | 
				
			||||||
            loc, TypeRange{result_type},
 | 
					            loc, TypeRange{result_type},
 | 
				
			||||||
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
 | 
					            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      case DotOperationType::kUnsupported:
 | 
					      case DotOperationType::kUnsupported:
 | 
				
			||||||
| 
						 | 
					@ -1248,21 +1233,13 @@ class DotGeneralOpOnTensorsConversion
 | 
				
			||||||
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
 | 
					        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
 | 
				
			||||||
    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
 | 
					    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
 | 
				
			||||||
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
 | 
					    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
 | 
				
			||||||
    auto init_tensor =
 | 
					    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
 | 
				
			||||||
        rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
 | 
					    Value zero_tensor =
 | 
				
			||||||
    {
 | 
					        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
 | 
				
			||||||
      OpBuilder::InsertionGuard guard(rewriter);
 | 
					 | 
				
			||||||
      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
 | 
					 | 
				
			||||||
                                     rewriter.getIndexType());
 | 
					 | 
				
			||||||
      Region& region = init_tensor.body();
 | 
					 | 
				
			||||||
      Block* block = rewriter.createBlock(®ion, region.begin(), arg_types);
 | 
					 | 
				
			||||||
      rewriter.setInsertionPointToEnd(block);
 | 
					 | 
				
			||||||
      rewriter.create<tensor::YieldOp>(loc, zero);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
 | 
					    auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
 | 
				
			||||||
        loc, /*resultTensorTypes=*/TypeRange{result_type},
 | 
					        loc, /*resultTensorTypes=*/TypeRange{result_type},
 | 
				
			||||||
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
 | 
					        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
 | 
				
			||||||
        /*outputBuffers=*/ValueRange{init_tensor});
 | 
					        /*outputBuffers=*/ValueRange{zero_tensor});
 | 
				
			||||||
    rewriter.replaceOp(op, linalg_op.getResults());
 | 
					    rewriter.replaceOp(op, linalg_op.getResults());
 | 
				
			||||||
    return success();
 | 
					    return success();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -1375,21 +1352,14 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
 | 
				
			||||||
    SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
 | 
					    SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
 | 
				
			||||||
        rewriter, loc, adaptor.operands()[0], result_type.cast<ShapedType>(),
 | 
					        rewriter, loc, adaptor.operands()[0], result_type.cast<ShapedType>(),
 | 
				
			||||||
        reduction_dims);
 | 
					        reduction_dims);
 | 
				
			||||||
    auto init_tensor =
 | 
					    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
 | 
				
			||||||
        rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
 | 
					    Value filled_tensor =
 | 
				
			||||||
    {
 | 
					        rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
 | 
				
			||||||
      OpBuilder::InsertionGuard guard(rewriter);
 | 
					            .getResult(0);
 | 
				
			||||||
      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
 | 
					 | 
				
			||||||
                                     rewriter.getIndexType());
 | 
					 | 
				
			||||||
      Region& region = init_tensor.body();
 | 
					 | 
				
			||||||
      Block* block = rewriter.createBlock(®ion, region.begin(), arg_types);
 | 
					 | 
				
			||||||
      rewriter.setInsertionPointToEnd(block);
 | 
					 | 
				
			||||||
      rewriter.create<tensor::YieldOp>(loc, init_value);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto linalg_op = rewriter.create<linalg::GenericOp>(
 | 
					    auto linalg_op = rewriter.create<linalg::GenericOp>(
 | 
				
			||||||
        loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
 | 
					        loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
 | 
				
			||||||
        /*outputBuffers=*/ValueRange{init_tensor}, indexing_maps,
 | 
					        /*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps,
 | 
				
			||||||
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));
 | 
					        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Convert the signature of the body. The reduce op region apply function
 | 
					    // Convert the signature of the body. The reduce op region apply function
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -910,10 +910,13 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
 | 
				
			||||||
  return %0 : tensor<2x?xf32>
 | 
					  return %0 : tensor<2x?xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
 | 
					// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.generate
 | 
					// CHECK: %[[C1:.*]] = constant 1 : index
 | 
				
			||||||
 | 
					// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
				
			||||||
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
				
			||||||
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.matmul
 | 
					// CHECK: linalg.matmul
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT]] : tensor<2x?xf32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -924,10 +927,13 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
 | 
				
			||||||
  return %0 : tensor<?xf32>
 | 
					  return %0 : tensor<?xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
 | 
					// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.generate
 | 
					// CHECK: %[[C0:.*]] = constant 0 : index
 | 
				
			||||||
 | 
					// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 | 
				
			||||||
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
 | 
				
			||||||
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.matvec
 | 
					// CHECK: linalg.matvec
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT]] : tensor<?xf32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -937,10 +943,11 @@ func @dot_dot(%arg0: tensor<?xf32>,
 | 
				
			||||||
  return %0 : tensor<f32>
 | 
					  return %0 : tensor<f32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
 | 
					// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.generate
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor []
 | 
				
			||||||
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.dot
 | 
					// CHECK: linalg.dot
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<f32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -958,10 +965,40 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
 | 
				
			||||||
  return %0 : tensor<?x?x?xf32>
 | 
					  return %0 : tensor<?x?x?xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
 | 
					// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.generate
 | 
					// CHECK: %[[C0:.*]] = constant 0 : index
 | 
				
			||||||
 | 
					// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 | 
				
			||||||
 | 
					// CHECK: %[[C1:.*]] = constant 1 : index
 | 
				
			||||||
 | 
					// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
 | 
				
			||||||
 | 
					// CHECK: %[[C2:.*]] = constant 2 : index
 | 
				
			||||||
 | 
					// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 | 
				
			||||||
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 | 
				
			||||||
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.batch_matmul
 | 
					// CHECK: linalg.batch_matmul
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @batch_matmul_large
 | 
				
			||||||
 | 
					  (%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x32x32xf32>) -> tensor<2x16x32xf32> {
 | 
				
			||||||
 | 
					  %0 = "mhlo.dot_general"(%arg0, %arg1) {
 | 
				
			||||||
 | 
					    dot_dimension_numbers = {
 | 
				
			||||||
 | 
					      lhs_batching_dimensions = dense<0> : tensor<1xi64>,
 | 
				
			||||||
 | 
					      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
 | 
				
			||||||
 | 
					      rhs_batching_dimensions = dense<0> : tensor<1xi64>,
 | 
				
			||||||
 | 
					      rhs_contracting_dimensions = dense<1> : tensor<1xi64>},
 | 
				
			||||||
 | 
					    precision_config = ["DEFAULT", "DEFAULT"]}
 | 
				
			||||||
 | 
					    : (tensor<2x16x32xf32>, tensor<2x32x32xf32>) -> tensor<2x16x32xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<2x16x32xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK: func @batch_matmul_large(
 | 
				
			||||||
 | 
					// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<2x16x32xf32>,
 | 
				
			||||||
 | 
					// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<2x32x32xf32>)
 | 
				
			||||||
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 16, 32]
 | 
				
			||||||
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
 | 
					// CHECK: %[[DOT:.*]] = linalg.batch_matmul
 | 
				
			||||||
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x16x32xf32>, tensor<2x32x32xf32>)
 | 
				
			||||||
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<2x16x32xf32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1001,13 +1038,14 @@ func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
 | 
				
			||||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
					// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
				
			||||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
					// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
				
			||||||
// CHECK-LABEL: @reduce_add
 | 
					// CHECK-LABEL: @reduce_add
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
					// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
				
			||||||
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
 | 
					// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5]
 | 
				
			||||||
 | 
					// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
					// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
				
			||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
					// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
				
			||||||
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
 | 
					// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
				
			||||||
| 
						 | 
					@ -1025,13 +1063,14 @@ func @reduce_minimum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32
 | 
				
			||||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
					// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
				
			||||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
					// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
				
			||||||
// CHECK-LABEL: @reduce_minimum
 | 
					// CHECK-LABEL: @reduce_minimum
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
					// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
				
			||||||
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
 | 
					// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5]
 | 
				
			||||||
 | 
					// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
					// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
				
			||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
					// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
				
			||||||
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
 | 
					// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
				
			||||||
// CHECK-NEXT:   %[[CMP:.*]] = cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[CMP:.*]] = cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
| 
						 | 
					@ -1050,13 +1089,14 @@ func @reduce_maximum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32
 | 
				
			||||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
					// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
				
			||||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
					// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
				
			||||||
// CHECK-LABEL: @reduce_maximum
 | 
					// CHECK-LABEL: @reduce_maximum
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
					// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
				
			||||||
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
 | 
					// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5]
 | 
				
			||||||
 | 
					// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
					// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
				
			||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
					// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
				
			||||||
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
 | 
					// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
				
			||||||
// CHECK-NEXT:   %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
| 
						 | 
					@ -1075,13 +1115,14 @@ func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<4xi32> {
 | 
				
			||||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 | 
					// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 | 
				
			||||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
					// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
				
			||||||
// CHECK-LABEL: @reduce_dim0
 | 
					// CHECK-LABEL: @reduce_dim0
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
					// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
				
			||||||
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
 | 
					// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [4]
 | 
				
			||||||
 | 
					// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
					// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
				
			||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
					// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
				
			||||||
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
 | 
					// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>)
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
				
			||||||
// CHECK-NEXT:   %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
| 
						 | 
					@ -1101,13 +1142,13 @@ func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> {
 | 
				
			||||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
					// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
				
			||||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
					// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
				
			||||||
// CHECK-LABEL: @reduce_init_const
 | 
					// CHECK-LABEL: @reduce_init_const
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = constant 0xFF800000 : f32
 | 
					// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [1]
 | 
				
			||||||
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
 | 
					// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %{{.*}})
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
					// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
				
			||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
					// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
				
			||||||
// CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>)
 | 
					// CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<1xf32>)
 | 
					// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<1xf32>)
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32
 | 
				
			||||||
// CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
 | 
				
			||||||
| 
						 | 
					@ -1126,13 +1167,14 @@ func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>,
 | 
				
			||||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 | 
					// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 | 
				
			||||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
 | 
					// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
 | 
				
			||||||
// CHECK-LABEL: @reduce_multi_dimensions
 | 
					// CHECK-LABEL: @reduce_multi_dimensions
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
					// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
				
			||||||
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
 | 
					// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [4]
 | 
				
			||||||
 | 
					// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
					// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
				
			||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
 | 
					// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
 | 
				
			||||||
// CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>)
 | 
					// CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>)
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
				
			||||||
| 
						 | 
					@ -1150,15 +1192,16 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
 | 
				
			||||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
					// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 | 
				
			||||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
					// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 | 
				
			||||||
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
 | 
					// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
					// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
 | 
				
			||||||
// CHECK: %[[C0:.*]] = constant 0 : index
 | 
					// CHECK-DAG: %[[C0:.*]] = constant 0 : index
 | 
				
			||||||
// CHECK: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
 | 
					// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
 | 
				
			||||||
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
 | 
					// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
 | 
				
			||||||
 | 
					// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 | 
				
			||||||
// CHECK: linalg.generic
 | 
					// CHECK: linalg.generic
 | 
				
			||||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
					// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 | 
				
			||||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
					// CHECK-SAME: iterator_types = ["parallel", "reduction"]
 | 
				
			||||||
// CHECK-SAME: ins(%{{.*}}tensor<?x?xi32>)
 | 
					// CHECK-SAME: ins(%{{.*}}tensor<?x?xi32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<?xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<?xi32>)
 | 
				
			||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
					// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 | 
				
			||||||
// CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
					// CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 | 
				
			||||||
// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
					// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue