Lower MHLO Dot to type-polymorphic linalg named ops
The linalg named ops are now type polymorphic, so the type-monomorphic varieties are redundant (and will be deleted soon). PiperOrigin-RevId: 360509010
This commit is contained in:
parent
1facbe9eb5
commit
8687f3e4cf
|
@ -1148,8 +1148,7 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
|
||||||
return dyn_shape;
|
return dyn_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename InputElType, int input_bit_width, typename OutputElType,
|
template <DotOperationType op_type, typename LinalgOp>
|
||||||
int output_bit_width, DotOperationType op_type, typename LinalgOp>
|
|
||||||
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
|
using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
|
||||||
|
@ -1159,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
||||||
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
|
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
if (GetDotOperationType(op) != op_type) return failure();
|
||||||
|
|
||||||
mhlo::DotOp::Adaptor adaptor(args);
|
mhlo::DotOp::Adaptor adaptor(args);
|
||||||
|
|
||||||
auto lhs_el_type =
|
Location loc = op.getLoc();
|
||||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
|
||||||
auto rhs_el_type =
|
|
||||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
|
||||||
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
|
|
||||||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto output_type = op.getType().cast<ShapedType>();
|
auto output_type = op.getType().cast<ShapedType>();
|
||||||
auto output_el_type = output_type.getElementType();
|
auto output_el_type = output_type.getElementType();
|
||||||
if (!output_el_type.isa<OutputElType>() ||
|
|
||||||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (GetDotOperationType(op) != op_type) return failure();
|
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
|
||||||
auto zero_attr = rewriter.getZeroAttr(output_el_type);
|
auto zero_attr = rewriter.getZeroAttr(output_el_type);
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
||||||
SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
|
SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
|
||||||
|
@ -1207,8 +1191,6 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
|
||||||
return dyn_shape;
|
return dyn_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename InputElType, int input_bit_width, typename OutputElType,
|
|
||||||
int output_bit_width, typename LinalgOp>
|
|
||||||
class DotGeneralOpOnTensorsConversion
|
class DotGeneralOpOnTensorsConversion
|
||||||
: public OpConversionPattern<mhlo::DotGeneralOp> {
|
: public OpConversionPattern<mhlo::DotGeneralOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -1247,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion
|
||||||
}
|
}
|
||||||
|
|
||||||
mhlo::DotGeneralOp::Adaptor adaptor(args);
|
mhlo::DotGeneralOp::Adaptor adaptor(args);
|
||||||
auto lhs_el_type =
|
|
||||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
|
||||||
auto rhs_el_type =
|
|
||||||
adaptor.lhs().getType().cast<ShapedType>().getElementType();
|
|
||||||
if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
|
|
||||||
lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto output_type = op.getType().cast<ShapedType>();
|
|
||||||
auto output_el_type = output_type.getElementType();
|
|
||||||
if (!output_el_type.isa<OutputElType>() ||
|
|
||||||
output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
auto output_type = op.getType().cast<ShapedType>();
|
||||||
|
auto output_el_type = output_type.getElementType();
|
||||||
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
|
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
|
||||||
rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
|
rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
|
||||||
auto zero_attr = rewriter.getZeroAttr(output_el_type);
|
auto zero_attr = rewriter.getZeroAttr(output_el_type);
|
||||||
|
@ -1271,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion
|
||||||
auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
|
auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
|
||||||
Value zero_tensor =
|
Value zero_tensor =
|
||||||
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
|
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
|
||||||
Operation* linalg_op = rewriter.create<LinalgOp>(
|
Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
|
||||||
loc, /*resultTensorTypes=*/TypeRange{op.getType()},
|
loc, /*resultTensorTypes=*/TypeRange{op.getType()},
|
||||||
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
|
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
|
||||||
/*outputBuffers=*/ValueRange{zero_tensor});
|
/*outputBuffers=*/ValueRange{zero_tensor});
|
||||||
|
@ -1709,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
ReverseConverter<mhlo::ReverseOp, false>,
|
ReverseConverter<mhlo::ReverseOp, false>,
|
||||||
SliceConverter<mhlo::SliceOp, false>,
|
SliceConverter<mhlo::SliceOp, false>,
|
||||||
TransposeConverter<mhlo::TransposeOp, false>,
|
TransposeConverter<mhlo::TransposeOp, false>,
|
||||||
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
|
DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
|
||||||
DotOperationType::kMatrixMatrix,
|
|
||||||
linalg::MatmulI8I8I32Op>,
|
|
||||||
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
|
|
||||||
DotOperationType::kMatrixVector,
|
|
||||||
linalg::MatvecI8I8I32Op>,
|
|
||||||
DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
|
|
||||||
DotOperationType::kVectorDot,
|
|
||||||
linalg::DotI8I8I32Op>,
|
|
||||||
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
|
|
||||||
DotOperationType::kMatrixMatrix,
|
|
||||||
linalg::MatmulI16I16I32Op>,
|
|
||||||
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
|
|
||||||
DotOperationType::kMatrixVector,
|
|
||||||
linalg::MatvecI16I16I32Op>,
|
|
||||||
DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
|
|
||||||
DotOperationType::kVectorDot,
|
|
||||||
linalg::DotI16I16I32Op>,
|
|
||||||
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
|
|
||||||
DotOperationType::kMatrixMatrix,
|
|
||||||
linalg::MatmulI32I32I32Op>,
|
|
||||||
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
|
|
||||||
DotOperationType::kMatrixVector,
|
|
||||||
linalg::MatvecI32I32I32Op>,
|
|
||||||
DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
|
|
||||||
DotOperationType::kVectorDot,
|
|
||||||
linalg::DotI32I32I32Op>,
|
|
||||||
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
|
||||||
DotOperationType::kMatrixMatrix,
|
|
||||||
linalg::MatmulOp>,
|
linalg::MatmulOp>,
|
||||||
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
|
||||||
DotOperationType::kMatrixVector,
|
|
||||||
linalg::MatvecOp>,
|
linalg::MatvecOp>,
|
||||||
DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
|
||||||
DotOperationType::kVectorDot, linalg::DotOp>,
|
DotGeneralOpOnTensorsConversion,
|
||||||
DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
|
|
||||||
linalg::BatchMatmulI8I8I32Op>,
|
|
||||||
DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
|
|
||||||
linalg::BatchMatmulI16I16I32Op>,
|
|
||||||
DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
|
|
||||||
linalg::BatchMatmulI32I32I32Op>,
|
|
||||||
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
|
||||||
linalg::BatchMatmulOp>,
|
|
||||||
NormalConvOpOnTensorsConversion,
|
NormalConvOpOnTensorsConversion,
|
||||||
ReduceOnTensorsConversion,
|
ReduceOnTensorsConversion,
|
||||||
PadOpOnTensorsConversion>(context);
|
PadOpOnTensorsConversion>(context);
|
||||||
|
|
|
@ -1004,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.matmul_i8_i8_i32
|
// CHECK: linalg.matmul
|
||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
|
||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||||
|
|
||||||
|
@ -1022,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.matmul_i16_i16_i32
|
// CHECK: linalg.matmul
|
||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
|
||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||||
|
|
||||||
|
@ -1040,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
|
||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.matmul_i32_i32_i32
|
// CHECK: linalg.matmul
|
||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
|
||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
|
||||||
|
|
||||||
|
@ -1131,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
|
||||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.batch_matmul_i8_i8_i32
|
// CHECK: linalg.batch_matmul
|
||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
|
||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
|
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
|
||||||
|
|
||||||
|
@ -1160,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
|
||||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
|
||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
|
||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
|
||||||
// CHECK: linalg.batch_matmul_i16_i16_i32
|
// CHECK: linalg.batch_matmul
|
||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
|
||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
|
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue