Lower integer matmuls to linalg
PiperOrigin-RevId: 359306495
commit 89f7f2bd65
parent 475b4a06a5
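For context, this change teaches the HLO-to-Linalg lowering to emit the integer linalg named ops. Below is a minimal sketch of the intended rewrite for an i8 x i8 -> i32 matmul; the function name and shapes are made up for illustration, it mirrors the new tests added at the bottom of this diff, and the exact assembly syntax depends on the MLIR version:

  func @matmul_i8(%lhs: tensor<2x3xi8>, %rhs: tensor<3x4xi8>) -> tensor<2x4xi32> {
    %0 = "mhlo.dot"(%lhs, %rhs) : (tensor<2x3xi8>, tensor<3x4xi8>) -> tensor<2x4xi32>
    return %0 : tensor<2x4xi32>
  }

  // After the HLO-to-Linalg lowering, roughly:
  //   %zero = constant 0 : i32
  //   %init = linalg.init_tensor [2, 4] : tensor<2x4xi32>
  //   %fill = linalg.fill(%init, %zero)
  //   %res  = linalg.matmul_i8_i8_i32
  //             ins(%lhs, %rhs : tensor<2x3xi8>, tensor<3x4xi8>)
  //             outs(%fill : tensor<2x4xi32>) -> tensor<2x4xi32>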

BUILD
@@ -699,6 +699,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
@@ -44,6 +44,7 @@ limitations under the License.
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"

 namespace mlir {
@@ -1130,6 +1131,8 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
   return dyn_shape;
 }

+template <typename InputElType, int input_bit_width, typename OutputElType,
+          int output_bit_width, DotOperationType op_type, typename LinalgOp>
 class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
  public:
   using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
@@ -1139,44 +1142,38 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
     if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
       return failure();
     }
-    Location loc = op.getLoc();
+
     mhlo::DotOp::Adaptor adaptor(args);
-    Type result_type = op.getResult().getType();
-    auto shaped_type = result_type.cast<ShapedType>();
-    DotOperationType op_type = GetDotOperationType(op);
-    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
+
+    auto lhs_el_type =
+        adaptor.lhs().getType().cast<ShapedType>().getElementType();
+    auto rhs_el_type =
+        adaptor.lhs().getType().cast<ShapedType>().getElementType();
+    if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
+        lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
+      return failure();
+    }
+
+    auto output_type = op.getType().cast<ShapedType>();
+    auto output_el_type = output_type.getElementType();
+    if (!output_el_type.isa<OutputElType>() ||
+        output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
+      return failure();
+    }
+
+    if (GetDotOperationType(op) != op_type) return failure();
+
+    Location loc = op.getLoc();
+    auto zero_attr = rewriter.getZeroAttr(output_el_type);
     Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
     SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
         rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
-    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
+    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
     Value zero_tensor =
         rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
-    linalg::LinalgOp linalg_op;
-    switch (op_type) {
-      case DotOperationType::kMatrixMatrix: {
-        linalg_op = rewriter.create<linalg::MatmulOp>(
-            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
-        break;
-      }
-      case DotOperationType::kMatrixVector: {
-        linalg_op = rewriter.create<linalg::MatvecOp>(
-            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
-        break;
-      }
-      case DotOperationType::kVectorDot: {
-        linalg_op = rewriter.create<linalg::DotOp>(
-            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
-        break;
-      }
-      case DotOperationType::kUnsupported:
-      default: {
-        return op.emitError("unsupported dot operation type");
-      }
-    }
-    rewriter.replaceOp(op, linalg_op->getResults());
+    rewriter.replaceOpWithNewOp<LinalgOp>(
+        op, TypeRange{op.getType()}, ValueRange{adaptor.lhs(), adaptor.rhs()},
+        ValueRange{zero_tensor});
     return success();
   }
 };
@@ -1193,6 +1190,8 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
   return dyn_shape;
 }

+template <typename InputElType, int input_bit_width, typename OutputElType,
+          int output_bit_width, typename LinalgOp>
 class DotGeneralOpOnTensorsConversion
     : public OpConversionPattern<mhlo::DotGeneralOp> {
  public:
@@ -1203,6 +1202,7 @@ class DotGeneralOpOnTensorsConversion
     if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
       return failure();
     }
+
     mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
     auto lhs_bathcing_dims =
         Extract1DVector(dim_numbers.lhs_batching_dimensions());
@@ -1228,22 +1228,38 @@ class DotGeneralOpOnTensorsConversion
       return rewriter.notifyMatchFailure(
           op, "expected rhs contracting dimensions exactly {1}");
     }
-    Location loc = op.getLoc();
+
     mhlo::DotGeneralOp::Adaptor adaptor(args);
-    Type result_type = op.getResult().getType();
-    auto shaped_type = result_type.cast<ShapedType>();
+    auto lhs_el_type =
+        adaptor.lhs().getType().cast<ShapedType>().getElementType();
+    auto rhs_el_type =
+        adaptor.lhs().getType().cast<ShapedType>().getElementType();
+    if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
+        lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
+      return failure();
+    }
+
+    auto output_type = op.getType().cast<ShapedType>();
+    auto output_el_type = output_type.getElementType();
+    if (!output_el_type.isa<OutputElType>() ||
+        output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
+      return failure();
+    }
+
+    Location loc = op.getLoc();
     SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
-        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
-    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
+        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
+    auto zero_attr = rewriter.getZeroAttr(output_el_type);
     Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
-    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
+    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
     Value zero_tensor =
         rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
-    auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
-        loc, /*resultTensorTypes=*/TypeRange{result_type},
+    Operation* linalg_op = rewriter.create<LinalgOp>(
+        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
         /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
         /*outputBuffers=*/ValueRange{zero_tensor});
-    rewriter.replaceOp(op, linalg_op.getResults());
+
+    rewriter.replaceOp(op, linalg_op->getResults());
     return success();
   }
 };
@@ -1560,8 +1576,51 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
       ReshapeOpConverter<mhlo::ReshapeOp, false>,
       ReverseConverter<mhlo::ReverseOp, false>,
       SliceConverter<mhlo::SliceOp, false>,
-      TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
-      DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
+      TransposeConverter<mhlo::TransposeOp, false>,
+      DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
+                               DotOperationType::kMatrixMatrix,
+                               linalg::MatmulI8I8I32Op>,
+      DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
+                               DotOperationType::kMatrixVector,
+                               linalg::MatvecI8I8I32Op>,
+      DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
+                               DotOperationType::kVectorDot,
+                               linalg::DotI8I8I32Op>,
+      DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
+                               DotOperationType::kMatrixMatrix,
+                               linalg::MatmulI16I16I32Op>,
+      DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
+                               DotOperationType::kMatrixVector,
+                               linalg::MatvecI16I16I32Op>,
+      DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
+                               DotOperationType::kVectorDot,
+                               linalg::DotI16I16I32Op>,
+      DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
+                               DotOperationType::kMatrixMatrix,
+                               linalg::MatmulI32I32I32Op>,
+      DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
+                               DotOperationType::kMatrixVector,
+                               linalg::MatvecI32I32I32Op>,
+      DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
+                               DotOperationType::kVectorDot,
+                               linalg::DotI32I32I32Op>,
+      DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
+                               DotOperationType::kMatrixMatrix,
+                               linalg::MatmulOp>,
+      DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
+                               DotOperationType::kMatrixVector,
+                               linalg::MatvecOp>,
+      DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
+                               DotOperationType::kVectorDot, linalg::DotOp>,
+      DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
+                                      linalg::BatchMatmulI8I8I32Op>,
+      DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
+                                      linalg::BatchMatmulI16I16I32Op>,
+      DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
+                                      linalg::BatchMatmulI32I32I32Op>,
+      DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
+                                      linalg::BatchMatmulOp>,
+      ReduceOnTensorsConversion>(context);
   patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
                    ReduceRegionXLAOpConversion<mhlo::MinOp>,
                    ReduceRegionXLAOpConversion<mhlo::MaxOp>,
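In short, the registrations above pair each supported element-type combination with a linalg named op; the MLIR op suffixes follow the usual input/input/accumulator naming convention checked by the tests below:

  i8,  i8  -> i32 : linalg.matmul_i8_i8_i32 / matvec_i8_i8_i32 / dot_i8_i8_i32 / batch_matmul_i8_i8_i32
  i16, i16 -> i32 : linalg.matmul_i16_i16_i32 / matvec_i16_i16_i32 / dot_i16_i16_i32 / batch_matmul_i16_i16_i32
  i32, i32 -> i32 : linalg.matmul_i32_i32_i32 / matvec_i32_i32_i32 / dot_i32_i32_i32 / batch_matmul_i32_i32_i32
  f32, f32 -> f32 : linalg.matmul / linalg.matvec / linalg.dot / linalg.batch_matmul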
@@ -960,7 +960,8 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
                                    tensor<3x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
-// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
+// CHECK-LABEL: func @dot_matmul(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
@@ -969,6 +970,58 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>)
+
+func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
+                           %arg1: tensor<3x?xi8>) -> tensor<2x?xi32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi8>,
+                                   tensor<3x?xi8>) -> tensor<2x?xi32>
+  return %0 : tensor<2x?xi32>
+}
+// CHECK-LABEL: func @dot_matmul_i8_i8_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.matmul_i8_i8_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
+
+// -----
+
+func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
+                             %arg1: tensor<3x?xi16>) -> tensor<2x?xi32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi16>,
+                                   tensor<3x?xi16>) -> tensor<2x?xi32>
+  return %0 : tensor<2x?xi32>
+}
+// CHECK-LABEL: func @dot_matmul_i16_i16_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.matmul_i16_i16_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
+
+// -----
+
+func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
+                             %arg1: tensor<3x?xi32>) -> tensor<2x?xi32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi32>,
+                                   tensor<3x?xi32>) -> tensor<2x?xi32>
+  return %0 : tensor<2x?xi32>
+}
+// CHECK-LABEL: func @dot_matmul_i32_i32_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.matmul_i32_i32_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)

 // -----

 func @dot_matvec(%arg0: tensor<?x3xf32>,
@@ -977,7 +1030,8 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
                                    tensor<3xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
-// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
+// CHECK-LABEL: func @dot_matvec(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
@@ -993,7 +1047,8 @@ func @dot_dot(%arg0: tensor<?xf32>,
   %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
-// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
+// CHECK-LABEL: func @dot_dot(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
 // CHECK: %[[INIT:.*]] = linalg.init_tensor []
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.dot
@@ -1002,7 +1057,7 @@ func @dot_dot(%arg0: tensor<?xf32>,

 // -----

-func @dot_general(%arg0: tensor<?x?x3xf32>,
+func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
                   %arg1: tensor<?x3x?xf32>) -> tensor<?x?x?xf32> {
   %0 = "mhlo.dot_general"(%arg0, %arg1) {
     dot_dimension_numbers = {
@@ -1015,7 +1070,8 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
   } : (tensor<?x?x3xf32>, tensor<?x3x?xf32>) -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
-// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
+// CHECK-LABEL: func @dot_general_batch_matmul(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 // CHECK: %[[C1:.*]] = constant 1 : index
@@ -1030,7 +1086,65 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,

 // -----

-func @batch_matmul_large
+func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
+                  %arg1: tensor<?x3x?xi8>) -> tensor<?x?x?xi32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+    },
+    precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<?x?x3xi8>, tensor<?x3x?xi8>) -> tensor<?x?x?xi32>
+  return %0 : tensor<?x?x?xi32>
+}
+// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.batch_matmul_i8_i8_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
+
+// -----
+
+func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
+                  %arg1: tensor<?x3x?xi16>) -> tensor<?x?x?xi32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+    },
+    precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<?x?x3xi16>, tensor<?x3x?xi16>) -> tensor<?x?x?xi32>
+  return %0 : tensor<?x?x?xi32>
+}
+// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.batch_matmul_i16_i16_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
+
+// -----
+
+func @dot_general_batch_matmul_large
  (%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x32x32xf32>) -> tensor<2x16x32xf32> {
   %0 = "mhlo.dot_general"(%arg0, %arg1) {
     dot_dimension_numbers = {
@@ -1042,7 +1156,7 @@ func @batch_matmul_large
   : (tensor<2x16x32xf32>, tensor<2x32x32xf32>) -> tensor<2x16x32xf32>
   return %0 : tensor<2x16x32xf32>
 }
-// CHECK: func @batch_matmul_large(
+// CHECK-LABEL: func @dot_general_batch_matmul_large(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<2x16x32xf32>,
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<2x32x32xf32>)
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 16, 32]