Add support for lowering mhlo.conv to Linalg on tensors.

This pattern only works for normal convolutions. It does not work for depthwise
convolutions. The Linalg conv ops are defined with static rank, so it only
supports 1d/2d/3d cases, which are the most typical cases.

This also refactors out the same check in lmhlo.conv lowering.

PiperOrigin-RevId: 359503527
This commit is contained in:
Hanhan Wang 2021-02-25 05:58:12 -08:00 committed by TensorFlow MLIR Team
parent c5f5d13930
commit 90f0d7f935
2 changed files with 310 additions and 52 deletions

View File

@ -141,6 +141,72 @@ AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
return inversePermutation(map);
}
/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
/// follows a canonical form:
///
/// * Input dimensions have order: (batch_count, spatial_dims,
/// input_channel_count).
/// * Filter dimensions have order: (spatial_dims, input_channel_count,
/// output_channel_count).
/// * Output dimensions have order: (batch_count, spatial_dims,
/// output_channel_count).
template <typename DimensionNumbersTy>
static bool HasCanonicalDimensionNumbers(
const DimensionNumbersTy& dimension_numbers) {
const int input_spatial_rank =
llvm::size(dimension_numbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
dimension_numbers.input_feature_dimension().getInt() !=
(input_spatial_rank + 1)) {
return false;
}
const int kernel_spatial_rank =
llvm::size(dimension_numbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
kernel_spatial_rank ||
dimension_numbers.kernel_output_feature_dimension().getInt() !=
(kernel_spatial_rank + 1)) {
return false;
}
const int output_spatial_rank =
llvm::size(dimension_numbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count.
if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
dimension_numbers.output_feature_dimension().getInt() !=
(output_spatial_rank + 1)) {
return false;
}
if (input_spatial_rank != output_spatial_rank ||
input_spatial_rank != kernel_spatial_rank) {
return false;
}
auto input_spatial_dim = dimension_numbers.input_spatial_dimensions().begin();
auto kernel_spatial_dim =
dimension_numbers.kernel_spatial_dimensions().begin();
auto output_spatial_dim =
dimension_numbers.output_spatial_dimensions().begin();
// Check spatial dims are ordered correctly.
for (int i = 0; i < input_spatial_rank; ++i) {
const int dim = i + 1;
if ((*input_spatial_dim++).getZExtValue() != dim ||
(*output_spatial_dim++).getZExtValue() != dim ||
(*kernel_spatial_dim++).getZExtValue() != i) {
return false;
}
}
return true;
}
template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
public:
@ -264,61 +330,10 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
public:
using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) mhlo -> linalg conversion.
LogicalResult matchAndRewrite(
lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const mhlo::ConvDimensionNumbers& dimension_numbers =
op.dimension_numbers()) {
const int input_spatial_rank =
llvm::size(dimension_numbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
dimension_numbers.input_feature_dimension().getInt() !=
(input_spatial_rank + 1))
return failure();
const int kernel_spatial_rank =
llvm::size(dimension_numbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
kernel_spatial_rank ||
dimension_numbers.kernel_output_feature_dimension().getInt() !=
(kernel_spatial_rank + 1))
return failure();
const int output_spatial_rank =
llvm::size(dimension_numbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count.
if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
dimension_numbers.output_feature_dimension().getInt() !=
(output_spatial_rank + 1))
return failure();
if (input_spatial_rank != output_spatial_rank ||
input_spatial_rank != kernel_spatial_rank)
return failure();
auto input_spatial_dim =
dimension_numbers.input_spatial_dimensions().begin();
auto kernel_spatial_dim =
dimension_numbers.kernel_spatial_dimensions().begin();
auto output_spatial_dim =
dimension_numbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly.
for (int i = 0; i < input_spatial_rank; ++i) {
const int dim = i + 1;
if ((*input_spatial_dim++).getZExtValue() != dim ||
(*output_spatial_dim++).getZExtValue() != dim ||
(*kernel_spatial_dim++).getZExtValue() != i)
return failure();
}
}
if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
// TODO: LHS dilation for deconvolution not supported yet.
// TODO(jurahul): Window reversal is not supported yet.
@ -1432,6 +1447,80 @@ struct PadOpOnTensorsConversion : public OpConversionPattern<mhlo::PadOp> {
}
};
/// Converts mhlo.conv operation to linalg named op. This only covers normal
/// convolution cases. The op must have canonical dimension numbers. Depthwise
/// convolution and pointwise convolution are not handled in the conversion.
struct NormalConvOpOnTensorsConversion
: public OpConversionPattern<mhlo::ConvOp> {
using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const override {
if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
if (op.feature_group_count() != 1u) return failure();
mhlo::ConvOp::Adaptor adaptor(args);
Location loc = op.getLoc();
Value input = adaptor.lhs();
Value filter = adaptor.rhs();
auto result_type = op.getResult().getType().cast<ShapedType>();
int64_t rank = result_type.getRank();
// Check if padding is zero.
DenseIntElementsAttr padding = op.paddingAttr();
if (padding &&
(!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
return rewriter.notifyMatchFailure(op, "expected no padding");
}
// The output shape is N spatial_dims F.
SmallVector<Value, 8> dyn_sizes;
for (int64_t i = 0, e = rank - 1; i < e; ++i) {
if (!result_type.isDynamicDim(i)) continue;
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i));
}
if (result_type.isDynamicDim(rank - 1)) {
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
}
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
linalg::LinalgOp res;
Attribute strides = op.window_stridesAttr();
// TODO(ataei): Only support dilated kernel right now. We need to consider
// input dilation for deconvolution cases.
Attribute dilations = op.rhs_dilationAttr();
switch (rank) {
case 3: {
res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
loc, result_type, ValueRange{input, filter},
ValueRange{zero_tensor}, dilations, strides);
break;
}
case 4: {
res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
loc, result_type, ValueRange{input, filter},
ValueRange{zero_tensor}, dilations, strides);
break;
}
case 5: {
res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
loc, result_type, ValueRange{input, filter},
ValueRange{zero_tensor}, dilations, strides);
break;
}
default:
return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
}
rewriter.replaceOp(op, res.getOperation()->getResults());
return success();
}
};
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
@ -1656,6 +1745,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
linalg::BatchMatmulI32I32I32Op>,
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
linalg::BatchMatmulOp>,
NormalConvOpOnTensorsConversion,
ReduceOnTensorsConversion,
PadOpOnTensorsConversion>(context);
// clang-format on

View File

@ -1441,3 +1441,171 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf3
// CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
// CHECK: linalg.yield %[[PAD]] : f32
// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>
// -----
func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 2 : i64,
input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
kernel_input_feature_dimension = 1 : i64,
kernel_output_feature_dimension = 2 : i64,
kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 2 : i64,
output_spatial_dimensions = dense<[1]> : tensor<1xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0], [0]]> : tensor<2x1xi64>,
rhs_dilation = dense<1> : tensor<1xi64>,
window_strides = dense<1> : tensor<1xi64>
} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_1d_input_nwc_filter_wcf
// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// -----
func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
-> tensor<?x?x?x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>
} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// -----
func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
-> tensor<?x?x?x?x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 4 : i64,
input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
kernel_input_feature_dimension = 3 : i64,
kernel_output_feature_dimension = 4 : i64,
kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 4 : i64,
output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
rhs_dilation = dense<1> : tensor<3xi64>,
window_strides = dense<1> : tensor<3xi64>
} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}
// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C4:.+]] = constant 4 : index
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// -----
func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>)
-> tensor<1x2x4x3xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>
} : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
return %0 : tensor<1x2x4x3xf32>
}
// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32>
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32>
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>