Lower integer matmuls to linalg

PiperOrigin-RevId: 359306495
This commit is contained in:
Geoffrey Martin-Noble 2021-02-24 09:43:54 -08:00 committed by TensorFlow MLIR Team
parent 475b4a06a5
commit 89f7f2bd65
3 changed files with 224 additions and 50 deletions

1
BUILD
View File

@ -699,6 +699,7 @@ cc_library(
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",
], ],

View File

@ -44,6 +44,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
@ -1130,6 +1131,8 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
return dyn_shape; return dyn_shape;
} }
template <typename InputElType, int input_bit_width, typename OutputElType,
int output_bit_width, DotOperationType op_type, typename LinalgOp>
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> { class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
public: public:
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern; using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
@ -1139,44 +1142,38 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) { if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
return failure(); return failure();
} }
Location loc = op.getLoc();
mhlo::DotOp::Adaptor adaptor(args); mhlo::DotOp::Adaptor adaptor(args);
Type result_type = op.getResult().getType();
auto shaped_type = result_type.cast<ShapedType>(); auto lhs_el_type =
DotOperationType op_type = GetDotOperationType(op); adaptor.lhs().getType().cast<ShapedType>().getElementType();
auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType()); auto rhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
return failure();
}
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
if (!output_el_type.isa<OutputElType>() ||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
return failure();
}
if (GetDotOperationType(op) != op_type) return failure();
Location loc = op.getLoc();
auto zero_attr = rewriter.getZeroAttr(output_el_type);
Value zero = rewriter.create<ConstantOp>(loc, zero_attr); Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes( SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type); rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape); auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
Value zero_tensor = Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0); rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
linalg::LinalgOp linalg_op; rewriter.replaceOpWithNewOp<LinalgOp>(
switch (op_type) { op, TypeRange{op.getType()}, ValueRange{adaptor.lhs(), adaptor.rhs()},
case DotOperationType::kMatrixMatrix: { ValueRange{zero_tensor});
linalg_op = rewriter.create<linalg::MatmulOp>(
loc, TypeRange{result_type},
ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
break;
}
case DotOperationType::kMatrixVector: {
linalg_op = rewriter.create<linalg::MatvecOp>(
loc, TypeRange{result_type},
ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
break;
}
case DotOperationType::kVectorDot: {
linalg_op = rewriter.create<linalg::DotOp>(
loc, TypeRange{result_type},
ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
break;
}
case DotOperationType::kUnsupported:
default: {
return op.emitError("unsupported dot operation type");
}
}
rewriter.replaceOp(op, linalg_op->getResults());
return success(); return success();
} }
}; };
@ -1193,6 +1190,8 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
return dyn_shape; return dyn_shape;
} }
template <typename InputElType, int input_bit_width, typename OutputElType,
int output_bit_width, typename LinalgOp>
class DotGeneralOpOnTensorsConversion class DotGeneralOpOnTensorsConversion
: public OpConversionPattern<mhlo::DotGeneralOp> { : public OpConversionPattern<mhlo::DotGeneralOp> {
public: public:
@ -1203,6 +1202,7 @@ class DotGeneralOpOnTensorsConversion
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) { if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
return failure(); return failure();
} }
mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers(); mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
auto lhs_bathcing_dims = auto lhs_bathcing_dims =
Extract1DVector(dim_numbers.lhs_batching_dimensions()); Extract1DVector(dim_numbers.lhs_batching_dimensions());
@ -1228,22 +1228,38 @@ class DotGeneralOpOnTensorsConversion
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "expected rhs contracting dimensions exactly {1}"); op, "expected rhs contracting dimensions exactly {1}");
} }
Location loc = op.getLoc();
mhlo::DotGeneralOp::Adaptor adaptor(args); mhlo::DotGeneralOp::Adaptor adaptor(args);
Type result_type = op.getResult().getType(); auto lhs_el_type =
auto shaped_type = result_type.cast<ShapedType>(); adaptor.lhs().getType().cast<ShapedType>().getElementType();
auto rhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
return failure();
}
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
if (!output_el_type.isa<OutputElType>() ||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
return failure();
}
Location loc = op.getLoc();
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes( SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type); rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType()); auto zero_attr = rewriter.getZeroAttr(output_el_type);
Value zero = rewriter.create<ConstantOp>(loc, zero_attr); Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape); auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
Value zero_tensor = Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0); rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
auto linalg_op = rewriter.create<linalg::BatchMatmulOp>( Operation* linalg_op = rewriter.create<LinalgOp>(
loc, /*resultTensorTypes=*/TypeRange{result_type}, loc, /*resultTensorTypes=*/TypeRange{op.getType()},
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()}, /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
/*outputBuffers=*/ValueRange{zero_tensor}); /*outputBuffers=*/ValueRange{zero_tensor});
rewriter.replaceOp(op, linalg_op.getResults());
rewriter.replaceOp(op, linalg_op->getResults());
return success(); return success();
} }
}; };
@ -1560,8 +1576,51 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
ReshapeOpConverter<mhlo::ReshapeOp, false>, ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>, ReverseConverter<mhlo::ReverseOp, false>,
SliceConverter<mhlo::SliceOp, false>, SliceConverter<mhlo::SliceOp, false>,
TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion, TransposeConverter<mhlo::TransposeOp, false>,
DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context); DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI32I32I32Op>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
linalg::BatchMatmulI8I8I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
linalg::BatchMatmulI16I16I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
linalg::BatchMatmulI32I32I32Op>,
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
linalg::BatchMatmulOp>,
ReduceOnTensorsConversion>(context);
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>, patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
ReduceRegionXLAOpConversion<mhlo::MinOp>, ReduceRegionXLAOpConversion<mhlo::MinOp>,
ReduceRegionXLAOpConversion<mhlo::MaxOp>, ReduceRegionXLAOpConversion<mhlo::MaxOp>,

