Add support for lowering 2D depthwise mhlo.conv to Linalg on tensors.

The change upstreams the pattern from IREE repo to MHLO repo.

PiperOrigin-RevId: 362300550
This commit is contained in:
Hanhan Wang 2021-03-11 08:40:31 -08:00 committed by TensorFlow MLIR Team
parent 94f9740c67
commit 630cabefb0
2 changed files with 225 additions and 2 deletions

View File

@ -151,6 +151,11 @@ AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
return inversePermutation(map);
}
/// Returns true if the given `attr` is a splat of the given `value`.
bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
}
/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
/// follows a canonical form:
///
@ -1455,8 +1460,7 @@ struct NormalConvOpOnTensorsConversion
// Check if padding is zero.
DenseIntElementsAttr padding = op.paddingAttr();
if (padding &&
(!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
if (padding && !isSplatValue(*op.padding(), 0)) {
return rewriter.notifyMatchFailure(op, "expected no padding");
}
@ -1512,6 +1516,145 @@ struct NormalConvOpOnTensorsConversion
}
};
/// Converts mhlo.convolution operation to
/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or
/// depthwise_conv_2d_input_nhwc_filter_hwc op.
struct DepthwiseConvOpOnTensorsConversion
: public OpConversionPattern<mhlo::ConvOp> {
using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const override {
if (op.batch_group_count() != 1) return failure();
if (op.padding() && !isSplatValue(*op.padding(), 0)) {
return rewriter.notifyMatchFailure(op,
"non-zero padding unsupported yet");
}
if ((op.lhs_dilation() && !isSplatValue(*op.lhs_dilation(), 1)) ||
(op.rhs_dilation() && !isSplatValue(*op.rhs_dilation(), 1))) {
return rewriter.notifyMatchFailure(op,
"non-one dialation unsupported yet");
}
if (const mhlo::ConvDimensionNumbers& dimension_numbers =
op.dimension_numbers()) {
// Make sure that this is 2-D convolution.
const auto spatial_rank =
llvm::size(dimension_numbers.input_spatial_dimensions());
if (spatial_rank != 2) {
return rewriter.notifyMatchFailure(op,
"only support 2-D cases for now");
}
// Make sure that this is depthwise convolution.
int64_t input_feature_dim =
dimension_numbers.input_feature_dimension().getInt();
int64_t input_feature_count =
op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
if (op.feature_group_count() != input_feature_count) {
return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
}
// Make sure that this convolution has a canonical form.
if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
return rewriter.notifyMatchFailure(op, "does not have canonical form");
}
}
DenseIntElementsAttr window_strides;
if (op.window_strides()) {
window_strides = op.window_strides().getValue();
} else {
window_strides = rewriter.getI64VectorAttr({1, 1});
}
mhlo::ConvOp::Adaptor adaptor(args);
Location loc = op.getLoc();
Value input = adaptor.lhs();
Value filter = adaptor.rhs();
auto result_type = op.getResult().getType().cast<RankedTensorType>();
if (!result_type.hasStaticShape()) {
return rewriter.notifyMatchFailure(op,
"expected output has static shapes");
}
auto filter_dims =
llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());
auto get_indices_vector = [](int start, int end) {
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
};
if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
// For cases where channel multiplier != 1
auto output_dims = result_type.getShape();
auto channel_multiplier = filter_dims[3];
SmallVector<int64_t> reshaped_output_dims;
reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
reshaped_output_dims.push_back(channel_multiplier);
reshaped_output_dims[3] /= channel_multiplier;
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
loc, reshaped_output_dims, result_type.getElementType());
auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
auto reshaped_output_type = RankedTensorType::get(
reshaped_output_dims, result_type.getElementType());
auto conv = rewriter.create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
op.getLoc(), reshaped_output_type, ValueRange{input, filter},
ValueRange{zero_tensor}, window_strides);
// Create a Linalg reshape op that converts the output from 5 dimensions
// into 4 dimensions (by collapsing the last two dimensions). This is
// needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns
// 5 dimensions for the output.
SmallVector<linalg::ReassociationIndices, 4> collapsed_dim_list = {
get_indices_vector(0, 1), get_indices_vector(1, 2),
get_indices_vector(2, 3), get_indices_vector(3, 5)};
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
op, result_type, conv.getResult(0), collapsed_dim_list);
} else {
// For cases where channel multiplier == 1
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
loc, result_type.getShape(), result_type.getElementType());
auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
// Create a Linalg reshape op that converts the filter from 4 dimensions
// into 3 dimensions (by droping the unit dimension). This is needed
// because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3
// dimensions for the filter.
filter_dims[2] = static_cast<int64_t>(op.feature_group_count());
filter_dims.pop_back();
RankedTensorType filter_shape =
RankedTensorType::get(filter_dims, op.getType().getElementType());
SmallVector<linalg::ReassociationIndices, 4> collapsed_dim_list = {
get_indices_vector(0, 1), get_indices_vector(1, 2),
get_indices_vector(2, 4)};
Value reshaped_filter = rewriter.create<linalg::TensorReshapeOp>(
loc, filter_shape, filter, collapsed_dim_list);
rewriter.replaceOpWithNewOp<linalg::DepthwiseConvInputNHWCFilterHWCOp>(
op, result_type, ValueRange{input, reshaped_filter},
ValueRange{zero_tensor}, window_strides);
}
return success();
}
};
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
@ -1701,6 +1844,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralOpOnTensorsConversion,
NormalConvOpOnTensorsConversion,
DepthwiseConvOpOnTensorsConversion,
ReduceOnTensorsConversion,
PadOpOnTensorsConversion>(context);
// clang-format on

View File

@ -1658,3 +1658,82 @@ func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>
// -----
func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>,
%arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 2 : i64,
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
// CHECK: func @depthwise_conv
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
// CHECK: %[[CST:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[CST]]) : tensor<2x3x4x2x3xf32>, f32 -> tensor<2x3x4x2x3xf32>
// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[IN]], %[[FILTER]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
// CHECK: %{{.+}} = linalg.tensor_reshape %[[OUT]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: : tensor<2x3x4x2x3xf32> into tensor<2x3x4x6xf32>
// -----
func @depthwise_conv_multiplier_1(%arg0: tensor<1x113x113x96xf32>,
%arg1: tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 96 : i64,
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
return %0 : tensor<1x56x56x96xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK: func @depthwise_conv_multiplier_1
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
// CHECK: %[[CST:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[CST]]) : tensor<1x56x56x96xf32>, f32 -> tensor<1x56x56x96xf32>
// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_reshape %[[FILTER]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]] : tensor<3x3x1x96xf32> into tensor<3x3x96xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
// CHECK-SAME: {strides = dense<2> : tensor<2xi64>}
// CHECK-SAME: ins(%[[IN]], %[[RESHAPED_FILTER]] : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>