Upstream the mhlo.reduce lowering to Linalg to the MHLO repo.
In IREE, we use an indexed_generic op to handle the initial value. Here,
however, we lower it to a generic op that carries an init_tensor, and leave
the handling of the initialization problem to later passes.

PiperOrigin-RevId: 354294807
parent 39589add22
commit 30ce82790d
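
Roughly, the new pattern rewrites an mhlo.reduce into a tensor.generate that
materializes the init value plus a linalg.generic that performs the scalar
arithmetic. A minimal sketch, following the @reduce_add test added below (the
authoritative generated IR is in that test's CHECK lines; names like %lhs,
%sum are illustrative only):

  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%lhs: tensor<i32>, %rhs: tensor<i32>):
    %1 = mhlo.add %lhs, %rhs : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>

becomes, approximately:

  %init = tensor.extract %arg1[] : tensor<i32>
  %init_tensor = tensor.generate {
  ^bb0(%i: index):
    tensor.yield %init : i32
  } : tensor<5xi32>
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<5x4xi32>) outs(%init_tensor : tensor<5xi32>) {
  ^bb0(%lhs: i32, %rhs: i32):
    %sum = addi %lhs, %rhs : i32
    linalg.yield %sum : i32
  } -> tensor<5xi32>
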
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"

@@ -35,19 +36,31 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace {

/// Returns a vector of `nLoops` iterator type names. All of them are
/// "parallel" except the last `nReduction` elements, which are "reduction".
SmallVector<StringRef, 3> GetParallelAndReductionIterators(
    unsigned nLoops, unsigned nReduction) {
  SmallVector<StringRef, 3> res(nLoops - nReduction,
                                getParallelIteratorTypeName());
  res.append(nReduction, getReductionIteratorTypeName());
  return res;
}
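
// For example, GetParallelAndReductionIterators(/*nLoops=*/4, /*nReduction=*/2)
// returns {"parallel", "parallel", "reduction", "reduction"}.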

SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  return GetParallelAndReductionIterators(nParallelLoops, 0);
}

template <bool isLHLO = true>

@@ -107,6 +120,35 @@ SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
  return ret;
}

/// Returns the constant value associated with the init value if the defining
/// operation is a constant.
Attribute GetInitValueAsConst(Value init) {
  DenseElementsAttr attr;
  if (!matchPattern(init, m_Constant(&attr))) return {};
  auto type = attr.getType().dyn_cast<ShapedType>();
  if (!type || type.getRank() != 0) return {};
  return attr.getValue({});
}
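
// For example, for "%cst = constant dense<0xFF800000> : tensor<f32>" this
// returns the scalar float attribute; for a non-constant or non-scalar init
// it returns a null Attribute and the caller falls back to tensor.extract.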

/// Returns a permutation AffineMap that puts all reduction dimensions last.
/// Parallel dimensions and reduction dimensions each keep their relative
/// order. E.g., if `rank` is 4 and `reduction_dims` is {1, 3}, then
/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
/// the AffineMap is returned.
AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
                                      ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> s;
  for (auto dim : reduction_dims) s.insert(dim);

  SmallVector<unsigned, 4> permutation;
  for (int i = 0; i < rank; ++i)
    if (!s.count(i)) permutation.push_back(i);
  for (auto dim : reduction_dims) permutation.push_back(dim);

  auto map = AffineMap::getPermutationMap(permutation, context);
  return inversePermutation(map);
}
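
// For example, with rank = 2 and reduction_dims = {0} the permutation is
// [1, 0], giving "(d0, d1) -> (d1, d0)" (its own inverse), as checked by the
// @reduce_dim0 test below.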

template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:

@@ -1226,6 +1268,146 @@ class DotGeneralOpOnTensorsConversion
  }
};

template <typename OpTy>
struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    // Only convert the body of reduction ops to std ops.
    auto parent_op = op.getOperation()->getParentRegion()->getParentOp();
    if (!isa<mhlo::ReduceOp, linalg::GenericOp, linalg::IndexedGenericOp>(
            parent_op)) {
      return failure();
    }
    if (!op.getResult().getType().template isa<TensorType>()) return failure();
    if (llvm::all_of(args, [](Value arg) {
          return arg.getType().template isa<TensorType>();
        })) {
      return failure();
    }
    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(),
                                                        args, &rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};
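
// For example, with this pattern an mhlo.add on i32 scalars inside a lowered
// reduction body becomes a std addi, as exercised by @reduce_add below.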
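
/// Returns the dynamic sizes of the result's parallel dimensions, taken from
/// `arg`; these become the operands of the tensor.generate op that builds the
/// init tensor.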
SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value arg, ShapedType result_type,
    ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> s;
  for (auto dim : reduction_dims) s.insert(dim);

  SmallVector<Value, 8> dyn_shape;
  int rank = arg.getType().cast<RankedTensorType>().getRank();
  for (int i = 0, j = 0; i < rank; ++i) {
    if (s.count(i)) continue;
    if (!result_type.isDynamicDim(j++)) continue;
    dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
  }

  return dyn_shape;
}

class ReduceRegionReturnOpConversion
    : public OpConversionPattern<mhlo::ReturnOp> {
 public:
  using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReturnOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, args);
    return success();
  }
};

class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = op.getLoc();
    mhlo::ReduceOp::Adaptor adaptor(args);
    if (op.getNumOperands() != 2) {
      return op.emitError("expects exactly two operands");
    }
    Value src = adaptor.operands()[0];
    auto src_type = src.getType().cast<ShapedType>();
    int src_rank = src_type.getRank();
    if (!src_rank) {
      return rewriter.notifyMatchFailure(op, "expects known-rank args");
    }

    // Check if init_value is constant. If so, inline the value into the
    // region.
    Value init_value = adaptor.init_values()[0];
    Attribute init_const_val = GetInitValueAsConst(init_value);
    if (init_const_val) {
      init_value = rewriter.create<ConstantOp>(
          init_value.getDefiningOp()->getLoc(), init_const_val);
    } else {
      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
    }

    // Prepare indexing maps for the linalg generic op; the entries are for
    // `src` and `dst`. Transpose `src` so that the reduction loops are
    // innermost, which is easier for later passes to exploit.
    SmallVector<AffineMap, 3> indexing_maps;
    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
    indexing_maps.emplace_back(GetTransposeMapForReduction(
        rewriter.getContext(), src_rank, reduction_dims));

    // The indexing map of `dst` drops the reduction loops. Since those loops
    // are now innermost, drop the last `reduction_dims.size()` dimensions; no
    // inverse permutation is needed because the remaining dimensions already
    // appear in order.
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0,
                                              exprs, rewriter.getContext()));
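
    // E.g., for @reduce_add below (rank 2, reducing dimension 1) the two maps
    // are (d0, d1) -> (d0, d1) and (d0, d1) -> (d0).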

    SmallVector<Value, 2> inputs = {src};
    Type result_type = op.getResult(0).getType();
    auto shaped_type = result_type.cast<ShapedType>();
    SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
        rewriter, loc, src, shaped_type, reduction_dims);
    auto init_tensor =
        rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
                                     rewriter.getIndexType());
      Region& region = init_tensor.body();
      Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
      rewriter.setInsertionPointToEnd(block);
      rewriter.create<tensor::YieldOp>(loc, init_value);
    }
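
    // For a static result this materializes, roughly,
    //   %init_tensor = tensor.generate { tensor.yield %init : i32 }
    // (see the CHECK lines of the reduce tests below).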

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
        /*outputBuffers=*/ValueRange{init_tensor}, indexing_maps,
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));

    // Convert the signature of the body. The reduce op's apply function has
    // the signature (lhs, rhs) -> output, all of the same tensor type t. This
    // is converted to the same signature over element types: e.g.,
    // "(tensor<f32>, tensor<f32>) -> tensor<f32>" becomes "(f32, f32) -> f32".
    Region& region = linalg_op.region();
    rewriter.inlineRegionBefore(op.body(), region, region.end());
    TypeConverter::SignatureConversion signatureConverter(2);
    signatureConverter.addInputs(0, src_type.getElementType());
    signatureConverter.addInputs(1, src_type.getElementType());
    rewriter.applySignatureConversion(&region, signatureConverter);
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  // clang-format off

@@ -1356,54 +1538,57 @@ namespace mhlo {

void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          OwningRewritePatternList* patterns) {
  patterns->insert<
      BroadcastConverter<mhlo::BroadcastOp, false>,
      ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
      HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
      PointwiseToLinalgConverter<mhlo::AbsOp, false>,
      PointwiseToLinalgConverter<mhlo::AddOp, false>,
      PointwiseToLinalgConverter<mhlo::AndOp, false>,
      PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
      PointwiseToLinalgConverter<mhlo::CeilOp, false>,
      PointwiseToLinalgConverter<mhlo::ClampOp, false>,
      PointwiseToLinalgConverter<mhlo::CompareOp, false>,
      PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
      PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
      PointwiseToLinalgConverter<mhlo::CopyOp, false>,
      PointwiseToLinalgConverter<mhlo::CosOp, false>,
      PointwiseToLinalgConverter<mhlo::DivOp, false>,
      PointwiseToLinalgConverter<mhlo::ExpOp, false>,
      PointwiseToLinalgConverter<mhlo::FloorOp, false>,
      PointwiseToLinalgConverter<mhlo::ImagOp, false>,
      PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
      PointwiseToLinalgConverter<mhlo::LogOp, false>,
      PointwiseToLinalgConverter<mhlo::LogisticOp, false>,
      PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
      PointwiseToLinalgConverter<mhlo::MaxOp, false>,
      PointwiseToLinalgConverter<mhlo::MinOp, false>,
      PointwiseToLinalgConverter<mhlo::MulOp, false>,
      PointwiseToLinalgConverter<mhlo::NegOp, false>,
      PointwiseToLinalgConverter<mhlo::NotOp, false>,
      PointwiseToLinalgConverter<mhlo::OrOp, false>,
      PointwiseToLinalgConverter<mhlo::PowOp, false>,
      PointwiseToLinalgConverter<mhlo::RealOp, false>,
      PointwiseToLinalgConverter<mhlo::RemOp, false>,
      PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SelectOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
      PointwiseToLinalgConverter<mhlo::SignOp, false>,
      PointwiseToLinalgConverter<mhlo::SinOp, false>,
      PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SubOp, false>,
      PointwiseToLinalgConverter<mhlo::TanhOp, false>,
      PointwiseToLinalgConverter<mhlo::XorOp, false>,
      ReshapeOpConverter<mhlo::ReshapeOp, false>,
      ReverseConverter<mhlo::ReverseOp, false>,
      TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
      DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
  patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
                   ReduceRegionXLAOpConversion<mhlo::MinOp>,
                   ReduceRegionXLAOpConversion<mhlo::MaxOp>,
                   ReduceRegionReturnOpConversion>(context);
}

std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {

@@ -980,3 +980,178 @@ func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>)
                   tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>
  return %0 : tensor<5xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_add
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32

// -----

func @reduce_minimum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
    %1 = mhlo.minimum %arg3, %arg4 : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>
  return %0 : tensor<5xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_minimum
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32

// -----

func @reduce_maximum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
    %1 = mhlo.maximum %arg3, %arg4 : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>
  return %0 : tensor<5xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_maximum
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32

// -----

func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<4xi32> {
  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
    %1 = mhlo.maximum %arg3, %arg4 : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<4xi32>
  return %0 : tensor<4xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_dim0
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32

// -----

func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> {
  %cst = constant dense<0xFF800000> : tensor<f32>
  %0 = "mhlo.reduce"(%arg0, %cst) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
    %1 = mhlo.add %arg1, %arg2 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_init_const
// CHECK: %[[INIT:.*]] = constant 0xFF800000 : f32
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<1xf32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32

// -----

func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>,
                              %arg1: tensor<i32>) -> tensor<4xi32> {
  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
    %1 = mhlo.add %arg2, %arg3 : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<5x4x3xi32>, tensor<i32>) -> tensor<4xi32>
  return %0 : tensor<4xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-LABEL: @reduce_multi_dimensions
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32

// -----

func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32> {
  %0 = "mhlo.reduce"(%arg0, %arg1) ({
  ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
    "mhlo.return"(%1) : (tensor<i32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xi32>, tensor<i32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<?x?xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<?xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32