Add MLIR generated kernel for Angle kernel.
This also requires a canonicalization pattern to remove a redundant dynamic reshape from rank 1 to rank 1. PiperOrigin-RevId: 355113135
This commit is contained in:
parent
9d682343a9
commit
96f8771ed7
|
@ -1310,6 +1310,40 @@ class DynamicReshapeOpNotActuallyDynamic
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Canonicalizes
|
||||||
|
// %0 = some_op(%tensor)
|
||||||
|
// %1 = "mhlo.dynamic_reshape"(%0, %shape)
|
||||||
|
// (tensor<?xT>, tensor<1xindex>) -> tensor<?xT>
|
||||||
|
// ... uses of %1.
|
||||||
|
//
|
||||||
|
// into
|
||||||
|
//
|
||||||
|
// ... uses of %0.
|
||||||
|
// This canonicalization is only correct if the input is correct!
|
||||||
|
// TODO(b/178779691): Use a more sophisticated canonicalization that preserves
|
||||||
|
// errors in input, and still allows us to get rid of redundant reshapes.
|
||||||
|
class RemoveRedundantRank1DynamicReshape
|
||||||
|
: public OpRewritePattern<DynamicReshapeOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(DynamicReshapeOp op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
auto type = op.result().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!type || type.getRank() != 1 || type.hasStaticShape()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "requires rank 1 shape tensor with dynamic dimension");
|
||||||
|
}
|
||||||
|
auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!operand_type || operand_type.getRank() != 1 ||
|
||||||
|
operand_type.hasStaticShape()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "requires rank 1 shape tensor with dynamic dimension");
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, {op.operand()});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Canonicalizes
|
// Canonicalizes
|
||||||
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
|
// %0 = "mhlo.dynamic_reshape"(%tensor, %shape)
|
||||||
// %1 = same_operands_and_result_shape_op(%tensor)
|
// %1 = same_operands_and_result_shape_op(%tensor)
|
||||||
|
@ -1354,6 +1388,7 @@ void DynamicReshapeOp::getCanonicalizationPatterns(
|
||||||
DynamicReshapeOpSameShapeOpResult,
|
DynamicReshapeOpSameShapeOpResult,
|
||||||
RemoveRedundantDynamicBroadcast,
|
RemoveRedundantDynamicBroadcast,
|
||||||
RemoveRedundantDynamicReshape,
|
RemoveRedundantDynamicReshape,
|
||||||
|
RemoveRedundantRank1DynamicReshape,
|
||||||
ShapeOfDynamicReshape
|
ShapeOfDynamicReshape
|
||||||
>(context);
|
>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
|
@ -594,12 +594,26 @@ func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) ->
|
||||||
return %1 : tensor<2xindex>
|
return %1 : tensor<2xindex>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_reshape_rank_1_to_rank_1
|
||||||
|
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
|
||||||
|
func @dynamic_reshape_rank_1_to_rank_1(%arg0: tensor<?xcomplex<f32>>,
|
||||||
|
%shape: tensor<?xindex>) -> tensor<?xf32> {
|
||||||
|
// CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.real"([[ARG0]]) : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
|
||||||
|
// CHECK: return [[RES]]
|
||||||
|
%0 = "mhlo.real"(%arg0): (tensor<?xcomplex<f32>>) -> tensor<?xf32>
|
||||||
|
%1 = shape.shape_of %arg0 : tensor<?xcomplex<f32>> -> tensor<1xindex>
|
||||||
|
%2 = shape.num_elements %1 : tensor<1xindex> -> index
|
||||||
|
%3 = tensor.from_elements %2 : tensor<1xindex>
|
||||||
|
%4 = "mhlo.dynamic_reshape"(%0, %3)
|
||||||
|
: (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
return %4 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape
|
// CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape
|
||||||
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
|
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
|
||||||
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
|
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
|
||||||
func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor<?xf16>, %shape: tensor<?xindex>) -> tensor<?xf16> {
|
func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor<?xf16>, %shape: tensor<?xindex>) -> tensor<?xf16> {
|
||||||
// CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.dynamic_reshape"([[ARG0]], %{{[a-zA-Z0-9]+}}) : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
|
// CHECK: return [[ARG0]]
|
||||||
// CHECK: return [[RES]]
|
|
||||||
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf16>, tensor<?xindex>) -> tensor<*xf16>
|
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf16>, tensor<?xindex>) -> tensor<*xf16>
|
||||||
%1 = shape.shape_of %0 : tensor<*xf16> -> tensor<?xindex>
|
%1 = shape.shape_of %0 : tensor<*xf16> -> tensor<?xindex>
|
||||||
%2 = shape.num_elements %1 : tensor<?xindex> -> index
|
%2 = shape.num_elements %1 : tensor<?xindex> -> index
|
||||||
|
|
Loading…
Reference in New Issue