Lower MHLO Dot to type-polymorphic linalg named ops

The linalg named ops are now type polymorphic, so the type-monomorphic
varieties are redundant (and will be deleted soon).

PiperOrigin-RevId: 360509010
This commit is contained in:
Geoffrey Martin-Noble 2021-03-02 13:59:50 -08:00 committed by TensorFlow MLIR Team
parent 1facbe9eb5
commit 8687f3e4cf
2 changed files with 15 additions and 83 deletions

View File

@ -1148,8 +1148,7 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
return dyn_shape;
}
template <typename InputElType, int input_bit_width, typename OutputElType,
int output_bit_width, DotOperationType op_type, typename LinalgOp>
template <DotOperationType op_type, typename LinalgOp>
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
public:
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
@ -1159,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
return failure();
}
if (GetDotOperationType(op) != op_type) return failure();
mhlo::DotOp::Adaptor adaptor(args);
auto lhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
auto rhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
return failure();
}
Location loc = op.getLoc();
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
if (!output_el_type.isa<OutputElType>() ||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
return failure();
}
if (GetDotOperationType(op) != op_type) return failure();
Location loc = op.getLoc();
auto zero_attr = rewriter.getZeroAttr(output_el_type);
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
@ -1207,8 +1191,6 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
return dyn_shape;
}
template <typename InputElType, int input_bit_width, typename OutputElType,
int output_bit_width, typename LinalgOp>
class DotGeneralOpOnTensorsConversion
: public OpConversionPattern<mhlo::DotGeneralOp> {
public:
@ -1247,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion
}
mhlo::DotGeneralOp::Adaptor adaptor(args);
auto lhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
auto rhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
return failure();
}
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
if (!output_el_type.isa<OutputElType>() ||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
return failure();
}
Location loc = op.getLoc();
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
auto zero_attr = rewriter.getZeroAttr(output_el_type);
@ -1271,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion
auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
Operation* linalg_op = rewriter.create<LinalgOp>(
Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
loc, /*resultTensorTypes=*/TypeRange{op.getType()},
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
/*outputBuffers=*/ValueRange{zero_tensor});
@ -1709,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
ReverseConverter<mhlo::ReverseOp, false>,
SliceConverter<mhlo::SliceOp, false>,
TransposeConverter<mhlo::TransposeOp, false>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI32I32I32Op>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixMatrix,
DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
linalg::MatmulOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixVector,
DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
linalg::MatvecOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
linalg::BatchMatmulI8I8I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
linalg::BatchMatmulI16I16I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
linalg::BatchMatmulI32I32I32Op>,
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
linalg::BatchMatmulOp>,
DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralOpOnTensorsConversion,
NormalConvOpOnTensorsConversion,
ReduceOnTensorsConversion,
PadOpOnTensorsConversion>(context);

View File

@ -1004,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i8_i8_i32
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@ -1022,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i16_i16_i32
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@ -1040,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i32_i32_i32
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@ -1131,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul_i8_i8_i32
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
@ -1160,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul_i16_i16_i32
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)