Use linalg.fill on tensors instead of tensor.generate in MHLO -> Linalg conversion.
linalg.fill on tensors is a structured op, so tile-and-fuse can be applied to it to reduce the fill overhead.

PiperOrigin-RevId: 355490400
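For illustration, a hand-written sketch (not pass output; SSA names such as %d1 and %cst are made up) of the init operand for a matmul whose result has one dynamic dimension. Before this change, it was materialized with an opaque generator region:

    %init = tensor.generate %d1 {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
    } : tensor<2x?xf32>

After this change, it is an init tensor plus a structured fill:

    %init = linalg.init_tensor [2, %d1] : tensor<2x?xf32>
    %fill = linalg.fill(%init, %cst) : tensor<2x?xf32>, f32 -> tensor<2x?xf32>

Because linalg.fill is itself a structured op, tiling and fusion passes can tile the fill together with the matmul or reduction that consumes it, instead of treating the initialization as an opaque region op.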
commit 44d0464d16
parent 10cd797d6d
@@ -84,22 +84,12 @@ bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
                 : llvm::all_of(op->getResults(), verify_type);
 }
 
-// TODO(pifon): Migrate to InitTensorOp when available.
-template <bool isLHLO>
 Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
-                    SmallVectorImpl<Value>& dyn_sizes) {
-  if (isLHLO) return nullptr;
+                    ArrayRef<Value> dyn_sizes) {
   return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
                                         type.getElementType());
 }
 
-template <bool isLHLO>
-Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type) {
-  SmallVector<Value, 0> empty;
-  return GetInitTensor<isLHLO>(b, loc, type, empty);
-}
-
-// TODO(pifon): This logic is used everywhere, the code should be shared.
 SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
                                           Value tensor) {
   auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
@@ -200,7 +190,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
      ShapedType result_type = result.getType().template cast<ShapedType>();
      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
      output_buffers.push_back(
-          GetInitTensor<isLHLO>(rewriter, loc, result_type, dyn_sizes));
+          GetInitTensor(rewriter, loc, result_type, dyn_sizes));
      op_result_types.push_back(result.getType());
    }
    body_result_types = llvm::to_vector<4>(llvm::map_range(
@@ -397,9 +387,9 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
        /*inputs=*/args.front(),
        /*outputBuffers=*/
-        isLHLO ? ValueRange{args.back()}
-               : ValueRange{GetInitTensor<isLHLO>(rewriter, loc, result_type,
-                                                  dyn_sizes)},
+        isLHLO
+            ? ValueRange{args.back()}
+            : ValueRange{GetInitTensor(rewriter, loc, result_type, dyn_sizes)},
        indexing_maps, GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
@@ -859,6 +849,10 @@ class IotaConverter : public OpConversionPattern<OpTy> {
    unsigned nloops = result_shaped_type.getRank();

    Location loc = iota_op.getLoc();
+    auto dyn_sizes = isLHLO
+                         ? SmallVector<Value, 2>()
+                         : ExtractDynamicSizes(rewriter, loc,
+                                               GetResultValue<isLHLO>(iota_op));
    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
        loc,
        /*resultTensorTypes=*/
@@ -866,8 +860,8 @@ class IotaConverter : public OpConversionPattern<OpTy> {
        /*inputs=*/ValueRange{},
        /*outputBuffers=*/
        isLHLO ? ValueRange{args}
-               : ValueRange{GetInitTensor<isLHLO>(rewriter, loc,
-                                                  result_shaped_type)},
+               : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
+                                          dyn_sizes)},
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
@@ -1107,21 +1101,20 @@ DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
  return DotOperationType::kUnsupported;
 }
 
-SmallVector<Value, 8> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
+SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
                                                 Value lhs, Value rhs,
-                                                 ShapedType result_type,
                                                 DotOperationType type) {
-  SmallVector<Value, 8> dyn_shape;
+  SmallVector<Value, 2> dyn_shape;
  switch (type) {
    case DotOperationType::kMatrixMatrix: {
-      if (result_type.isDynamicDim(0))
+      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
-      if (result_type.isDynamicDim(1))
+      if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
      break;
    }
    case DotOperationType::kMatrixVector: {
-      if (result_type.isDynamicDim(0))
+      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
      break;
    }
@@ -1148,39 +1141,31 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
    Type result_type = op.getResult().getType();
    auto shaped_type = result_type.cast<ShapedType>();
    DotOperationType op_type = GetDotOperationType(op);
-    SmallVector<Value, 8> dyn_shape = GetDotOpInitTensorDynSizes(
-        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type);
    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
-    auto init_tensor =
-        rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
-                                     rewriter.getIndexType());
-      Region& region = init_tensor.body();
-      Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
-      rewriter.setInsertionPointToEnd(block);
-      rewriter.create<tensor::YieldOp>(loc, zero);
-    }
+    SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
+        rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
+    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
+    Value zero_tensor =
+        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
    linalg::LinalgOp linalg_op;
    switch (op_type) {
      case DotOperationType::kMatrixMatrix: {
        linalg_op = rewriter.create<linalg::MatmulOp>(
            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
+            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
        break;
      }
      case DotOperationType::kMatrixVector: {
        linalg_op = rewriter.create<linalg::MatvecOp>(
            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
+            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
        break;
      }
      case DotOperationType::kVectorDot: {
        linalg_op = rewriter.create<linalg::DotOp>(
            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
+            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
        break;
      }
      case DotOperationType::kUnsupported:
@@ -1248,21 +1233,13 @@ class DotGeneralOpOnTensorsConversion
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
-    auto init_tensor =
-        rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
-                                     rewriter.getIndexType());
-      Region& region = init_tensor.body();
-      Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
-      rewriter.setInsertionPointToEnd(block);
-      rewriter.create<tensor::YieldOp>(loc, zero);
-    }
+    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
+    Value zero_tensor =
+        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
    auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
        loc, /*resultTensorTypes=*/TypeRange{result_type},
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
-        /*outputBuffers=*/ValueRange{init_tensor});
+        /*outputBuffers=*/ValueRange{zero_tensor});
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
@@ -1375,21 +1352,14 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
    SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
        rewriter, loc, adaptor.operands()[0], result_type.cast<ShapedType>(),
        reduction_dims);
-    auto init_tensor =
-        rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
-                                     rewriter.getIndexType());
-      Region& region = init_tensor.body();
-      Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
-      rewriter.setInsertionPointToEnd(block);
-      rewriter.create<tensor::YieldOp>(loc, init_value);
-    }
+    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
+    Value filled_tensor =
+        rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
+            .getResult(0);
 
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
-        /*outputBuffers=*/ValueRange{init_tensor}, indexing_maps,
+        /*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps,
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));
 
    // Convert the signature of the body. The reduce op region apply function
@@ -910,10 +910,13 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
  return %0 : tensor<2x?xf32>
 }
 // CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
-// CHECK: %[[INIT:.*]] = tensor.generate
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
-// CHECK-SAME: outs(%[[INIT]] : tensor<2x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>)
 
 // -----
 
@@ -924,10 +927,13 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
  return %0 : tensor<?xf32>
 }
 // CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
-// CHECK: %[[INIT:.*]] = tensor.generate
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.matvec
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>)
-// CHECK-SAME: outs(%[[INIT]] : tensor<?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
 
 // -----
 
@@ -937,10 +943,11 @@ func @dot_dot(%arg0: tensor<?xf32>,
  return %0 : tensor<f32>
 }
 // CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
-// CHECK: %[[INIT:.*]] = tensor.generate
+// CHECK: %[[INIT:.*]] = linalg.init_tensor []
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.dot
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
-// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<f32>)
 
 // -----
 
@@ -958,10 +965,40 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
  return %0 : tensor<?x?x?xf32>
 }
 // CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
-// CHECK: %[[INIT:.*]] = tensor.generate
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.batch_matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
-// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>)
 
+// -----
+
+func @batch_matmul_large
+    (%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x32x32xf32>) -> tensor<2x16x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      rhs_contracting_dimensions = dense<1> : tensor<1xi64>},
+    precision_config = ["DEFAULT", "DEFAULT"]}
+    : (tensor<2x16x32xf32>, tensor<2x32x32xf32>) -> tensor<2x16x32xf32>
+  return %0 : tensor<2x16x32xf32>
+}
+// CHECK: func @batch_matmul_large(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<2x16x32xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<2x32x32xf32>)
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 16, 32]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: %[[DOT:.*]] = linalg.batch_matmul
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x16x32xf32>, tensor<2x32x32xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x16x32xf32>)
+
 // -----
 
@@ -1001,13 +1038,14 @@ func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: @reduce_add
-// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
-// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
+// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
+// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5]
+// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
 // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
-// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
+// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
 // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 // CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@@ -1025,13 +1063,14 @@ func @reduce_minimum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: @reduce_minimum
-// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
-// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
+// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
+// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5]
+// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
 // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
-// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
+// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
 // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 // CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
@@ -1050,13 +1089,14 @@ func @reduce_maximum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: @reduce_maximum
-// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
-// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
+// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
+// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [5]
+// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
 // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
-// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
+// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
 // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 // CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
@@ -1075,13 +1115,14 @@ func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<4xi32> {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: @reduce_dim0
-// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
-// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
+// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
+// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [4]
+// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
 // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
-// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>)
+// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>)
 // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 // CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
@@ -1101,13 +1142,13 @@ func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: @reduce_init_const
-// CHECK: %[[INIT:.*]] = constant 0xFF800000 : f32
-// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
+// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [1]
+// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %{{.*}})
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
 // CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>)
-// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<1xf32>)
+// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<1xf32>)
 // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
 // CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32
 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@@ -1126,13 +1167,14 @@ func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>,
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
 // CHECK-LABEL: @reduce_multi_dimensions
-// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
-// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
+// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
+// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [4]
+// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
 // CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>)
-// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>)
+// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>)
 // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 // CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@@ -1150,15 +1192,16 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
-// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
+// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
+// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
+// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
 // CHECK-SAME: ins(%{{.*}}tensor<?x?xi32>)
-// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<?xi32>)
+// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<?xi32>)
 // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
 // CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32