Add support for lowering NHWC pooling mhlo.reduce_window to Linalg on tensors.

This change upstreams the pattern from the IREE repo to the MHLO repo.

PiperOrigin-RevId: 362312573
commit 4f5e1c51dd
parent 630cabefb0
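For reference, the lowering turns an NHWC pooling expressed as mhlo.reduce_window into a linalg named pooling op. A minimal before/after sketch, distilled from the tests added below (the value names here are illustrative, not from the commit):

  // Before: max pooling over 3x3 windows with stride 2 on an NHWC tensor.
  %0 = "mhlo.reduce_window"(%input, %init) ( {
  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
    %m = mhlo.maximum %a, %b : tensor<f32>
    "mhlo.return"(%m) : (tensor<f32>) -> ()
  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>

  // After: the window becomes a shape-only tensor, and the init value is
  // extracted and used to fill the output tensor.
  %window = linalg.init_tensor [3, 3] : tensor<3x3xf32>
  %out = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
  %v = tensor.extract %init[] : tensor<f32>
  %fill = linalg.fill(%out, %v) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
  %res = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>,
                                  strides = dense<2> : vector<2xi64>}
    ins(%input, %window : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
    outs(%fill : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>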
@@ -1655,6 +1655,120 @@ struct DepthwiseConvOpOnTensorsConversion
  }
};

struct ReduceWindowOpOnTensorsConversion
    : public OpConversionPattern<mhlo::ReduceWindowOp> {
  using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;

  /// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
  /// the pooling is determined based on the body of the reduce window
  /// operation. This class enumerates the different variants.
  enum class PoolingType {
    kMin,
    kMax,
    kAdd,
  };

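  // For example, a reduce_window whose body computes
  //   %max = mhlo.maximum %a, %b : tensor<f32>
  // is classified as kMax below and lowered to linalg.pooling_nhwc_max.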
  static PoolingType getPoolingType(Region& region) {
    assert(region.getBlocks().size() == 1 &&
           "expected the region to have exactly one block");
    Block& block = region.front();
    assert(block.getOperations().size() == 2 &&
           "expected the block to have exactly two operations");
    auto op = block.begin();
    if (isa<mhlo::MinOp>(op)) return PoolingType::kMin;
    if (isa<mhlo::MaxOp>(op)) return PoolingType::kMax;
    if (isa<mhlo::AddOp>(op)) return PoolingType::kAdd;

    llvm_unreachable("unknown pooling type");
  }

  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    auto loc = op.getLoc();
    auto result_type = op.getResult().getType().cast<ShapedType>();
    if (result_type.getRank() != 4) {
      return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op");
    }

    // Create a fake window dimension.
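    // The linalg pooling ops read only the shape of this tensor: it supplies
    // the H and W window extents through the op's indexing maps, and its
    // element values are never accessed.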
    SmallVector<int64_t, 4> shapes;
    shapes.push_back(op.window_dimensions().getValue<int64_t>(1));
    shapes.push_back(op.window_dimensions().getValue<int64_t>(2));
    auto fake_window_dims = rewriter.create<linalg::InitTensorOp>(
        loc, shapes, result_type.getElementType());

    if (op.window_strides() &&
        (op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
         op.window_strides().getValue().getValue<int64_t>(3) != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_strides to be [1,x,y,1]");
    }
    if (op.window_dimensions() &&
        (op.window_dimensions().getValue<int64_t>(0) != 1 ||
         op.window_dimensions().getValue<int64_t>(3) != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_dimensions to be [1,x,y,1]");
    }

    if (!args[0].getType().cast<ShapedType>().getElementType().isF32()) {
      return rewriter.notifyMatchFailure(op, "expected element type to be f32");
    }

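    // linalg pooling ops model only the spatial (H, W) strides and dilations,
    // so pick out entries 1 and 2 of the 4-D attributes; both default to 1.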
    Attribute strides;
    if (op.window_stridesAttr()) {
      strides = rewriter.getI64VectorAttr(
          {op.window_strides().getValue().getValue<int64_t>(1),
           op.window_strides().getValue().getValue<int64_t>(2)});
    } else {
      strides = rewriter.getI64VectorAttr({1, 1});
    }
    Attribute dilations;
    if (op.window_dilations()) {
      dilations = rewriter.getI64VectorAttr(
          {op.window_dilations().getValue().getValue<int64_t>(1),
           op.window_dilations().getValue().getValue<int64_t>(2)});
    } else {
      dilations = rewriter.getI64VectorAttr({1, 1});
    }

    Value init_tensor = rewriter.create<linalg::InitTensorOp>(
        loc, result_type.getShape(), result_type.getElementType());
    Value init_value = args[1];
    init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
    Value filled_init_tensor =
        rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
            .getResult(0);
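    // Create the concrete pooling op through one shared code path; the null
    // typed pointer below only carries the op type into the generic lambda.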
    auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp {
      return cast<linalg::LinalgOp>(
          rewriter
              .create<std::remove_pointer_t<decltype(type_ptr)>>(
                  loc, ArrayRef<Type>{result_type},
                  ValueRange{args[0], fake_window_dims.getResult()},
                  filled_init_tensor, dilations, strides)
              .getOperation());
    };
    linalg::LinalgOp pooling_op;
    PoolingType pooling_type = getPoolingType(op.body());
    switch (pooling_type) {
      case PoolingType::kMin: {
        pooling_op = create_op(static_cast<linalg::PoolingNHWCMinOp*>(nullptr));
        break;
      }
      case PoolingType::kMax: {
        pooling_op = create_op(static_cast<linalg::PoolingNHWCMaxOp*>(nullptr));
        break;
      }
      case PoolingType::kAdd: {
        pooling_op = create_op(static_cast<linalg::PoolingNHWCSumOp*>(nullptr));
        break;
      }
    }
    rewriter.replaceOp(op, pooling_op->getResult(0));
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  // clang-format off

@@ -1846,6 +1960,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
      NormalConvOpOnTensorsConversion,
      DepthwiseConvOpOnTensorsConversion,
      ReduceOnTensorsConversion,
      ReduceWindowOpOnTensorsConversion,
      PadOpOnTensorsConversion>(context);
  // clang-format on
  patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
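The new pattern is exercised by the FileCheck tests below. For orientation, a test file of this kind is typically driven by a RUN line along these lines (the actual RUN line sits outside this hunk, so the exact flags are an assumption):

  // RUN: mlir-hlo-opt -hlo-legalize-to-linalg -split-input-file %s | FileCheck %s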
@@ -1737,3 +1737,105 @@ func @depthwise_conv_multiplier_1(%arg0: tensor<1x113x113x96xf32>,
// CHECK-SAME: {strides = dense<2> : tensor<2xi64>}
// CHECK-SAME: ins(%[[IN]], %[[RESHAPED_FILTER]] : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>

// -----

func @reduce_window_min_nhwc(%arg0: tensor<1x18x18x64xf32>,
                             %arg1: tensor<f32>) -> tensor<1x8x8x64xf32> {
  %0 = "mhlo.reduce_window"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
    %1 = mhlo.minimum %arg2, %arg3 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
  return %0 : tensor<1x8x8x64xf32>
}
// CHECK-LABEL: func @reduce_window_min_nhwc
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_min
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>

// -----

func @reduce_window_max_nhwc(%arg0: tensor<1x18x18x64xf32>,
                             %arg1: tensor<f32>) -> tensor<1x8x8x64xf32> {
  %0 = "mhlo.reduce_window"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
    %1 = mhlo.maximum %arg2, %arg3 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
  return %0 : tensor<1x8x8x64xf32>
}
// CHECK-LABEL: func @reduce_window_max_nhwc
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>

// -----

func @reduce_window_sum_nhwc(%arg0: tensor<1x18x18x64xf32>,
                             %arg1: tensor<f32>) -> tensor<1x8x8x64xf32> {
  %0 = "mhlo.reduce_window"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
  return %0 : tensor<1x8x8x64xf32>
}
// CHECK-LABEL: func @reduce_window_sum_nhwc
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>

// -----

func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1x8x8x64xf32> {
  %0 = constant dense<0xFF800000> : tensor<f32>
  %1 = "mhlo.reduce_window"(%arg0, %0) ( {
  ^bb0(%arg1: tensor<f32>, %arg2 : tensor<f32>):
    %2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
  return %1 : tensor<1x8x8x64xf32>
}
// -----

// CHECK-LABEL: func @reduce_window_max_nhwc
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-DAG: %[[CST:.+]] = constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[CST]][] : tensor<f32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
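
For contrast, a window that pools across the batch dimension trips the "expected window_dimensions to be [1,x,y,1]" check in the pattern above and is left unconverted. A hypothetical example, not part of the commit's test suite:

  // Not converted: window_dimensions[0] != 1 (pools across batch).
  func @reduce_window_batch_pooling(%arg0: tensor<2x18x18x64xf32>,
                                    %arg1: tensor<f32>) -> tensor<1x8x8x64xf32> {
    %0 = "mhlo.reduce_window"(%arg0, %arg1) ( {
    ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
      %1 = mhlo.maximum %arg2, %arg3 : tensor<f32>
      "mhlo.return"(%1) : (tensor<f32>) -> ()
    }) {window_dimensions = dense<[2, 3, 3, 1]> : tensor<4xi64>,
        window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<2x18x18x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
    return %0 : tensor<1x8x8x64xf32>
  }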