From d4f2c767d33aff0f8c899237c5bc3f570de6f9d2 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev
Date: Mon, 9 Nov 2020 04:23:54 -0800
Subject: [PATCH] [HLO] Fix HLO DynamicBroadcastInDimOp -> LHLO lowering.

The conversion had a bug in the computation of the strides and sizes
arguments for std.memref_reinterpret_cast. The previous version also relied
on linalg::ReshapeOp to do broadcasting when the rank of the output was
higher than the rank of the input. Now the broadcasting is done entirely via
descriptor modification, and linalg::ReshapeOp was replaced with CopyOp.

PiperOrigin-RevId: 341379871
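
For illustration only (not part of the change): the lowering now rests on two
ingredients, row-major strides computed as a reversed scan product and
stride-0 dimensions for expansion. Below is a minimal standalone C++ sketch
of that idea, with a plain std::vector standing in for the memref descriptor;
all names here are invented for the example.

  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Row-major strides as a reversed scan product: the minor-most dimension
  // has stride 1; every outer stride is the product of all inner sizes.
  std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& sizes) {
    std::vector<int64_t> strides(sizes.size(), 1);
    int64_t stride_so_far = 1;
    for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
      strides[i] = stride_so_far;
      stride_so_far *= sizes[i];
    }
    return strides;
  }

  int main() {
    // Operand: tensor<3xf32> with values [1, 2, 3].
    std::vector<float> buffer = {1, 2, 3};
    std::vector<int64_t> operand_strides = RowMajorStrides({3});

    // dynamic_broadcast_in_dim to 3x4 with broadcast_dimensions = [0]:
    // result dim 0 maps to an operand dim of equal size, so it keeps the
    // operand's flattened-buffer stride; result dim 1 has no operand dim,
    // so its stride is 0 and every index re-reads the same element.
    std::vector<int64_t> result_sizes = {3, 4};
    std::vector<int64_t> result_strides = {operand_strides[0], 0};

    for (int64_t i = 0; i < result_sizes[0]; ++i) {
      for (int64_t j = 0; j < result_sizes[1]; ++j)
        std::cout << buffer[i * result_strides[0] + j * result_strides[1]]
                  << ' ';
      std::cout << '\n';
    }
    // Prints 1 1 1 1 / 2 2 2 2 / 3 3 3 3, the same result that the
    // memref_reinterpret_cast + lmhlo.copy pair produces.
  }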
---
 .../mhlo/transforms/hlo_legalize_to_lhlo.cc |  83 ++-
 tests/end2end/broadcast.mlir                | 506 ++++++++++++++----
 tests/hlo-legalize-to-lhlo.mlir             |  93 ++--
 3 files changed, 492 insertions(+), 190 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index aca5977..6710d37 100644
--- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -194,8 +194,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
     Value transformed_operand =
         InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
-    rewriter.create<lmhlo::BroadcastInDimOp>(
-        loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
+    rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
 
     rewriter.replaceOp(op, {resultBuffer});
 
@@ -211,48 +210,76 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
     auto loc = op.getLoc();
     auto operand_type = operand.getType().cast<MemRefType>();
     auto operand_shape = operand_type.getShape();
+    auto operand_rank = operand_type.getRank();
 
-    SmallVector<Value, 2> sizes, strides;
-    sizes.reserve(operand_shape.size());
-    strides.reserve(operand_shape.size());
+    auto result_type = op.getType().cast<RankedTensorType>();
+    auto result_rank = result_type.getRank();
 
     Value zero = b->create<ConstantIndexOp>(loc, 0);
     Value one = b->create<ConstantIndexOp>(loc, 1);
-    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
-      Value broadcast_dim_value =
-          b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
-      Value result_dim_size = b->create<ExtractElementOp>(
-          loc, op.output_dimensions(), broadcast_dim_value);
-      Value operand_dim_size =
-          ShapedType::isDynamic(operand_shape[dim.index()])
-              ? b->create<DimOp>(loc, operand, dim.index()).getResult()
-              : b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
-                    .getResult();
-      // TODO(pifon): Revisit if this cast is needed. Maybe we can use
-      // tensor<index> for `output_dimensions` as well.
+
+    // Compute the strides as a reversed scan product, working from the
+    // minor-most to the major-most dimension. Additionally, save the operand
+    // shape Values to use in the next loop.
+    SmallVector<Value, 2> operand_strides(operand_rank, one);
+    SmallVector<Value, 2> operand_sizes(operand_rank, one);
+    Value stride_so_far = one;
+    for (int i = operand_rank - 1; i >= 0; --i) {
+      Value operand_dim_size =
+          ShapedType::isDynamic(operand_shape[i])
+              ? b->create<DimOp>(loc, operand, i).getResult()
+              : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
+      operand_sizes[i] = operand_dim_size;
+
+      operand_strides[i] = stride_so_far;
+      if (i > 0) {
+        stride_so_far =
+            b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
+      }
+    }
+
+    SmallVector<Value, 2> sizes, strides;
+    sizes.reserve(result_rank);
+    strides.reserve(result_rank);
+
+    DenseMap<int, int> output_to_input_dim;
+    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
+      output_to_input_dim[dim.value().getSExtValue()] = dim.index();
+    }
+    for (int i = 0; i < result_rank; ++i) {
+      Value i_val = b->create<ConstantIndexOp>(loc, i);
+      Value result_dim_size =
+          b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
       if (!result_dim_size.getType().isIndex()) {
         result_dim_size =
             b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
       }
+      sizes.push_back(result_dim_size);
+
+      auto it = output_to_input_dim.find(i);
+      // If the rank of the output is greater than the rank of the input,
+      // i.e. this output dimension has no entry in the inverse
+      // broadcast_dimensions map, set the stride to 0 to emulate padding the
+      // shape with 1s and the corresponding expansion.
+      if (it == output_to_input_dim.end()) {
+        strides.push_back(zero);
+        continue;
+      }
 
       // There can be two cases:
-      // 1) Operand dim == result dim => expansion is not needed => stride := 1.
+      // 1) Operand dim == result dim => expansion is not needed
+      //    => stride := flattened buffer stride
      // 2) Operand dim < result dim => expansion is needed => stride := 0.
-      Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
-                                             operand_dim_size, result_dim_size);
-      strides.push_back(
-          b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
-
-      // Size of input dim can be set to the size of the corresponding output
-      // dimension for both cases.
-      sizes.push_back(result_dim_size);
+      int dim = it->second;
+      Value is_expansion = b->create<CmpIOp>(
+          loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
+      strides.push_back(b->create<mlir::SelectOp>(loc, is_expansion, zero,
+                                                  operand_strides[dim]));
     }
 
     // Type-erased memref type with static rank, dynamic sizes and strides.
-    SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
+    SmallVector<int64_t, 2> dynamic_layout(result_rank,
                                            MemRefType::kDynamicStrideOrOffset);
-    SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
+    SmallVector<int64_t, 2> dynamic_shape(result_rank,
                                           MemRefType::kDynamicSize);
     auto type_erased_memref_type = MemRefType::get(
         dynamic_shape, operand_type.getElementType(),
diff --git a/tests/end2end/broadcast.mlir b/tests/end2end/broadcast.mlir
index f8f6ce3..e2b71c7 100644
--- a/tests/end2end/broadcast.mlir
+++ b/tests/end2end/broadcast.mlir
@@ -1,4 +1,10 @@
-// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
+// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo \
+// RUN:   -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse \
+// RUN:   -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
+// RUN:   -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s --dump-input=always
 
 func @main() -> () {
   call @trivial_broadcast_wrapper() : () -> ()
@@ -8,6 +14,9 @@ func @main() -> () {
   call @broadcast_in_Y_dim_transpose_wrapper() : () -> ()
   call @broadcast_scalar_1d_wrapper() : () -> ()
   call @broadcast_scalar_2d_wrapper() : () -> ()
+  call @broadcast_to_the_same_shape() : () -> ()
+  call @broadcast_1d_to_2d() : () -> ()
+  call @broadcast_1d_to_2d_with_transpose() : () -> ()
   return
 }
 
@@ -15,199 +24,490 @@ func @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
 func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
 
 func @trivial_broadcast_wrapper() {
-  %input = alloc() : memref<3xf32>
-  %c1f32 = constant 1.0 : f32
-  %c0 = constant 0 : index
-  store %c1f32, %input[%c0] : memref<3xf32>
-  %c2f32 = constant 2.0 : f32
-  %c1 = constant 1 : index
-  store %c2f32, %input[%c1] : memref<3xf32>
-  %c3f32 = constant 3.0 : f32
-  %c2 = constant 2 : index
-  store %c3f32, %input[%c2] : memref<3xf32>
-  %input_tensor = tensor_load %input : memref<3xf32>
+  %input_buf = alloc() : memref<3xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+
+  // Test BroadcastInDimOp.
+ %output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<3xf32>) -> tensor<3x4xf32> - %output = alloc() : memref<3x4xf32> - tensor_store %output_tensor, %output : memref<3x4xf32> + %output_buf = alloc() : memref<3x4xf32> + tensor_store %output, %output_buf : memref<3x4xf32> - %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32> - call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> () + %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [2, 2, 2, 2] + // CHECK-NEXT: [3, 3, 3, 3] + + // Test DynamicBroadcastInDimOp. + %c3 = constant 3 : index + %c4 = constant 4 : index + %shape = tensor_from_elements %c3, %c4 : tensor<2xindex> + %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) { + broadcast_dimensions = dense<0> : tensor<1xi64> + } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32> + + %dyn_output_buf = alloc() : memref<3x4xf32> + tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32> + + %unranked_dyn_output = memref_cast %dyn_output_buf + : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [2, 2, 2, 2] + // CHECK-NEXT: [3, 3, 3, 3] return } -// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] -// CHECK: [1, 1, 1, 1] -// CHECK: [2, 2, 2, 2] -// CHECK: [3, 3, 3, 3] func @broadcast_in_X_dim_wrapper() { - %input = alloc() : memref<1x4xf32> + %input_buf = alloc() : memref<1x4xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input[%c0, %c0] : memref<1x4xf32> + store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32> %c2f32 = constant 2.0 : f32 %c1 = constant 1 : index - store %c2f32, %input[%c0, %c1] : memref<1x4xf32> + store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32> %c3f32 = constant 3.0 : f32 %c2 = constant 2 : index - store %c3f32, %input[%c0, %c2] : memref<1x4xf32> + store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32> %c4f32 = constant 4.0 : f32 %c3 = constant 3 : index - store %c4f32, %input[%c0, %c3] : memref<1x4xf32> - %input_tensor = tensor_load %input : memref<1x4xf32> + store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32> + %input = tensor_load %input_buf : memref<1x4xf32> - %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) { + // Test BroadcastInDimOp. + %output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<1x4xf32>) -> tensor<3x4xf32> - %output = alloc() : memref<3x4xf32> - tensor_store %output_tensor, %output : memref<3x4xf32> + %output_buf = alloc() : memref<3x4xf32> + tensor_store %output, %output_buf : memref<3x4xf32> - %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32> - call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> () + %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 2, 3, 4] + // CHECK-NEXT: [1, 2, 3, 4] + // CHECK-NEXT: [1, 2, 3, 4] + + // Test DynamicBroadcastInDimOp. 
+ %c4 = constant 4 : index + %shape = tensor_from_elements %c3, %c4 : tensor<2xindex> + %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + } : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32> + + %dyn_output_buf = alloc() : memref<3x4xf32> + tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32> + + %unranked_dyn_output = memref_cast %dyn_output_buf + : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 2, 3, 4] + // CHECK-NEXT: [1, 2, 3, 4] + // CHECK-NEXT: [1, 2, 3, 4] return } -// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] -// CHECK: [1, 2, 3, 4] -// CHECK: [1, 2, 3, 4] -// CHECK: [1, 2, 3, 4] func @broadcast_in_Y_dim_wrapper() { - %input = alloc() : memref<3x1xf32> + %input_buf = alloc() : memref<3x1xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input[%c0, %c0] : memref<3x1xf32> + store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32> %c2f32 = constant 2.0 : f32 %c1 = constant 1 : index - store %c2f32, %input[%c1, %c0] : memref<3x1xf32> + store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32> %c3f32 = constant 3.0 : f32 %c2 = constant 2 : index - store %c3f32, %input[%c2, %c0] : memref<3x1xf32> - %input_tensor = tensor_load %input : memref<3x1xf32> + store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32> + %input = tensor_load %input_buf : memref<3x1xf32> - %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) { + // Test BroadcastInDimOp. + %output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<3x1xf32>) -> tensor<3x4xf32> - %output = alloc() : memref<3x4xf32> - tensor_store %output_tensor, %output : memref<3x4xf32> + %output_buf = alloc() : memref<3x4xf32> + tensor_store %output, %output_buf : memref<3x4xf32> - %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32> - call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> () + %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [2, 2, 2, 2] + // CHECK-NEXT: [3, 3, 3, 3] + + // Test DynamicBroadcastInDimOp. 
+ %c3 = constant 3 : index + %c4 = constant 4 : index + %shape = tensor_from_elements %c3, %c4 : tensor<2xindex> + %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + } : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> + + %dyn_output_buf = alloc() : memref<3x4xf32> + tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32> + + %unranked_dyn_output = memref_cast %dyn_output_buf + : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [2, 2, 2, 2] + // CHECK-NEXT: [3, 3, 3, 3] return } -// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] -// CHECK: [1, 1, 1, 1] -// CHECK: [2, 2, 2, 2] -// CHECK: [3, 3, 3, 3] func @broadcast_in_X_dim_transpose_wrapper() { - %input = alloc() : memref<4x1xf32> + %input_buf = alloc() : memref<4x1xf32> %c1f32 = constant 1.0 : f32 %c0 = constant 0 : index - store %c1f32, %input[%c0, %c0] : memref<4x1xf32> + store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32> %c2f32 = constant 2.0 : f32 %c1 = constant 1 : index - store %c2f32, %input[%c1, %c0] : memref<4x1xf32> + store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32> %c3f32 = constant 3.0 : f32 %c2 = constant 2 : index - store %c3f32, %input[%c2, %c0] : memref<4x1xf32> + store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32> %c4f32 = constant 4.0 : f32 %c3 = constant 3 : index - store %c4f32, %input[%c3, %c0] : memref<4x1xf32> - %input_tensor = tensor_load %input : memref<4x1xf32> + store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32> + %input = tensor_load %input_buf : memref<4x1xf32> - %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) { + // Test BroadcastInDimOp. + %output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> } : (tensor<4x1xf32>) -> tensor<3x4xf32> - %output = alloc() : memref<3x4xf32> - tensor_store %output_tensor, %output : memref<3x4xf32> + %output_buf = alloc() : memref<3x4xf32> + tensor_store %output, %output_buf : memref<3x4xf32> - %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32> - call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> () + %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 2, 3, 4] + // CHECK-NEXT: [1, 2, 3, 4] + // CHECK-NEXT: [1, 2, 3, 4] + + // Test DynamicBroadcastInDimOp. 
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
+  } : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
 
 func @broadcast_in_Y_dim_transpose_wrapper() {
-  %input = alloc() : memref<1x3xf32>
+  %input_buf = alloc() : memref<1x3xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<1x3xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c0, %c1] : memref<1x3xf32>
+  store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c0, %c2] : memref<1x3xf32>
-  %input_tensor = tensor_load %input : memref<1x3xf32>
+  store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32>
+  %input = tensor_load %input_buf : memref<1x3xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
   } : (tensor<1x3xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
+  } : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [2, 2, 2, 2]
-// CHECK: [3, 3, 3, 3]
 
 func @broadcast_scalar_1d_wrapper() {
-  %input = alloc() : memref<1xf32>
+  %input_buf = alloc() : memref<1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0] : memref<1xf32>
-  %input_tensor = tensor_load %input : memref<1xf32>
+  store %c1f32, %input_buf[%c0] : memref<1xf32>
+  %input = tensor_load %input_buf : memref<1xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<0> : tensor<1xi64>
   } : (tensor<1xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
 
 func @broadcast_scalar_2d_wrapper() {
-  %input = alloc() : memref<1x1xf32>
+  %input_buf = alloc() : memref<1x1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<1x1xf32>
-  %input_tensor = tensor_load %input : memref<1x1xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32>
+  %input = tensor_load %input_buf : memref<1x1xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+ %output = "mhlo.broadcast_in_dim"(%input) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor<1x1xf32>) -> tensor<3x4xf32> - %output = alloc() : memref<3x4xf32> - tensor_store %output_tensor, %output : memref<3x4xf32> + %output_buf = alloc() : memref<3x4xf32> + tensor_store %output, %output_buf : memref<3x4xf32> - %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32> - call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> () + %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [1, 1, 1, 1] + + // Test DynamicBroadcastInDimOp. + %c3 = constant 3 : index + %c4 = constant 4 : index + %shape = tensor_from_elements %c3, %c4 : tensor<2xindex> + %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + } : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32> + + %dyn_output_buf = alloc() : memref<3x4xf32> + tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32> + + %unranked_dyn_output = memref_cast %dyn_output_buf + : memref<3x4xf32> to memref<*xf32> + call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [1, 1, 1, 1] + // CHECK-NEXT: [1, 1, 1, 1] return } -// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1] -// CHECK: [1, 1, 1, 1] -// CHECK: [1, 1, 1, 1] -// CHECK: [1, 1, 1, 1] +func @broadcast_to_the_same_shape() { + %input_buf = alloc() : memref<2x3xf32> + + %c1f32 = constant 1.0 : f32 + %c2f32 = constant 2.0 : f32 + %c3f32 = constant 3.0 : f32 + + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32> + store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32> + store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32> + store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32> + store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32> + store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32> + %input = tensor_load %input_buf : memref<2x3xf32> + + // Test BroadcastInDimOp. + %output = "mhlo.broadcast_in_dim"(%input) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + } : (tensor<2x3xf32>) -> tensor<2x3xf32> + + %output_buf = alloc() : memref<2x3xf32> + tensor_store %output, %output_buf : memref<2x3xf32> + + %unraked_output = memref_cast %output_buf : memref<2x3xf32> to memref<*xf32> + call @print_memref_f32(%unraked_output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] + // CHECK-NEXT: [1, 2, 3] + // CHECK-NEXT: [1, 2, 3] + + // Test DynamicBroadcastInDimOp. 
+  %shape = tensor_from_elements %c2, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
+
+  %dyn_output_buf = alloc() : memref<2x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<2x3xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  return
+}
+
+func @broadcast_1d_to_2d() {
+  %input_buf = alloc() : memref<3xf32>
+
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<3xf32>) -> tensor<3x3xf32>
+
+  %output_buf = alloc() : memref<3x3xf32>
+  tensor_store %output, %output_buf : memref<3x3xf32>
+
+  %unranked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %shape = tensor_from_elements %c3, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
+
+  %dyn_output_buf = alloc() : memref<3x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x3xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3]
+  return
+}
+
+func @broadcast_1d_to_2d_with_transpose() {
+  %input_buf = alloc() : memref<3xf32>
+
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
+    broadcast_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<3xf32>) -> tensor<3x3xf32>
+
+  %output_buf = alloc() : memref<3x3xf32>
+  tensor_store %output, %output_buf : memref<3x3xf32>
+
+  %unranked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %shape = tensor_from_elements %c3, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
+
+  %dyn_output_buf = alloc() : memref<3x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x3xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  return
+}
diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir
index 399ec9e..910129c 100644
--- a/tests/hlo-legalize-to-lhlo.mlir
+++ b/tests/hlo-legalize-to-lhlo.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
+// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
+// RUN:   -buffer-deallocation -split-input-file -cse %s -o - \
+// RUN: | FILECHECK_OPTS="" FileCheck %s
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -153,64 +155,41 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
 
 // -----
 
-func @external_func() -> tensor<3xi64>
-
-// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
 
 // CHECK-LABEL: func @dyn_broadcast
-func @dyn_broadcast(%operand: memref<?x?xf32>) {
-  // CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
+func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
+  // CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
   %tensor_operand = tensor_load %operand : memref<?x?xf32>
   %c1 = constant 1 : i64
   %shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
   %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
     broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
   } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = tensor_from_elements
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
-  // CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
-  // CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
-  // CHECK: %[[C2:.*]] = constant 2 : index
-  // CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
-  // CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
-  // CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
-
-  // CHECK: %[[C0_:.*]] = constant 0 : index
-  // CHECK: %[[C1_:.*]] = constant 1 : index
-
-  // CHECK: %[[C1__:.*]] = constant 1 : index
-  // CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
-  // CHECK: %[[C0___:.*]] = constant 0 : index
-  // CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
-  // CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
-  // CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
-  // CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
-
-  // CHECK: %[[C2_:.*]] = constant 2 : index
-  // CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
-  // CHECK: %[[C1___:.*]] = constant 1 : index
-  // CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
-  // CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
-  // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
-  // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
-
-  // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
-  // CHECK-SAME: offset: [0],
-  // CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
-  // CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
-  // CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #[[MAP]]>
-
-  // CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
-  // CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
-  // CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
-
-  // Do not store the value back to avoid the tensor-store being rewritten to
-  // a copy into the pre-allocated argument.
-  return
+  %rank = rank %tensor_result : tensor<?x?x?xf32>
+  return %rank : index
 }
+// CHECK: %[[SHAPE:.*]] = tensor_from_elements
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
+// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
+// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
+// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
+// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
+// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
+// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
+// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
+// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
+// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
+// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
+// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
+// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #[[MAP]]>
+// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
+// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
 
 // -----
 
@@ -483,11 +462,9 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
   // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
   // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
   // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-  // CHECK: %[[C0_:.*]] = constant 0 : index
-  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
   // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-  // CHECK: %[[C1_:.*]] = constant 1 : index
-  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
   // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
   // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
   // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@@ -508,11 +485,9 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
   // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
   // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
   // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-  // CHECK: %[[C0_:.*]] = constant 0 : index
-  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
   // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-  // CHECK: %[[C1_:.*]] = constant 1 : index
-  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
   // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
   // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
   // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
@@ -645,7 +620,7 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
   %4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
   %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
   %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
-  // CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
+  // CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
   %7 = mhlo.maximum %5, %6 : tensor<?xf16>
   // CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
   shape.assuming_yield %7 : tensor<?xf16>