[HLO] Fix HLO DynamicBroadcastInDimOp -> LHLO lowering.
The conversion had a bug in the computation of the strides and sizes arguments for std.memref_reinterpret_cast. The previous version also relied on linalg::ReshapeOp to do the broadcasting when the rank of the output was higher than the rank of the input. Now the broadcasting is done entirely via descriptor modification, and the linalg::ReshapeOp was replaced with a CopyOp.

PiperOrigin-RevId: 341379871
commit d4f2c767d3 (parent 4ef12aa000)
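In effect, the broadcast is now materialized purely in the memref descriptor: every expanded dimension gets stride 0, so reads along it revisit the same underlying elements, and a plain lmhlo.copy then writes out the expanded result. A minimal sketch of the IR shape this produces for a 1-D to 2-D broadcast (the value names and the #map layout are illustrative, not taken verbatim from the pass output):

    // Stride 0 in the leading dimension makes every output row alias the operand.
    %transformed = memref_reinterpret_cast %operand to
        offset: [0], sizes: [%rows, %cols], strides: [%c0, %c1]
        : memref<?xf32> to memref<?x?xf32, #map>
    "lmhlo.copy"(%transformed, %result)
        : (memref<?x?xf32, #map>, memref<?x?xf32>) -> ()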
@@ -194,8 +194,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
     Value transformed_operand =
         InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
-    rewriter.create<lmhlo::BroadcastInDimOp>(
-        loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
+    rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
 
     rewriter.replaceOp(op, {resultBuffer});
@@ -211,48 +210,76 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
     auto loc = op.getLoc();
     auto operand_type = operand.getType().cast<MemRefType>();
     auto operand_shape = operand_type.getShape();
+    auto operand_rank = operand_type.getRank();
 
-    SmallVector<Value, 2> sizes, strides;
-    sizes.reserve(operand_shape.size());
-    strides.reserve(operand_shape.size());
+    auto result_type = op.getType().cast<RankedTensorType>();
+    auto result_rank = result_type.getRank();
 
     Value zero = b->create<ConstantIndexOp>(loc, 0);
     Value one = b->create<ConstantIndexOp>(loc, 1);
-    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
-      Value broadcast_dim_value =
-          b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
-      Value result_dim_size = b->create<ExtractElementOp>(
-          loc, op.output_dimensions(), broadcast_dim_value);
-      Value operand_dim_size =
-          ShapedType::isDynamic(operand_shape[dim.index()])
-              ? b->create<DimOp>(loc, operand, dim.index()).getResult()
-              : b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
-                    .getResult();
 
-      // TODO(pifon): Revisit if this cast is needed. Maybe we can use
-      // tensor<index> for `output_dimensions` as well.
+    // Compute a reversed scan product: the stride of each dimension of the
+    // flattened operand buffer, working from minor to major dimensions.
+    // Additionally, save the operand shape Values to use in the next loop.
+    SmallVector<Value, 2> operand_strides(operand_rank, one);
+    SmallVector<Value, 2> operand_sizes(operand_rank, one);
+    Value stride_so_far = one;
+    for (int i = operand_rank - 1; i >= 0; --i) {
+      Value operand_dim_size =
+          ShapedType::isDynamic(operand_shape[i])
+              ? b->create<DimOp>(loc, operand, i).getResult()
+              : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
+      operand_sizes[i] = operand_dim_size;
+
+      operand_strides[i] = stride_so_far;
+      if (i > 0) {
+        stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
+      }
+    }
+
+    SmallVector<Value, 2> sizes, strides;
+    sizes.reserve(result_rank);
+    strides.reserve(result_rank);
+
+    DenseMap<int, int> output_to_input_dim;
+    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
+      output_to_input_dim[dim.value().getSExtValue()] = dim.index();
+    }
+    for (int i = 0; i < result_rank; ++i) {
+      Value i_val = b->create<ConstantIndexOp>(loc, i);
+      Value result_dim_size =
+          b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
       if (!result_dim_size.getType().isIndex()) {
         result_dim_size =
            b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
       }
+      sizes.push_back(result_dim_size);
+
+      auto it = output_to_input_dim.find(i);
+      // If the rank of the output is greater than the rank of the input, i.e.
+      // there is no entry for this output dimension in the inverse
+      // broadcast_dimensions map, we also set the stride to 0 to emulate
+      // padding of the shape with 1s and the corresponding expansion.
+      if (it == output_to_input_dim.end()) {
+        strides.push_back(zero);
+        continue;
+      }
 
       // There can be two cases:
-      // 1) Operand dim == result dim => expansion is not needed => stride := 1.
+      // 1) Operand dim == result dim => expansion is not needed
+      //    => stride := flattened buffer stride.
      // 2) Operand dim < result dim => expansion is needed => stride := 0.
-      Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
-                                             operand_dim_size, result_dim_size);
-      strides.push_back(
-          b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
-
-      // Size of input dim can be set to the size of the corresponding output
-      // dimension for both cases.
-      sizes.push_back(result_dim_size);
+      int dim = it->second;
+      Value is_expansion = b->create<CmpIOp>(
+          loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
+      strides.push_back(b->create<mlir::SelectOp>(loc, is_expansion, zero,
+                                                  operand_strides[dim]));
     }
 
     // Type-erased memref type with static rank, dynamic sizes and strides.
-    SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
+    SmallVector<int64_t, 2> dynamic_layout(result_rank,
                                            MemRefType::kDynamicStrideOrOffset);
-    SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
+    SmallVector<int64_t, 2> dynamic_shape(result_rank,
                                           MemRefType::kDynamicSize);
     auto type_erased_memref_type = MemRefType::get(
         dynamic_shape, operand_type.getElementType(),
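For intuition, consider the 1x4 operand broadcast to a 3x4 result with broadcast_dimensions = [0, 1], one of the cases exercised by the tests below (the arithmetic here is a hand-worked illustration, not pass output). The reversed scan over the row-major operand yields flattened strides [4, 1] and saved sizes [1, 4]. Output dim 0 maps to operand dim 0, where 1 < 3, so expansion is detected and the stride becomes 0; output dim 1 maps to operand dim 1, where 4 is not less than 4, so the flattened stride 1 is kept. The reinterpret cast therefore carries sizes [3, 4] and strides [0, 1], and every output row rereads the single operand row.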
@@ -1,4 +1,10 @@
-// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
+// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo \
+// RUN: -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse \
+// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
+// RUN: -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s --dump-input=always
 
 func @main() -> () {
   call @trivial_broadcast_wrapper() : () -> ()
@@ -8,6 +14,9 @@ func @main() -> () {
   call @broadcast_in_Y_dim_transpose_wrapper() : () -> ()
   call @broadcast_scalar_1d_wrapper() : () -> ()
   call @broadcast_scalar_2d_wrapper() : () -> ()
+  call @broadcast_to_the_same_shape() : () -> ()
+  call @broadcast_1d_to_2d() : () -> ()
+  call @broadcast_1d_to_2d_with_transpose() : () -> ()
   return
 }
@@ -15,199 +24,490 @@ func @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface }
 func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
 
 func @trivial_broadcast_wrapper() {
-  %input = alloc() : memref<3xf32>
-  %c1f32 = constant 1.0 : f32
-  %c0 = constant 0 : index
-  store %c1f32, %input[%c0] : memref<3xf32>
-  %c2f32 = constant 2.0 : f32
-  %c1 = constant 1 : index
-  store %c2f32, %input[%c1] : memref<3xf32>
-  %c3f32 = constant 3.0 : f32
-  %c2 = constant 2 : index
-  store %c3f32, %input[%c2] : memref<3xf32>
-  %input_tensor = tensor_load %input : memref<3xf32>
+  %input_buf = alloc() : memref<3xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<0> : tensor<1xi64>
   } : (tensor<3xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [2, 2, 2, 2]
-// CHECK: [3, 3, 3, 3]
 
 func @broadcast_in_X_dim_wrapper() {
-  %input = alloc() : memref<1x4xf32>
+  %input_buf = alloc() : memref<1x4xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<1x4xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<1x4xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c0, %c1] : memref<1x4xf32>
+  store %c2f32, %input_buf[%c0, %c1] : memref<1x4xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c0, %c2] : memref<1x4xf32>
+  store %c3f32, %input_buf[%c0, %c2] : memref<1x4xf32>
   %c4f32 = constant 4.0 : f32
   %c3 = constant 3 : index
-  store %c4f32, %input[%c0, %c3] : memref<1x4xf32>
-  %input_tensor = tensor_load %input : memref<1x4xf32>
+  store %c4f32, %input_buf[%c0, %c3] : memref<1x4xf32>
+  %input = tensor_load %input_buf : memref<1x4xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
   } : (tensor<1x4xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+
+  // Test DynamicBroadcastInDimOp.
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<1x4xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
 
 func @broadcast_in_Y_dim_wrapper() {
-  %input = alloc() : memref<3x1xf32>
+  %input_buf = alloc() : memref<3x1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<3x1xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<3x1xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c1, %c0] : memref<3x1xf32>
+  store %c2f32, %input_buf[%c1, %c0] : memref<3x1xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c2, %c0] : memref<3x1xf32>
-  %input_tensor = tensor_load %input : memref<3x1xf32>
+  store %c3f32, %input_buf[%c2, %c0] : memref<3x1xf32>
+  %input = tensor_load %input_buf : memref<3x1xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
   } : (tensor<3x1xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<3x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [2, 2, 2, 2]
-// CHECK: [3, 3, 3, 3]
 
 func @broadcast_in_X_dim_transpose_wrapper() {
-  %input = alloc() : memref<4x1xf32>
+  %input_buf = alloc() : memref<4x1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<4x1xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<4x1xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c1, %c0] : memref<4x1xf32>
+  store %c2f32, %input_buf[%c1, %c0] : memref<4x1xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c2, %c0] : memref<4x1xf32>
+  store %c3f32, %input_buf[%c2, %c0] : memref<4x1xf32>
   %c4f32 = constant 4.0 : f32
   %c3 = constant 3 : index
-  store %c4f32, %input[%c3, %c0] : memref<4x1xf32>
-  %input_tensor = tensor_load %input : memref<4x1xf32>
+  store %c4f32, %input_buf[%c3, %c0] : memref<4x1xf32>
+  %input = tensor_load %input_buf : memref<4x1xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
   } : (tensor<4x1xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+
+  // Test DynamicBroadcastInDimOp.
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
+  } : (tensor<4x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
+  // CHECK-NEXT: [1, 2, 3, 4]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
-// CHECK: [1, 2, 3, 4]
 
 func @broadcast_in_Y_dim_transpose_wrapper() {
-  %input = alloc() : memref<1x3xf32>
+  %input_buf = alloc() : memref<1x3xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<1x3xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<1x3xf32>
   %c2f32 = constant 2.0 : f32
   %c1 = constant 1 : index
-  store %c2f32, %input[%c0, %c1] : memref<1x3xf32>
+  store %c2f32, %input_buf[%c0, %c1] : memref<1x3xf32>
   %c3f32 = constant 3.0 : f32
   %c2 = constant 2 : index
-  store %c3f32, %input[%c0, %c2] : memref<1x3xf32>
-  %input_tensor = tensor_load %input : memref<1x3xf32>
+  store %c3f32, %input_buf[%c0, %c2] : memref<1x3xf32>
+  %input = tensor_load %input_buf : memref<1x3xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
   } : (tensor<1x3xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
+  } : (tensor<1x3xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3, 3]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [2, 2, 2, 2]
-// CHECK: [3, 3, 3, 3]
 
 func @broadcast_scalar_1d_wrapper() {
-  %input = alloc() : memref<1xf32>
+  %input_buf = alloc() : memref<1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0] : memref<1xf32>
-  %input_tensor = tensor_load %input : memref<1xf32>
+  store %c1f32, %input_buf[%c0] : memref<1xf32>
+  %input = tensor_load %input_buf : memref<1xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<0> : tensor<1xi64>
   } : (tensor<1xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
 
 func @broadcast_scalar_2d_wrapper() {
-  %input = alloc() : memref<1x1xf32>
+  %input_buf = alloc() : memref<1x1xf32>
   %c1f32 = constant 1.0 : f32
   %c0 = constant 0 : index
-  store %c1f32, %input[%c0, %c0] : memref<1x1xf32>
-  %input_tensor = tensor_load %input : memref<1x1xf32>
+  store %c1f32, %input_buf[%c0, %c0] : memref<1x1xf32>
+  %input = tensor_load %input_buf : memref<1x1xf32>
 
-  %output_tensor = "mhlo.broadcast_in_dim"(%input_tensor) {
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
     broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
   } : (tensor<1x1xf32>) -> tensor<3x4xf32>
 
-  %output = alloc() : memref<3x4xf32>
-  tensor_store %output_tensor, %output : memref<3x4xf32>
+  %output_buf = alloc() : memref<3x4xf32>
+  tensor_store %output, %output_buf : memref<3x4xf32>
 
-  %cast_for_print = memref_cast %output : memref<3x4xf32> to memref<*xf32>
-  call @print_memref_f32(%cast_for_print) : (memref<*xf32>) -> ()
+  %unranked_output = memref_cast %output_buf : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %shape = tensor_from_elements %c3, %c4 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<3x4xf32>
+
+  %dyn_output_buf = alloc() : memref<3x4xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x4xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x4xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
+  // CHECK-NEXT: [1, 1, 1, 1]
   return
 }
-// CHECK: rank = 2 offset = 0 sizes = [3, 4] strides = [4, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
-// CHECK: [1, 1, 1, 1]
 
+func @broadcast_to_the_same_shape() {
+  %input_buf = alloc() : memref<2x3xf32>
+
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  store %c1f32, %input_buf[%c0, %c0] : memref<2x3xf32>
+  store %c1f32, %input_buf[%c1, %c0] : memref<2x3xf32>
+  store %c2f32, %input_buf[%c0, %c1] : memref<2x3xf32>
+  store %c2f32, %input_buf[%c1, %c1] : memref<2x3xf32>
+  store %c3f32, %input_buf[%c0, %c2] : memref<2x3xf32>
+  store %c3f32, %input_buf[%c1, %c2] : memref<2x3xf32>
+  %input = tensor_load %input_buf : memref<2x3xf32>
+
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<2x3xf32>) -> tensor<2x3xf32>
+
+  %output_buf = alloc() : memref<2x3xf32>
+  tensor_store %output, %output_buf : memref<2x3xf32>
+
+  %unranked_output = memref_cast %output_buf : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %shape = tensor_from_elements %c2, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<2x3xf32>, tensor<2xindex>) -> tensor<2x3xf32>
+
+  %dyn_output_buf = alloc() : memref<2x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<2x3xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  return
+}
+
+func @broadcast_1d_to_2d() {
+  %input_buf = alloc() : memref<3xf32>
+
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<3xf32>) -> tensor<3x3xf32>
+
+  %output_buf = alloc() : memref<3x3xf32>
+  tensor_store %output, %output_buf : memref<3x3xf32>
+
+  %unranked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %shape = tensor_from_elements %c3, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
+
+  %dyn_output_buf = alloc() : memref<3x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x3xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 1, 1]
+  // CHECK-NEXT: [2, 2, 2]
+  // CHECK-NEXT: [3, 3, 3]
+  return
+}
+
+func @broadcast_1d_to_2d_with_transpose() {
+  %input_buf = alloc() : memref<3xf32>
+
+  %c1f32 = constant 1.0 : f32
+  %c2f32 = constant 2.0 : f32
+  %c3f32 = constant 3.0 : f32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  store %c1f32, %input_buf[%c0] : memref<3xf32>
+  store %c2f32, %input_buf[%c1] : memref<3xf32>
+  store %c3f32, %input_buf[%c2] : memref<3xf32>
+  %input = tensor_load %input_buf : memref<3xf32>
+
+  // Test BroadcastInDimOp.
+  %output = "mhlo.broadcast_in_dim"(%input) {
+    broadcast_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<3xf32>) -> tensor<3x3xf32>
+
+  %output_buf = alloc() : memref<3x3xf32>
+  tensor_store %output, %output_buf : memref<3x3xf32>
+
+  %unranked_output = memref_cast %output_buf : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+
+  // Test DynamicBroadcastInDimOp.
+  %c3 = constant 3 : index
+  %shape = tensor_from_elements %c3, %c3 : tensor<2xindex>
+  %dyn_output = "mhlo.dynamic_broadcast_in_dim"(%input, %shape) {
+    broadcast_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<3xf32>, tensor<2xindex>) -> tensor<3x3xf32>
+
+  %dyn_output_buf = alloc() : memref<3x3xf32>
+  tensor_store %dyn_output, %dyn_output_buf : memref<3x3xf32>
+
+  %unranked_dyn_output = memref_cast %dyn_output_buf
+    : memref<3x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_dyn_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  // CHECK-NEXT: [1, 2, 3]
+  return
+}
@@ -1,4 +1,6 @@
-// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
+// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
+// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
+// RUN: | FILECHECK_OPTS="" FileCheck %s
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -153,64 +155,41 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
 
 // -----
 
-func @external_func() -> tensor<3xi64>
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
 
-// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
-
 // CHECK-LABEL: func @dyn_broadcast
-func @dyn_broadcast(%operand: memref<?x?xf32>) {
-// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
+func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
+// CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
   %tensor_operand = tensor_load %operand : memref<?x?xf32>
   %c1 = constant 1 : i64
   %shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
   %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
     broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
   } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = tensor_from_elements
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
-  // CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
-  // CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
-  // CHECK: %[[C2:.*]] = constant 2 : index
-  // CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
-  // CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
-  // CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
-
-  // CHECK: %[[C0_:.*]] = constant 0 : index
-  // CHECK: %[[C1_:.*]] = constant 1 : index
-
-  // CHECK: %[[C1__:.*]] = constant 1 : index
-  // CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
-  // CHECK: %[[C0___:.*]] = constant 0 : index
-  // CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
-  // CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
-  // CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
-  // CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
-
-  // CHECK: %[[C2_:.*]] = constant 2 : index
-  // CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
-  // CHECK: %[[C1___:.*]] = constant 1 : index
-  // CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
-  // CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
-  // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
-  // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
-
-  // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
-  // CHECK-SAME: offset: [0],
-  // CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
-  // CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
-  // CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>
-
-  // CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
-  // CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
-  // CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
-
-  // Do not store the value back to avoid the tensor-store being rewritten to
-  // a copy into the pre-allocated argument.
-  return
+  %rank = rank %tensor_result : tensor<?x?x?xf32>
+  return %rank : index
 }
+// CHECK: %[[SHAPE:.*]] = tensor_from_elements
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
+// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
+// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
+// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
+// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
+// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
+// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
+// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
+// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
+// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
+// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
+// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
+// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
+// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
+// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
 
 // -----
 
@@ -483,11 +462,9 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
 // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
 // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
 // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-// CHECK: %[[C0_:.*]] = constant 0 : index
-// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
 // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-// CHECK: %[[C1_:.*]] = constant 1 : index
-// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
 // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
 // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
 // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@@ -508,11 +485,9 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
 // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
 // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
 // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-// CHECK: %[[C0_:.*]] = constant 0 : index
-// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
 // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-// CHECK: %[[C1_:.*]] = constant 1 : index
-// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
 // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
 // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
 // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
@@ -645,7 +620,7 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
     %4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
     %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
     %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
-    // CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
+    // CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
     %7 = mhlo.maximum %5, %6 : tensor<?xf16>
     // CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
     shape.assuming_yield %7 : tensor<?xf16>