diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index bb995bb..3d770e7 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1655,6 +1655,120 @@ struct DepthwiseConvOpOnTensorsConversion } }; +struct ReduceWindowOpOnTensorsConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + /// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of + /// the pooling is determined based on the body of the reduce window + /// operation. This class enumerates the different variants. + enum class PoolingType { + kMin, + kMax, + kAdd, + }; + + static PoolingType getPoolingType(Region& region) { + assert(region.getBlocks().size() == 1 && + "expected the region has exactlly one block"); + Block& block = region.front(); + assert(block.getOperations().size() == 2 && + "expected the block has exactlly two operations"); + auto op = block.begin(); + if (isa(op)) return PoolingType::kMin; + if (isa(op)) return PoolingType::kMax; + if (isa(op)) return PoolingType::kAdd; + + llvm_unreachable("unknown pooling type"); + } + + LogicalResult matchAndRewrite( + mhlo::ReduceWindowOp op, ArrayRef args, + ConversionPatternRewriter& rewriter) const override { + auto loc = op.getLoc(); + auto result_type = op.getResult().getType().cast(); + if (result_type.getRank() != 4) { + return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op"); + } + + // Create a fake window dimension. + SmallVector shapes; + shapes.push_back(op.window_dimensions().getValue(1)); + shapes.push_back(op.window_dimensions().getValue(2)); + auto fake_window_dims = rewriter.create( + loc, shapes, result_type.getElementType()); + + if (op.window_strides() && + (op.window_strides().getValue().getValue(0) != 1 || + op.window_strides().getValue().getValue(3) != 1)) { + return rewriter.notifyMatchFailure( + op, "expected window_strides to be [1,x,y,1]"); + } + if (op.window_dimensions() && + (op.window_dimensions().getValue(0) != 1 || + op.window_dimensions().getValue(3) != 1)) { + return rewriter.notifyMatchFailure( + op, "expected window_dimensions to be [1,x,y,1]"); + } + + if (!args[0].getType().cast().getElementType().isF32()) { + return rewriter.notifyMatchFailure(op, "expected element type to be f32"); + } + + Attribute strides; + if (op.window_stridesAttr()) { + strides = rewriter.getI64VectorAttr( + {op.window_strides().getValue().getValue(1), + op.window_strides().getValue().getValue(2)}); + } else { + strides = rewriter.getI64VectorAttr({1, 1}); + } + Attribute dilations; + if (op.window_dilations()) { + dilations = rewriter.getI64VectorAttr( + {op.window_dilations().getValue().getValue(1), + op.window_dilations().getValue().getValue(2)}); + } else { + dilations = rewriter.getI64VectorAttr({1, 1}); + } + + Value init_tensor = rewriter.create( + loc, result_type.getShape(), result_type.getElementType()); + Value init_value = args[1]; + init_value = rewriter.create(loc, init_value); + Value filled_init_tensor = + rewriter.create(loc, init_tensor, init_value) + .getResult(0); + auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp { + return cast( + rewriter + .create>( + loc, ArrayRef{result_type}, + ValueRange{args[0], fake_window_dims.getResult()}, + filled_init_tensor, dilations, strides) + .getOperation()); + }; + linalg::LinalgOp pooling_op; + PoolingType pooling_type = getPoolingType(op.body()); + switch (pooling_type) { + case PoolingType::kMin: { + pooling_op = create_op(static_cast(nullptr)); + break; + } + case PoolingType::kMax: { + pooling_op = create_op(static_cast(nullptr)); + break; + } + case PoolingType::kAdd: { + pooling_op = create_op(static_cast(nullptr)); + break; + } + } + rewriter.replaceOp(op, pooling_op->getResult(0)); + return success(); + } +}; + void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off @@ -1846,6 +1960,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, NormalConvOpOnTensorsConversion, DepthwiseConvOpOnTensorsConversion, ReduceOnTensorsConversion, + ReduceWindowOpOnTensorsConversion, PadOpOnTensorsConversion>(context); // clang-format on patterns->insert, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index e233622..714cd0a 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -1737,3 +1737,105 @@ func @depthwise_conv_multiplier_1(%arg0: tensor<1x113x113x96xf32>, // CHECK-SAME: {strides = dense<2> : tensor<2xi64>} // CHECK-SAME: ins(%[[IN]], %[[RESHAPED_FILTER]] : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) // CHECK-SAME: outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> + +// ----- + +func @reduce_window_min_nhwc(%arg0: tensor<1x18x18x64xf32>, + %arg1: tensor) -> tensor<1x8x8x64xf32>{ + %0 = "mhlo.reduce_window"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3 : tensor): + %1 = mhlo.minimum %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor) -> tensor<1x8x8x64xf32> + return %0 : tensor<1x8x8x64xf32> +} +// CHECK-LABEL: func @reduce_window_min_nhwc +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32> +// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32> +// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32> +// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_min +// CHECK-SAME: {dilations = dense<1> : vector<2xi64> +// CHECK-SAME: strides = dense<2> : vector<2xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> + +// ----- + +func @reduce_window_max_nhwc(%arg0: tensor<1x18x18x64xf32>, + %arg1: tensor) -> tensor<1x8x8x64xf32>{ + %0 = "mhlo.reduce_window"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3 : tensor): + %1 = mhlo.maximum %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor) -> tensor<1x8x8x64xf32> + return %0 : tensor<1x8x8x64xf32> +} +// CHECK-LABEL: func @reduce_window_max_nhwc +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32> +// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32> +// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32> +// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max +// CHECK-SAME: {dilations = dense<1> : vector<2xi64> +// CHECK-SAME: strides = dense<2> : vector<2xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> + +// ----- + +func @reduce_window_sum_nhwc(%arg0: tensor<1x18x18x64xf32>, + %arg1: tensor) -> tensor<1x8x8x64xf32>{ + %0 = "mhlo.reduce_window"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3 : tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor) -> tensor<1x8x8x64xf32> + return %0 : tensor<1x8x8x64xf32> +} +// CHECK-LABEL: func @reduce_window_sum_nhwc +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32> +// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32> +// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32> +// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum +// CHECK-SAME: {dilations = dense<1> : vector<2xi64> +// CHECK-SAME: strides = dense<2> : vector<2xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> + +// ----- + +func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1x8x8x64xf32> { + %0 = constant dense<0xFF800000> : tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2 : tensor): + %2 = mhlo.maximum %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor) -> tensor<1x8x8x64xf32> + return %1 : tensor<1x8x8x64xf32> +} + +// ----- +// CHECK-LABEL: func @reduce_window_max_nhwc +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-DAG: %[[CST:.+]] = constant dense<0xFF800000> : tensor +// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32> +// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32 +// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[CST]][] : tensor +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32> +// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max +// CHECK-SAME: {dilations = dense<1> : vector<2xi64> +// CHECK-SAME: strides = dense<2> : vector<2xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>