Lower MHLO Dot to type-polymorphic linalg named ops
The linalg named ops are now type polymorphic, so the type-monomorphic varieties are redundant (and will be deleted soon). PiperOrigin-RevId: 360509010
This commit is contained in:
		
							parent
							
								
									1facbe9eb5
								
							
						
					
					
						commit
						8687f3e4cf
					
				| 
						 | 
					@ -1148,8 +1148,7 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
 | 
				
			||||||
  return dyn_shape;
 | 
					  return dyn_shape;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename InputElType, int input_bit_width, typename OutputElType,
 | 
					template <DotOperationType op_type, typename LinalgOp>
 | 
				
			||||||
          int output_bit_width, DotOperationType op_type, typename LinalgOp>
 | 
					 | 
				
			||||||
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 | 
					class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
 | 
					  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
 | 
				
			||||||
| 
						 | 
					@ -1159,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 | 
				
			||||||
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
 | 
					    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
 | 
				
			||||||
      return failure();
 | 
					      return failure();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    if (GetDotOperationType(op) != op_type) return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mhlo::DotOp::Adaptor adaptor(args);
 | 
					    mhlo::DotOp::Adaptor adaptor(args);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto lhs_el_type =
 | 
					    Location loc = op.getLoc();
 | 
				
			||||||
        adaptor.lhs().getType().cast<ShapedType>().getElementType();
 | 
					 | 
				
			||||||
    auto rhs_el_type =
 | 
					 | 
				
			||||||
        adaptor.lhs().getType().cast<ShapedType>().getElementType();
 | 
					 | 
				
			||||||
    if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
 | 
					 | 
				
			||||||
        lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
 | 
					 | 
				
			||||||
      return failure();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto output_type = op.getType().cast<ShapedType>();
 | 
					    auto output_type = op.getType().cast<ShapedType>();
 | 
				
			||||||
    auto output_el_type = output_type.getElementType();
 | 
					    auto output_el_type = output_type.getElementType();
 | 
				
			||||||
    if (!output_el_type.isa<OutputElType>() ||
 | 
					 | 
				
			||||||
        output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
 | 
					 | 
				
			||||||
      return failure();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (GetDotOperationType(op) != op_type) return failure();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Location loc = op.getLoc();
 | 
					 | 
				
			||||||
    auto zero_attr = rewriter.getZeroAttr(output_el_type);
 | 
					    auto zero_attr = rewriter.getZeroAttr(output_el_type);
 | 
				
			||||||
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
 | 
					    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
 | 
				
			||||||
    SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
 | 
					    SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
 | 
				
			||||||
| 
						 | 
					@ -1207,8 +1191,6 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
 | 
				
			||||||
  return dyn_shape;
 | 
					  return dyn_shape;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename InputElType, int input_bit_width, typename OutputElType,
 | 
					 | 
				
			||||||
          int output_bit_width, typename LinalgOp>
 | 
					 | 
				
			||||||
class DotGeneralOpOnTensorsConversion
 | 
					class DotGeneralOpOnTensorsConversion
 | 
				
			||||||
    : public OpConversionPattern<mhlo::DotGeneralOp> {
 | 
					    : public OpConversionPattern<mhlo::DotGeneralOp> {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
| 
						 | 
					@ -1247,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mhlo::DotGeneralOp::Adaptor adaptor(args);
 | 
					    mhlo::DotGeneralOp::Adaptor adaptor(args);
 | 
				
			||||||
    auto lhs_el_type =
 | 
					 | 
				
			||||||
        adaptor.lhs().getType().cast<ShapedType>().getElementType();
 | 
					 | 
				
			||||||
    auto rhs_el_type =
 | 
					 | 
				
			||||||
        adaptor.lhs().getType().cast<ShapedType>().getElementType();
 | 
					 | 
				
			||||||
    if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<InputElType>() ||
 | 
					 | 
				
			||||||
        lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
 | 
					 | 
				
			||||||
      return failure();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto output_type = op.getType().cast<ShapedType>();
 | 
					 | 
				
			||||||
    auto output_el_type = output_type.getElementType();
 | 
					 | 
				
			||||||
    if (!output_el_type.isa<OutputElType>() ||
 | 
					 | 
				
			||||||
        output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
 | 
					 | 
				
			||||||
      return failure();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Location loc = op.getLoc();
 | 
					    Location loc = op.getLoc();
 | 
				
			||||||
 | 
					    auto output_type = op.getType().cast<ShapedType>();
 | 
				
			||||||
 | 
					    auto output_el_type = output_type.getElementType();
 | 
				
			||||||
    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
 | 
					    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
 | 
				
			||||||
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
 | 
					        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
 | 
				
			||||||
    auto zero_attr = rewriter.getZeroAttr(output_el_type);
 | 
					    auto zero_attr = rewriter.getZeroAttr(output_el_type);
 | 
				
			||||||
| 
						 | 
					@ -1271,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion
 | 
				
			||||||
    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
 | 
					    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
 | 
				
			||||||
    Value zero_tensor =
 | 
					    Value zero_tensor =
 | 
				
			||||||
        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
 | 
					        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
 | 
				
			||||||
    Operation* linalg_op = rewriter.create<LinalgOp>(
 | 
					    Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
 | 
				
			||||||
        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
 | 
					        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
 | 
				
			||||||
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
 | 
					        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
 | 
				
			||||||
        /*outputBuffers=*/ValueRange{zero_tensor});
 | 
					        /*outputBuffers=*/ValueRange{zero_tensor});
 | 
				
			||||||
| 
						 | 
					@ -1709,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
 | 
				
			||||||
      ReverseConverter<mhlo::ReverseOp, false>,
 | 
					      ReverseConverter<mhlo::ReverseOp, false>,
 | 
				
			||||||
      SliceConverter<mhlo::SliceOp, false>,
 | 
					      SliceConverter<mhlo::SliceOp, false>,
 | 
				
			||||||
      TransposeConverter<mhlo::TransposeOp, false>,
 | 
					      TransposeConverter<mhlo::TransposeOp, false>,
 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
 | 
					      DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
 | 
				
			||||||
                               DotOperationType::kMatrixMatrix,
 | 
					 | 
				
			||||||
                               linalg::MatmulI8I8I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kMatrixVector,
 | 
					 | 
				
			||||||
                               linalg::MatvecI8I8I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kVectorDot,
 | 
					 | 
				
			||||||
                               linalg::DotI8I8I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kMatrixMatrix,
 | 
					 | 
				
			||||||
                               linalg::MatmulI16I16I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kMatrixVector,
 | 
					 | 
				
			||||||
                               linalg::MatvecI16I16I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kVectorDot,
 | 
					 | 
				
			||||||
                               linalg::DotI16I16I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kMatrixMatrix,
 | 
					 | 
				
			||||||
                               linalg::MatmulI32I32I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kMatrixVector,
 | 
					 | 
				
			||||||
                               linalg::MatvecI32I32I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kVectorDot,
 | 
					 | 
				
			||||||
                               linalg::DotI32I32I32Op>,
 | 
					 | 
				
			||||||
      DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
 | 
					 | 
				
			||||||
                               DotOperationType::kMatrixMatrix,
 | 
					 | 
				
			||||||
                               linalg::MatmulOp>,
 | 
					                               linalg::MatmulOp>,
 | 
				
			||||||
      DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
 | 
					      DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
 | 
				
			||||||
                               DotOperationType::kMatrixVector,
 | 
					 | 
				
			||||||
                               linalg::MatvecOp>,
 | 
					                               linalg::MatvecOp>,
 | 
				
			||||||
      DotOpOnTensorsConversion<FloatType, 32, FloatType, 32,
 | 
					      DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
 | 
				
			||||||
                               DotOperationType::kVectorDot, linalg::DotOp>,
 | 
					      DotGeneralOpOnTensorsConversion,
 | 
				
			||||||
      DotGeneralOpOnTensorsConversion<IntegerType, 8, IntegerType, 32,
 | 
					 | 
				
			||||||
                                      linalg::BatchMatmulI8I8I32Op>,
 | 
					 | 
				
			||||||
      DotGeneralOpOnTensorsConversion<IntegerType, 16, IntegerType, 32,
 | 
					 | 
				
			||||||
                                      linalg::BatchMatmulI16I16I32Op>,
 | 
					 | 
				
			||||||
      DotGeneralOpOnTensorsConversion<IntegerType, 32, IntegerType, 32,
 | 
					 | 
				
			||||||
                                      linalg::BatchMatmulI32I32I32Op>,
 | 
					 | 
				
			||||||
      DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
 | 
					 | 
				
			||||||
                                      linalg::BatchMatmulOp>,
 | 
					 | 
				
			||||||
      NormalConvOpOnTensorsConversion,
 | 
					      NormalConvOpOnTensorsConversion,
 | 
				
			||||||
      ReduceOnTensorsConversion,
 | 
					      ReduceOnTensorsConversion,
 | 
				
			||||||
      PadOpOnTensorsConversion>(context);
 | 
					      PadOpOnTensorsConversion>(context);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1004,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
 | 
				
			||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
					// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
				
			||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.matmul_i8_i8_i32
 | 
					// CHECK: linalg.matmul
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1022,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
 | 
				
			||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
					// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
				
			||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.matmul_i16_i16_i32
 | 
					// CHECK: linalg.matmul
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1040,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
 | 
				
			||||||
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
					// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 | 
				
			||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.matmul_i32_i32_i32
 | 
					// CHECK: linalg.matmul
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1131,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
 | 
				
			||||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 | 
					// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 | 
				
			||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.batch_matmul_i8_i8_i32
 | 
					// CHECK: linalg.batch_matmul
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1160,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
 | 
				
			||||||
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 | 
					// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 | 
				
			||||||
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 | 
					// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 | 
				
			||||||
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
					// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
 | 
				
			||||||
// CHECK: linalg.batch_matmul_i16_i16_i32
 | 
					// CHECK: linalg.batch_matmul
 | 
				
			||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
 | 
					// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
 | 
				
			||||||
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
 | 
					// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue