diff --git a/BUILD b/BUILD index 6fad5e9..e401b35 100644 --- a/BUILD +++ b/BUILD @@ -699,6 +699,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 5244bf3..187d5a1 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -44,6 +44,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -1130,6 +1131,8 @@ SmallVector GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc, return dyn_shape; } +template class DotOpOnTensorsConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1139,44 +1142,38 @@ class DotOpOnTensorsConversion : public OpConversionPattern { if (!VerifyHloOpBufferOrTensorSemantics(op)) { return failure(); } - Location loc = op.getLoc(); + mhlo::DotOp::Adaptor adaptor(args); - Type result_type = op.getResult().getType(); - auto shaped_type = result_type.cast(); - DotOperationType op_type = GetDotOperationType(op); - auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType()); + + auto lhs_el_type = + adaptor.lhs().getType().cast().getElementType(); + auto rhs_el_type = + adaptor.lhs().getType().cast().getElementType(); + if (lhs_el_type != rhs_el_type || !lhs_el_type.isa() || + lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) { + return failure(); + } + + auto output_type = op.getType().cast(); + auto output_el_type = output_type.getElementType(); + if (!output_el_type.isa() || + output_el_type.getIntOrFloatBitWidth() != output_bit_width) { + return failure(); + } + + if (GetDotOperationType(op) != op_type) return failure(); + + Location loc = op.getLoc(); + auto zero_attr = rewriter.getZeroAttr(output_el_type); Value zero = rewriter.create(loc, zero_attr); SmallVector dyn_shape = GetDotOpInitTensorDynSizes( rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type); - auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape); + auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape); Value zero_tensor = rewriter.create(loc, init_tensor, zero).getResult(0); - linalg::LinalgOp linalg_op; - switch (op_type) { - case DotOperationType::kMatrixMatrix: { - linalg_op = rewriter.create( - loc, TypeRange{result_type}, - ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor}); - break; - } - case DotOperationType::kMatrixVector: { - linalg_op = rewriter.create( - loc, TypeRange{result_type}, - ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor}); - break; - } - case DotOperationType::kVectorDot: { - linalg_op = rewriter.create( - loc, TypeRange{result_type}, - ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor}); - break; - } - case DotOperationType::kUnsupported: - default: { - return op.emitError("unsupported dot operation type"); - } - } - rewriter.replaceOp(op, linalg_op->getResults()); + rewriter.replaceOpWithNewOp( + op, TypeRange{op.getType()}, ValueRange{adaptor.lhs(), adaptor.rhs()}, + ValueRange{zero_tensor}); return success(); } }; @@ -1193,6 +1190,8 @@ SmallVector GetDotGeneralOpInitTensorDynSizes( return dyn_shape; } +template class DotGeneralOpOnTensorsConversion : public OpConversionPattern { public: @@ -1203,6 +1202,7 @@ class DotGeneralOpOnTensorsConversion if (!VerifyHloOpBufferOrTensorSemantics(op)) { return failure(); } + mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers(); auto lhs_bathcing_dims = Extract1DVector(dim_numbers.lhs_batching_dimensions()); @@ -1228,22 +1228,38 @@ class DotGeneralOpOnTensorsConversion return rewriter.notifyMatchFailure( op, "expected rhs contracting dimensions exactly {1}"); } - Location loc = op.getLoc(); + mhlo::DotGeneralOp::Adaptor adaptor(args); - Type result_type = op.getResult().getType(); - auto shaped_type = result_type.cast(); + auto lhs_el_type = + adaptor.lhs().getType().cast().getElementType(); + auto rhs_el_type = + adaptor.lhs().getType().cast().getElementType(); + if (lhs_el_type != rhs_el_type || !lhs_el_type.isa() || + lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) { + return failure(); + } + + auto output_type = op.getType().cast(); + auto output_el_type = output_type.getElementType(); + if (!output_el_type.isa() || + output_el_type.getIntOrFloatBitWidth() != output_bit_width) { + return failure(); + } + + Location loc = op.getLoc(); SmallVector dyn_shape = GetDotGeneralOpInitTensorDynSizes( - rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type); - auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType()); + rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type); + auto zero_attr = rewriter.getZeroAttr(output_el_type); Value zero = rewriter.create(loc, zero_attr); - auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape); + auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape); Value zero_tensor = rewriter.create(loc, init_tensor, zero).getResult(0); - auto linalg_op = rewriter.create( - loc, /*resultTensorTypes=*/TypeRange{result_type}, + Operation* linalg_op = rewriter.create( + loc, /*resultTensorTypes=*/TypeRange{op.getType()}, /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()}, /*outputBuffers=*/ValueRange{zero_tensor}); - rewriter.replaceOp(op, linalg_op.getResults()); + + rewriter.replaceOp(op, linalg_op->getResults()); return success(); } }; @@ -1560,8 +1576,51 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, ReshapeOpConverter, ReverseConverter, SliceConverter, - TransposeConverter, DotOpOnTensorsConversion, - DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context); + TransposeConverter, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotOpOnTensorsConversion, + DotGeneralOpOnTensorsConversion, + DotGeneralOpOnTensorsConversion, + DotGeneralOpOnTensorsConversion, + DotGeneralOpOnTensorsConversion, + ReduceOnTensorsConversion>(context); patterns->insert, ReduceRegionXLAOpConversion, ReduceRegionXLAOpConversion, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 7846f2d..a5c65f1 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -960,7 +960,8 @@ func @dot_matmul(%arg0: tensor<2x3xf32>, tensor<3x?xf32>) -> tensor<2x?xf32> return %0 : tensor<2x?xf32> } -// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) +// CHECK-LABEL: func @dot_matmul( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] @@ -969,6 +970,58 @@ func @dot_matmul(%arg0: tensor<2x3xf32>, // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>) // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>) +func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>, + %arg1: tensor<3x?xi8>) -> tensor<2x?xi32> { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi8>, + tensor<3x?xi8>) -> tensor<2x?xi32> + return %0 : tensor<2x?xi32> +} +// CHECK-LABEL: func @dot_matmul_i8_i8_i32( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>) +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] +// CHECK: linalg.matmul_i8_i8_i32 +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>) +// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) + +// ----- + +func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>, + %arg1: tensor<3x?xi16>) -> tensor<2x?xi32> { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi16>, + tensor<3x?xi16>) -> tensor<2x?xi32> + return %0 : tensor<2x?xi32> +} +// CHECK-LABEL: func @dot_matmul_i16_i16_i32( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>) +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] +// CHECK: linalg.matmul_i16_i16_i32 +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>) +// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) + +// ----- + +func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>, + %arg1: tensor<3x?xi32>) -> tensor<2x?xi32> { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi32>, + tensor<3x?xi32>) -> tensor<2x?xi32> + return %0 : tensor<2x?xi32> +} +// CHECK-LABEL: func @dot_matmul_i32_i32_i32( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>) +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] +// CHECK: linalg.matmul_i32_i32_i32 +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) + // ----- func @dot_matvec(%arg0: tensor, @@ -977,7 +1030,8 @@ func @dot_matvec(%arg0: tensor, tensor<3xf32>) -> tensor return %0 : tensor } -// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<3xf32>) +// CHECK-LABEL: func @dot_matvec( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<3xf32>) // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]] @@ -993,7 +1047,8 @@ func @dot_dot(%arg0: tensor, %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } -// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) +// CHECK-LABEL: func @dot_dot( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) // CHECK: %[[INIT:.*]] = linalg.init_tensor [] // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] // CHECK: linalg.dot @@ -1002,7 +1057,7 @@ func @dot_dot(%arg0: tensor, // ----- -func @dot_general(%arg0: tensor, +func @dot_general_batch_matmul(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { @@ -1015,7 +1070,8 @@ func @dot_general(%arg0: tensor, } : (tensor, tensor) -> tensor return %0 : tensor } -// CHECK: func @dot_general(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) +// CHECK-LABEL: func @dot_general_batch_matmul( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] // CHECK: %[[C1:.*]] = constant 1 : index @@ -1030,7 +1086,65 @@ func @dot_general(%arg0: tensor, // ----- -func @batch_matmul_large +func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor, + %arg1: tensor) -> tensor { + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + rhs_batching_dimensions = dense<0> : tensor<1xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64> + }, + precision_config = ["DEFAULT", "DEFAULT"] + } : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] +// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] +// CHECK: linalg.batch_matmul_i8_i8_i32 +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor) + +// ----- + +func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor, + %arg1: tensor) -> tensor { + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + rhs_batching_dimensions = dense<0> : tensor<1xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64> + }, + precision_config = ["DEFAULT", "DEFAULT"] + } : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32( +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]] +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]] +// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]] +// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]] +// CHECK: linalg.batch_matmul_i16_i16_i32 +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor) + +// ----- + +func @dot_general_batch_matmul_large (%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x32x32xf32>) -> tensor<2x16x32xf32> { %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { @@ -1042,7 +1156,7 @@ func @batch_matmul_large : (tensor<2x16x32xf32>, tensor<2x32x32xf32>) -> tensor<2x16x32xf32> return %0 : tensor<2x16x32xf32> } -// CHECK: func @batch_matmul_large( +// CHECK-LABEL: func @dot_general_batch_matmul_large( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<2x16x32xf32>, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<2x32x32xf32>) // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 16, 32]