Lower MHLO Dot to type-polymorphic linalg named ops

The linalg named ops are now type polymorphic, so the type-monomorphic
varieties are redundant (and will be deleted soon).

PiperOrigin-RevId: 360509010
This commit is contained in:
Geoffrey Martin-Noble 2021-03-02 13:59:50 -08:00 committed by TensorFlow MLIR Team
parent 1facbe9eb5
commit 8687f3e4cf
2 changed files with 15 additions and 83 deletions

View File

@ -1148,8 +1148,7 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
return dyn_shape; return dyn_shape;
} }
template <typename InputElType, int input_bit_width, typename OutputElType, template <DotOperationType op_type, typename LinalgOp>
int output_bit_width, DotOperationType op_type, typename LinalgOp>
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> { class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
public: public:
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern; using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
@ -1159,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) { if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
return failure(); return failure();
} }
if (GetDotOperationType(op) != op_type) return failure();
mhlo::DotOp::Adaptor adaptor(args); mhlo::DotOp::Adaptor adaptor(args);
auto lhs_el_type = Location loc = op.getLoc();
adaptor.lhs().getType().cast<ShapedType>().getElementType();
auto rhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
return failure();
}
auto output_type = op.getType().cast<ShapedType>(); auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType(); auto output_el_type = output_type.getElementType();
if (!output_el_type.isa<OutputElType>() ||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
return failure();
}
if (GetDotOperationType(op) != op_type) return failure();
Location loc = op.getLoc();
auto zero_attr = rewriter.getZeroAttr(output_el_type); auto zero_attr = rewriter.getZeroAttr(output_el_type);
Value zero = rewriter.create<ConstantOp>(loc, zero_attr); Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes( SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
@ -1207,8 +1191,6 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
return dyn_shape; return dyn_shape;
} }
template <typename InputElType, int input_bit_width, typename OutputElType,
int output_bit_width, typename LinalgOp>
class DotGeneralOpOnTensorsConversion class DotGeneralOpOnTensorsConversion
: public OpConversionPattern<mhlo::DotGeneralOp> { : public OpConversionPattern<mhlo::DotGeneralOp> {
public: public:
@ -1247,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion
} }
mhlo::DotGeneralOp::Adaptor adaptor(args); mhlo::DotGeneralOp::Adaptor adaptor(args);
auto lhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
auto rhs_el_type =
adaptor.lhs().getType().cast<ShapedType>().getElementType();
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
return failure();
}
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
if (!output_el_type.isa<OutputElType>() ||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
return failure();
}
Location loc = op.getLoc(); Location loc = op.getLoc();
auto output_type = op.getType().cast<ShapedType>();
auto output_el_type = output_type.getElementType();
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes( SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type); rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
auto zero_attr = rewriter.getZeroAttr(output_el_type); auto zero_attr = rewriter.getZeroAttr(output_el_type);
@ -1271,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion
auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape); auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
Value zero_tensor = Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0); rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
Operation* linalg_op = rewriter.create<LinalgOp>( Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
loc, /*resultTensorTypes=*/TypeRange{op.getType()}, loc, /*resultTensorTypes=*/TypeRange{op.getType()},
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()}, /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
/*outputBuffers=*/ValueRange{zero_tensor}); /*outputBuffers=*/ValueRange{zero_tensor});
@ -1709,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
ReverseConverter<mhlo::ReverseOp, false>, ReverseConverter<mhlo::ReverseOp, false>,
SliceConverter<mhlo::SliceOp, false>, SliceConverter<mhlo::SliceOp, false>,
TransposeConverter<mhlo::TransposeOp, false>, TransposeConverter<mhlo::TransposeOp, false>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32, DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
DotOperationType::kMatrixMatrix,
linalg::MatmulI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI8I8I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI16I16I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kMatrixVector,
linalg::MatvecI32I32I32Op>,
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
DotOperationType::kVectorDot,
linalg::DotI32I32I32Op>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
DotOperationType::kMatrixMatrix,
linalg::MatmulOp>, linalg::MatmulOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32, DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
DotOperationType::kMatrixVector,
linalg::MatvecOp>, linalg::MatvecOp>,
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32, DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
DotOperationType::kVectorDot, linalg::DotOp>, DotGeneralOpOnTensorsConversion,
DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
linalg::BatchMatmulI8I8I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
linalg::BatchMatmulI16I16I32Op>,
DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
linalg::BatchMatmulI32I32I32Op>,
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
linalg::BatchMatmulOp>,
NormalConvOpOnTensorsConversion, NormalConvOpOnTensorsConversion,
ReduceOnTensorsConversion, ReduceOnTensorsConversion,
PadOpOnTensorsConversion>(context); PadOpOnTensorsConversion>(context);

View File

@ -1004,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i8_i8_i32 // CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>) // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@ -1022,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i16_i16_i32 // CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>) // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@ -1040,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul_i32_i32_i32 // CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>) // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
@ -1131,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] // CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul_i8_i8_i32 // CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>) // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>) // CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
@ -1160,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] // CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul_i16_i16_i32 // CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>) // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>) // CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)