Lower integer matmuls to linalg

Templatize DotOpOnTensorsConversion and DotGeneralOpOnTensorsConversion on
the input/output element types and bit widths, and register instantiations
that lower the i8xi8->i32, i16xi16->i32, and i32xi32->i32 (plus the existing
f32) variants of mhlo.dot and mhlo.dot_general to the corresponding named
linalg ops, e.g. linalg.matmul_i8_i8_i32 and linalg.batch_matmul_i8_i8_i32.

PiperOrigin-RevId: 359306495
parent 475b4a06a5
commit 89f7f2bd65

BUILD
@@ -699,6 +699,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
     ],
@@ -44,6 +44,7 @@ limitations under the License.
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -1130,6 +1131,8 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
   return dyn_shape;
 }
 
+template <typename InputElType, int input_bit_width, typename OutputElType,
+          int output_bit_width, DotOperationType op_type, typename LinalgOp>
 class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
  public:
   using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
@@ -1139,44 +1142,38 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
     if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
       return failure();
     }
-    Location loc = op.getLoc();
     mhlo::DotOp::Adaptor adaptor(args);
-    Type result_type = op.getResult().getType();
-    auto shaped_type = result_type.cast<ShapedType>();
-    DotOperationType op_type = GetDotOperationType(op);
-    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
+    auto lhs_el_type =
+        adaptor.lhs().getType().cast<ShapedType>().getElementType();
+    auto rhs_el_type =
+        adaptor.rhs().getType().cast<ShapedType>().getElementType();
+    if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
+        lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
+      return failure();
+    }
+
+    auto output_type = op.getType().cast<ShapedType>();
+    auto output_el_type = output_type.getElementType();
+    if (!output_el_type.isa<OutputElType>() ||
+        output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
+      return failure();
+    }
+
+    if (GetDotOperationType(op) != op_type) return failure();
+
+    Location loc = op.getLoc();
+    auto zero_attr = rewriter.getZeroAttr(output_el_type);
     Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
     SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
         rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
-    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
+    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
     Value zero_tensor =
         rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
-    linalg::LinalgOp linalg_op;
-    switch (op_type) {
-      case DotOperationType::kMatrixMatrix: {
-        linalg_op = rewriter.create<linalg::MatmulOp>(
-            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
-        break;
-      }
-      case DotOperationType::kMatrixVector: {
-        linalg_op = rewriter.create<linalg::MatvecOp>(
-            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
-        break;
-      }
-      case DotOperationType::kVectorDot: {
-        linalg_op = rewriter.create<linalg::DotOp>(
-            loc, TypeRange{result_type},
-            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
-        break;
-      }
-      case DotOperationType::kUnsupported:
-      default: {
-        return op.emitError("unsupported dot operation type");
-      }
-    }
-    rewriter.replaceOp(op, linalg_op->getResults());
+    rewriter.replaceOpWithNewOp<LinalgOp>(
+        op, TypeRange{op.getType()}, ValueRange{adaptor.lhs(), adaptor.rhs()},
+        ValueRange{zero_tensor});
     return success();
   }
 };
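Each instantiation of the templated pattern fires only for one combination of input element type and width, output element type and width, and dot shape; everything else returns failure() so a sibling instantiation can match instead. A standalone sketch of that guard, using the same MLIR type APIs (the helper name SignatureMatches is hypothetical, for illustration only, not part of this change):

// Hypothetical helper: returns true iff both operands share an element type
// of the given class and bit width, and the result element type has the
// given class and bit width. Mirrors the checks in matchAndRewrite above.
template <typename InputElType, int input_bit_width, typename OutputElType,
          int output_bit_width>
static bool SignatureMatches(Value lhs, Value rhs, Value result) {
  auto lhs_el = lhs.getType().cast<ShapedType>().getElementType();
  auto rhs_el = rhs.getType().cast<ShapedType>().getElementType();
  auto out_el = result.getType().cast<ShapedType>().getElementType();
  return lhs_el == rhs_el && lhs_el.isa<InputElType>() &&
         lhs_el.getIntOrFloatBitWidth() == input_bit_width &&
         out_el.isa<OutputElType>() &&
         out_el.getIntOrFloatBitWidth() == output_bit_width;
}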
@@ -1193,6 +1190,8 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
   return dyn_shape;
 }
 
+template <typename InputElType, int input_bit_width, typename OutputElType,
+          int output_bit_width, typename LinalgOp>
 class DotGeneralOpOnTensorsConversion
     : public OpConversionPattern<mhlo::DotGeneralOp> {
  public:
@@ -1203,6 +1202,7 @@ class DotGeneralOpOnTensorsConversion
     if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
       return failure();
     }
+
     mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
     auto lhs_bathcing_dims =
         Extract1DVector(dim_numbers.lhs_batching_dimensions());
@@ -1228,22 +1228,38 @@ class DotGeneralOpOnTensorsConversion
       return rewriter.notifyMatchFailure(
           op, "expected rhs contracting dimensions exactly {1}");
     }
-    Location loc = op.getLoc();
     mhlo::DotGeneralOp::Adaptor adaptor(args);
-    Type result_type = op.getResult().getType();
-    auto shaped_type = result_type.cast<ShapedType>();
+    auto lhs_el_type =
+        adaptor.lhs().getType().cast<ShapedType>().getElementType();
+    auto rhs_el_type =
+        adaptor.rhs().getType().cast<ShapedType>().getElementType();
+    if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
+        lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
+      return failure();
+    }
+
+    auto output_type = op.getType().cast<ShapedType>();
+    auto output_el_type = output_type.getElementType();
+    if (!output_el_type.isa<OutputElType>() ||
+        output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
+      return failure();
+    }
+
+    Location loc = op.getLoc();
     SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
-        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
-    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
+        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
+    auto zero_attr = rewriter.getZeroAttr(output_el_type);
     Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
-    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
+    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
     Value zero_tensor =
         rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
-    auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
-        loc, /*resultTensorTypes=*/TypeRange{result_type},
+    Operation* linalg_op = rewriter.create<LinalgOp>(
+        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
         /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
         /*outputBuffers=*/ValueRange{zero_tensor});
-    rewriter.replaceOp(op, linalg_op.getResults());
+    rewriter.replaceOp(op, linalg_op->getResults());
     return success();
   }
 };
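With both patterns templated the same way, each supported flavor is a single instantiation. For example, the i8 x i8 -> i32 batch matmul lowering could be named with an alias (the alias itself is illustrative, not part of this diff; the registration below spells the instantiation inline):

// Illustrative alias over the template above.
using DotGeneralI8I8I32 =
    DotGeneralOpOnTensorsConversion<IntegerType, /*input_bit_width=*/8,
                                    IntegerType, /*output_bit_width=*/32,
                                    linalg::BatchMatmulI8I8I32Op>;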
@@ -1560,8 +1576,51 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                    ReshapeOpConverter<mhlo::ReshapeOp, false>,
                    ReverseConverter<mhlo::ReverseOp, false>,
                    SliceConverter<mhlo::SliceOp, false>,
-                   TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
-                   DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
+                   TransposeConverter<mhlo::TransposeOp, false>,
+                   DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
+                                            DotOperationType::kMatrixMatrix,
+                                            linalg::MatmulI8I8I32Op>,
+                   DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
+                                            DotOperationType::kMatrixVector,
+                                            linalg::MatvecI8I8I32Op>,
+                   DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
+                                            DotOperationType::kVectorDot,
+                                            linalg::DotI8I8I32Op>,
+                   DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
+                                            DotOperationType::kMatrixMatrix,
+                                            linalg::MatmulI16I16I32Op>,
+                   DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
+                                            DotOperationType::kMatrixVector,
+                                            linalg::MatvecI16I16I32Op>,
+                   DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
+                                            DotOperationType::kVectorDot,
+                                            linalg::DotI16I16I32Op>,
+                   DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
+                                            DotOperationType::kMatrixMatrix,
+                                            linalg::MatmulI32I32I32Op>,
+                   DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
+                                            DotOperationType::kMatrixVector,
+                                            linalg::MatvecI32I32I32Op>,
+                   DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
+                                            DotOperationType::kVectorDot,
+                                            linalg::DotI32I32I32Op>,
+                   DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
+                                            DotOperationType::kMatrixMatrix,
+                                            linalg::MatmulOp>,
+                   DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
+                                            DotOperationType::kMatrixVector,
+                                            linalg::MatvecOp>,
+                   DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
+                                            DotOperationType::kVectorDot, linalg::DotOp>,
+                   DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
+                                                   linalg::BatchMatmulI8I8I32Op>,
+                   DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
+                                                   linalg::BatchMatmulI16I16I32Op>,
+                   DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
+                                                   linalg::BatchMatmulI32I32I32Op>,
+                   DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
+                                                   linalg::BatchMatmulOp>,
+                   ReduceOnTensorsConversion>(context);
   patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
                    ReduceRegionXLAOpConversion<mhlo::MinOp>,
                    ReduceRegionXLAOpConversion<mhlo::MaxOp>,
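For context, these patterns are driven by MLIR's dialect conversion framework. A rough sketch of a driver, assuming the usual partial-conversion setup of this lowering pass in the pre-RewritePatternSet API of this era (the function body below is illustrative, not code from this diff):

// Minimal sketch: collect the patterns registered above and run a partial
// conversion that rejects any remaining mhlo ops. Target setup is assumed.
static void RunHloToLinalgOnTensors(FuncOp func) {
  MLIRContext* context = func.getContext();
  OwningRewritePatternList patterns;
  populateHLOToLinalgConversionPattern(context, &patterns);

  ConversionTarget target(*context);
  target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                         tensor::TensorDialect, scf::SCFDialect>();
  target.addIllegalDialect<mhlo::MhloDialect>();
  if (failed(applyPartialConversion(func, target, std::move(patterns))))
    func.emitError("mhlo to linalg-on-tensors conversion failed");
}

The test changes below exercise the new instantiations end to end.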
@@ -960,7 +960,8 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
                  tensor<3x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
-// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
+// CHECK-LABEL: func @dot_matmul(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
@@ -969,6 +970,58 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>)
 
+func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
+                           %arg1: tensor<3x?xi8>) -> tensor<2x?xi32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi8>,
+                                   tensor<3x?xi8>) -> tensor<2x?xi32>
+  return %0 : tensor<2x?xi32>
+}
+// CHECK-LABEL: func @dot_matmul_i8_i8_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.matmul_i8_i8_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
+
+// -----
+
+func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
+                             %arg1: tensor<3x?xi16>) -> tensor<2x?xi32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi16>,
+                                   tensor<3x?xi16>) -> tensor<2x?xi32>
+  return %0 : tensor<2x?xi32>
+}
+// CHECK-LABEL: func @dot_matmul_i16_i16_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.matmul_i16_i16_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
+
+// -----
+
+func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
+                             %arg1: tensor<3x?xi32>) -> tensor<2x?xi32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi32>,
+                                   tensor<3x?xi32>) -> tensor<2x?xi32>
+  return %0 : tensor<2x?xi32>
+}
+// CHECK-LABEL: func @dot_matmul_i32_i32_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.matmul_i32_i32_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
+
 // -----
 
 func @dot_matvec(%arg0: tensor<?x3xf32>,
@@ -977,7 +1030,8 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
                  tensor<3xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
-// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
+// CHECK-LABEL: func @dot_matvec(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
@@ -993,7 +1047,8 @@ func @dot_dot(%arg0: tensor<?xf32>,
   %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
-// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
+// CHECK-LABEL: func @dot_dot(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
 // CHECK: %[[INIT:.*]] = linalg.init_tensor []
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 // CHECK: linalg.dot
@@ -1002,7 +1057,7 @@ func @dot_dot(%arg0: tensor<?xf32>,
 
 // -----
 
-func @dot_general(%arg0: tensor<?x?x3xf32>,
+func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
                   %arg1: tensor<?x3x?xf32>) -> tensor<?x?x?xf32> {
   %0 = "mhlo.dot_general"(%arg0, %arg1) {
     dot_dimension_numbers = {
@@ -1015,7 +1070,8 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
   } : (tensor<?x?x3xf32>, tensor<?x3x?xf32>) -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
-// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
+// CHECK-LABEL: func @dot_general_batch_matmul(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
 // CHECK: %[[C1:.*]] = constant 1 : index
@@ -1030,7 +1086,65 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
 
 // -----
 
-func @batch_matmul_large
+func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
+                                         %arg1: tensor<?x3x?xi8>) -> tensor<?x?x?xi32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+    },
+    precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<?x?x3xi8>, tensor<?x3x?xi8>) -> tensor<?x?x?xi32>
+  return %0 : tensor<?x?x?xi32>
+}
+// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.batch_matmul_i8_i8_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
+
+// -----
+
+func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
+                                           %arg1: tensor<?x3x?xi16>) -> tensor<?x?x?xi32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+    },
+    precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<?x?x3xi16>, tensor<?x3x?xi16>) -> tensor<?x?x?xi32>
+  return %0 : tensor<?x?x?xi32>
+}
+// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
+// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
+// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
+// CHECK: linalg.batch_matmul_i16_i16_i32
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
+
+// -----
+
+func @dot_general_batch_matmul_large
 (%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x32x32xf32>) -> tensor<2x16x32xf32> {
   %0 = "mhlo.dot_general"(%arg0, %arg1) {
     dot_dimension_numbers = {
@@ -1042,7 +1156,7 @@ func @batch_matmul_large
     : (tensor<2x16x32xf32>, tensor<2x32x32xf32>) -> tensor<2x16x32xf32>
   return %0 : tensor<2x16x32xf32>
 }
-// CHECK: func @batch_matmul_large(
+// CHECK-LABEL: func @dot_general_batch_matmul_large(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<2x16x32xf32>,
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<2x32x32xf32>)
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 16, 32]