From 300a7c11ce9614f53826870ad6ac415b837ba5e9 Mon Sep 17 00:00:00 2001
From: Hanhan Wang <hanchung@google.com>
Date: Tue, 12 Jan 2021 22:07:29 -0800
Subject: [PATCH] Upstream mhlo.dot_general lowering to Linalg to MHLO repo

PiperOrigin-RevId: 351514250
---
 .../mhlo/transforms/legalize_to_linalg.cc          | 86 ++++++++++++++++++-
 tests/hlo-legalize-to-linalg.mlir                  | 21 +++++
 2 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index f1cc090..9a31f8c 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -99,6 +99,14 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
   return dyn_sizes;
 }
 
+SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
+  SmallVector<int64_t, 4> ret;
+  for (const APInt& element : elements) {
+    ret.push_back(element.getLimitedValue());
+  }
+  return ret;
+}
+
 template <typename OpTy, bool isLHLO = true>
 class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
  public:
@@ -1131,6 +1139,81 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
   }
 };
 
+SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
+    OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
+  SmallVector<Value, 8> dyn_shape;
+  if (result_type.isDynamicDim(0))
+    dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
+  if (result_type.isDynamicDim(1))
+    dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
+  if (result_type.isDynamicDim(2))
+    dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
+  return dyn_shape;
+}
+
+class DotGeneralOpOnTensorsConversion
+    : public OpConversionPattern<mhlo::DotGeneralOp> {
+ public:
+  using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mhlo::DotGeneralOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
+      return failure();
+    }
+    mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
+    auto lhs_bathcing_dims =
+        Extract1DVector(dim_numbers.lhs_batching_dimensions());
+    auto rhs_bathcing_dims =
+        Extract1DVector(dim_numbers.rhs_batching_dimensions());
+    auto lhs_contracting_dims =
+        Extract1DVector(dim_numbers.lhs_contracting_dimensions());
+    auto rhs_contracting_dims =
+        Extract1DVector(dim_numbers.rhs_contracting_dimensions());
+    if (lhs_bathcing_dims.size() != 1 || lhs_bathcing_dims[0] != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "expected lhs batching dimensions exactly {0}");
+    }
+    if (rhs_bathcing_dims.size() != 1 || rhs_bathcing_dims[0] != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "expected rhs batching dimensions exactly {0}");
+    }
+    if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "expected lhs contracting dimensions exactly {2}");
+    }
+    if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected rhs contracting dimensions exactly {1}");
+    }
+    Location loc = op.getLoc();
+    mhlo::DotGeneralOp::Adaptor adaptor(args);
+    Type result_type = op.getResult().getType();
+    auto shaped_type = result_type.cast<ShapedType>();
+    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
+        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
+    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
+    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
+    auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
+        loc, result_type, dyn_shape);
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
+                                     rewriter.getIndexType());
+      Region& region = init_tensor.body();
+      Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
+      rewriter.setInsertionPointToEnd(block);
+      rewriter.create<YieldOp>(loc, zero);
+    }
+    auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
+        loc, /*resultTensorTypes=*/TypeRange{result_type},
+        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
+        /*outputBuffers=*/ValueRange{init_tensor});
+    rewriter.replaceOp(op, linalg_op.getResults());
+    return success();
+  }
+};
+
 void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                            OwningRewritePatternList* patterns) {
   // clang-format off
@@ -1299,7 +1382,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                    ReshapeOpConverter<mhlo::ReshapeOp, false>,
                    ReverseConverter<mhlo::ReverseOp, false>,
                    TransposeConverter<mhlo::TransposeOp, false>,
-                   DotOpOnTensorsConversion>(context);
+                   DotOpOnTensorsConversion, DotGeneralOpOnTensorsConversion>(
+      context);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index b5d762e..0058307 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -871,3 +871,24 @@ func @dot_dot(%arg0: tensor<?xf32>,
 // CHECK: linalg.dot
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
 // CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
+
+// -----
+
+func @dot_general(%arg0: tensor<?x?x3xf32>,
+                  %arg1: tensor<?x3x?xf32>) -> tensor<?x?x?xf32> {
+  %0 ="mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<0> : tensor<1xi64>,
+      rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+    },
+    precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<?x?x3xf32>, tensor<?x3x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
+// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
+// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)