View File

@ -960,7 +960,8 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
tensor<3x?xf32>) -> tensor<2x?xf32> tensor<3x?xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32> return %0 : tensor<2x?xf32>
} }
// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) // CHECK-LABEL: func @dot_matmul(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
@ -969,6 +970,58 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>) // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>) // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>)
func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
%arg1: tensor<3x?xi8>) -> tensor<2x?xi32> {
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi8>,
tensor<3x?xi8>) -> tensor<2x?xi32>
return %0 : tensor<2x?xi32>
}
// CHECK-LABEL: func @dot_matmul_i8_i8_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i8_i8_i32
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
// -----
func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
%arg1: tensor<3x?xi16>) -> tensor<2x?xi32> {
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi16>,
tensor<3x?xi16>) -> tensor<2x?xi32>
return %0 : tensor<2x?xi32>
}
// CHECK-LABEL: func @dot_matmul_i16_i16_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i16_i16_i32
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
// -----
func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
%arg1: tensor<3x?xi32>) -> tensor<2x?xi32> {
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi32>,
tensor<3x?xi32>) -> tensor<2x?xi32>
return %0 : tensor<2x?xi32>
}
// CHECK-LABEL: func @dot_matmul_i32_i32_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i32_i32_i32
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
// ----- // -----
func @dot_matvec(%arg0: tensor<?x3xf32>, func @dot_matvec(%arg0: tensor<?x3xf32>,
@ -977,7 +1030,8 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
tensor<3xf32>) -> tensor<?xf32> tensor<3xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>) // CHECK-LABEL: func @dot_matvec(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] // CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
@ -993,7 +1047,8 @@ func @dot_dot(%arg0: tensor<?xf32>,
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32> %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
return %0 : tensor<f32> return %0 : tensor<f32>
} }
// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>) // CHECK-LABEL: func @dot_dot(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
// CHECK: %[[INIT:.*]] = linalg.init_tensor [] // CHECK: %[[INIT:.*]] = linalg.init_tensor []
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.dot // CHECK: linalg.dot
@ -1002,7 +1057,7 @@ func @dot_dot(%arg0: tensor<?xf32>,
// ----- // -----
func @dot_general(%arg0: tensor<?x?x3xf32>, func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
%arg1: tensor<?x3x?xf32>) -> tensor<?x?x?xf32> { %arg1: tensor<?x3x?xf32>) -> tensor<?x?x?xf32> {
%0 = "mhlo.dot_general"(%arg0, %arg1) { %0 = "mhlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = { dot_dimension_numbers = {
@ -1015,7 +1070,8 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
} : (tensor<?x?x3xf32>, tensor<?x3x?xf32>) -> tensor<?x?x?xf32> } : (tensor<?x?x3xf32>, tensor<?x3x?xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32> return %0 : tensor<?x?x?xf32>
} }
// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>) // CHECK-LABEL: func @dot_general_batch_matmul(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] // CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
@ -1030,7 +1086,65 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
// ----- // -----
func @batch_matmul_large func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
%arg1: tensor<?x3x?xi8>) -> tensor<?x?x?xi32> {
%0 = "mhlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = {
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
rhs_batching_dimensions = dense<0> : tensor<1xi64>,
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
},
precision_config = ["DEFAULT", "DEFAULT"]
} : (tensor<?x?x3xi8>, tensor<?x3x?xi8>) -> tensor<?x?x?xi32>
return %0 : tensor<?x?x?xi32>
}
// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul_i8_i8_i32
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
// -----
func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
%arg1: tensor<?x3x?xi16>) -> tensor<?x?x?xi32> {
%0 = "mhlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = {
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
rhs_batching_dimensions = dense<0> : tensor<1xi64>,
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
},
precision_config = ["DEFAULT", "DEFAULT"]
} : (tensor<?x?x3xi16>, tensor<?x3x?xi16>) -> tensor<?x?x?xi32>
return %0 : tensor<?x?x?xi32>
}
// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul_i16_i16_i32
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
// -----
func @dot_general_batch_matmul_large
(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x32x32xf32>) -> tensor<2x16x32xf32> { (%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x32x32xf32>) -> tensor<2x16x32xf32> {
%0 = "mhlo.dot_general"(%arg0, %arg1) { %0 = "mhlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = { dot_dimension_numbers = {
@ -1042,7 +1156,7 @@ func @batch_matmul_large
: (tensor<2x16x32xf32>, tensor<2x32x32xf32>) -> tensor<2x16x32xf32> : (tensor<2x16x32xf32>, tensor<2x32x32xf32>) -> tensor<2x16x32xf32>
return %0 : tensor<2x16x32xf32> return %0 : tensor<2x16x32xf32>
} }
// CHECK: func @batch_matmul_large( // CHECK-LABEL: func @dot_general_batch_matmul_large(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<2x16x32xf32>, // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<2x16x32xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<2x32x32xf32>) // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<2x32x32xf32>)
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 16, 32] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 16, 32]