Updates LLVM usage to match
[7e78d89052b1](https://github.com/llvm/llvm-project/commit/7e78d89052b1)

PiperOrigin-RevId: 333090785
This commit is contained in:
Benjamin Kramer 2020-09-22 09:06:55 -07:00 committed by TensorFlow MLIR Team
parent 7abd557a61
commit dd92c8ef61
4 changed files with 107 additions and 90 deletions

View File

@ -1,2 +1,2 @@
93fd30bac3345fea4f5beba3241f1ef4f2f5f419
7e78d89052b15f32ea56f018698194c7c9627152

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <numeric>
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@ -32,8 +33,10 @@ limitations under the License.
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -75,69 +78,69 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
auto argType =
op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
if (!argType.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects ranked args");
return failure();
}
auto elemTy = argType.getElementType();
if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa<ComplexType>()) {
return failure();
}
ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
if (!t0) return failure();
unsigned nloops = t0.getRank();
auto fail = [&](ShapedType t) {
return !t || !t.hasRank() || t.getRank() != nloops ||
!(t.getElementType().isSignlessIntOrFloat() ||
t.getElementType().isa<ComplexType>());
};
if (llvm::any_of(args,
[&](Value v) {
return fail(v.getType().dyn_cast<ShapedType>());
}) ||
llvm::any_of(op.getOperation()->getResultTypes(),
[&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
return emitError(loc,
"lhlo to linalg conversion expects ranked args of "
"signless int, float or complex element type with ")
<< nloops << " parallel iterators: " << *(op.getOperation());
// Construct the indexing maps needed for linalg.generic ops.
SmallVector<AffineMap, 2> indexing_maps;
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
    // This doesn't account for implicit broadcast, but the working assumption
    here is that all broadcasts have been made explicit.
unsigned nloops = argType.getRank();
    // in HLO/LHLO is that all broadcasts are made explicit.
if (isLHLO && !nloops) return failure();
int operandCount = (isLHLO ? args.size() - 1 : args.size());
auto verifyArgOrResultType = [&](Value val) -> ShapedType {
auto shapedType = val.getType().dyn_cast<ShapedType>();
if (!shapedType ||
(!shapedType.isa<MemRefType>() &&
!shapedType.isa<RankedTensorType>()) ||
shapedType.getRank() != nloops)
return nullptr;
indexing_maps.emplace_back(
nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext()));
return shapedType;
};
for (const auto& arg : llvm::enumerate(args)) {
auto shapedType = verifyArgOrResultType(arg.value());
if (!shapedType) return failure();
auto& result_or_body_arg =
arg.index() < operandCount ? bodyArgTypes : bodyResultTypes;
result_or_body_arg.emplace_back(shapedType.getElementType());
}
int numInputs = (isLHLO ? args.size() - 1 : args.size());
ValueRange inputs(args.take_front(numInputs));
for (Value in : inputs)
bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
ValueRange outputBuffers(args.take_back(args.size() - numInputs));
for (Value out : outputBuffers)
bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
if (!isLHLO) {
      // HLO operations have return values of tensor type.
assert(bodyResultTypes.empty() &&
"When lowering HLO ops result can't be part of arguments");
Value result = op.getOperation()->getResult(0);
auto shapedType = verifyArgOrResultType(result);
if (!shapedType) return failure();
bodyResultTypes.push_back(shapedType.getElementType());
opResultTypes.push_back(shapedType);
bodyResultTypes.push_back(getElementTypeOrSelf(result));
opResultTypes.push_back(result.getType());
}
int64_t args_count = bodyArgTypes.size();
int64_t results_count = bodyResultTypes.size();
AffineMap commonIndexingMap =
nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext());
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
commonIndexingMap);
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, opResultTypes, args, args_count, results_count, indexing_maps,
loc, opResultTypes, inputs, outputBuffers,
/*initTensors=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there.
Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes,
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
});
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
@ -299,12 +302,15 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
auto nloops = resultType.getRank();
auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*argsIn=*/1,
/*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
loc,
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType,
/*inputs=*/args.front(),
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
/*initTensor=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
});
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
return success();
}
@ -420,8 +426,8 @@ class LhloBroadcastInDimConverter
Value val =
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
rewriter.create<linalg::GenericOp>(
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
/*argsIn=*/0, /*argsOut=*/1,
loc, /*inputs=*/ValueRange{},
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
@ -432,9 +438,8 @@ class LhloBroadcastInDimConverter
auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
operand_type, &rewriter);
rewriter.create<linalg::GenericOp>(
loc, llvm::None,
llvm::makeArrayRef({operand, operand_adaptor.output()}),
/*argsIn=*/1, /*argsOut=*/1, indexing_maps,
loc, /*inputs=*/ValueRange{operand},
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
@ -697,9 +702,12 @@ class IotaConverter : public OpConversionPattern<OpTy> {
unsigned nloops = resultShapedType.getRank();
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
0, // args_in
1, // args_out
iotaOp.getLoc(),
/*resultTensorTypes=*/
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType},
/*inputs=*/ValueRange{},
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
/*initTensors=*/ValueRange{},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
@ -717,7 +725,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None);
else
rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
return success();
}
};
@ -862,8 +870,6 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// %0 = addf %arg4, %arg5 : f32
// "linalg.yield"(%0) : (f32) -> ()
// }) {
// args_in = 2,
// args_out = 1,
// indexing_maps = [#map0, #map0, #map0],
// iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()

View File

@ -3,20 +3,24 @@
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
linalg.generic #pointwise_2d_trait
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
outs(%temp_result : memref<6x6xf32>) {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
}
linalg.generic #pointwise_2d_trait
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
outs(%result : memref<6x6xf32>) {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
}
dealloc %temp_result : memref<6x6xf32>
return
}
@ -59,36 +63,34 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg2: memref<100x10xf32>) {
%0 = alloc() : memref<100x10xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} %arg1, %0 {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
}: memref<100xf32>, memref<100x10xf32>
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg1 : memref<100xf32>)
outs(%0 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
}
%1 = alloc() : memref<100x10xf32>
linalg.generic {
args_in = 2 : i64,
args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} %arg0, %0, %1 {
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
outs(%1 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = subf %arg3, %arg4 : f32
linalg.yield %2 : f32
}: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32>
}
dealloc %0 : memref<100x10xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} %1, %arg2 {
iterator_types = ["parallel", "parallel"]}
ins(%1 : memref<100x10xf32>)
outs(%arg2 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%2 = exp %arg3 : f32
linalg.yield %2 : f32
}: memref<100x10xf32>, memref<100x10xf32>
}
dealloc %1 : memref<100x10xf32>
return
}
@ -130,20 +132,24 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
%temp_result = alloc() : memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
linalg.generic #pointwise_4d_trait
ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
outs(%temp_result : memref<6x6x6x6xf32>) {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result {
}
linalg.generic #pointwise_4d_trait
ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
outs(%result : memref<6x6x6x6xf32>) {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
}
dealloc %temp_result : memref<6x6x6x6xf32>
return
}
@ -184,21 +190,25 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
// -----
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
%temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
linalg.generic #pointwise_2d_trait
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
outs(%temp_result : memref<6x6xf32>) {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
}
%result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
linalg.generic #pointwise_2d_trait
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
outs(%result : memref<6x6xf32>) {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
}
dealloc %temp_result : memref<6x6xf32>
return %result : memref<6x6xf32>
}

View File

@ -263,7 +263,8 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]]
// CHECK-SAME: memref<1x5xf32> into memref<5xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps =
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]]
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-SAME: ins(%[[RESHAPED_ARG]] :
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32