mlir-hlo/tests/hlo-legalize-to-lhlo-only-d...

35 lines
1.7 KiB
MLIR

PR #49970: [MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49970

1. Add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim.
2. Add a flag `convert-to-lmhlo-only` to separate the following two cases:
   - hlo-to-lhlo only: simply lower all mhlo ops to their lmhlo counterparts without applying any optimization (e.g. eliding buffer copies). Buffer optimization is not easy with dynamic shapes, especially when control flow is involved, so it is left to a separate, dedicated pass.
   - hlo-to-lhlo-or-memref-directly: lower some metadata-only mhlo ops (e.g. reshape) directly to the memref dialect, and lower the others to their lmhlo counterparts.

Copybara import of the project:

-- 562bd65a368f6194405c4ae6900e3b4388a5ec03 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim

PiperOrigin-RevId: 377603395
2021-06-05 06:35:08 +08:00
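The commit message distinguishes two modes, and the test file that follows exercises only `convert-to-lmhlo-only=true`. For contrast, here is a rough sketch of what the memref-direct mode might produce for the `@dynamic_reshape` case. It assumes (not verified against this commit's actual output) that the metadata-only reshape is rewritten into the upstream `memref.reshape` op, so no output buffer is allocated and no lmhlo op is emitted:

```mlir
// Assumed output with convert-to-lmhlo-only=false (a hand-written sketch,
// not generated by the pass): memref.reshape only reinterprets %arg's buffer
// using the extents stored in %shape, so there is no memref.alloc and no copy.
func @dynamic_reshape(%arg: memref<?x?xf32>, %shape: memref<3xindex>) -> memref<?x?x?xf32> {
  %result = memref.reshape %arg(%shape)
      : (memref<?x?xf32>, memref<3xindex>) -> memref<?x?x?xf32>
  return %result : memref<?x?x?xf32>
}
```

Because the reshaped result would alias the operand's buffer, this is the kind of buffer reuse that the lmhlo-only mode leaves to a dedicated optimization pass instead.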
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=convert-to-lmhlo-only=true \
// RUN: -canonicalize -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
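// mhlo.dynamic_reshape: each result extent is loaded from the shape operand,
// an output buffer of that size is allocated, lmhlo.dynamic_reshape writes
// into it, and that buffer is returned.
// CHECK-LABEL: func @dynamic_reshape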
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
func @dynamic_reshape(%lhs: tensor<?x?xf32>, %rhs: tensor<3xindex>) -> tensor<?x?x?xf32> {
// CHECK-NOT: tensor_load
// CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
// CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
// CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
// CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
// CHECK: "lmhlo.dynamic_reshape"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
// CHECK: return %[[OUTPUT]]
%result = "mhlo.dynamic_reshape"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
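}

// -----

// mhlo.dynamic_broadcast_in_dim follows the same pattern. Here
// broadcast_dimensions = [1, 2] maps operand dimensions 0 and 1 to result
// dimensions 1 and 2.
// CHECK-LABEL: func @dynamic_broadcast_in_dim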
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
func @dynamic_broadcast_in_dim(%operand: tensor<?x?xf32>, %shape: tensor<3xindex>) -> tensor<?x?x?xf32> {
// CHECK-NOT: tensor_load
// CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
// CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
// CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
// CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
// CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
// CHECK: return %[[OUTPUT]]
%result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
}