From dd92c8ef619d35e33e6ed11d3aa25e215557c586 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 22 Sep 2020 09:06:55 -0700 Subject: [PATCH] Integrate LLVM at llvm/llvm-project@7e78d89052b1 Updates LLVM usage to match [7e78d89052b1](https://github.com/llvm/llvm-project/commit/7e78d89052b1) PiperOrigin-RevId: 333090785 --- build_tools/llvm_version.txt | 2 +- .../mhlo/transforms/legalize_to_linalg.cc | 116 +++++++++--------- tests/lhlo-fuse-linalg.mlir | 76 +++++++----- tests/lhlo-legalize-to-linalg.mlir | 3 +- 4 files changed, 107 insertions(+), 90 deletions(-) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 7c90cec..c96fc97 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1,2 +1,2 @@ -93fd30bac3345fea4f5beba3241f1ef4f2f5f419 +7e78d89052b15f32ea56f018698194c7c9627152 diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 0a8105e..1e8442e 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" @@ -32,8 +33,10 @@ limitations under the License. #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -75,69 +78,69 @@ class PointwiseToLinalgConverter : public OpConversionPattern { OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); - auto argType = - op.getOperation()->getOperand(0).getType().template cast(); - if (!argType.hasRank()) { - emitError(loc, "lhlo to linalg conversion expects ranked args"); - return failure(); - } - auto elemTy = argType.getElementType(); - if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa()) { - return failure(); - } + ShapedType t0 = args[0].getType().template dyn_cast(); + if (!t0) return failure(); + + unsigned nloops = t0.getRank(); + auto fail = [&](ShapedType t) { + return !t || !t.hasRank() || t.getRank() != nloops || + !(t.getElementType().isSignlessIntOrFloat() || + t.getElementType().isa()); + }; + if (llvm::any_of(args, + [&](Value v) { + return fail(v.getType().dyn_cast()); + }) || + llvm::any_of(op.getOperation()->getResultTypes(), + [&](Type t) { return fail(t.dyn_cast()); })) + return emitError(loc, + "lhlo to linalg conversion expects ranked args of " + "signless int, float or complex element type with ") + << nloops << " parallel iterators: " << *(op.getOperation()); // Construct the indexing maps needed for linalg.generic ops. - SmallVector indexing_maps; SmallVector bodyArgTypes, bodyResultTypes, opResultTypes; // This doesnt account for implicit broadcast, but the working assumption - // here is that are broadcasts have been made explicit. - unsigned nloops = argType.getRank(); + // in HLO/LHLO is that are broadcasts are made explicit. if (isLHLO && !nloops) return failure(); - int operandCount = (isLHLO ? args.size() - 1 : args.size()); - auto verifyArgOrResultType = [&](Value val) -> ShapedType { - auto shapedType = val.getType().dyn_cast(); - if (!shapedType || - (!shapedType.isa() && - !shapedType.isa()) || - shapedType.getRank() != nloops) - return nullptr; - indexing_maps.emplace_back( - nloops ? rewriter.getMultiDimIdentityMap(nloops) - : AffineMap::get(nloops, 0, rewriter.getContext())); - return shapedType; - }; - for (const auto& arg : llvm::enumerate(args)) { - auto shapedType = verifyArgOrResultType(arg.value()); - if (!shapedType) return failure(); - auto& result_or_body_arg = - arg.index() < operandCount ? bodyArgTypes : bodyResultTypes; - result_or_body_arg.emplace_back(shapedType.getElementType()); - } + int numInputs = (isLHLO ? args.size() - 1 : args.size()); + + ValueRange inputs(args.take_front(numInputs)); + for (Value in : inputs) + bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); + + ValueRange outputBuffers(args.take_back(args.size() - numInputs)); + for (Value out : outputBuffers) + bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType())); + if (!isLHLO) { // HLO operations have return as tensor types. assert(bodyResultTypes.empty() && "When lowering HLO ops result can't be part of arguments"); Value result = op.getOperation()->getResult(0); - auto shapedType = verifyArgOrResultType(result); - if (!shapedType) return failure(); - bodyResultTypes.push_back(shapedType.getElementType()); - opResultTypes.push_back(shapedType); + bodyResultTypes.push_back(getElementTypeOrSelf(result)); + opResultTypes.push_back(result.getType()); } - int64_t args_count = bodyArgTypes.size(); - int64_t results_count = bodyResultTypes.size(); + AffineMap commonIndexingMap = + nloops ? rewriter.getMultiDimIdentityMap(nloops) + : AffineMap::get(nloops, 0, rewriter.getContext()); + SmallVector indexing_maps(args.size() + (isLHLO ? 0 : 1), + commonIndexingMap); + auto linalgOp = rewriter.create( - loc, opResultTypes, args, args_count, results_count, indexing_maps, + loc, opResultTypes, inputs, outputBuffers, + /*initTensors=*/ValueRange{}, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { // TODO(ravishankarm) : For now use the method in lmhlo namespace. // That method needs to be moved out of there. Value opResult = lmhlo::HloOpToStdScalarOp::map( op, bodyResultTypes, - llvm::to_vector<2>(args.take_front(args_count)), &rewriter); + llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter); nestedBuilder.create(loc, opResult); }); rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); @@ -299,12 +302,15 @@ class DataMovementOpConverter : public OpConversionPattern { auto nloops = resultType.getRank(); auto loc = op.getLoc(); auto linalgOp = rewriter.create( - loc, isLHLO ? ArrayRef{} : resultType, args, /*argsIn=*/1, - /*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), + loc, + /*resultTensorTypes=*/isLHLO ? ArrayRef{} : resultType, + /*inputs=*/args.front(), + /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{}, + /*initTensor=*/ValueRange{}, indexing_maps, + GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); }); - rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); return success(); } @@ -420,8 +426,8 @@ class LhloBroadcastInDimConverter Value val = rewriter.create(loc, operand, llvm::makeArrayRef({zero})); rewriter.create( - loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()), - /*argsIn=*/0, /*argsOut=*/1, + loc, /*inputs=*/ValueRange{}, + /*outputBuffers=*/ValueRange{operand_adaptor.output()}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { @@ -432,9 +438,8 @@ class LhloBroadcastInDimConverter auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape, operand_type, &rewriter); rewriter.create( - loc, llvm::None, - llvm::makeArrayRef({operand, operand_adaptor.output()}), - /*argsIn=*/1, /*argsOut=*/1, indexing_maps, + loc, /*inputs=*/ValueRange{operand}, + /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); @@ -697,9 +702,12 @@ class IotaConverter : public OpConversionPattern { unsigned nloops = resultShapedType.getRank(); auto linalgOp = rewriter.create( - iotaOp.getLoc(), isLHLO ? ArrayRef{} : resultShapedType, args, - 0, // args_in - 1, // args_out + iotaOp.getLoc(), + /*resultTensorTypes=*/ + isLHLO ? ArrayRef{} : ArrayRef{resultShapedType}, + /*inputs=*/ValueRange{}, + /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{}, + /*initTensors=*/ValueRange{}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, @@ -717,7 +725,7 @@ class IotaConverter : public OpConversionPattern { if (isLHLO) rewriter.replaceOp(iotaOp, llvm::None); else - rewriter.replaceOp(iotaOp, linalgOp.output_tensors()); + rewriter.replaceOp(iotaOp, linalgOp.result_tensors()); return success(); } }; @@ -862,8 +870,6 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, // %0 = addf %arg4, %arg5 : f32 // "linalg.yield"(%0) : (f32) -> () // }) { -// args_in = 2, -// args_out = 1, // indexing_maps = [#map0, #map0, #map0], // iterator_types = ["parallel", "parallel"], // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () diff --git a/tests/lhlo-fuse-linalg.mlir b/tests/lhlo-fuse-linalg.mlir index 6a67466..9a218b3 100644 --- a/tests/lhlo-fuse-linalg.mlir +++ b/tests/lhlo-fuse-linalg.mlir @@ -3,20 +3,24 @@ // RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP #map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { %temp_result = alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { + linalg.generic #pointwise_2d_trait + ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>) + outs(%temp_result : memref<6x6xf32>) { ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): %out = addf %summand_1_in, %summand_2_in : f32 linalg.yield %out : f32 - } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> - linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { + } + linalg.generic #pointwise_2d_trait + ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>) + outs(%result : memref<6x6xf32>) { ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): %out = mulf %temp_result_in, %multiplier_in : f32 linalg.yield %out : f32 - } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + } dealloc %temp_result : memref<6x6xf32> return } @@ -59,36 +63,34 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, %arg2: memref<100x10xf32>) { %0 = alloc() : memref<100x10xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, - indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } %arg1, %0 { - ^bb0(%arg3: f32, %arg4: f32): // no predecessors - linalg.yield %arg3 : f32 - }: memref<100xf32>, memref<100x10xf32> + indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg1 : memref<100xf32>) + outs(%0 : memref<100x10xf32>) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + linalg.yield %arg3 : f32 + } %1 = alloc() : memref<100x10xf32> linalg.generic { - args_in = 2 : i64, - args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } %arg0, %0, %1 { + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>) + outs(%1 : memref<100x10xf32>) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors %2 = subf %arg3, %arg4 : f32 linalg.yield %2 : f32 - }: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32> + } dealloc %0 : memref<100x10xf32> linalg.generic { - args_in = 1 : i64, - args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } %1, %arg2 { + iterator_types = ["parallel", "parallel"]} + ins(%1 : memref<100x10xf32>) + outs(%arg2 : memref<100x10xf32>) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors %2 = exp %arg3 : f32 linalg.yield %2 : f32 - }: memref<100x10xf32>, memref<100x10xf32> + } dealloc %1 : memref<100x10xf32> return } @@ -130,20 +132,24 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // ----- #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} +#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { %temp_result = alloc() : memref<6x6x6x6xf32> - linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result { + linalg.generic #pointwise_4d_trait + ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>) + outs(%temp_result : memref<6x6x6x6xf32>) { ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): %out = addf %summand_1_in, %summand_2_in : f32 linalg.yield %out : f32 - } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> - linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result { + } + linalg.generic #pointwise_4d_trait + ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>) + outs(%result : memref<6x6x6x6xf32>) { ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): %out = mulf %temp_result_in, %multiplier_in : f32 linalg.yield %out : f32 - } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> + } dealloc %temp_result : memref<6x6x6x6xf32> return } @@ -184,21 +190,25 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // ----- #map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, %summand_2: memref<6x6xf32>) -> memref<6x6xf32> { %temp_result = alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { + linalg.generic #pointwise_2d_trait + ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>) + outs(%temp_result : memref<6x6xf32>) { ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): %out = addf %summand_1_in, %summand_2_in : f32 linalg.yield %out : f32 - } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + } %result = alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { + linalg.generic #pointwise_2d_trait + ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>) + outs(%result : memref<6x6xf32>) { ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): %out = mulf %temp_result_in, %multiplier_in : f32 linalg.yield %out : f32 - } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + } dealloc %temp_result : memref<6x6xf32> return %result : memref<6x6xf32> } diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 3162f37..d0e19c1 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -263,7 +263,8 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, // CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]] // CHECK-SAME: memref<1x5xf32> into memref<5xf32> // CHECK: linalg.generic {{{.*}}indexing_maps = -// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]] +// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-SAME: ins(%[[RESHAPED_ARG]] : // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32