From d3ea3abdec3e604245620a491d6812ec124bc438 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 3 Nov 2020 09:49:13 -0800 Subject: [PATCH] Remove `results_escape_functions` from HloLegalizeToLhlo PiperOrigin-RevId: 340464958 --- .../mlir-hlo/Dialect/mhlo/transforms/passes.h | 8 +- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 26 +- tests/end2end/broadcast.mlir | 2 +- tests/end2end/reduce.mlir | 2 +- tests/end2end/reshape.mlir | 2 +- tests/hlo-legalize-to-lhlo-unranked.mlir | 2 +- tests/hlo-legalize-to-lhlo.mlir | 390 +++++++++--------- 7 files changed, 199 insertions(+), 233 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index b1933f6..3345884 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -48,12 +48,8 @@ std::unique_ptr> createLegalizeToStdPass(); std::unique_ptr createChloLegalizeToHloPass(); /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary -/// buffers if necessary. If `results_escape_functions` is set to true, -/// allocated buffers for function results will be returned and escape the -/// function. Otherwise, the signature is rewritten with extra arguments for the -/// buffers that are to be used for results. -std::unique_ptr> createLegalizeToLhloPass( - bool results_escape_functions = false); +/// buffers if necessary. +std::unique_ptr> createLegalizeToLhloPass(); // Lowers from HLO dialect to Linalg dialect. std::unique_ptr> createLegalizeHloToLinalgPass(); diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 1cf3915..ae897a6 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -504,12 +504,7 @@ struct HloLegalizeToLhlo public: HloLegalizeToLhlo() = default; - HloLegalizeToLhlo(const HloLegalizeToLhlo& o) { - this->results_escape_function = o.results_escape_function.getValue(); - } - explicit HloLegalizeToLhlo(bool results_escape_function) { - this->results_escape_function.setValue(results_escape_function); - } + HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {} void runOnOperation() override { OwningRewritePatternList patterns; @@ -542,13 +537,6 @@ struct HloLegalizeToLhlo isMemRefType); }); - auto kind = results_escape_function - ? BufferizeTypeConverter::KeepAsFunctionResult - : BufferizeTypeConverter::AppendToArgumentsList; - converter.setResultConversionKind( - kind); - converter.setResultConversionKind(kind); - populateHLOToLHLOConversionPattern(&context, &converter, &patterns); populateWithBufferizeOpConversionPatterns( @@ -559,13 +547,6 @@ struct HloLegalizeToLhlo std::move(patterns)))) signalPassFailure(); } - - private: - Option results_escape_function{ - *this, "results-escape-function", - llvm::cl::desc( - "Allocate the results of functions within the functions body"), - llvm::cl::init(false)}; }; } // namespace @@ -625,9 +606,8 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, // clang-format on } -std::unique_ptr> createLegalizeToLhloPass( - bool results_escape_function) { - return std::make_unique(results_escape_function); +std::unique_ptr> createLegalizeToLhloPass() { + return std::make_unique(); } } // namespace mhlo diff --git a/tests/end2end/broadcast.mlir b/tests/end2end/broadcast.mlir index 9bd0c09..dd2c311 100644 --- a/tests/end2end/broadcast.mlir +++ b/tests/end2end/broadcast.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s +// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s func @main() -> () { call @trivial_broadcast_wrapper() : () -> () diff --git a/tests/end2end/reduce.mlir b/tests/end2end/reduce.mlir index 8eb6553..b018b73 100644 --- a/tests/end2end/reduce.mlir +++ b/tests/end2end/reduce.mlir @@ -1,5 +1,5 @@ // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \ -// RUN: -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting \ +// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \ // RUN: -buffer-deallocation -copy-removal -canonicalize -cse \ // RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \ // RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \ diff --git a/tests/end2end/reshape.mlir b/tests/end2end/reshape.mlir index e8dc509..8311546 100644 --- a/tests/end2end/reshape.mlir +++ b/tests/end2end/reshape.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s +// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s func @main() -> () { call @reshape_with_static_shape_size_matrix_to_1D() : () -> () diff --git a/tests/hlo-legalize-to-lhlo-unranked.mlir b/tests/hlo-legalize-to-lhlo-unranked.mlir index 6400943..7e3d13e 100644 --- a/tests/hlo-legalize-to-lhlo-unranked.mlir +++ b/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s // CHECK-LABEL: func @func_op_unranked_arg_result func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 608973f..7f880bd 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -1,13 +1,12 @@ -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s -// BOTH-LABEL: func @attrs +// CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.exponential"(%tensor_operand) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -17,16 +16,13 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { return %arg0 : tensor<4xf32> } -// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) -// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () -// PRE-NEXT: return -// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] -// ESC-NOT: "lmhlo.copy" -// ESC-NEXT: return %[[ARG0]] +// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] +// CHECK-NOT: "lmhlo.copy" +// CHECK-NEXT: return %[[ARG0]] // ----- -// BOTH-LABEL: func @func_op_long +// CHECK-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> %2 = mhlo.add %arg0, %1 : tensor<4xf32> @@ -35,91 +31,87 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> %5 = mhlo.multiply %2, %4 : tensor<4xf32> return %5 : tensor<4xf32> } -// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) -// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> -// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) -// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) -// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> -// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) -// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> -//  BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) -// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> -// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) -// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> -// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> -// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () -// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> -// PRE-NEXT: return -// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32> +// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> +// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) +// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) +// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) +// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> +//  CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) +// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) +// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> +// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> +// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32> // ----- -// BOTH-LABEL: func @fusion +// CHECK-LABEL: func @fusion func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { - // BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}}) - // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> + // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}}) + // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) - // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> + // CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) + // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) - // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> - // BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) + // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> + // CHECK-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> - // BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> - // BOTH-NEXT: return + // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> + // CHECK-NEXT: return return } // ----- -// BOTH-LABEL: func @copy +// CHECK-LABEL: func @copy func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.copy"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @exp +// CHECK-LABEL: func @exp func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.exponential"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @log +// CHECK-LABEL: func @log func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.log"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.log"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.log"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @select +// CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_pred = tensor_load %pred : memref<2x2xi1> @@ -127,34 +119,34 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + // CHECK: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @compare +// CHECK-LABEL: func @compare func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> - // BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} + // CHECK: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} tensor_store %tensor_result, %result : memref<2x2xi1> return } // ----- -// BOTH-LABEL: func @broadcast +// CHECK-LABEL: func @broadcast func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_operand = tensor_load %operand : memref<5xf32> %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<10x5xf32> - // BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<10x5xf32> return } @@ -163,56 +155,56 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { func @external_func() -> tensor<3xi64> -// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> +// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> -// BOTH-LABEL: func @dyn_broadcast +// CHECK-LABEL: func @dyn_broadcast func @dyn_broadcast(%operand: memref) { - // BOTH-SAME: (%[[OPERAND:.*]]: memref) + // CHECK-SAME: (%[[OPERAND:.*]]: memref) %tensor_operand = tensor_load %operand : memref %c1 = constant 1 : i64 %shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64> %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> } : (tensor, tensor<3xi64>) -> tensor - // BOTH: %[[SHAPE:.*]] = tensor_from_elements - // BOTH: %[[C0:.*]] = constant 0 : index - // BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> - // BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index - // BOTH: %[[C1:.*]] = constant 1 : index - // BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64> - // BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index - // BOTH: %[[C2:.*]] = constant 2 : index - // BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64> - // BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index - // BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]]) + // CHECK: %[[SHAPE:.*]] = tensor_from_elements + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> + // CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64> + // CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index + // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64> + // CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index + // CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]]) - // BOTH: %[[C0_:.*]] = constant 0 : index - // BOTH: %[[C1_:.*]] = constant 1 : index + // CHECK: %[[C0_:.*]] = constant 0 : index + // CHECK: %[[C1_:.*]] = constant 1 : index - // BOTH: %[[C1__:.*]] = constant 1 : index - // BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64> - // BOTH: %[[C0___:.*]] = constant 0 : index - // BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref - // BOTH: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index - // BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]] - // BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index + // CHECK: %[[C1__:.*]] = constant 1 : index + // CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64> + // CHECK: %[[C0___:.*]] = constant 0 : index + // CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref + // CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index + // CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]] + // CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index - // BOTH: %[[C2_:.*]] = constant 2 : index - // BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64> - // BOTH: %[[C1___:.*]] = constant 1 : index - // BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref - // BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index - // BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] - // BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index + // CHECK: %[[C2_:.*]] = constant 2 : index + // CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64> + // CHECK: %[[C1___:.*]] = constant 1 : index + // CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref + // CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index + // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] + // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index - // BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast - // BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) - // BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] - // BOTH-SAME: : memref -> memref + // CHECK: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast + // CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) + // CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] + // CHECK-SAME: : memref -> memref - // BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { - // BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> - // BOTH-SAME: } : (memref, memref) -> () + // CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { + // CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> + // CHECK-SAME: } : (memref, memref) -> () // Do not store the value back to avoid the tensor-store being rewritten to // a copy into the pre-allocated argument. @@ -221,7 +213,7 @@ func @dyn_broadcast(%operand: memref) { // ----- -// BOTH-LABEL: func @complex +// CHECK-LABEL: func @complex func @complex(%real: memref<2x2xf32>, %imag: memref<2x2xf32>, %result: memref<2x2xcomplex>) { @@ -229,14 +221,14 @@ func @complex(%real: memref<2x2xf32>, %tensor_imag = tensor_load %imag : memref<2x2xf32> %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex> - // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xcomplex> return } // ----- -// BOTH-LABEL: func @complex_dyn +// CHECK-LABEL: func @complex_dyn func @complex_dyn(%real: memref, %imag: memref, %result: memref>) { @@ -244,50 +236,50 @@ func @complex_dyn(%real: memref, %tensor_imag = tensor_load %imag : memref %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) : (tensor, tensor) -> tensor> - // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref> return } // ----- -// BOTH-LABEL: func @real +// CHECK-LABEL: func @real func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> %tensor_result = "mhlo.real"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> - // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @real_dyn +// CHECK-LABEL: func @real_dyn func @real_dyn(%operand: memref>, %result: memref) { %tensor_operand = tensor_load %operand : memref> %tensor_result = "mhlo.real"(%tensor_operand) : (tensor>) -> tensor - // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref return } // ----- -// BOTH-LABEL: func @imag +// CHECK-LABEL: func @imag func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> %tensor_result = "mhlo.imag"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> - // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @gather +// CHECK-LABEL: func @gather func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) { %tensor_operand = tensor_load %operand : memref<13x7xf32> %tensor_idxs = tensor_load %idxs : memref<5xi32> @@ -302,176 +294,176 @@ func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5 , name = "gather.71" , slice_sizes = dense<[1, 7]> : tensor<2xi64> } : (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32> - // BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}}) + // CHECK: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<5x7xf32> return } // ----- -// BOTH-LABEL: func @imag_dyn +// CHECK-LABEL: func @imag_dyn func @imag_dyn(%operand: memref>, %result: memref) { %tensor_operand = tensor_load %operand : memref> %tensor_result = "mhlo.imag"(%tensor_operand) : (tensor>) -> tensor - // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref return } // ----- -// BOTH-LABEL: func @iota +// CHECK-LABEL: func @iota func @iota(%result: memref<10xi32>) { %tensor_result = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10xi32> - // BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} + // CHECK: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} tensor_store %tensor_result, %result : memref<10xi32> return } // ----- -// BOTH-LABEL: func @abs +// CHECK-LABEL: func @abs func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.abs"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.abs"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @ceil +// CHECK-LABEL: func @ceil func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.ceil"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.ceil"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @convert +// CHECK-LABEL: func @convert func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.convert"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}}) - // BOTH-NOT: tensor_store + // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) + // CHECK-NOT: tensor_store tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @cos +// CHECK-LABEL: func @cos func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.cosine"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.cosine"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @floor +// CHECK-LABEL: func @floor func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.floor"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.floor"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @neg +// CHECK-LABEL: func @neg func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.negate"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.negate"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @not +// CHECK-LABEL: func @not func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) { %tensor_operand = tensor_load %operand : memref<2x2xi32> %tensor_result = "mhlo.not"(%tensor_operand) : (tensor<2x2xi32>) -> tensor<2x2xi32> - // BOTH: "lmhlo.not"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.not"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xi32> return } // ----- -// BOTH-LABEL: func @rsqrt +// CHECK-LABEL: func @rsqrt func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.rsqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.rsqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @sign +// CHECK-LABEL: func @sign func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.sign"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.sign"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @sqrt +// CHECK-LABEL: func @sqrt func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.sqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.sqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @tanh +// CHECK-LABEL: func @tanh func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.tanh"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.tanh"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @remainder +// CHECK-LABEL: func @remainder func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) + // CHECK: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -479,61 +471,60 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x // ----- // Dynamic shape binary element-wise operation. -// BOTH-LABEL: func @add_dyn +// CHECK-LABEL: func @add_dyn func @add_dyn(%lhs: tensor, %rhs: tensor) { %result = "mhlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor - // BOTH: %[[C0:.*]] = constant 0 : index - // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref - // BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 - // BOTH: %[[C1:.*]] = constant 1 : index - // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref - // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> - // BOTH: %[[C0_:.*]] = constant 0 : index - // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64> - // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index - // BOTH: %[[C1_:.*]] = constant 1 : index - // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> - // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index - // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) - // BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref + // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref + // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 + // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> + // CHECK: %[[C0_:.*]] = constant 0 : index + // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64> + // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index + // CHECK: %[[C1_:.*]] = constant 1 : index + // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> + // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index + // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) + // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () return } // ----- // Dynamic shape unary element-wise operation. -// BOTH-LABEL: func @tanh_dyn +// CHECK-LABEL: func @tanh_dyn func @tanh_dyn(%arg0: tensor) { %result = "mhlo.tanh"(%arg0) : (tensor) -> tensor - // BOTH: %[[C0:.*]] = constant 0 : index - // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref - // BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 - // BOTH: %[[C1:.*]] = constant 1 : index - // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref - // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // BOTH: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> - // BOTH: %[[C0_:.*]] = constant 0 : index - // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64> - // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index - // BOTH: %[[C1_:.*]] = constant 1 : index - // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> - // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index - // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) - // BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref + // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref + // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 + // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> + // CHECK: %[[C0_:.*]] = constant 0 : index + // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64> + // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index + // CHECK: %[[C1_:.*]] = constant 1 : index + // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> + // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index + // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) + // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () return } // ----- -// BOTH-LABEL: func @dot +// CHECK-LABEL: func @dot func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { -// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) -// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] -// BOTH-NEXT: %[[ALLOC:.*]] = alloc -// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) { +// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] +// CHECK-NEXT: %[[ALLOC:.*]] = alloc +// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) { // dot_dimension_numbers = { // lhs_batching_dimensions = dense<> : tensor<0xi64>, // lhs_contracting_dimensions = dense<1> : tensor<1xi64>, @@ -542,22 +533,21 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { // : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () %dot = "mhlo.dot"(%arg0, %arg0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> -// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]]) -// ESC: return %[[ALLOC]] +// CHECK: return %[[ALLOC]] return %dot : tensor<1024x1024xf32> } // ----- -// BOTH-LABEL: func @conv +// CHECK-LABEL: func @conv func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> { %c0 = constant 0 : index - // BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32> - // BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) - // BOTH-SAME: padding = dense<[ - // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> - // BOTH-SAME: rhs_dilation = dense<[1, 2]> - // BOTH-SAME: window_strides = dense<[2, 1]> + // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32> + // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) + // CHECK-SAME: padding = dense<[ + // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> + // CHECK-SAME: rhs_dilation = dense<[1, 2]> + // CHECK-SAME: window_strides = dense<[2, 1]> %out = "mhlo.convolution"(%filter, %input) { batch_group_count = 1 : i64, dimension_numbers = { @@ -581,18 +571,18 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor // ----- -// BOTH-LABEL: func @reduce +// CHECK-LABEL: func @reduce func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor) -> tensor<1xf32> { - // BOTH: %[[OUT:.*]] = alloc() : memref<1xf32> - // BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( { - // BOTH: ^bb0(%[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, - // BOTH-SAME: %[[ARG3:.*]]: memref): - // BOTH: %[[TMP:.*]] = alloc() : memref - // BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]]) - // BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]]) - // BOTH: "lmhlo.terminator"() : () -> () - // BOTH: }) {dimensions = dense<1> : tensor<1xi64>} - // BOTH-SAME: : (memref<1x8xf32>, memref, memref<1xf32>) -> () + // CHECK: %[[OUT:.*]] = alloc() : memref<1xf32> + // CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( { + // CHECK: ^bb0(%[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, + // CHECK-SAME: %[[ARG3:.*]]: memref): + // CHECK: %[[TMP:.*]] = alloc() : memref + // CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]]) + // CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]]) + // CHECK: "lmhlo.terminator"() : () -> () + // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} + // CHECK-SAME: : (memref<1x8xf32>, memref, memref<1xf32>) -> () %0 = "mhlo.reduce"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors %1 = mhlo.add %arg2, %arg3 : tensor @@ -604,25 +594,25 @@ func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor) -> tensor<1xf32> { // ----- -// BOTH-LABEL: func @transpose +// CHECK-LABEL: func @transpose func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.transpose"(%tensor_operand) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} - // BOTH-NOT: tensor_store + // CHECK: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} + // CHECK-NOT: tensor_store tensor_store %tensor_result, %result : memref<2x2xf32> return } // ----- -// BOTH-LABEL: func @custom_call -// BOTH-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>) +// CHECK-LABEL: func @custom_call +// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>) func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) { %arg0_tensor = tensor_load %arg0 : memref<2x2xf32> %arg1_tensor = tensor_load %arg1 : memref<2x3xf32> - // BOTH: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false} + // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false} %result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor) {backend_config = "", call_target_name = "foo", has_side_effect = false} : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16> @@ -632,10 +622,10 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memre // ---- -// BOTH-LABEL: func @isfinite +// CHECK-LABEL: func @isfinite func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) { %arg0_tensor = tensor_load %arg0 : memref<2x2xf32> - // BOTH: "lmhlo.is_finite"(%{{.*}}, %{{.*}}) + // CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}}) %result_tensor = "mhlo.is_finite"(%arg0_tensor) : (tensor<2x2xf32>) -> tensor<2x2xi1> tensor_store %result_tensor, %result: memref<2x2xi1> return @@ -644,19 +634,19 @@ func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) { // ----- // Test that assuming ops propagate memref types. -// BOTH-LABEL: func @shape_assuming_memref +// CHECK-LABEL: func @shape_assuming_memref func @shape_assuming_memref(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = shape.const_witness true - // BOTH: shape.assuming %{{.*}} -> (memref) + // CHECK: shape.assuming %{{.*}} -> (memref) %2 = shape.assuming %1 -> (tensor) { %3 = shape.shape_of %arg0 : tensor -> tensor %4 = tensor_cast %3 : tensor to tensor<1xindex> %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor - // BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref, memref, memref) -> () + // CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref, memref, memref) -> () %7 = mhlo.maximum %5, %6 : tensor - // BOTH: shape.assuming_yield %{{.*}} : memref + // CHECK: shape.assuming_yield %{{.*}} : memref shape.assuming_yield %7 : tensor } return %2 : tensor