diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index be6eaa9..2a523e6 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -141,6 +141,72 @@ AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
   return inversePermutation(map);
 }
 
+/// Returns true if the given `dimension_numbers` from an mhlo.convolution op
+/// follows a canonical form:
+///
+/// * Input dimensions have order: (batch_count, spatial_dims,
+///   input_channel_count).
+/// * Filter dimensions have order: (spatial_dims, input_channel_count,
+///   output_channel_count).
+/// * Output dimensions have order: (batch_count, spatial_dims,
+///   output_channel_count).
+template <typename DimensionNumbersTy>
+static bool HasCanonicalDimensionNumbers(
+    const DimensionNumbersTy& dimension_numbers) {
+  const int input_spatial_rank =
+      llvm::size(dimension_numbers.input_spatial_dimensions());
+  // The dimensions for input should follow the order of
+  // batch_count, spatial_dims..., input_feature_count.
+  if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
+      dimension_numbers.input_feature_dimension().getInt() !=
+          (input_spatial_rank + 1)) {
+    return false;
+  }
+
+  const int kernel_spatial_rank =
+      llvm::size(dimension_numbers.kernel_spatial_dimensions());
+  // The dimensions for filter should follow the order of
+  // spatial_dims..., input_feature_count, output_feature_count.
+  if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
+          kernel_spatial_rank ||
+      dimension_numbers.kernel_output_feature_dimension().getInt() !=
+          (kernel_spatial_rank + 1)) {
+    return false;
+  }
+
+  const int output_spatial_rank =
+      llvm::size(dimension_numbers.output_spatial_dimensions());
+  // The dimensions for output should follow the order of
+  // batch_count, spatial_dims..., output_feature_count.
+  if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
+      dimension_numbers.output_feature_dimension().getInt() !=
+          (output_spatial_rank + 1)) {
+    return false;
+  }
+
+  if (input_spatial_rank != output_spatial_rank ||
+      input_spatial_rank != kernel_spatial_rank) {
+    return false;
+  }
+
+  auto input_spatial_dim =
+      dimension_numbers.input_spatial_dimensions().begin();
+  auto kernel_spatial_dim =
+      dimension_numbers.kernel_spatial_dimensions().begin();
+  auto output_spatial_dim =
+      dimension_numbers.output_spatial_dimensions().begin();
+  // Check that the spatial dims are ordered correctly.
+  for (int i = 0; i < input_spatial_rank; ++i) {
+    const int dim = i + 1;
+    if ((*input_spatial_dim++).getZExtValue() != dim ||
+        (*output_spatial_dim++).getZExtValue() != dim ||
+        (*kernel_spatial_dim++).getZExtValue() != i) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 template <typename OpTy, bool isLHLO = true>
 class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
  public:
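For reference, here is an example of dimension numbers this predicate rejects (illustrative only, not part of the patch): NCHW-style input numbering places the feature dimension before the spatial dimensions, so the `input_feature_dimension == input_spatial_rank + 1` check fails:

    // Rejected by HasCanonicalDimensionNumbers: the feature dimension (1)
    // precedes the spatial dimensions, so 1 != input_spatial_rank + 1 (= 3).
    input_batch_dimension = 0 : i64,
    input_feature_dimension = 1 : i64,
    input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>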
@@ -264,61 +330,10 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
  public:
   using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
 
-  // This code has been adapted from IREE's
-  // (https://github.com/google/iree/) mhlo -> linalg conversion.
   LogicalResult matchAndRewrite(
       lmhlo::ConvOp op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    // Check validity of dimension information.
-    if (const mhlo::ConvDimensionNumbers& dimension_numbers =
-            op.dimension_numbers()) {
-      const int input_spatial_rank =
-          llvm::size(dimension_numbers.input_spatial_dimensions());
-      // The dimensions for input should follow the order of
-      // batch_count, spatial_dims..., input_feature_count.
-      if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
-          dimension_numbers.input_feature_dimension().getInt() !=
-              (input_spatial_rank + 1))
-        return failure();
-
-      const int kernel_spatial_rank =
-          llvm::size(dimension_numbers.kernel_spatial_dimensions());
-      // The dimensions for filter should follow the order of
-      // spatial_dims..., input_feature_count, num_output_feature_count.
-      if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
-              kernel_spatial_rank ||
-          dimension_numbers.kernel_output_feature_dimension().getInt() !=
-              (kernel_spatial_rank + 1))
-        return failure();
-
-      const int output_spatial_rank =
-          llvm::size(dimension_numbers.output_spatial_dimensions());
-      // The dimensions for output should follow the order of
-      // batch_count, spatial_dims..., output_feature_count.
-      if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
-          dimension_numbers.output_feature_dimension().getInt() !=
-              (output_spatial_rank + 1))
-        return failure();
-
-      if (input_spatial_rank != output_spatial_rank ||
-          input_spatial_rank != kernel_spatial_rank)
-        return failure();
-
-      auto input_spatial_dim =
-          dimension_numbers.input_spatial_dimensions().begin();
-      auto kernel_spatial_dim =
-          dimension_numbers.kernel_spatial_dimensions().begin();
-      auto output_spatial_dim =
-          dimension_numbers.output_spatial_dimensions().begin();
-      // Check if spatial dims are ordered correctly.
-      for (int i = 0; i < input_spatial_rank; ++i) {
-        const int dim = i + 1;
-        if ((*input_spatial_dim++).getZExtValue() != dim ||
-            (*output_spatial_dim++).getZExtValue() != dim ||
-            (*kernel_spatial_dim++).getZExtValue() != i)
-          return failure();
-      }
-    }
+    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
 
     // TODO: LHS dilation for deconvolution not supported yet.
     // TODO(jurahul): Window reversal is not supported yet.
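The tensor-based pattern introduced in the next hunk additionally requires `feature_group_count == 1`. As an illustrative non-example (hypothetical, not part of the patch): a depthwise convolution sets `feature_group_count` to the input channel count, so an op like the following is deliberately left unmatched, pending a separate depthwise lowering:

    %0 = "mhlo.convolution"(%input, %filter) {
      batch_group_count = 1 : i64,
      dimension_numbers = {
        input_batch_dimension = 0 : i64,
        input_feature_dimension = 3 : i64,
        input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
        kernel_input_feature_dimension = 2 : i64,
        kernel_output_feature_dimension = 3 : i64,
        kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
        output_batch_dimension = 0 : i64,
        output_feature_dimension = 3 : i64,
        output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
      },
      // Equals the input channel count: a depthwise conv, so not matched.
      feature_group_count = 4 : i64,
      padding = dense<0> : tensor<2x2xi64>,
      rhs_dilation = dense<1> : tensor<2xi64>,
      window_strides = dense<1> : tensor<2xi64>
    } : (tensor<1x10x10x4xf32>, tensor<3x3x1x4xf32>) -> tensor<1x8x8x4xf32>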
@@ -1432,6 +1447,80 @@ struct PadOpOnTensorsConversion : public OpConversionPattern<mhlo::PadOp> {
   }
 };
 
+/// Converts an mhlo.conv operation to a linalg named op. This only covers
+/// normal convolution cases. The op must have canonical dimension numbers.
+/// Depthwise convolution and pointwise convolution are not handled in the
+/// conversion.
+struct NormalConvOpOnTensorsConversion
+    : public OpConversionPattern<mhlo::ConvOp> {
+  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ConvOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const override {
+    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
+    if (op.feature_group_count() != 1u) return failure();
+
+    mhlo::ConvOp::Adaptor adaptor(args);
+    Location loc = op.getLoc();
+    Value input = adaptor.lhs();
+    Value filter = adaptor.rhs();
+    auto result_type = op.getResult().getType().cast<ShapedType>();
+    int64_t rank = result_type.getRank();
+
+    // Check if padding is zero.
+    DenseIntElementsAttr padding = op.paddingAttr();
+    if (padding &&
+        (!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
+      return rewriter.notifyMatchFailure(op, "expected no padding");
+    }
+
+    // The output shape is (N, spatial_dims..., F).
+    SmallVector<Value, 8> dyn_sizes;
+    for (int64_t i = 0, e = rank - 1; i < e; ++i) {
+      if (!result_type.isDynamicDim(i)) continue;
+      dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i));
+    }
+    if (result_type.isDynamicDim(rank - 1)) {
+      dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
+    }
+    Value init_tensor = rewriter.create<linalg::InitTensorOp>(
+        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
+    auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
+    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
+    Value zero_tensor =
+        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
+    linalg::LinalgOp res;
+    Attribute strides = op.window_stridesAttr();
+    // TODO(ataei): Only the dilated kernel is supported right now. We need to
+    // consider input dilation for deconvolution cases.
+    Attribute dilations = op.rhs_dilationAttr();
+    switch (rank) {
+      case 3: {
+        res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
+            loc, result_type, ValueRange{input, filter},
+            ValueRange{zero_tensor}, dilations, strides);
+        break;
+      }
+      case 4: {
+        res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
+            loc, result_type, ValueRange{input, filter},
+            ValueRange{zero_tensor}, dilations, strides);
+        break;
+      }
+      case 5: {
+        res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
+            loc, result_type, ValueRange{input, filter},
+            ValueRange{zero_tensor}, dilations, strides);
+        break;
+      }
+      default:
+        return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
+    }
+    rewriter.replaceOp(op, res.getOperation()->getResults());
+    return success();
+  }
+};
+
 void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                            OwningRewritePatternList* patterns) {
   // clang-format off
@@ -1656,6 +1745,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                       linalg::BatchMatmulI32I32I32Op>,
                   DotGeneralOpOnTensorsConversion,
+                  NormalConvOpOnTensorsConversion,
                   ReduceOnTensorsConversion,
                   PadOpOnTensorsConversion>(context);
   // clang-format on
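To make the emitted IR concrete before the FileCheck tests: for a fully static 1-D case with stride 1, dilation 1, and no padding, the new pattern produces IR of roughly the following shape. The shapes (1x10x4 input, 3x4x16 filter) are invented for this sketch and do not appear in the patch; the output width 8 follows from (10 - 3) / 1 + 1:

    %init = linalg.init_tensor [1, 8, 16] : tensor<1x8x16xf32>
    %zero = constant 0.000000e+00 : f32
    %fill = linalg.fill(%init, %zero) : tensor<1x8x16xf32>, f32 -> tensor<1x8x16xf32>
    %res = linalg.conv_1d_input_nwc_filter_wcf
             {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
             ins(%input, %filter : tensor<1x10x4xf32>, tensor<3x4x16xf32>)
             outs(%fill : tensor<1x8x16xf32>) -> tensor<1x8x16xf32>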
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 5d320c9..fdaf9f8 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -1441,3 +1441,171 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf32> {
 // CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
 // CHECK:   linalg.yield %[[PAD]] : f32
 // CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>
+
+// -----
+
+func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
+  -> tensor<?x?x?xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 2 : i64,
+      input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+      kernel_input_feature_dimension = 1 : i64,
+      kernel_output_feature_dimension = 2 : i64,
+      kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 2 : i64,
+      output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<[[0], [0]]> : tensor<2x1xi64>,
+    rhs_dilation = dense<1> : tensor<1xi64>,
+    window_strides = dense<1> : tensor<1xi64>
+  } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK: %[[C0:.+]] = constant 0 : index
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK: %[[C2:.+]] = constant 2 : index
+// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
+// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_1d_input_nwc_filter_wcf
+// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
+// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+// -----
+
+func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
+  -> tensor<?x?x?x?xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK: %[[C0:.+]] = constant 0 : index
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[C2:.+]] = constant 2 : index
+// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[C3:.+]] = constant 3 : index
+// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+// -----
+
+func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
+  -> tensor<?x?x?x?x?xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 4 : i64,
+      input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+      kernel_input_feature_dimension = 3 : i64,
+      kernel_output_feature_dimension = 4 : i64,
+      kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 4 : i64,
+      output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
+    rhs_dilation = dense<1> : tensor<3xi64>,
+    window_strides = dense<1> : tensor<3xi64>
+  } : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK: %[[C0:.+]] = constant 0 : index
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[C2:.+]] = constant 2 : index
+// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[C3:.+]] = constant 3 : index
+// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
+// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
+// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+
+// -----
+
+func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>)
+  -> tensor<1x2x4x3xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<0> : tensor<2x2xi64>,
+    rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
+  return %0 : tensor<1x2x4x3xf32>
+}
+// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32>
+// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32>
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>
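Assuming the test file's existing RUN header is unchanged (it sits outside this hunk), the new cases are exercised the same way as the rest of the file, e.g.:

    mlir-hlo-opt -hlo-legalize-to-linalg -split-input-file tests/hlo-legalize-to-linalg.mlir | FileCheck tests/hlo-legalize-to-linalg.mlir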