Add support for lowering 2D depthwise mhlo.conv to Linalg on tensors.
The change upstreams the pattern from IREE repo to MHLO repo. PiperOrigin-RevId: 362300550
This commit is contained in:
parent
94f9740c67
commit
630cabefb0
|
@ -151,6 +151,11 @@ AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
|
|||
return inversePermutation(map);
|
||||
}
|
||||
|
||||
/// Returns true if the given `attr` is a splat of the given `value`.
|
||||
bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
|
||||
return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
|
||||
}
|
||||
|
||||
/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
|
||||
/// follows a canonical form:
|
||||
///
|
||||
|
@ -1455,8 +1460,7 @@ struct NormalConvOpOnTensorsConversion
|
|||
|
||||
// Check if padding is zero.
|
||||
DenseIntElementsAttr padding = op.paddingAttr();
|
||||
if (padding &&
|
||||
(!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
|
||||
if (padding && !isSplatValue(*op.padding(), 0)) {
|
||||
return rewriter.notifyMatchFailure(op, "expected no padding");
|
||||
}
|
||||
|
||||
|
@ -1512,6 +1516,145 @@ struct NormalConvOpOnTensorsConversion
|
|||
}
|
||||
};
|
||||
|
||||
/// Converts mhlo.convolution operation to
|
||||
/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or
|
||||
/// depthwise_conv_2d_input_nhwc_filter_hwc op.
|
||||
struct DepthwiseConvOpOnTensorsConversion
|
||||
: public OpConversionPattern<mhlo::ConvOp> {
|
||||
using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
if (op.batch_group_count() != 1) return failure();
|
||||
|
||||
if (op.padding() && !isSplatValue(*op.padding(), 0)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-zero padding unsupported yet");
|
||||
}
|
||||
|
||||
if ((op.lhs_dilation() && !isSplatValue(*op.lhs_dilation(), 1)) ||
|
||||
(op.rhs_dilation() && !isSplatValue(*op.rhs_dilation(), 1))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-one dialation unsupported yet");
|
||||
}
|
||||
|
||||
if (const mhlo::ConvDimensionNumbers& dimension_numbers =
|
||||
op.dimension_numbers()) {
|
||||
// Make sure that this is 2-D convolution.
|
||||
const auto spatial_rank =
|
||||
llvm::size(dimension_numbers.input_spatial_dimensions());
|
||||
if (spatial_rank != 2) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support 2-D cases for now");
|
||||
}
|
||||
|
||||
// Make sure that this is depthwise convolution.
|
||||
int64_t input_feature_dim =
|
||||
dimension_numbers.input_feature_dimension().getInt();
|
||||
int64_t input_feature_count =
|
||||
op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
|
||||
if (op.feature_group_count() != input_feature_count) {
|
||||
return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
|
||||
}
|
||||
|
||||
// Make sure that this convolution has a canonical form.
|
||||
if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
|
||||
return rewriter.notifyMatchFailure(op, "does not have canonical form");
|
||||
}
|
||||
}
|
||||
|
||||
DenseIntElementsAttr window_strides;
|
||||
if (op.window_strides()) {
|
||||
window_strides = op.window_strides().getValue();
|
||||
} else {
|
||||
window_strides = rewriter.getI64VectorAttr({1, 1});
|
||||
}
|
||||
|
||||
mhlo::ConvOp::Adaptor adaptor(args);
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.lhs();
|
||||
Value filter = adaptor.rhs();
|
||||
auto result_type = op.getResult().getType().cast<RankedTensorType>();
|
||||
if (!result_type.hasStaticShape()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"expected output has static shapes");
|
||||
}
|
||||
|
||||
auto filter_dims =
|
||||
llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());
|
||||
|
||||
auto get_indices_vector = [](int start, int end) {
|
||||
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
|
||||
};
|
||||
|
||||
if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
|
||||
// For cases where channel multiplier != 1
|
||||
auto output_dims = result_type.getShape();
|
||||
auto channel_multiplier = filter_dims[3];
|
||||
SmallVector<int64_t> reshaped_output_dims;
|
||||
reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
|
||||
reshaped_output_dims.push_back(channel_multiplier);
|
||||
reshaped_output_dims[3] /= channel_multiplier;
|
||||
|
||||
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, reshaped_output_dims, result_type.getElementType());
|
||||
auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
|
||||
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
||||
Value zero_tensor =
|
||||
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
|
||||
|
||||
auto reshaped_output_type = RankedTensorType::get(
|
||||
reshaped_output_dims, result_type.getElementType());
|
||||
auto conv = rewriter.create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
|
||||
op.getLoc(), reshaped_output_type, ValueRange{input, filter},
|
||||
ValueRange{zero_tensor}, window_strides);
|
||||
|
||||
// Create a Linalg reshape op that converts the output from 5 dimensions
|
||||
// into 4 dimensions (by collapsing the last two dimensions). This is
|
||||
// needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns
|
||||
// 5 dimensions for the output.
|
||||
SmallVector<linalg::ReassociationIndices, 4> collapsed_dim_list = {
|
||||
get_indices_vector(0, 1), get_indices_vector(1, 2),
|
||||
get_indices_vector(2, 3), get_indices_vector(3, 5)};
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||
op, result_type, conv.getResult(0), collapsed_dim_list);
|
||||
} else {
|
||||
// For cases where channel multiplier == 1
|
||||
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, result_type.getShape(), result_type.getElementType());
|
||||
auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
|
||||
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
||||
Value zero_tensor =
|
||||
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
|
||||
|
||||
// Create a Linalg reshape op that converts the filter from 4 dimensions
|
||||
// into 3 dimensions (by droping the unit dimension). This is needed
|
||||
// because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3
|
||||
// dimensions for the filter.
|
||||
|
||||
filter_dims[2] = static_cast<int64_t>(op.feature_group_count());
|
||||
filter_dims.pop_back();
|
||||
|
||||
RankedTensorType filter_shape =
|
||||
RankedTensorType::get(filter_dims, op.getType().getElementType());
|
||||
|
||||
SmallVector<linalg::ReassociationIndices, 4> collapsed_dim_list = {
|
||||
get_indices_vector(0, 1), get_indices_vector(1, 2),
|
||||
get_indices_vector(2, 4)};
|
||||
|
||||
Value reshaped_filter = rewriter.create<linalg::TensorReshapeOp>(
|
||||
loc, filter_shape, filter, collapsed_dim_list);
|
||||
|
||||
rewriter.replaceOpWithNewOp<linalg::DepthwiseConvInputNHWCFilterHWCOp>(
|
||||
op, result_type, ValueRange{input, reshaped_filter},
|
||||
ValueRange{zero_tensor}, window_strides);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
|
@ -1701,6 +1844,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
|
||||
DotGeneralOpOnTensorsConversion,
|
||||
NormalConvOpOnTensorsConversion,
|
||||
DepthwiseConvOpOnTensorsConversion,
|
||||
ReduceOnTensorsConversion,
|
||||
PadOpOnTensorsConversion>(context);
|
||||
// clang-format on
|
||||
|
|
|
@ -1658,3 +1658,82 @@ func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2
|
|||
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>,
|
||||
%arg1: tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
},
|
||||
feature_group_count = 2 : i64,
|
||||
padding = dense<0> : tensor<2x2xi64>,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1)>
|
||||
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
|
||||
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
|
||||
// CHECK: func @depthwise_conv
|
||||
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
|
||||
// CHECK: %[[CST:.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[CST]]) : tensor<2x3x4x2x3xf32>, f32 -> tensor<2x3x4x2x3xf32>
|
||||
// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%[[IN]], %[[FILTER]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
|
||||
// CHECK: %{{.+}} = linalg.tensor_reshape %[[OUT]]
|
||||
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
|
||||
// CHECK-SAME: : tensor<2x3x4x2x3xf32> into tensor<2x3x4x6xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_multiplier_1(%arg0: tensor<1x113x113x96xf32>,
|
||||
%arg1: tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> {
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
},
|
||||
feature_group_count = 96 : i64,
|
||||
padding = dense<0> : tensor<2x2xi64>,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
|
||||
return %0 : tensor<1x56x56x96xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
|
||||
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
|
||||
// CHECK: func @depthwise_conv_multiplier_1
|
||||
// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
|
||||
// CHECK: %[[CST:.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[CST]]) : tensor<1x56x56x96xf32>, f32 -> tensor<1x56x56x96xf32>
|
||||
// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_reshape %[[FILTER]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]] : tensor<3x3x1x96xf32> into tensor<3x3x96xf32>
|
||||
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
|
||||
// CHECK-SAME: {strides = dense<2> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%[[IN]], %[[RESHAPED_FILTER]] : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
|
||||
|
|
Loading…
Reference in New Issue