Lower MHLO Dot to type-polymorphic linalg named ops
The linalg named ops are now type polymorphic, so the type-monomorphic varieties are redundant (and will be deleted soon).

PiperOrigin-RevId: 360509010
This commit is contained in:
parent
1facbe9eb5
commit
8687f3e4cf
|
@ -1148,8 +1148,7 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
|
|||
return dyn_shape;
|
||||
}
|
||||
|
||||
template <typename InputElType, int input_bit_width, typename OutputElType,
|
||||
int output_bit_width, DotOperationType op_type, typename LinalgOp>
|
||||
template <DotOperationType op_type, typename LinalgOp>
|
||||
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
||||
public:
|
||||
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
|
||||
|
@ -1159,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
|||
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
|
||||
return failure();
|
||||
}
|
||||
if (GetDotOperationType(op) != op_type) return failure();
|
||||
|
||||
mhlo::DotOp::Adaptor adaptor(args);
|
||||
|
||||
auto lhs_el_type =
|
||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
||||
auto rhs_el_type =
|
||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
||||
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
|
||||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
auto output_el_type = output_type.getElementType();
|
||||
if (!output_el_type.isa<OutputElType>() ||
|
||||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (GetDotOperationType(op) != op_type) return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto zero_attr = rewriter.getZeroAttr(output_el_type);
|
||||
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
||||
SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
|
||||
|
@ -1207,8 +1191,6 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
|
|||
return dyn_shape;
|
||||
}
|
||||
|
||||
template <typename InputElType, int input_bit_width, typename OutputElType,
|
||||
int output_bit_width, typename LinalgOp>
|
||||
class DotGeneralOpOnTensorsConversion
|
||||
: public OpConversionPattern<mhlo::DotGeneralOp> {
|
||||
public:
|
||||
|
@ -1247,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion
|
|||
}
|
||||
|
||||
mhlo::DotGeneralOp::Adaptor adaptor(args);
|
||||
auto lhs_el_type =
|
||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
||||
auto rhs_el_type =
|
||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
||||
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
|
||||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
auto output_el_type = output_type.getElementType();
|
||||
if (!output_el_type.isa<OutputElType>() ||
|
||||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
auto output_el_type = output_type.getElementType();
|
||||
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
|
||||
rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
|
||||
auto zero_attr = rewriter.getZeroAttr(output_el_type);
|
||||
|
@ -1271,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion
|
|||
auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
|
||||
Value zero_tensor =
|
||||
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
|
||||
Operation* linalg_op = rewriter.create<LinalgOp>(
|
||||
Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
|
||||
loc, /*resultTensorTypes=*/TypeRange{op.getType()},
|
||||
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
|
||||
/*outputBuffers=*/ValueRange{zero_tensor});
|
||||
|
@ -1709,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
SliceConverter<mhlo::SliceOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>,
|
||||
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
|
||||
DotOperationType::kMatrixMatrix,
|
||||
linalg::MatmulI8I8I32Op>,
|
||||
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
|
||||
DotOperationType::kMatrixVector,
|
||||
linalg::MatvecI8I8I32Op>,
|
||||
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
|
||||
DotOperationType::kVectorDot,
|
||||
linalg::DotI8I8I32Op>,
|
||||
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
|
||||
DotOperationType::kMatrixMatrix,
|
||||
linalg::MatmulI16I16I32Op>,
|
||||
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
|
||||
DotOperationType::kMatrixVector,
|
||||
linalg::MatvecI16I16I32Op>,
|
||||
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
|
||||
DotOperationType::kVectorDot,
|
||||
linalg::DotI16I16I32Op>,
|
||||
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
|
||||
DotOperationType::kMatrixMatrix,
|
||||
linalg::MatmulI32I32I32Op>,
|
||||
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
|
||||
DotOperationType::kMatrixVector,
|
||||
linalg::MatvecI32I32I32Op>,
|
||||
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
|
||||
DotOperationType::kVectorDot,
|
||||
linalg::DotI32I32I32Op>,
|
||||
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
||||
DotOperationType::kMatrixMatrix,
|
||||
DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
|
||||
linalg::MatmulOp>,
|
||||
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
||||
DotOperationType::kMatrixVector,
|
||||
DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
|
||||
linalg::MatvecOp>,
|
||||
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
||||
DotOperationType::kVectorDot, linalg::DotOp>,
|
||||
DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
|
||||
linalg::BatchMatmulI8I8I32Op>,
|
||||
DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
|
||||
linalg::BatchMatmulI16I16I32Op>,
|
||||
DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
|
||||
linalg::BatchMatmulI32I32I32Op>,
|
||||
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
||||
linalg::BatchMatmulOp>,
|
||||
DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
|
||||
DotGeneralOpOnTensorsConversion,
|
||||
NormalConvOpOnTensorsConversion,
|
||||
ReduceOnTensorsConversion,
|
||||
PadOpOnTensorsConversion>(context);
|
||||
|
|
|
@ -1004,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
|
|||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.matmul_i8_i8_i32
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||
|
||||
|
@ -1022,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
|
|||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.matmul_i16_i16_i32
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||
|
||||
|
@ -1040,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
|
|||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.matmul_i32_i32_i32
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||
|
||||
|
@ -1131,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
|
|||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.batch_matmul_i8_i8_i32
|
||||
// CHECK: linalg.batch_matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
|
||||
|
||||
|
@ -1160,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
|
|||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||
// CHECK: linalg.batch_matmul_i16_i16_i32
|
||||
// CHECK: linalg.batch_matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
|
||||
|
||||
|
|
Loading…
Reference in New Issue