From 96f8771ed718174a8d38386abc4521dcc7f124ff Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 2 Feb 2021 00:45:39 -0800 Subject: [PATCH] Add MLIR generated kernel for Angle kernel. This also requires a canonicalization pattern to remove a redundant dynamic reshape from rank 1 to rank 1. PiperOrigin-RevId: 355113135 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 35 ++++++++++++++++++++++++++++++++++ tests/canonicalize.mlir | 18 +++++++++++++++-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 47b0765..a6ca692 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1310,6 +1310,40 @@ class DynamicReshapeOpNotActuallyDynamic } }; +// Canonicalizes +// %0 = some_op(%tensor) +// %1 = "mhlo.dynamic_reshape"(%0, %shape) +// (tensor<?xT>, tensor<1xindex>) -> tensor<?xT> +// ... uses of %1. +// +// into +// +// ... uses of %0. +// This canonicalization is only correct if the input is correct! +// TODO(b/178779691): Use a more sophisticated canonicalization that preserves +// errors in input, and still allows us to get rid of redundant reshapes. 
+class RemoveRedundantRank1DynamicReshape + : public OpRewritePattern<DynamicReshapeOp> { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DynamicReshapeOp op, + PatternRewriter& rewriter) const override { + auto type = op.result().getType().dyn_cast<RankedTensorType>(); + if (!type || type.getRank() != 1 || type.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "requires rank 1 shape tensor with dynamic dimension"); + } + auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>(); + if (!operand_type || operand_type.getRank() != 1 || + operand_type.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "requires rank 1 shape tensor with dynamic dimension"); + } + rewriter.replaceOp(op, {op.operand()}); + return success(); + } +}; + // Canonicalizes // %0 = "mhlo.dynamic_reshape"(%tensor, %shape) // %1 = same_operands_and_result_shape_op(%tensor) @@ -1354,6 +1388,7 @@ void DynamicReshapeOp::getCanonicalizationPatterns( DynamicReshapeOpSameShapeOpResult, RemoveRedundantDynamicBroadcast, RemoveRedundantDynamicReshape, + RemoveRedundantRank1DynamicReshape, ShapeOfDynamicReshape >(context); // clang-format on diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 3b6cf16..5d48701 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -594,12 +594,26 @@ func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> return %1 : tensor<2xindex> } +// CHECK-LABEL: func @dynamic_reshape_rank_1_to_rank_1 +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] +func @dynamic_reshape_rank_1_to_rank_1(%arg0: tensor<?xcomplex<f32>>, + %shape: tensor<?xindex>) -> tensor<?xf32> { + // CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.real"([[ARG0]]) : (tensor<?xcomplex<f32>>) -> tensor<?xf32> + // CHECK: return [[RES]] + %0 = "mhlo.real"(%arg0): (tensor<?xcomplex<f32>>) -> tensor<?xf32> + %1 = shape.shape_of %arg0 : tensor<?xcomplex<f32>> -> tensor<1xindex> + %2 = shape.num_elements %1 : tensor<1xindex> -> index + %3 = tensor.from_elements %2 : tensor<1xindex> + %4 = "mhlo.dynamic_reshape"(%0, %3) + : (tensor<?xf32>, tensor<1xindex>) -> 
tensor<?xf32> + return %4 : tensor<?xf32> +} + // CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor<?xf16>, %shape: tensor<?xindex>) -> tensor<?xf16> { - // CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.dynamic_reshape"([[ARG0]], %{{[a-zA-Z0-9]+}}) : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16> - // CHECK: return [[RES]] + // CHECK: return [[ARG0]] %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf16>, tensor<?xindex>) -> tensor<*xf16> %1 = shape.shape_of %0 : tensor<*xf16> -> tensor<?xindex> %2 = shape.num_elements %1 : tensor<?xindex> -> index