Integrate LLVM at llvm/llvm-project@7e78d89052
Updates LLVM usage to match [7e78d89052b1](https://github.com/llvm/llvm-project/commit/7e78d89052b1).

PiperOrigin-RevId: 333090785
parent 7abd557a61
commit dd92c8ef61
@@ -1,2 +1,2 @@
-93fd30bac3345fea4f5beba3241f1ef4f2f5f419
+7e78d89052b15f32ea56f018698194c7c9627152
@@ -17,6 +17,7 @@ limitations under the License.
 #include <numeric>
 
+#include "llvm/ADT/STLExtras.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@@ -32,8 +33,10 @@ limitations under the License.
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -75,69 +78,69 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     auto loc = op.getLoc();
-    auto argType =
-        op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
-    if (!argType.hasRank()) {
-      emitError(loc, "lhlo to linalg conversion expects ranked args");
-      return failure();
-    }
-    auto elemTy = argType.getElementType();
-    if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa<ComplexType>()) {
-      return failure();
-    }
+    ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
+    if (!t0) return failure();
+
+    unsigned nloops = t0.getRank();
+    auto fail = [&](ShapedType t) {
+      return !t || !t.hasRank() || t.getRank() != nloops ||
+             !(t.getElementType().isSignlessIntOrFloat() ||
+               t.getElementType().isa<ComplexType>());
+    };
+    if (llvm::any_of(args,
+                     [&](Value v) {
+                       return fail(v.getType().dyn_cast<ShapedType>());
+                     }) ||
+        llvm::any_of(op.getOperation()->getResultTypes(),
+                     [&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
+      return emitError(loc,
+                       "lhlo to linalg conversion expects ranked args of "
+                       "signless int, float or complex element type with ")
+             << nloops << " parallel iterators: " << *(op.getOperation());
 
     // Construct the indexing maps needed for linalg.generic ops.
-    SmallVector<AffineMap, 2> indexing_maps;
     SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
 
     // This doesnt account for implicit broadcast, but the working assumption
-    // here is that are broadcasts have been made explicit.
-    unsigned nloops = argType.getRank();
+    // in HLO/LHLO is that are broadcasts are made explicit.
 
     if (isLHLO && !nloops) return failure();
 
-    int operandCount = (isLHLO ? args.size() - 1 : args.size());
-    auto verifyArgOrResultType = [&](Value val) -> ShapedType {
-      auto shapedType = val.getType().dyn_cast<ShapedType>();
-      if (!shapedType ||
-          (!shapedType.isa<MemRefType>() &&
-           !shapedType.isa<RankedTensorType>()) ||
-          shapedType.getRank() != nloops)
-        return nullptr;
-      indexing_maps.emplace_back(
-          nloops ? rewriter.getMultiDimIdentityMap(nloops)
-                 : AffineMap::get(nloops, 0, rewriter.getContext()));
-      return shapedType;
-    };
-    for (const auto& arg : llvm::enumerate(args)) {
-      auto shapedType = verifyArgOrResultType(arg.value());
-      if (!shapedType) return failure();
-      auto& result_or_body_arg =
-          arg.index() < operandCount ? bodyArgTypes : bodyResultTypes;
-      result_or_body_arg.emplace_back(shapedType.getElementType());
-    }
+    int numInputs = (isLHLO ? args.size() - 1 : args.size());
+
+    ValueRange inputs(args.take_front(numInputs));
+    for (Value in : inputs)
+      bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+
+    ValueRange outputBuffers(args.take_back(args.size() - numInputs));
+    for (Value out : outputBuffers)
+      bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
 
     if (!isLHLO) {
       // HLO operations have return as tensor types.
       assert(bodyResultTypes.empty() &&
              "When lowering HLO ops result can't be part of arguments");
       Value result = op.getOperation()->getResult(0);
-      auto shapedType = verifyArgOrResultType(result);
-      if (!shapedType) return failure();
-      bodyResultTypes.push_back(shapedType.getElementType());
-      opResultTypes.push_back(shapedType);
+      bodyResultTypes.push_back(getElementTypeOrSelf(result));
+      opResultTypes.push_back(result.getType());
     }
 
-    int64_t args_count = bodyArgTypes.size();
-    int64_t results_count = bodyResultTypes.size();
+    AffineMap commonIndexingMap =
+        nloops ? rewriter.getMultiDimIdentityMap(nloops)
+               : AffineMap::get(nloops, 0, rewriter.getContext());
+    SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
+                                            commonIndexingMap);
 
     auto linalgOp = rewriter.create<linalg::GenericOp>(
-        loc, opResultTypes, args, args_count, results_count, indexing_maps,
+        loc, opResultTypes, inputs, outputBuffers,
+        /*initTensors=*/ValueRange{}, indexing_maps,
         GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
           // That method needs to be moved out of there.
           Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
               op, bodyResultTypes,
-              llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
+              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
           nestedBuilder.create<linalg::YieldOp>(loc, opResult);
         });
     rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
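The recurring change in this file is the linalg::GenericOp builder migration: the old overload took a flat operand list plus args_in/args_out counts, while the new one takes explicit inputs, outputBuffers, and initTensors ranges. A minimal sketch of the post-migration call shape, reusing names from the hunk above (loc, rewriter, inputs, outputBuffers, indexing_maps, and nloops are assumed to be in scope as in PointwiseToLinalgConverter; the body shown is a placeholder, not the converter's real body):

    // Sketch of the new builder call, mirroring the call sites in this diff.
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/opResultTypes,
        /*inputs=*/inputs,                // values the body reads
        /*outputBuffers=*/outputBuffers,  // memrefs written in place (LHLO path)
        /*initTensors=*/ValueRange{},     // no reduction seeds for pointwise ops
        indexing_maps, GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& b, Location nestedLoc, ValueRange blockArgs) {
          // The region gets one scalar block argument per input and per
          // output buffer; placeholder body: yield the first input unchanged.
          b.create<linalg::YieldOp>(nestedLoc, blockArgs.front());
        });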
@@ -299,12 +302,15 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
     auto nloops = resultType.getRank();
     auto loc = op.getLoc();
     auto linalgOp = rewriter.create<linalg::GenericOp>(
-        loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*argsIn=*/1,
-        /*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
+        loc,
+        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType,
+        /*inputs=*/args.front(),
+        /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
+        /*initTensor=*/ValueRange{}, indexing_maps,
+        GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
           nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
         });
 
     rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
     return success();
   }
@@ -420,8 +426,8 @@ class LhloBroadcastInDimConverter
       Value val =
           rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
       rewriter.create<linalg::GenericOp>(
-          loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
-          /*argsIn=*/0, /*argsOut=*/1,
+          loc, /*inputs=*/ValueRange{},
+          /*outputBuffers=*/ValueRange{operand_adaptor.output()},
          llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
@@ -432,9 +438,8 @@ class LhloBroadcastInDimConverter
     auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
                                          operand_type, &rewriter);
     rewriter.create<linalg::GenericOp>(
-        loc, llvm::None,
-        llvm::makeArrayRef({operand, operand_adaptor.output()}),
-        /*argsIn=*/1, /*argsOut=*/1, indexing_maps,
+        loc, /*inputs=*/ValueRange{operand},
+        /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
         GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
           nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
@@ -697,9 +702,12 @@ class IotaConverter : public OpConversionPattern<OpTy> {
     unsigned nloops = resultShapedType.getRank();
 
     auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
-        iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
-        0,  // args_in
-        1,  // args_out
+        iotaOp.getLoc(),
+        /*resultTensorTypes=*/
+        isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType},
+        /*inputs=*/ValueRange{},
+        /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
+        /*initTensors=*/ValueRange{},
         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
         GetNParallelLoopsAttrs(nloops),
         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
@@ -717,7 +725,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
     if (isLHLO)
       rewriter.replaceOp(iotaOp, llvm::None);
     else
-      rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
+      rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
     return success();
   }
 };
@@ -862,8 +870,6 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
 //      %0 = addf %arg4, %arg5 : f32
 //      "linalg.yield"(%0) : (f32) -> ()
 //    }) {
-//      args_in = 2,
-//      args_out = 1,
 //      indexing_maps = [#map0, #map0, #map0],
 //      iterator_types = ["parallel", "parallel"],
 //    } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
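For orientation, a hedged sketch of how a conversion pass of this era might drive the populate function named in the hunk header above; everything beyond the function's name and its MLIRContext* parameter (the pattern-list argument, the legality setup, the dialect names) is an assumption, not taken from this diff:

    // Hypothetical driver for populateLHLOToLinalgConversionPattern; the
    // second parameter is assumed to be an OwningRewritePatternList*.
    void runLhloToLinalg(FuncOp func, MLIRContext* context) {
      OwningRewritePatternList patterns;
      populateLHLOToLinalgConversionPattern(context, &patterns);

      ConversionTarget target(*context);
      // Assumed legality split: linalg/std become legal, lmhlo must vanish.
      target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
      target.addIllegalDialect<lmhlo::LmhloDialect>();

      if (failed(applyPartialConversion(func, target, patterns)))
        llvm::errs() << "lhlo-to-linalg conversion failed\n";
    }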
@@ -3,20 +3,24 @@
 // RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
-#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
 func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
              %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
   %temp_result = alloc() : memref<6x6xf32>
-  linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
+  linalg.generic #pointwise_2d_trait
+    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
+   outs(%temp_result : memref<6x6xf32>) {
   ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
     %out = addf %summand_1_in, %summand_2_in : f32
     linalg.yield %out : f32
-  } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
-  linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
+  }
+  linalg.generic #pointwise_2d_trait
+    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
+   outs(%result : memref<6x6xf32>) {
   ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
     %out = mulf %temp_result_in, %multiplier_in : f32
     linalg.yield %out : f32
-  } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
+  }
   dealloc %temp_result : memref<6x6xf32>
   return
 }
@@ -59,36 +63,34 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
                       %arg2: memref<100x10xf32>) {
   %0 = alloc() : memref<100x10xf32>
   linalg.generic {
-    args_in = 1 : i64,
-    args_out = 1 : i64,
-    indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } %arg1, %0 {
+      indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg1 : memref<100xf32>)
+     outs(%0 : memref<100x10xf32>) {
   ^bb0(%arg3: f32, %arg4: f32): // no predecessors
     linalg.yield %arg3 : f32
-  }: memref<100xf32>, memref<100x10xf32>
+  }
   %1 = alloc() : memref<100x10xf32>
   linalg.generic {
-    args_in = 2 : i64,
-    args_out = 1 : i64,
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } %arg0, %0, %1 {
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
+   outs(%1 : memref<100x10xf32>) {
   ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
     %2 = subf %arg3, %arg4 : f32
     linalg.yield %2 : f32
-  }: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32>
+  }
   dealloc %0 : memref<100x10xf32>
   linalg.generic {
-    args_in = 1 : i64,
-    args_out = 1 : i64,
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } %1, %arg2 {
+    iterator_types = ["parallel", "parallel"]}
+    ins(%1 : memref<100x10xf32>)
+   outs(%arg2 : memref<100x10xf32>) {
   ^bb0(%arg3: f32, %arg4: f32): // no predecessors
     %2 = exp %arg3 : f32
     linalg.yield %2 : f32
-  }: memref<100x10xf32>, memref<100x10xf32>
+  }
   dealloc %1 : memref<100x10xf32>
   return
 }
@@ -130,20 +132,24 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
 func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
                 %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
   %temp_result = alloc() : memref<6x6x6x6xf32>
-  linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
+  linalg.generic #pointwise_4d_trait
+    ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
+   outs(%temp_result : memref<6x6x6x6xf32>) {
   ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
     %out = addf %summand_1_in, %summand_2_in : f32
     linalg.yield %out : f32
-  } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
-  linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result {
+  }
+  linalg.generic #pointwise_4d_trait
+    ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
+   outs(%result : memref<6x6x6x6xf32>) {
   ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
     %out = mulf %temp_result_in, %multiplier_in : f32
     linalg.yield %out : f32
-  } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
+  }
   dealloc %temp_result : memref<6x6x6x6xf32>
   return
 }
@@ -184,21 +190,25 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
 // -----
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
-#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
 func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
              %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
   %temp_result = alloc() : memref<6x6xf32>
-  linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
+  linalg.generic #pointwise_2d_trait
+    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
+   outs(%temp_result : memref<6x6xf32>) {
   ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
     %out = addf %summand_1_in, %summand_2_in : f32
     linalg.yield %out : f32
-  } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
+  }
   %result = alloc() : memref<6x6xf32>
-  linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
+  linalg.generic #pointwise_2d_trait
+    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
+   outs(%result : memref<6x6xf32>) {
   ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
     %out = mulf %temp_result_in, %multiplier_in : f32
     linalg.yield %out : f32
-  } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
+  }
   dealloc %temp_result : memref<6x6xf32>
   return %result : memref<6x6xf32>
 }
@@ -263,7 +263,8 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
 // CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]]
 // CHECK-SAME: memref<1x5xf32> into memref<5xf32>
 // CHECK: linalg.generic {{{.*}}indexing_maps =
-// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]]
+// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
+// CHECK-SAME: ins(%[[RESHAPED_ARG]] :
 // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
 // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32