Remove `results_escape_functions` from HloLegalizeToLhlo

PiperOrigin-RevId: 340464958
Sean Silva 2020-11-03 09:49:13 -08:00 committed by TensorFlow MLIR Team
parent 46dac6955b
commit d3ea3abdec
7 changed files with 199 additions and 233 deletions

View File

@@ -48,12 +48,8 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the
/// function. Otherwise, the signature is rewritten with extra arguments for the
/// buffers that are to be used for results.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_functions = false);
/// buffers if necessary.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
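For context (not part of the diff): the removed flag chose between two buffer calling conventions for function results. A minimal sketch of both lowerings for an identity function, mirroring the @return_func test updated below; the memref forms are illustrative, not verbatim pass output:

// Tensor-level input.
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  return %arg0 : tensor<4xf32>
}

// With results-escape-function=false (the old default), the signature was
// rewritten with an extra output argument and the result copied into it:
func @return_func(%arg0: memref<4xf32>, %result: memref<4xf32>) {
  "lmhlo.copy"(%arg0, %result) : (memref<4xf32>, memref<4xf32>) -> ()
  return
}

// With results-escape-function=true (now the only behavior), the allocated
// buffer is returned and escapes the function:
func @return_func(%arg0: memref<4xf32>) -> memref<4xf32> {
  return %arg0 : memref<4xf32>
}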

View File

@@ -504,12 +504,7 @@ struct HloLegalizeToLhlo
public:
HloLegalizeToLhlo() = default;
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
this->results_escape_function = o.results_escape_function.getValue();
}
explicit HloLegalizeToLhlo(bool results_escape_function) {
this->results_escape_function.setValue(results_escape_function);
}
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
void runOnOperation() override {
OwningRewritePatternList patterns;
@@ -542,13 +537,6 @@ struct HloLegalizeToLhlo
isMemRefType);
});
auto kind = results_escape_function
? BufferizeTypeConverter::KeepAsFunctionResult
: BufferizeTypeConverter::AppendToArgumentsList;
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
kind);
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
lmhlo::CopyOp>(
@@ -559,13 +547,6 @@ struct HloLegalizeToLhlo
std::move(patterns))))
signalPassFailure();
}
private:
Option<bool> results_escape_function{
*this, "results-escape-function",
llvm::cl::desc(
"Allocate the results of functions within the functions body"),
llvm::cl::init(false)};
};
} // namespace
@@ -625,9 +606,8 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
// clang-format on
}
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_function) {
return std::make_unique<HloLegalizeToLhlo>(results_escape_function);
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
return std::make_unique<HloLegalizeToLhlo>();
}
} // namespace mhlo
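For context (not part of the diff): command-line pipelines drop the flag accordingly, as the updated RUN lines below show. Old and new invocations side by side:

// Old: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true %s
// New: mlir-hlo-opt -hlo-legalize-to-lhlo %s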

View File

@@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () {
call @trivial_broadcast_wrapper() : () -> ()

View File

@@ -1,5 +1,5 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
// RUN: -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting \
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
// RUN: -buffer-deallocation -copy-removal -canonicalize -cse \
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \

View File

@@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @main() -> () {
call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()

View File

@@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
// CHECK-LABEL: func @func_op_unranked_arg_result
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {

View File

@@ -1,13 +1,12 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
// BOTH-LABEL: func @attrs
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
// CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -17,16 +16,13 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32>
}
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: return
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// ESC-NOT: "lmhlo.copy"
// ESC-NEXT: return %[[ARG0]]
// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NOT: "lmhlo.copy"
// CHECK-NEXT: return %[[ARG0]]
// -----
// BOTH-LABEL: func @func_op_long
// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = mhlo.add %arg0, %1 : tensor<4xf32>
@@ -35,91 +31,87 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%5 = mhlo.multiply %2, %4 : tensor<4xf32>
return %5 : tensor<4xf32>
}
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
// PRE-NEXT: return
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// -----
// BOTH-LABEL: func @fusion
// CHECK-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: return
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: return
return
}
// -----
// BOTH-LABEL: func @copy
// CHECK-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @exp
// CHECK-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @log
// CHECK-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.log"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @select
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_pred = tensor_load %pred : memref<2x2xi1>
@@ -127,34 +119,34 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
// CHECK: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @compare
// CHECK-LABEL: func @compare
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
// CHECK: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
tensor_store %tensor_result, %result : memref<2x2xi1>
return
}
// -----
// BOTH-LABEL: func @broadcast
// CHECK-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_operand = tensor_load %operand : memref<5xf32>
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<10x5xf32>
return
}
@@ -163,56 +155,56 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
func @external_func() -> tensor<3xi64>
// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// BOTH-LABEL: func @dyn_broadcast
// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%c1 = constant 1 : i64
%shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// BOTH: %[[SHAPE:.*]] = tensor_from_elements
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// BOTH: %[[C2:.*]] = constant 2 : index
// BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[C1__:.*]] = constant 1 : index
// BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// BOTH: %[[C0___:.*]] = constant 0 : index
// BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// BOTH: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[C2_:.*]] = constant 2 : index
// BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// BOTH: %[[C1___:.*]] = constant 1 : index
// BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map>
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
// CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map>
// BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back to avoid the tensor-store being rewritten to
// a copy into the pre-allocated argument.
@@ -221,7 +213,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// -----
// BOTH-LABEL: func @complex
// CHECK-LABEL: func @complex
func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%result: memref<2x2xcomplex<f32>>) {
@@ -229,14 +221,14 @@ func @complex(%real: memref<2x2xf32>,
%tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
return
}
// -----
// BOTH-LABEL: func @complex_dyn
// CHECK-LABEL: func @complex_dyn
func @complex_dyn(%real: memref<?xf32>,
%imag: memref<?xf32>,
%result: memref<?xcomplex<f32>>) {
@@ -244,50 +236,50 @@ func @complex_dyn(%real: memref<?xf32>,
%tensor_imag = tensor_load %imag : memref<?xf32>
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>>
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<?xcomplex<f32>>
return
}
// -----
// BOTH-LABEL: func @real
// CHECK-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @real_dyn
// CHECK-LABEL: func @real_dyn
func @real_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
%tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<?xcomplex<f32>>) -> tensor<?xf32>
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<?xf32>
return
}
// -----
// BOTH-LABEL: func @imag
// CHECK-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @gather
// CHECK-LABEL: func @gather
func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) {
%tensor_operand = tensor_load %operand : memref<13x7xf32>
%tensor_idxs = tensor_load %idxs : memref<5xi32>
@@ -302,176 +294,176 @@ func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5
, name = "gather.71"
, slice_sizes = dense<[1, 7]> : tensor<2xi64> }
: (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
// BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
// CHECK: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<5x7xf32>
return
}
// -----
// BOTH-LABEL: func @imag_dyn
// CHECK-LABEL: func @imag_dyn
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>
%tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<?xcomplex<f32>>) -> tensor<?xf32>
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<?xf32>
return
}
// -----
// BOTH-LABEL: func @iota
// CHECK-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
%tensor_result = "mhlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
// CHECK: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
return
}
// -----
// BOTH-LABEL: func @abs
// CHECK-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @ceil
// CHECK-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @convert
// CHECK-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
// CHECK-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @cos
// CHECK-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @floor
// CHECK-LABEL: func @floor
func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.floor"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.floor"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @neg
// CHECK-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @not
// CHECK-LABEL: func @not
func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
%tensor_operand = tensor_load %operand : memref<2x2xi32>
%tensor_result = "mhlo.not"(%tensor_operand)
: (tensor<2x2xi32>) -> tensor<2x2xi32>
// BOTH: "lmhlo.not"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.not"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// BOTH-LABEL: func @rsqrt
// CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @sign
// CHECK-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @sqrt
// CHECK-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @tanh
// CHECK-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @remainder
// CHECK-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
// CHECK: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@@ -479,61 +471,60 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// -----
// Dynamic shape binary element-wise operation.
// BOTH-LABEL: func @add_dyn
// CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "mhlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// -----
// Dynamic shape unary element-wise operation.
// BOTH-LABEL: func @tanh_dyn
// CHECK-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "mhlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// -----
// BOTH-LABEL: func @dot
// CHECK-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// dot_dimension_numbers = {
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
// lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
@@ -542,22 +533,21 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]]
// CHECK: return %[[ALLOC]]
return %dot : tensor<1024x1024xf32>
}
// -----
// BOTH-LABEL: func @conv
// CHECK-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH-SAME: padding = dense<[
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
// BOTH-SAME: window_strides = dense<[2, 1]>
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME: padding = dense<[
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
// CHECK-SAME: window_strides = dense<[2, 1]>
%out = "mhlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64,
dimension_numbers = {
@@ -581,18 +571,18 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
// -----
// BOTH-LABEL: func @reduce
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// BOTH: %[[OUT:.*]] = alloc() : memref<1xf32>
// BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// BOTH: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// BOTH-SAME: %[[ARG3:.*]]: memref<f32>):
// BOTH: %[[TMP:.*]] = alloc() : memref<f32>
// BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// BOTH: "lmhlo.terminator"() : () -> ()
// BOTH: }) {dimensions = dense<1> : tensor<1xi64>}
// BOTH-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
// CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
// CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
// CHECK: %[[TMP:.*]] = alloc() : memref<f32>
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// CHECK: "lmhlo.terminator"() : () -> ()
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>}
// CHECK-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
%0 = "mhlo.reduce"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
@@ -604,25 +594,25 @@ func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// -----
// BOTH-LABEL: func @transpose
// CHECK-LABEL: func @transpose
func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.transpose"(%tensor_operand) {permutation = dense<[1, 0]> : tensor<2xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
// BOTH-NOT: tensor_store
// CHECK: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
// CHECK-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @custom_call
// BOTH-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
// CHECK-LABEL: func @custom_call
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
%arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
// BOTH: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
%result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
{backend_config = "", call_target_name = "foo", has_side_effect = false}
: (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
@@ -632,10 +622,10 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memre
// -----
// BOTH-LABEL: func @isfinite
// CHECK-LABEL: func @isfinite
func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
// BOTH: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
// CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
%result_tensor = "mhlo.is_finite"(%arg0_tensor) : (tensor<2x2xf32>) -> tensor<2x2xi1>
tensor_store %result_tensor, %result: memref<2x2xi1>
return
@@ -644,19 +634,19 @@ func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
// -----
// Test that assuming ops propagate memref types.
// BOTH-LABEL: func @shape_assuming_memref
// CHECK-LABEL: func @shape_assuming_memref
func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
%1 = shape.const_witness true
// BOTH: shape.assuming %{{.*}} -> (memref<?xf16>)
// CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
%2 = shape.assuming %1 -> (tensor<?xf16>) {
%3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
// BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
// CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
// BOTH: shape.assuming_yield %{{.*}} : memref<?xf16>
// CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
shape.assuming_yield %7 : tensor<?xf16>
}
return %2 : tensor<?xf16>