diff --git a/WORKSPACE b/WORKSPACE index c5d4c8a..1d908e8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -15,9 +15,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "56fcd4ea8dafecbd71ff6eda7db6407d66505c93" +LLVM_COMMIT = "91e7a17133324ac4beaf6ed45c170436c2d91c98" -LLVM_SHA256 = "71537fa919251e225c34a820bd7ef3d0a295db58c5c5f07032323d45f4b318e7" +LLVM_SHA256 = "07ef834c47337dc7b38c765693a3b5a1835aac2f716fd9a06c374a71722c20de" LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index b3c240e..8e58c2d 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1,2 +1,2 @@ -56fcd4ea8dafecbd71ff6eda7db6407d66505c93 +91e7a17133324ac4beaf6ed45c170436c2d91c98 diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index ae05e1b..3d72efc 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1051,19 +1051,18 @@ class SliceConverter : public OpConversionPattern { return failure(); } - SmallVector ranges; + SmallVector offsets, sizes, strides; for (int i = 0, e = arg_type.getRank(); i < e; ++i) { - Value start_index = rewriter.create( - loc, slice_op.start_indices().getValue(i)); - Value limit_index = rewriter.create( - loc, slice_op.limit_indices().getValue(i)); - Value stride = rewriter.create( - loc, slice_op.strides().getValue(i)); - ranges.push_back(rewriter.create(loc, start_index, - limit_index, stride)); + offsets.push_back(rewriter.getI64IntegerAttr( + slice_op.start_indices().getValue(i))); + sizes.push_back(rewriter.getI64IntegerAttr( + slice_op.limit_indices().getValue(i) - + slice_op.start_indices().getValue(i))); + strides.push_back( + rewriter.getI64IntegerAttr(slice_op.strides().getValue(i))); } - auto linalg_slice = - rewriter.create(loc, slice_op.getOperand(0), ranges); + auto linalg_slice = rewriter.create(loc, slice_op.getOperand(0), + offsets, sizes, strides); rewriter.create(loc, linalg_slice, slice_op.getOperand(1)); rewriter.eraseOp(slice_op); return success(); diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index 1637eef..a8710ef 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -119,32 +119,32 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { // Compute memrefs for the value to reduce. This makes it easier to just // inline the body. auto output = *reduce_op.out().begin(); - // TODO(herhut) Move this to the SliceOp builder. auto resType = MemRefType::get( - llvm::None, output.getType().cast().getElementType(), + llvm::None, getElementTypeOrSelf(output.getType()), makeStridedLinearLayoutMap(llvm::None, MemRefType::getDynamicStrideOrOffset(), rewriter.getContext())); - auto accumulator = rewriter.create( - loc, resType, output, ArrayRef{launch_op.getThreadIds().x}); + OpFoldResult offset = launch_op.getThreadIds().x; + auto oneAttr = rewriter.getI64IntegerAttr(1); + OpFoldResult size = oneAttr; + OpFoldResult stride = oneAttr; + auto accumulator = rewriter.create(loc, resType, output, + offset, size, stride); llvm::SmallVector indexings; auto input_buffer = *reduce_op.operands().begin(); - auto input_type = input_buffer.getType().cast(); - for (int64_t dim = 0; dim < input_type.getRank(); ++dim) { - indexings.push_back(dim == reducing_dimension - ? loop.getInductionVar() - : launch_op.getThreadIds().x); - } - // TODO(herhut) Move this to the SliceOp builder. - auto input = *reduce_op.operand_begin(); - auto rhs = rewriter.create( - loc, - MemRefType::get( - llvm::None, input_type.getElementType(), - makeStridedLinearLayoutMap(llvm::None, - MemRefType::getDynamicStrideOrOffset(), - rewriter.getContext())), - input, indexings); + auto input_type_rank = + input_buffer.getType().cast().getRank(); + + Value input = *reduce_op.operand_begin(); + SmallVector offsets = llvm::to_vector<4>(llvm::map_range( + llvm::seq(0, input_type_rank), [&](int dim) -> OpFoldResult { + return dim == reducing_dimension ? loop.getInductionVar() + : launch_op.getThreadIds().x; + })); + SmallVector sizes(input_type_rank, oneAttr); + SmallVector strides(input_type_rank, oneAttr); + auto rhs = rewriter.create(loc, accumulator.getType(), input, + offsets, sizes, strides); // Now copy over the actual body of the reduction, leaving out the // terminator. diff --git a/tests/lhlo-legalize-to-gpu.mlir b/tests/lhlo-legalize-to-gpu.mlir index f89c996..26b230a 100644 --- a/tests/lhlo-legalize-to-gpu.mlir +++ b/tests/lhlo-legalize-to-gpu.mlir @@ -13,22 +13,22 @@ func @reduce(%arg: memref<100x10xf32>, return } -// CHECK: func @reduce(%[[ARG0:.*]]: memref<100x10xf32>, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref<100xf32>) { +// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)> + +// CHECK: func @reduce(%[[ARG0:.*]]: memref<100x10xf32>, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref<100xf32>) { // CHECK-DAG: %[[C100:.*]] = constant 100 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index -// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) { -// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref -// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32> -// CHECK-DAG: %[[LB:.*]] = constant 0 : index -// CHECK-DAG: %[[UB:.*]] = constant 10 : index -// CHECK-DAG: %[[STEP:.*]] = constant 1 : index -// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { -// CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref -// CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref -// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () -// CHECK: } -// CHECK: gpu.terminator -// CHECK: } -// CHECK: return -// CHECK: } -// CHECK: } +// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) { +// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref +// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32> +// CHECK-DAG: %[[LB:.*]] = constant 0 : index +// CHECK-DAG: %[[UB:.*]] = constant 10 : index +// CHECK-DAG: %[[STEP:.*]] = constant 1 : index +// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: %[[LHS:.*]] = subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref +// CHECK: %[[RHS:.*]] = subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref +// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () +// CHECK: } +// CHECK: gpu.terminator +// CHECK: } +// CHECK: return diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index ec614be..e31369c 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -745,15 +745,7 @@ func @slice(%operand: memref, %result: memref) { } : (memref, memref) -> () return } -// CHECK: %[[L0:.*]] = constant 0 : index -// CHECK: %[[L2:.*]] = constant 2 : index -// CHECK: %[[L1:.*]] = constant 1 : index -// CHECK: %[[LHS:.*]] = linalg.range %[[L0]] : %[[L2]] : %[[L1]] -// CHECK: %[[R0:.*]] = constant 1 : index -// CHECK: %[[R2:.*]] = constant 3 : index -// CHECK: %[[R1:.*]] = constant 1 : index -// CHECK: %[[RHS:.*]] = linalg.range %[[R0]] : %[[R2]] : %[[R1]] -// CHECK: %[[RESULT:.*]] = linalg.slice %[[IN]][%[[LHS]], %[[RHS]]] +// CHECK: %[[RESULT:.*]] = subview %[[IN]][0, 1] [2, 2] [1, 1] : memref to memref<2x2xf32, #{{.*}}> // CHECK: linalg.copy(%[[RESULT]], %[[OUT]]) // -----