Updates LLVM usage to match
[7e78d89052b1](https://github.com/llvm/llvm-project/commit/7e78d89052b1)

PiperOrigin-RevId: 333090785
This commit is contained in:
Benjamin Kramer 2020-09-22 09:06:55 -07:00 committed by TensorFlow MLIR Team
parent 7abd557a61
commit dd92c8ef61
4 changed files with 107 additions and 90 deletions

View File

@ -1,2 +1,2 @@
93fd30bac3345fea4f5beba3241f1ef4f2f5f419 7e78d89052b15f32ea56f018698194c7c9627152

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <numeric> #include <numeric>
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@ -32,8 +33,10 @@ limitations under the License.
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@ -75,69 +78,69 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
OpTy op, ArrayRef<Value> args, OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc(); auto loc = op.getLoc();
auto argType = ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
op.getOperation()->getOperand(0).getType().template cast<ShapedType>(); if (!t0) return failure();
if (!argType.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects ranked args"); unsigned nloops = t0.getRank();
return failure(); auto fail = [&](ShapedType t) {
} return !t || !t.hasRank() || t.getRank() != nloops ||
auto elemTy = argType.getElementType(); !(t.getElementType().isSignlessIntOrFloat() ||
if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa<ComplexType>()) { t.getElementType().isa<ComplexType>());
return failure(); };
} if (llvm::any_of(args,
[&](Value v) {
return fail(v.getType().dyn_cast<ShapedType>());
}) ||
llvm::any_of(op.getOperation()->getResultTypes(),
[&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
return emitError(loc,
"lhlo to linalg conversion expects ranked args of "
"signless int, float or complex element type with ")
<< nloops << " parallel iterators: " << *(op.getOperation());
// Construct the indexing maps needed for linalg.generic ops. // Construct the indexing maps needed for linalg.generic ops.
SmallVector<AffineMap, 2> indexing_maps;
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes; SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
// This doesnt account for implicit broadcast, but the working assumption // This doesnt account for implicit broadcast, but the working assumption
// here is that are broadcasts have been made explicit. // in HLO/LHLO is that are broadcasts are made explicit.
unsigned nloops = argType.getRank();
if (isLHLO && !nloops) return failure(); if (isLHLO && !nloops) return failure();
int operandCount = (isLHLO ? args.size() - 1 : args.size()); int numInputs = (isLHLO ? args.size() - 1 : args.size());
auto verifyArgOrResultType = [&](Value val) -> ShapedType {
auto shapedType = val.getType().dyn_cast<ShapedType>(); ValueRange inputs(args.take_front(numInputs));
if (!shapedType || for (Value in : inputs)
(!shapedType.isa<MemRefType>() && bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
!shapedType.isa<RankedTensorType>()) ||
shapedType.getRank() != nloops) ValueRange outputBuffers(args.take_back(args.size() - numInputs));
return nullptr; for (Value out : outputBuffers)
indexing_maps.emplace_back( bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext()));
return shapedType;
};
for (const auto& arg : llvm::enumerate(args)) {
auto shapedType = verifyArgOrResultType(arg.value());
if (!shapedType) return failure();
auto& result_or_body_arg =
arg.index() < operandCount ? bodyArgTypes : bodyResultTypes;
result_or_body_arg.emplace_back(shapedType.getElementType());
}
if (!isLHLO) { if (!isLHLO) {
// HLO operations have return as tensor types. // HLO operations have return as tensor types.
assert(bodyResultTypes.empty() && assert(bodyResultTypes.empty() &&
"When lowering HLO ops result can't be part of arguments"); "When lowering HLO ops result can't be part of arguments");
Value result = op.getOperation()->getResult(0); Value result = op.getOperation()->getResult(0);
auto shapedType = verifyArgOrResultType(result); bodyResultTypes.push_back(getElementTypeOrSelf(result));
if (!shapedType) return failure(); opResultTypes.push_back(result.getType());
bodyResultTypes.push_back(shapedType.getElementType());
opResultTypes.push_back(shapedType);
} }
int64_t args_count = bodyArgTypes.size(); AffineMap commonIndexingMap =
int64_t results_count = bodyResultTypes.size(); nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext());
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
commonIndexingMap);
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, opResultTypes, args, args_count, results_count, indexing_maps, loc, opResultTypes, inputs, outputBuffers,
/*initTensors=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace. // TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there. // That method needs to be moved out of there.
Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>( Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes, op, bodyResultTypes,
llvm::to_vector<2>(args.take_front(args_count)), &rewriter); llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult); nestedBuilder.create<linalg::YieldOp>(loc, opResult);
}); });
rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
@ -299,12 +302,15 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
auto nloops = resultType.getRank(); auto nloops = resultType.getRank();
auto loc = op.getLoc(); auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*argsIn=*/1, loc,
/*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType,
/*inputs=*/args.front(),
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
/*initTensor=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin()); nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
}); });
rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
return success(); return success();
} }
@ -420,8 +426,8 @@ class LhloBroadcastInDimConverter
Value val = Value val =
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero})); rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
rewriter.create<linalg::GenericOp>( rewriter.create<linalg::GenericOp>(
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()), loc, /*inputs=*/ValueRange{},
/*argsIn=*/0, /*argsOut=*/1, /*outputBuffers=*/ValueRange{operand_adaptor.output()},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
@ -432,9 +438,8 @@ class LhloBroadcastInDimConverter
auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape, auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
operand_type, &rewriter); operand_type, &rewriter);
rewriter.create<linalg::GenericOp>( rewriter.create<linalg::GenericOp>(
loc, llvm::None, loc, /*inputs=*/ValueRange{operand},
llvm::makeArrayRef({operand, operand_adaptor.output()}), /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
/*argsIn=*/1, /*argsOut=*/1, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin()); nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
@ -697,9 +702,12 @@ class IotaConverter : public OpConversionPattern<OpTy> {
unsigned nloops = resultShapedType.getRank(); unsigned nloops = resultShapedType.getRank();
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>( auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args, iotaOp.getLoc(),
0, // args_in /*resultTensorTypes=*/
1, // args_out isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType},
/*inputs=*/ValueRange{},
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
/*initTensors=*/ValueRange{},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
@ -717,7 +725,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
if (isLHLO) if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None); rewriter.replaceOp(iotaOp, llvm::None);
else else
rewriter.replaceOp(iotaOp, linalgOp.output_tensors()); rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
return success(); return success();
} }
}; };
@ -862,8 +870,6 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// %0 = addf %arg4, %arg5 : f32 // %0 = addf %arg4, %arg5 : f32
// "linalg.yield"(%0) : (f32) -> () // "linalg.yield"(%0) : (f32) -> ()
// }) { // }) {
// args_in = 2,
// args_out = 1,
// indexing_maps = [#map0, #map0, #map0], // indexing_maps = [#map0, #map0, #map0],
// iterator_types = ["parallel", "parallel"], // iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()

