Integrate LLVM at llvm/llvm-project@7e78d89052
Updates LLVM usage to match [7e78d89052b1](https://github.com/llvm/llvm-project/commit/7e78d89052b1) PiperOrigin-RevId: 333090785
This commit is contained in:
parent
7abd557a61
commit
dd92c8ef61
|
@ -1,2 +1,2 @@
|
|||
93fd30bac3345fea4f5beba3241f1ef4f2f5f419
|
||||
7e78d89052b15f32ea56f018698194c7c9627152
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
|
||||
#include <numeric>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||
|
@ -32,8 +33,10 @@ limitations under the License.
|
|||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
|
@ -75,69 +78,69 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
|||
OpTy op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = op.getLoc();
|
||||
auto argType =
|
||||
op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
|
||||
if (!argType.hasRank()) {
|
||||
emitError(loc, "lhlo to linalg conversion expects ranked args");
|
||||
return failure();
|
||||
}
|
||||
auto elemTy = argType.getElementType();
|
||||
if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa<ComplexType>()) {
|
||||
return failure();
|
||||
}
|
||||
ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
|
||||
if (!t0) return failure();
|
||||
|
||||
unsigned nloops = t0.getRank();
|
||||
auto fail = [&](ShapedType t) {
|
||||
return !t || !t.hasRank() || t.getRank() != nloops ||
|
||||
!(t.getElementType().isSignlessIntOrFloat() ||
|
||||
t.getElementType().isa<ComplexType>());
|
||||
};
|
||||
if (llvm::any_of(args,
|
||||
[&](Value v) {
|
||||
return fail(v.getType().dyn_cast<ShapedType>());
|
||||
}) ||
|
||||
llvm::any_of(op.getOperation()->getResultTypes(),
|
||||
[&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
|
||||
return emitError(loc,
|
||||
"lhlo to linalg conversion expects ranked args of "
|
||||
"signless int, float or complex element type with ")
|
||||
<< nloops << " parallel iterators: " << *(op.getOperation());
|
||||
|
||||
// Construct the indexing maps needed for linalg.generic ops.
|
||||
SmallVector<AffineMap, 2> indexing_maps;
|
||||
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
|
||||
|
||||
// This doesn't account for implicit broadcast, but the working assumption
|
||||
// here is that all broadcasts have been made explicit.
|
||||
unsigned nloops = argType.getRank();
|
||||
// in HLO/LHLO is that all broadcasts are made explicit.
|
||||
|
||||
if (isLHLO && !nloops) return failure();
|
||||
|
||||
int operandCount = (isLHLO ? args.size() - 1 : args.size());
|
||||
auto verifyArgOrResultType = [&](Value val) -> ShapedType {
|
||||
auto shapedType = val.getType().dyn_cast<ShapedType>();
|
||||
if (!shapedType ||
|
||||
(!shapedType.isa<MemRefType>() &&
|
||||
!shapedType.isa<RankedTensorType>()) ||
|
||||
shapedType.getRank() != nloops)
|
||||
return nullptr;
|
||||
indexing_maps.emplace_back(
|
||||
nloops ? rewriter.getMultiDimIdentityMap(nloops)
|
||||
: AffineMap::get(nloops, 0, rewriter.getContext()));
|
||||
return shapedType;
|
||||
};
|
||||
for (const auto& arg : llvm::enumerate(args)) {
|
||||
auto shapedType = verifyArgOrResultType(arg.value());
|
||||
if (!shapedType) return failure();
|
||||
auto& result_or_body_arg =
|
||||
arg.index() < operandCount ? bodyArgTypes : bodyResultTypes;
|
||||
result_or_body_arg.emplace_back(shapedType.getElementType());
|
||||
}
|
||||
int numInputs = (isLHLO ? args.size() - 1 : args.size());
|
||||
|
||||
ValueRange inputs(args.take_front(numInputs));
|
||||
for (Value in : inputs)
|
||||
bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
|
||||
|
||||
ValueRange outputBuffers(args.take_back(args.size() - numInputs));
|
||||
for (Value out : outputBuffers)
|
||||
bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
|
||||
|
||||
if (!isLHLO) {
|
||||
// HLO operations have return as tensor types.
|
||||
assert(bodyResultTypes.empty() &&
|
||||
"When lowering HLO ops result can't be part of arguments");
|
||||
Value result = op.getOperation()->getResult(0);
|
||||
auto shapedType = verifyArgOrResultType(result);
|
||||
if (!shapedType) return failure();
|
||||
bodyResultTypes.push_back(shapedType.getElementType());
|
||||
opResultTypes.push_back(shapedType);
|
||||
bodyResultTypes.push_back(getElementTypeOrSelf(result));
|
||||
opResultTypes.push_back(result.getType());
|
||||
}
|
||||
|
||||
int64_t args_count = bodyArgTypes.size();
|
||||
int64_t results_count = bodyResultTypes.size();
|
||||
AffineMap commonIndexingMap =
|
||||
nloops ? rewriter.getMultiDimIdentityMap(nloops)
|
||||
: AffineMap::get(nloops, 0, rewriter.getContext());
|
||||
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
|
||||
commonIndexingMap);
|
||||
|
||||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||
loc, opResultTypes, args, args_count, results_count, indexing_maps,
|
||||
loc, opResultTypes, inputs, outputBuffers,
|
||||
/*initTensors=*/ValueRange{}, indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
|
||||
// That method needs to be moved out of there.
|
||||
Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
|
||||
op, bodyResultTypes,
|
||||
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
|
||||
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
|
||||
});
|
||||
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
|
||||
|
@ -299,12 +302,15 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
|
|||
auto nloops = resultType.getRank();
|
||||
auto loc = op.getLoc();
|
||||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*argsIn=*/1,
|
||||
/*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
|
||||
loc,
|
||||
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType,
|
||||
/*inputs=*/args.front(),
|
||||
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
|
||||
/*initTensor=*/ValueRange{}, indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
|
||||
return success();
|
||||
}
|
||||
|
@ -420,8 +426,8 @@ class LhloBroadcastInDimConverter
|
|||
Value val =
|
||||
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
|
||||
rewriter.create<linalg::GenericOp>(
|
||||
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
|
||||
/*argsIn=*/0, /*argsOut=*/1,
|
||||
loc, /*inputs=*/ValueRange{},
|
||||
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
|
||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
|
@ -432,9 +438,8 @@ class LhloBroadcastInDimConverter
|
|||
auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
|
||||
operand_type, &rewriter);
|
||||
rewriter.create<linalg::GenericOp>(
|
||||
loc, llvm::None,
|
||||
llvm::makeArrayRef({operand, operand_adaptor.output()}),
|
||||
/*argsIn=*/1, /*argsOut=*/1, indexing_maps,
|
||||
loc, /*inputs=*/ValueRange{operand},
|
||||
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
|
||||
|
@ -697,9 +702,12 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
|||
unsigned nloops = resultShapedType.getRank();
|
||||
|
||||
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
|
||||
iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
|
||||
0, // args_in
|
||||
1, // args_out
|
||||
iotaOp.getLoc(),
|
||||
/*resultTensorTypes=*/
|
||||
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType},
|
||||
/*inputs=*/ValueRange{},
|
||||
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
|
||||
/*initTensors=*/ValueRange{},
|
||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
|
||||
|
@ -717,7 +725,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
|||
if (isLHLO)
|
||||
rewriter.replaceOp(iotaOp, llvm::None);
|
||||
else
|
||||
rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
|
||||
rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -862,8 +870,6 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
// %0 = addf %arg4, %arg5 : f32
|
||||
// "linalg.yield"(%0) : (f32) -> ()
|
||||
// }) {
|
||||
// args_in = 2,
|
||||
// args_out = 1,
|
||||
// indexing_maps = [#map0, #map0, #map0],
|
||||
// iterator_types = ["parallel", "parallel"],
|
||||
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
|
|
|
@ -3,20 +3,24 @@
|
|||
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
|
||||
outs(%temp_result : memref<6x6xf32>) {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
|
||||
}
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
|
||||
outs(%result : memref<6x6xf32>) {
|
||||
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
}
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
return
|
||||
}
|
||||
|
@ -59,36 +63,34 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
%arg2: memref<100x10xf32>) {
|
||||
%0 = alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
args_in = 1 : i64,
|
||||
args_out = 1 : i64,
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
} %arg1, %0 {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg1 : memref<100xf32>)
|
||||
outs(%0 : memref<100x10xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
linalg.yield %arg3 : f32
|
||||
}: memref<100xf32>, memref<100x10xf32>
|
||||
}
|
||||
%1 = alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
args_in = 2 : i64,
|
||||
args_out = 1 : i64,
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
} %arg0, %0, %1 {
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
|
||||
outs(%1 : memref<100x10xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
|
||||
%2 = subf %arg3, %arg4 : f32
|
||||
linalg.yield %2 : f32
|
||||
}: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32>
|
||||
}
|
||||
dealloc %0 : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
args_in = 1 : i64,
|
||||
args_out = 1 : i64,
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
} %1, %arg2 {
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%1 : memref<100x10xf32>)
|
||||
outs(%arg2 : memref<100x10xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
%2 = exp %arg3 : f32
|
||||
linalg.yield %2 : f32
|
||||
}: memref<100x10xf32>, memref<100x10xf32>
|
||||
}
|
||||
dealloc %1 : memref<100x10xf32>
|
||||
return
|
||||
}
|
||||
|
@ -130,20 +132,24 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||
%temp_result = alloc() : memref<6x6x6x6xf32>
|
||||
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
|
||||
linalg.generic #pointwise_4d_trait
|
||||
ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
|
||||
outs(%temp_result : memref<6x6x6x6xf32>) {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
|
||||
linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result {
|
||||
}
|
||||
linalg.generic #pointwise_4d_trait
|
||||
ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
|
||||
outs(%result : memref<6x6x6x6xf32>) {
|
||||
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
|
||||
}
|
||||
dealloc %temp_result : memref<6x6x6x6xf32>
|
||||
return
|
||||
}
|
||||
|
@ -184,21 +190,25 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
|||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
|
||||
outs(%temp_result : memref<6x6xf32>) {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
}
|
||||
%result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
|
||||
outs(%result : memref<6x6xf32>) {
|
||||
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
}
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
return %result : memref<6x6xf32>
|
||||
}
|
||||
|
|
|
@ -263,7 +263,8 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
|
|||
// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]]
|
||||
// CHECK-SAME: memref<1x5xf32> into memref<5xf32>
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps =
|
||||
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]]
|
||||
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
// CHECK-SAME: ins(%[[RESHAPED_ARG]] :
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||
|
||||
|
|
Loading…
Reference in New Issue