Add support for lowering mhlo.pad to linalg.pad_tensor
The change upstreams the pattern from IREE repo to MHLO repo. PiperOrigin-RevId: 359481543
This commit is contained in:
parent
459362b206
commit
45a1249fe2
|
@ -1397,6 +1397,41 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Converts mhlo.pad operation to linalg.pad_tensor op.
|
||||||
|
struct PadOpOnTensorsConversion : public OpConversionPattern<mhlo::PadOp> {
|
||||||
|
using OpConversionPattern<mhlo::PadOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::PadOp op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
|
mhlo::PadOp::Adaptor adaptor(args);
|
||||||
|
if (llvm::any_of(
|
||||||
|
op.interior_padding().getValues<APInt>(),
|
||||||
|
[](const APInt& int_val) { return int_val.getZExtValue() != 0; })) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "expected no interior padding");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value padding_val =
|
||||||
|
rewriter.createOrFold<tensor::ExtractOp>(loc, adaptor.padding_value());
|
||||||
|
|
||||||
|
const auto& edge_padding_low = op.edge_padding_low();
|
||||||
|
const auto& edge_padding_high = op.edge_padding_high();
|
||||||
|
SmallVector<OpFoldResult, 4> low, high;
|
||||||
|
for (auto it : llvm::zip(edge_padding_low, edge_padding_high)) {
|
||||||
|
low.push_back(rewriter.createOrFold<ConstantIndexOp>(
|
||||||
|
loc, std::get<0>(it).getZExtValue()));
|
||||||
|
high.push_back(rewriter.createOrFold<ConstantIndexOp>(
|
||||||
|
loc, std::get<1>(it).getZExtValue()));
|
||||||
|
}
|
||||||
|
Type result_type = op.getResult().getType();
|
||||||
|
auto pad_tensor_op = linalg::PadTensorOp::createPadScalarOp(
|
||||||
|
result_type, adaptor.operand(), padding_val, low, high, loc, rewriter);
|
||||||
|
rewriter.replaceOp(op, pad_tensor_op.getResult());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
@ -1529,6 +1564,7 @@ namespace mhlo {
|
||||||
|
|
||||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
BroadcastConverter<mhlo::BroadcastOp, false>,
|
BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||||
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
|
||||||
|
@ -1620,7 +1656,9 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
linalg::BatchMatmulI32I32I32Op>,
|
linalg::BatchMatmulI32I32I32Op>,
|
||||||
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
||||||
linalg::BatchMatmulOp>,
|
linalg::BatchMatmulOp>,
|
||||||
ReduceOnTensorsConversion>(context);
|
ReduceOnTensorsConversion,
|
||||||
|
PadOpOnTensorsConversion>(context);
|
||||||
|
// clang-format on
|
||||||
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
||||||
ReduceRegionXLAOpConversion<mhlo::MinOp>,
|
ReduceRegionXLAOpConversion<mhlo::MinOp>,
|
||||||
ReduceRegionXLAOpConversion<mhlo::MaxOp>,
|
ReduceRegionXLAOpConversion<mhlo::MaxOp>,
|
||||||
|
|
|
@ -1396,3 +1396,48 @@ func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @slice_stride_part
|
// CHECK-LABEL: func @slice_stride_part
|
||||||
// CHECK: subtensor %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32>
|
// CHECK: subtensor %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
|
||||||
|
%0 = constant dense<0.0> : tensor<f32>
|
||||||
|
%1 = "mhlo.pad"(%arg0, %0) {
|
||||||
|
edge_padding_high = dense<[2, 3]> : tensor<2xi64>,
|
||||||
|
edge_padding_low = dense<[4, 5]> : tensor<2xi64>,
|
||||||
|
interior_padding = dense<0> : tensor<2xi64>
|
||||||
|
} : (tensor<12x4xf32>, tensor<f32>) -> tensor<18x12xf32>
|
||||||
|
return %1 : tensor<18x12xf32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @pad_cst
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-DAG: %[[CST:.+]] = constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
// CHECK-DAG: %[[PAD:.+]] = tensor.extract %[[CST]][] : tensor<f32>
|
||||||
|
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
|
||||||
|
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
|
||||||
|
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
|
||||||
|
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
|
||||||
|
// CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
|
||||||
|
// CHECK: linalg.yield %[[PAD]] : f32
|
||||||
|
// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf32> {
|
||||||
|
%0 = "mhlo.pad"(%arg0, %arg1) {
|
||||||
|
edge_padding_high = dense<[2, 3]> : tensor<2xi64>,
|
||||||
|
edge_padding_low = dense<[4, 5]> : tensor<2xi64>,
|
||||||
|
interior_padding = dense<0> : tensor<2xi64>
|
||||||
|
} : (tensor<12x4xf32>, tensor<f32>) -> tensor<18x12xf32>
|
||||||
|
return %0 : tensor<18x12xf32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @pad_tensor
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
|
||||||
|
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
|
||||||
|
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
|
||||||
|
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
|
||||||
|
// CHECK-DAG: %[[PAD:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
|
||||||
|
// CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
|
||||||
|
// CHECK: linalg.yield %[[PAD]] : f32
|
||||||
|
// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>
|
||||||
|
|
Loading…
Reference in New Issue