View File

@ -3,20 +3,24 @@
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP // RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
#map0 = affine_map<(d0, d1) -> (d0, d1)> #map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} #pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() : memref<6x6xf32> %temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { linalg.generic #pointwise_2d_trait
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
outs(%temp_result : memref<6x6xf32>) {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32 %out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> }
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { linalg.generic #pointwise_2d_trait
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
outs(%result : memref<6x6xf32>) {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32 %out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> }
dealloc %temp_result : memref<6x6xf32> dealloc %temp_result : memref<6x6xf32>
return return
} }
@ -59,36 +63,34 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg2: memref<100x10xf32>) { %arg2: memref<100x10xf32>) {
%0 = alloc() : memref<100x10xf32> %0 = alloc() : memref<100x10xf32>
linalg.generic { linalg.generic {
args_in = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0)>,
args_out = 1 : i64, affine_map<(d0, d1) -> (d0, d1)>],
indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
iterator_types = ["parallel", "parallel"] ins(%arg1 : memref<100xf32>)
} %arg1, %0 { outs(%0 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors ^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32 linalg.yield %arg3 : f32
}: memref<100xf32>, memref<100x10xf32> }
%1 = alloc() : memref<100x10xf32> %1 = alloc() : memref<100x10xf32>
linalg.generic { linalg.generic {
args_in = 2 : i64,
args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"] iterator_types = ["parallel", "parallel"]}
} %arg0, %0, %1 { ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
outs(%1 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = subf %arg3, %arg4 : f32 %2 = subf %arg3, %arg4 : f32
linalg.yield %2 : f32 linalg.yield %2 : f32
}: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32> }
dealloc %0 : memref<100x10xf32> dealloc %0 : memref<100x10xf32>
linalg.generic { linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"] iterator_types = ["parallel", "parallel"]}
} %1, %arg2 { ins(%1 : memref<100x10xf32>)
outs(%arg2 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors ^bb0(%arg3: f32, %arg4: f32): // no predecessors
%2 = exp %arg3 : f32 %2 = exp %arg3 : f32
linalg.yield %2 : f32 linalg.yield %2 : f32
}: memref<100x10xf32>, memref<100x10xf32> }
dealloc %1 : memref<100x10xf32> dealloc %1 : memref<100x10xf32>
return return
} }
@ -130,20 +132,24 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
// ----- // -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} #pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
%temp_result = alloc() : memref<6x6x6x6xf32> %temp_result = alloc() : memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result { linalg.generic #pointwise_4d_trait
ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
outs(%temp_result : memref<6x6x6x6xf32>) {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32 %out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> }
linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result { linalg.generic #pointwise_4d_trait
ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
outs(%result : memref<6x6x6x6xf32>) {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32 %out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> }
dealloc %temp_result : memref<6x6x6x6xf32> dealloc %temp_result : memref<6x6x6x6xf32>
return return
} }
@ -184,21 +190,25 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
// ----- // -----
#map0 = affine_map<(d0, d1) -> (d0, d1)> #map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} #pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> { %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
%temp_result = alloc() : memref<6x6xf32> %temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { linalg.generic #pointwise_2d_trait
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
outs(%temp_result : memref<6x6xf32>) {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32 %out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> }
%result = alloc() : memref<6x6xf32> %result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { linalg.generic #pointwise_2d_trait
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
outs(%result : memref<6x6xf32>) {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32 %out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32 linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> }
dealloc %temp_result : memref<6x6xf32> dealloc %temp_result : memref<6x6xf32>
return %result : memref<6x6xf32> return %result : memref<6x6xf32>
} }

View File

@ -263,7 +263,8 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]] // CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]]
// CHECK-SAME: memref<1x5xf32> into memref<5xf32> // CHECK-SAME: memref<1x5xf32> into memref<5xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = // CHECK: linalg.generic {{{.*}}indexing_maps =
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]] // CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-SAME: ins(%[[RESHAPED_ARG]] :
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32