[HLO] Fix HLO DynamicBroadcastInDimOp -> LHLO lowering.

The conversion had a bug in the computation of the strides and sizes arguments for std.memref_reinterpret_cast. The previous version also relied on linalg::ReshapeOp to perform the broadcast when the rank of the output was higher than the rank of the input. The broadcasting is now done entirely via descriptor modification, and linalg::ReshapeOp has been replaced with CopyOp.

PiperOrigin-RevId: 341379871
Alexander Belyaev 2020-11-09 04:23:54 -08:00 committed by TensorFlow MLIR Team
parent 4ef12aa000
commit d4f2c767d3
3 changed files with 492 additions and 190 deletions
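Background for the change below: a broadcast never has to move data. Rewriting the memref descriptor so that every expanded (or newly created) dimension has stride 0 makes the same underlying element appear at every index along that dimension, and a plain element-wise copy through such a view materializes the broadcast. A minimal standalone C++ sketch of the trick (illustration only, with our own naming; not code from this patch):

#include <cstdio>
#include <vector>

// Read a 1-D buffer of shape [3] through a fake [3 x 4] descriptor whose
// inner stride is 0: every column index maps back to the same element, so
// the copy below produces the same result as broadcast_in_dim with
// broadcast_dimensions = [0].
int main() {
  std::vector<float> input = {1.0f, 2.0f, 3.0f};
  const int sizes[2] = {3, 4};
  const int strides[2] = {1, 0};  // stride 0 on the expanded dimension

  std::vector<float> output(3 * 4);
  for (int i = 0; i < sizes[0]; ++i)
    for (int j = 0; j < sizes[1]; ++j)
      output[i * 4 + j] = input[i * strides[0] + j * strides[1]];

  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 4; ++j) std::printf("%g ", output[i * 4 + j]);
    std::printf("\n");  // prints rows [1 1 1 1], [2 2 2 2], [3 3 3 3]
  }
}

With the descriptor doing the expansion, the lowering only needs a copy op; no reshape or broadcast op is left after bufferization.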


@@ -194,8 +194,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
     Value transformed_operand =
         InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
-    rewriter.create<lmhlo::BroadcastInDimOp>(
-        loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
+    rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
     rewriter.replaceOp(op, {resultBuffer});
@@ -211,48 +210,76 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
     auto loc = op.getLoc();
     auto operand_type = operand.getType().cast<MemRefType>();
     auto operand_shape = operand_type.getShape();
+    auto operand_rank = operand_type.getRank();
 
-    SmallVector<Value, 2> sizes, strides;
-    sizes.reserve(operand_shape.size());
-    strides.reserve(operand_shape.size());
+    auto result_type = op.getType().cast<RankedTensorType>();
+    auto result_rank = result_type.getRank();
 
     Value zero = b->create<ConstantIndexOp>(loc, 0);
     Value one = b->create<ConstantIndexOp>(loc, 1);
-    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
-      Value broadcast_dim_value =
-          b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
-      Value result_dim_size = b->create<ExtractElementOp>(
-          loc, op.output_dimensions(), broadcast_dim_value);
-      Value operand_dim_size =
-          ShapedType::isDynamic(operand_shape[dim.index()])
-              ? b->create<DimOp>(loc, operand, dim.index()).getResult()
-              : b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
-                    .getResult();
-
-      // TODO(pifon): Revisit if this cast is needed. Maybe we can use
-      // tensor<index> for `output_dimensions` as well.
+
+    // Compute a reversed scan product. Compute the stride for the dimensions so
+    // far, working from minor to major dimensions. Additionally, save the
+    // operand shape Values to use in the next loop.
+    SmallVector<Value, 2> operand_strides(operand_rank, one);
+    SmallVector<Value, 2> operand_sizes(operand_rank, one);
+    Value stride_so_far = one;
+    for (int i = operand_rank - 1; i >= 0; --i) {
+      Value operand_dim_size =
+          ShapedType::isDynamic(operand_shape[i])
+              ? b->create<DimOp>(loc, operand, i).getResult()
+              : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
+      operand_sizes[i] = operand_dim_size;
+      operand_strides[i] = stride_so_far;
+      if (i > 0) {
+        stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
+      }
+    }
+
+    SmallVector<Value, 2> sizes, strides;
+    sizes.reserve(result_rank);
+    strides.reserve(result_rank);
+
+    DenseMap<int, int> output_to_input_dim;
+    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
+      output_to_input_dim[dim.value().getSExtValue()] = dim.index();
+    }
+
+    for (int i = 0; i < result_rank; ++i) {
+      Value i_val = b->create<ConstantIndexOp>(loc, i);
+      Value result_dim_size =
+          b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
       if (!result_dim_size.getType().isIndex()) {
         result_dim_size =
             b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
       }
+      sizes.push_back(result_dim_size);
+
+      auto it = output_to_input_dim.find(i);
+      // If the rank of the output is greater than the rank of the input, i.e.
+      // there was no output dimension in the inverse broadcast_dimensions map,
+      // we also set the stride to 0 to emulate padding of the shape with 1s
+      // and the corresponding expansion.
+      if (it == output_to_input_dim.end()) {
+        strides.push_back(zero);
+        continue;
+      }
 
       // There can be two cases:
-      // 1) Operand dim == result dim => expansion is not needed => stride := 1.
+      // 1) Operand dim == result dim => expansion is not needed
+      //    => stride := flattened buffer stride
       // 2) Operand dim < result dim => expansion is needed => stride := 0.
-      Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
-                                             operand_dim_size, result_dim_size);
-      strides.push_back(
-          b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
-
-      // Size of input dim can be set to the size of the corresponding output
-      // dimension for both cases.
-      sizes.push_back(result_dim_size);
+      int dim = it->second;
+      Value is_expansion = b->create<CmpIOp>(
+          loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
+      strides.push_back(b->create<mlir::SelectOp>(loc, is_expansion, zero,
                                                  operand_strides[dim]));
     }
 
     // Type-erased memref type with static rank, dynamic sizes and strides.
-    SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
+    SmallVector<int64_t, 2> dynamic_layout(result_rank,
                                            MemRefType::kDynamicStrideOrOffset);
-    SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
+    SmallVector<int64_t, 2> dynamic_shape(result_rank,
                                           MemRefType::kDynamicSize);
     auto type_erased_memref_type = MemRefType::get(
         dynamic_shape, operand_type.getElementType(),
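For reference, the sizes/strides logic introduced above can be modeled in a few lines of plain C++ for static shapes. This is a sketch under the assumption of a row-major operand; the helper name BroadcastDescriptor is ours and does not exist in the pass:

#include <cstdint>
#include <map>
#include <utility>
#include <vector>

// Standalone model of the fixed descriptor computation: expanded output dims
// and dims absent from broadcast_dims get stride 0; kept dims reuse the
// operand's flattened stride instead of the old (buggy) constant 1.
std::pair<std::vector<int64_t>, std::vector<int64_t>> BroadcastDescriptor(
    const std::vector<int64_t>& operand_shape,
    const std::vector<int64_t>& broadcast_dims,
    const std::vector<int64_t>& output_shape) {
  int64_t rank = operand_shape.size();
  std::vector<int64_t> operand_strides(rank, 1);
  for (int64_t i = rank - 2; i >= 0; --i)  // reversed scan product
    operand_strides[i] = operand_strides[i + 1] * operand_shape[i + 1];

  std::map<int64_t, int64_t> output_to_input;  // inverse of broadcast_dims
  for (int64_t i = 0; i < rank; ++i) output_to_input[broadcast_dims[i]] = i;

  std::vector<int64_t> sizes, strides;
  for (int64_t i = 0; i < (int64_t)output_shape.size(); ++i) {
    sizes.push_back(output_shape[i]);
    auto it = output_to_input.find(i);
    bool expansion =
        it == output_to_input.end() ||  // dim does not exist in the operand
        operand_shape[it->second] < output_shape[i];  // 1 -> N expansion
    strides.push_back(expansion ? 0 : operand_strides[it->second]);
  }
  return {sizes, strides};
}

BroadcastDescriptor({2, 3}, {0, 2}, {2, 2, 3}) returns sizes [2, 2, 3] and strides [3, 0, 1]. The previous lowering used stride 1 for every non-expanded dimension, which is only correct for the innermost one, and it emitted no entry at all for output dimensions absent from broadcast_dimensions; the reversed scan product and the inverse dimension map fix both.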


@@ -1,4 +1,10 @@
-// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
+// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo \
+// RUN: -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse \
+// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
+// RUN: -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s --dump-input=always
 
 func @main() -> () {
   call @trivial_broadcast_wrapper() : () -> ()
@@ -8,6 +14,9 @@ func @main() -> () {
   call @broadcast_in_Y_dim_transpose_wrapper() : () -> ()
   call @broadcast_scalar_1d_wrapper() : () -> ()
   call @broadcast_scalar_2d_wrapper() : () -> ()
+  call @broadcast_to_the_same_shape() : () -> ()
+  call @broadcast_1d_to_2d() : () -> ()
+  call @broadcast_1d_to_2d_with_transpose() : () -> ()
   return
 }
@@ -15,199 +24,490 @@ func @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
 func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
 
 func @trivial_broadcast_wrapper() {
-  %input = alloc() : memref<3xf32>
-  %c1f32 = constant 1.0 : f32
-  %c0 = constant 0 : index
-  store %c1f32, %input[%c0] : memref<3xf32>
-  %c2f32 = constant 2.0 : f32
-  %c1 = constant 1 : index
-  store %c2f32, %input[%c1] : memref<3xf32>
-  %c3f32 = constant 3.0 : f32
-  %c2 = constant 2 : index
-  store %c3f32, %input[%c2] : memref<3xf32>
-  %input_tensor = tensor_load %input : memref<3xf32>
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  %input_buf = alloc() : memref<3xf32>
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<0> : tensor<1xi64>
   } : (tensor<3xf32>) -> tensor<3x4xf32>
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [2, 2, 2, 2]
-// CHECK: [3, 3, 3, 3]
 
 func @broadcast_in_X_dim_wrapper() {
-  %input = alloc() : memref<1x4xf32>
+  %input_buf = alloc() : memref<1x4xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<1x4xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c0, %c1] : memref<1x4xf32>
+  store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c0, %c2] : memref<1x4xf32>
+  store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32>
   %c4f32 = constant 4.0 : f32
   %c3 = constant 3 : index
-  store %c4f32, %input[%c0, %c3] : memref<1x4xf32>
-  %input_tensor = tensor_load %input : memref<1x4xf32>
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32>
+  %input = tensor_load %input_buf : memref<1x4xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
   } : (tensor<1x4xf32>) -> tensor<3x4xf32>
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // Test DynamicBroadcastInDimOp.
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
 
 func @broadcast_in_Y_dim_wrapper() {
-  %input = alloc() : memref<3x1xf32>
+  %input_buf = alloc() : memref<3x1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<3x1xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c1, %c0] : memref<3x1xf32>
+  store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c2, %c0] : memref<3x1xf32>
-  %input_tensor = tensor_load %input : memref<3x1xf32>
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32>
+  %input = tensor_load %input_buf : memref<3x1xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
   } : (tensor<3x1xf32>) -> tensor<3x4xf32>
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [2, 2, 2, 2]
-// CHECK: [3, 3, 3, 3]
 
 func @broadcast_in_X_dim_transpose_wrapper() {
-  %input = alloc() : memref<4x1xf32>
+  %input_buf = alloc() : memref<4x1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<4x1xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c1, %c0] : memref<4x1xf32>
+  store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c2, %c0] : memref<4x1xf32>
+  store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32>
   %c4f32 = constant 4.0 : f32
   %c3 = constant 3 : index
-  store %c4f32, %input[%c3, %c0] : memref<4x1xf32>
-  %input_tensor = tensor_load %input : memref<4x1xf32>
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32>
+  %input = tensor_load %input_buf : memref<4x1xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
   } : (tensor<4x1xf32>) -> tensor<3x4xf32>
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // Test DynamicBroadcastInDimOp.
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
+  } : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
 
 func @broadcast_in_Y_dim_transpose_wrapper() {
-  %input = alloc() : memref<1x3xf32>
+  %input_buf = alloc() : memref<1x3xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<1x3xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c0, %c1] : memref<1x3xf32>
+  store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c0, %c2] : memref<1x3xf32>
-  %input_tensor = tensor_load %input : memref<1x3xf32>
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32>
+  %input = tensor_load %input_buf : memref<1x3xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
   } : (tensor<1x3xf32>) -> tensor<3x4xf32>
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
+  } : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [2, 2, 2, 2]
-// CHECK: [3, 3, 3, 3]
 
 func @broadcast_scalar_1d_wrapper() {
-  %input = alloc() : memref<1xf32>
+  %input_buf = alloc() : memref<1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0] : memref<1xf32>
-  %input_tensor = tensor_load %input : memref<1xf32>
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  store %c1f32, %input_buf[%c0] : memref<1xf32>
+  %input = tensor_load %input_buf : memref<1xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<0> : tensor<1xi64>
   } : (tensor<1xf32>) -> tensor<3x4xf32>
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
 
 func @broadcast_scalar_2d_wrapper() {
-  %input = alloc() : memref<1x1xf32>
+  %input_buf = alloc() : memref<1x1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<1x1xf32>
-  %input_tensor = tensor_load %input : memref<1x1xf32>
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32>
+  %input = tensor_load %input_buf : memref<1x1xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
   } : (tensor<1x1xf32>) -> tensor<3x4xf32>
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
 
+func @broadcast_to_the_same_shape() {
+  %input_buf = alloc() : memref<2x3xf32>
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32>
+  store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32>
+  store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32>
+  store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32>
+  store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32>
+  store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32>
+  %input = tensor_load %input_buf : memref<2x3xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  %output_buf = alloc() : memref<2x3xf32>
+  tensor_store %output, %output_buf : memref<2x3xf32>
+  %unranked_output = memref_cast %output_buf : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  // Test DynamicBroadcastInDimOp.
+  %shape = tensor_from_elements %c2, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
+  %dyn_output_buf = alloc() : memref<2x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<2x3xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  return
+}
+
+func @broadcast_1d_to_2d() {
+  %input_buf = alloc() : memref<3xf32>
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<3xf32>) -> tensor<3x3xf32>
+  %output_buf = alloc() : memref<3x3xf32>
+  tensor_store %output, %output_buf : memref<3x3xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3]
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %shape = tensor_from_elements %c3, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
+  %dyn_output_buf = alloc() : memref<3x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x3xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3]
+  return
+}
+
+func @broadcast_1d_to_2d_with_transpose() {
+  %input_buf = alloc() : memref<3xf32>
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
+    broadcast_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<3xf32>) -> tensor<3x3xf32>
+  %output_buf = alloc() : memref<3x3xf32>
+  tensor_store %output, %output_buf : memref<3x3xf32>
+  %unranked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %shape = tensor_from_elements %c3, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
+  %dyn_output_buf = alloc() : memref<3x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x3xf32>
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  return
+}
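As a sanity check on the expected output of @broadcast_1d_to_2d_with_transpose above: operand dim 0 maps to output dim 1, so the descriptor is sizes [3, 3], strides [0, 1], and replaying the indexing by hand reproduces the CHECK lines. A standalone C++ sketch (ours, for illustration):

#include <cstdio>

// Output dim 0 is the expanded one and gets stride 0, while output dim 1
// keeps the operand's unit stride. Reading through the descriptor yields
// the broadcast result without copying the operand first.
int main() {
  float input[3] = {1.0f, 2.0f, 3.0f};
  const int sizes[2] = {3, 3};
  const int strides[2] = {0, 1};  // dim 0 expanded, dim 1 -> operand dim 0
  for (int i = 0; i < sizes[0]; ++i) {
    for (int j = 0; j < sizes[1]; ++j)
      std::printf("%g ", input[i * strides[0] + j * strides[1]]);
    std::printf("\n");  // each row prints: 1 2 3
  }
}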


@@ -1,4 +1,6 @@
-// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
+// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
+// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
+// RUN: | FILECHECK_OPTS="" FileCheck %s
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -153,64 +155,41 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
 // -----
 
-func @external_func() -> tensor<3xi64>
-
-// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
 
 // CHECK-LABEL: func @dyn_broadcast
-func @dyn_broadcast(%operand: memref<?x?xf32>) {
-  // CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
+func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
+  // CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
   %tensor_operand = tensor_load %operand : memref<?x?xf32>
   %c1 = constant 1 : i64
   %shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
   %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
     broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
   } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
+  %rank = rank %tensor_result : tensor<?x?x?xf32>
+  return %rank : index
+}
+
 // CHECK: %[[SHAPE:.*]] = tensor_from_elements
 // CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
-// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
+// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
+// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
 // CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
-// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
+// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
+// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
 // CHECK: %[[C2:.*]] = constant 2 : index
-// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
-// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
-// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
-// CHECK: %[[C0_:.*]] = constant 0 : index
-// CHECK: %[[C1_:.*]] = constant 1 : index
-// CHECK: %[[C1__:.*]] = constant 1 : index
-// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
-// CHECK: %[[C0___:.*]] = constant 0 : index
-// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
-// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
-// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
-// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
-// CHECK: %[[C2_:.*]] = constant 2 : index
-// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
-// CHECK: %[[C1___:.*]] = constant 1 : index
-// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
-// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
-// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
-// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
-// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
-// CHECK-SAME: offset: [0],
-// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
-// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
-// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>
-// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
-// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
-// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
-
-  // Do not store the value back to avoid the tensor-store being rewritten to
-  // a copy into the pre-allocated argument.
-  return
-}
+// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
+// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
+// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
+// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
+// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
+// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
+// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
+// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
+// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
+// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
+// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
+// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
+// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
 
 // -----
@@ -483,11 +462,9 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
 // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
 // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
 // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-// CHECK: %[[C0_:.*]] = constant 0 : index
-// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
 // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-// CHECK: %[[C1_:.*]] = constant 1 : index
-// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
 // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
 // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
 // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@@ -508,11 +485,9 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
 // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
 // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
 // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-// CHECK: %[[C0_:.*]] = constant 0 : index
-// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
 // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-// CHECK: %[[C1_:.*]] = constant 1 : index
-// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
 // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
 // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
 // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
@@ -645,7 +620,7 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
   %4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
   %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
   %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
-  // CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
+  // CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
  %7 = mhlo.maximum %5, %6 : tensor<?xf16>
   // CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
   shape.assuming_yield %7 : tensor<?xf16>
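The stride selection checked in @dyn_broadcast above can be read as straight-line arithmetic. A C++ paraphrase of the generated IR (our naming, mirroring the CHECK captures; not code from the pass):

#include <cstdint>

// For a ?x? operand broadcast into 3-D with broadcast_dimensions = [1, 2]:
// output dim 0 has no operand counterpart (stride 0); dims 1 and 2 select
// between 0 (expansion) and the operand's flattened strides {dim1, 1}.
void DynBroadcastStrides(int64_t operand_dim0, int64_t operand_dim1,
                         int64_t size1, int64_t size2, int64_t strides[3]) {
  int64_t op_stride0 = 1 * operand_dim1;  // %[[OP_STRIDE_0]] = muli ...
  strides[0] = 0;                                      // new dimension
  strides[1] = operand_dim0 < size1 ? 0 : op_stride0;  // %[[STRIDE_1]]
  strides[2] = operand_dim1 < size2 ? 0 : 1;           // %[[STRIDE_2]]
}

The -cse flag added to the RUN line is what removes the duplicated index constants (%[[C0_]], %[[C1_]], ...) that the old CHECK lines had to track.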