From 90f0d7f935406889cfa57ca1679a92b170ca5f72 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Thu, 25 Feb 2021 05:58:12 -0800 Subject: [PATCH] Add support for lowering mhlo.conv to Linalg on tensors. This pattern only works for normal convolutions. It does not work for depthwise convolutions. The Linalg conv ops are defined with static rank, so it only supports 1d/2d/3d cases, which are the most typical cases. This also refactors out the same check in lmhlo.conv lowering. PiperOrigin-RevId: 359503527 --- .../mhlo/transforms/legalize_to_linalg.cc | 194 +++++++++++++----- tests/hlo-legalize-to-linalg.mlir | 168 +++++++++++++++ 2 files changed, 310 insertions(+), 52 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index be6eaa9..2a523e6 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -141,6 +141,72 @@ AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank, return inversePermutation(map); } +/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op +/// follows a canonical form: +/// +/// * Input dimensions have order: (batch_count, spatial_dims, +/// input_channel_count). +/// * Filter dimensions have order: (spatial_dims, input_channel_count, +/// output_channel_count). +/// * Output dimensions have order: (batch_count, spatial_dims, +/// output_channel_count). +template +static bool HasCanonicalDimensionNumbers( + const DimensionNumbersTy& dimension_numbers) { + const int input_spatial_rank = + llvm::size(dimension_numbers.input_spatial_dimensions()); + // The dimensions for input should follow the order of + // batch_count, spatial_dims..., input_feature_count. + if (dimension_numbers.input_batch_dimension().getInt() != 0 || + dimension_numbers.input_feature_dimension().getInt() != + (input_spatial_rank + 1)) { + return false; + } + + const int kernel_spatial_rank = + llvm::size(dimension_numbers.kernel_spatial_dimensions()); + // The dimensions for filter should follow the order of + // spatial_dims..., input_feature_count, num_output_feature_count. + if (dimension_numbers.kernel_input_feature_dimension().getInt() != + kernel_spatial_rank || + dimension_numbers.kernel_output_feature_dimension().getInt() != + (kernel_spatial_rank + 1)) { + return false; + } + + const int output_spatial_rank = + llvm::size(dimension_numbers.output_spatial_dimensions()); + // The dimensions for output should follow the order of + // batch_count, spatial_dims.., output_feature_count. + if (dimension_numbers.output_batch_dimension().getInt() != 0 || + dimension_numbers.output_feature_dimension().getInt() != + (output_spatial_rank + 1)) { + return false; + } + + if (input_spatial_rank != output_spatial_rank || + input_spatial_rank != kernel_spatial_rank) { + return false; + } + + auto input_spatial_dim = dimension_numbers.input_spatial_dimensions().begin(); + auto kernel_spatial_dim = + dimension_numbers.kernel_spatial_dimensions().begin(); + auto output_spatial_dim = + dimension_numbers.output_spatial_dimensions().begin(); + // Check spatial dims are ordered correctly. + for (int i = 0; i < input_spatial_rank; ++i) { + const int dim = i + 1; + if ((*input_spatial_dim++).getZExtValue() != dim || + (*output_spatial_dim++).getZExtValue() != dim || + (*kernel_spatial_dim++).getZExtValue() != i) { + return false; + } + } + + return true; +} + template class PointwiseToLinalgConverter : public OpConversionPattern { public: @@ -264,61 +330,10 @@ struct ConvToLinalgConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - // This code has been adapted from IREE's - // (https://github.com/google/iree/) mhlo -> linalg conversion. LogicalResult matchAndRewrite( lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - // Check validity of dimension information. - if (const mhlo::ConvDimensionNumbers& dimension_numbers = - op.dimension_numbers()) { - const int input_spatial_rank = - llvm::size(dimension_numbers.input_spatial_dimensions()); - // The dimensions for input should follow the order of - // batch_count, spatial_dims..., input_feature_count. - if (dimension_numbers.input_batch_dimension().getInt() != 0 || - dimension_numbers.input_feature_dimension().getInt() != - (input_spatial_rank + 1)) - return failure(); - - const int kernel_spatial_rank = - llvm::size(dimension_numbers.kernel_spatial_dimensions()); - // The dimensions for filter should follow the order of - // spatial_dims..., input_feature_count, num_output_feature_count. - if (dimension_numbers.kernel_input_feature_dimension().getInt() != - kernel_spatial_rank || - dimension_numbers.kernel_output_feature_dimension().getInt() != - (kernel_spatial_rank + 1)) - return failure(); - - const int output_spatial_rank = - llvm::size(dimension_numbers.output_spatial_dimensions()); - // The dimensions for output should follow the order of - // batch_count, spatial_dims.., output_feature_count. - if (dimension_numbers.output_batch_dimension().getInt() != 0 || - dimension_numbers.output_feature_dimension().getInt() != - (output_spatial_rank + 1)) - return failure(); - - if (input_spatial_rank != output_spatial_rank || - input_spatial_rank != kernel_spatial_rank) - return failure(); - - auto input_spatial_dim = - dimension_numbers.input_spatial_dimensions().begin(); - auto kernel_spatial_dim = - dimension_numbers.kernel_spatial_dimensions().begin(); - auto output_spatial_dim = - dimension_numbers.output_spatial_dimensions().begin(); - // Check if spatial dims are ordered correctly. - for (int i = 0; i < input_spatial_rank; ++i) { - const int dim = i + 1; - if ((*input_spatial_dim++).getZExtValue() != dim || - (*output_spatial_dim++).getZExtValue() != dim || - (*kernel_spatial_dim++).getZExtValue() != i) - return failure(); - } - } + if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure(); // TODO: LHS dilation for deconvolution not supported yet. // TODO(jurahul): Window reversal is not supported yet. @@ -1432,6 +1447,80 @@ struct PadOpOnTensorsConversion : public OpConversionPattern { } }; +/// Converts mhlo.conv operation to linalg named op. This only covers normal +/// convolution cases. The op must have canonical dimension numbers. Depthwise +/// convolution and pointwise convolution are not handled in the conversion. +struct NormalConvOpOnTensorsConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ConvOp op, ArrayRef args, + ConversionPatternRewriter& rewriter) const override { + if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure(); + if (op.feature_group_count() != 1u) return failure(); + + mhlo::ConvOp::Adaptor adaptor(args); + Location loc = op.getLoc(); + Value input = adaptor.lhs(); + Value filter = adaptor.rhs(); + auto result_type = op.getResult().getType().cast(); + int64_t rank = result_type.getRank(); + + // Check if padding is zero. + DenseIntElementsAttr padding = op.paddingAttr(); + if (padding && + (!padding.isSplat() || padding.getSplatValue() != 0)) { + return rewriter.notifyMatchFailure(op, "expected no padding"); + } + + // The output shape is N spatial_dims F. + SmallVector dyn_sizes; + for (int64_t i = 0, e = rank - 1; i < e; ++i) { + if (!result_type.isDynamicDim(i)) continue; + dyn_sizes.push_back(rewriter.create(loc, input, i)); + } + if (result_type.isDynamicDim(rank - 1)) { + dyn_sizes.push_back(rewriter.create(loc, filter, rank - 1)); + } + Value init_tensor = rewriter.create( + loc, dyn_sizes, result_type.getShape(), result_type.getElementType()); + auto zero_attr = rewriter.getZeroAttr(result_type.getElementType()); + Value zero = rewriter.create(loc, zero_attr); + Value zero_tensor = + rewriter.create(loc, init_tensor, zero).getResult(0); + linalg::LinalgOp res; + Attribute strides = op.window_stridesAttr(); + // TODO(ataei): Only support dilated kernel right now. We need to consider + // input dilation for deconvolution cases. + Attribute dilations = op.rhs_dilationAttr(); + switch (rank) { + case 3: { + res = rewriter.create( + loc, result_type, ValueRange{input, filter}, + ValueRange{zero_tensor}, dilations, strides); + break; + } + case 4: { + res = rewriter.create( + loc, result_type, ValueRange{input, filter}, + ValueRange{zero_tensor}, dilations, strides); + break; + } + case 5: { + res = rewriter.create( + loc, result_type, ValueRange{input, filter}, + ValueRange{zero_tensor}, dilations, strides); + break; + } + default: + return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op"); + } + rewriter.replaceOp(op, res.getOperation()->getResults()); + return success(); + } +}; + void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off @@ -1656,6 +1745,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, linalg::BatchMatmulI32I32I32Op>, DotGeneralOpOnTensorsConversion, + NormalConvOpOnTensorsConversion, ReduceOnTensorsConversion, PadOpOnTensorsConversion>(context); // clang-format on diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 5d320c9..fdaf9f8 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -1441,3 +1441,171 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor) -> tensor<18x12xf3 // CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]] // CHECK: linalg.yield %[[PAD]] : f32 // CHECK: } : tensor<12x4xf32> to tensor<18x12xf32> + +// ----- + +func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor, %arg1: tensor) + -> tensor { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 2 : i64, + input_spatial_dimensions = dense<[1]> : tensor<1xi64>, + kernel_input_feature_dimension = 1 : i64, + kernel_output_feature_dimension = 2 : i64, + kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 2 : i64, + output_spatial_dimensions = dense<[1]> : tensor<1xi64> + }, + feature_group_count = 1 : i64, + padding = dense<[[0], [0]]> : tensor<2x1xi64>, + rhs_dilation = dense<1> : tensor<1xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK: %[[C0:.+]] = constant 0 : index +// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[C1:.+]] = constant 1 : index +// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[C2:.+]] = constant 2 : index +// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor +// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]] +// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) +// CHECK: linalg.conv_1d_input_nwc_filter_wcf +// CHECK-SAME: {dilations = dense<1> : tensor<1xi64> +// CHECK-SAME: strides = dense<1> : tensor<1xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor + +// ----- + +func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor, %arg1: tensor) + -> tensor { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, + kernel_input_feature_dimension = 2 : i64, + kernel_output_feature_dimension = 3 : i64, + kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 3 : i64, + output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK: %[[C0:.+]] = constant 0 : index +// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[C1:.+]] = constant 1 : index +// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[C2:.+]] = constant 2 : index +// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[C3:.+]] = constant 3 : index +// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor +// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]] +// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) +// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-SAME: {dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor + +// ----- + +func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor, %arg1: tensor) + -> tensor { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 4 : i64, + input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, + kernel_input_feature_dimension = 3 : i64, + kernel_output_feature_dimension = 4 : i64, + kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 4 : i64, + output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64> + }, + feature_group_count = 1 : i64, + padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>, + rhs_dilation = dense<1> : tensor<3xi64>, + window_strides = dense<1> : tensor<3xi64> + } : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK: %[[C0:.+]] = constant 0 : index +// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[C1:.+]] = constant 1 : index +// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[C2:.+]] = constant 2 : index +// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[C3:.+]] = constant 3 : index +// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor +// CHECK: %[[C4:.+]] = constant 4 : index +// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor +// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]] +// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) +// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf +// CHECK-SAME: {dilations = dense<1> : tensor<3xi64> +// CHECK-SAME: strides = dense<1> : tensor<3xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) +// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor + +// ----- + +func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>) + -> tensor<1x2x4x3xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, + kernel_input_feature_dimension = 2 : i64, + kernel_output_feature_dimension = 3 : i64, + kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 3 : i64, + output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + padding = dense<0> : tensor<2x2xi64>, + rhs_dilation = dense<[2, 1]> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32> + return %0 : tensor<1x2x4x3xf32> +} +// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32> +// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32 +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32> +// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf +// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>