diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 60102cb..f1cc090 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1014,6 +1014,123 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
   }
 };
 
+enum class DotOperationType {  // Which linalg named op an mhlo.dot maps to.
+  kVectorDot = 0,
+  kMatrixVector = 1,
+  kMatrixMatrix = 2,
+  kUnsupported = 3
+};
+
+DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {  // Classify by operand ranks.
+  ArrayRef<int64_t> lhs_shape =
+      dot_op.lhs().getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> rhs_shape =
+      dot_op.rhs().getType().cast<ShapedType>().getShape();
+  auto shape_matches = [](int64_t a, int64_t b) {  // Dynamic dims match anything.
+    return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
+           a == b;
+  };
+  if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
+      shape_matches(lhs_shape[0], rhs_shape[0])) {
+    return DotOperationType::kVectorDot;
+  }
+  if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
+      shape_matches(lhs_shape[1], rhs_shape[0])) {
+    return DotOperationType::kMatrixVector;
+  }
+  if (lhs_shape.size() == 2 && rhs_shape.size() == 2 &&  // NOTE: lhs rank check; was rhs twice.
+      shape_matches(lhs_shape[1], rhs_shape[0])) {
+    return DotOperationType::kMatrixMatrix;
+  }
+  return DotOperationType::kUnsupported;
+}
+
+SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
+                                                 Value lhs, Value rhs,
+                                                 ShapedType result_type,
+                                                 DotOperationType type) {
+  SmallVector<Value, 2> dyn_shape;  // Runtime sizes for each dynamic result dim.
+  switch (type) {
+    case DotOperationType::kMatrixMatrix: {
+      if (result_type.isDynamicDim(0))
+        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
+      if (result_type.isDynamicDim(1))
+        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
+      break;
+    }
+    case DotOperationType::kMatrixVector: {
+      if (result_type.isDynamicDim(0))
+        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
+      break;
+    }
+    case DotOperationType::kVectorDot:
+    case DotOperationType::kUnsupported:
+    default: {
+      break;
+    }
+  }
+  return dyn_shape;
+}
+
+class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
+ public:
+  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mhlo::DotOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
+      return failure();
+    }
+    Location loc = op.getLoc();
+    mhlo::DotOp::Adaptor adaptor(args);
+    Type result_type = op.getResult().getType();
+    auto shaped_type = result_type.cast<ShapedType>();
+    DotOperationType op_type = GetDotOperationType(op);
+    SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
+        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type);
+    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
+    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
+    auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
+        loc, result_type, dyn_shape);
+    {  // Fill the init tensor's body region with a yield of zero.
+      OpBuilder::InsertionGuard guard(rewriter);
+      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
+                                     rewriter.getIndexType());
+      Region& region = init_tensor.body();
+      Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
+      rewriter.setInsertionPointToEnd(block);
+      rewriter.create<YieldOp>(loc, zero);
+    }
+    linalg::LinalgOp linalg_op;
+    switch (op_type) {
+      case DotOperationType::kMatrixMatrix: {
+        linalg_op = rewriter.create<linalg::MatmulOp>(
+            loc, TypeRange{result_type},
+            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
+        break;
+      }
+      case DotOperationType::kMatrixVector: {
+        linalg_op = rewriter.create<linalg::MatvecOp>(
+            loc, TypeRange{result_type},
+            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
+        break;
+      }
+      case DotOperationType::kVectorDot: {
+        linalg_op = rewriter.create<linalg::DotOp>(
+            loc, TypeRange{result_type},
+            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
+        break;
+      }
+      case DotOperationType::kUnsupported:
+      default: {
+        return op.emitError("unsupported dot operation type");
+      }
+    }
+    rewriter.replaceOp(op, linalg_op->getResults());
+    return success();
+  }
+};
+
 void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { // clang-format off @@ -1181,7 +1298,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, - TransposeConverter>(context); + TransposeConverter, + DotOpOnTensorsConversion>(context); } std::unique_ptr> createLegalizeHloToLinalgPass() { diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 31b89d2..b5d762e 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -830,3 +830,44 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor { // CHECK-SAME: ins([[CST]] : tensor) outs([[INIT]] : tensor) // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +func @dot_matmul(%arg0: tensor<2x3xf32>, + %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>, + tensor<3x?xf32>) -> tensor<2x?xf32> + return %0 : tensor<2x?xf32> +} +// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) +// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>) +// CHECK-SAME: outs(%[[INIT]] : tensor<2x?xf32>) + +// ----- + +func @dot_matvec(%arg0: tensor, + %arg1: tensor<3xf32>) -> tensor { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, + tensor<3xf32>) -> tensor + return %0 : tensor +} +// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<3xf32>) +// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements +// CHECK: linalg.matvec +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<3xf32>) +// CHECK-SAME: outs(%[[INIT]] : tensor) + +// ----- + +func @dot_dot(%arg0: tensor, + %arg1: tensor) -> tensor { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: 
tensor) +// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements +// CHECK: linalg.dot +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) +// CHECK-SAME: outs(%[[INIT]] : tensor)