From 8f58f844e5c2405068eab6e886bcb1c73bced57f Mon Sep 17 00:00:00 2001
From: Hanhan Wang
Date: Mon, 11 Jan 2021 10:33:14 -0800
Subject: [PATCH] Upstream mhlo.dot lowering to Linalg to MHLO repo.

We prototyped the lowering from mhlo.dot to linalg.matmul in IREE. Since
Linalg now supports matmul in tensors world, we can move the lowering logic
to tensors world, and upstream to legalize_to_linalg.cc.

The patch lowers the mhlo.dot to the linalg.matmul/matvec/dot in tensors
world.

PiperOrigin-RevId: 351184911
---
 .../mhlo/transforms/legalize_to_linalg.cc | 120 +++++++++++++++++-
 tests/hlo-legalize-to-linalg.mlir         |  41 ++++++
 2 files changed, 160 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 60102cb..f1cc090 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1014,6 +1014,123 @@ class SliceConverter : public OpConversionPattern {
   }
 };
 
+// Classifies an mhlo.dot by operand ranks: vector.vector, matrix.vector,
+// or matrix.matrix.
+enum class DotOperationType {
+  kVectorDot = 0,
+  kMatrixVector = 1,
+  kMatrixMatrix = 2,
+  kUnsupported = 3
+};
+
+DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
+  ArrayRef<int64_t> lhs_shape =
+      dot_op.lhs().getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> rhs_shape =
+      dot_op.rhs().getType().cast<ShapedType>().getShape();
+  // A dynamic dimension is compatible with anything; only a static/static
+  // mismatch on the contraction dimension disqualifies a case.
+  auto shape_matches = [](int64_t a, int64_t b) {
+    return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
+           a == b;
+  };
+  if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
+      shape_matches(lhs_shape[0], rhs_shape[0])) {
+    return DotOperationType::kVectorDot;
+  }
+  if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
+      shape_matches(lhs_shape[1], rhs_shape[0])) {
+    return DotOperationType::kMatrixVector;
+  }
+  // NOTE(review): original tested rhs_shape.size() == 2 twice; the lhs rank
+  // must be checked here, otherwise a 1-D lhs could classify as matmul.
+  if (lhs_shape.size() == 2 && rhs_shape.size() == 2 &&
+      shape_matches(lhs_shape[1], rhs_shape[0])) {
+    return DotOperationType::kMatrixMatrix;
+  }
+  return DotOperationType::kUnsupported;
+}
+
+// Returns the SSA values for the dynamic dimensions of the result init
+// tensor for the given dot kind.
+SmallVector<Value, 2>
GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
+                           Value lhs, Value rhs,
+                           ShapedType result_type,
+                           DotOperationType type) {
+  SmallVector<Value, 2> dyn_shape;
+  switch (type) {
+    case DotOperationType::kMatrixMatrix: {
+      // Result is (lhs rows) x (rhs cols); emit a dim op per dynamic dim.
+      if (result_type.isDynamicDim(0))
+        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
+      if (result_type.isDynamicDim(1))
+        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
+      break;
+    }
+    case DotOperationType::kMatrixVector: {
+      // Result is a vector of (lhs rows).
+      if (result_type.isDynamicDim(0))
+        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
+      break;
+    }
+    case DotOperationType::kVectorDot:
+    case DotOperationType::kUnsupported:
+    default: {
+      // Scalar result (or unsupported): no dynamic dimensions.
+      break;
+    }
+  }
+  return dyn_shape;
+}
+
+// Lowers mhlo.dot on tensors to linalg.matmul/matvec/dot, materializing a
+// zero-filled init tensor for the output.
+class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
+ public:
+  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mhlo::DotOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
+      return failure();
+    }
+    Location loc = op.getLoc();
+    mhlo::DotOp::Adaptor adaptor(args);
+    Type result_type = op.getResult().getType();
+    auto shaped_type = result_type.cast<ShapedType>();
+    DotOperationType op_type = GetDotOperationType(op);
+    SmallVector<Value, 8> dyn_shape = GetDotOpInitTensorDynSizes(
+        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type);
+    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
+    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
+    auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
+        loc, result_type, dyn_shape);
+    {
+      // Fill the init tensor's generator region: yield zero at every index.
+      OpBuilder::InsertionGuard guard(rewriter);
+      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
+                                     rewriter.getIndexType());
+      Region& region = init_tensor.body();
+      Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
+      rewriter.setInsertionPointToEnd(block);
+      rewriter.create<YieldOp>(loc, zero);
+    }
+    linalg::LinalgOp linalg_op;
+    switch (op_type) {
+      case DotOperationType::kMatrixMatrix: {
+        linalg_op = rewriter.create<linalg::MatmulOp>(
+            loc, TypeRange{result_type},
+            ValueRange{adaptor.lhs(),
adaptor.rhs()}, ValueRange{init_tensor});
+        break;
+      }
+      case DotOperationType::kMatrixVector: {
+        linalg_op = rewriter.create<linalg::MatvecOp>(
+            loc, TypeRange{result_type},
+            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
+        break;
+      }
+      case DotOperationType::kVectorDot: {
+        linalg_op = rewriter.create<linalg::DotOp>(
+            loc, TypeRange{result_type},
+            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
+        break;
+      }
+      case DotOperationType::kUnsupported:
+      default: {
+        return op.emitError("unsupported dot operation type");
+      }
+    }
+    rewriter.replaceOp(op, linalg_op->getResults());
+    return success();
+  }
+};
+
 void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                            OwningRewritePatternList* patterns) {
   // clang-format off
@@ -1181,7 +1298,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                       PointwiseToLinalgConverter,
                       ReshapeOpConverter,
                       ReverseConverter,
-                      TransposeConverter>(context);
+                      TransposeConverter,
+                      DotOpOnTensorsConversion>(context);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 31b89d2..b5d762e 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -830,3 +830,44 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor {
 // CHECK-SAME: ins([[CST]] : tensor) outs([[INIT]] : tensor)
 // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
 // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
+
+// -----
+
+func @dot_matmul(%arg0: tensor<2x3xf32>,
+                 %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
+                 tensor<3x?xf32>) -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}
+// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
+// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
+// CHECK-SAME:
outs(%[[INIT]] : tensor<2x?xf32>)
+
+// -----
+
+func @dot_matvec(%arg0: tensor<?x3xf32>,
+                 %arg1: tensor<3xf32>) -> tensor<?xf32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?x3xf32>,
+                 tensor<3xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
+// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: linalg.matvec
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>)
+// CHECK-SAME: outs(%[[INIT]] : tensor<?xf32>)
+
+// -----
+
+func @dot_dot(%arg0: tensor<?xf32>,
+              %arg1: tensor<?xf32>) -> tensor<f32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
+// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: linalg.dot
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
+// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)