Upstream mhlo.reduce lowering to Linalg to MHLO repo.

In IREE, we use indexed generic op to handle the initial value. However, we
lower it to a generic op that carries an init_tensor here, and leave the handle
of initialization problem to later passes.

PiperOrigin-RevId: 354294807
This commit is contained in:
Hanhan Wang 2021-01-28 05:44:49 -08:00 committed by TensorFlow MLIR Team
parent 39589add22
commit 30ce82790d
2 changed files with 410 additions and 50 deletions

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <numeric> #include <numeric>
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@ -35,19 +36,31 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h" #include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace { namespace {
/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
/// are "parallel" except the last `nReduction` elements, where are "reduction"
/// attributes.
SmallVector<StringRef, 3> GetParallelAndReductionIterators(
unsigned nLoops, unsigned nReduction) {
SmallVector<StringRef, 3> res(nLoops - nReduction,
getParallelIteratorTypeName());
res.append(nReduction, getReductionIteratorTypeName());
return res;
}
SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) { SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
static constexpr StringRef kParallelIterType = "parallel"; return GetParallelAndReductionIterators(nParallelLoops, 0);
return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
} }
template <bool isLHLO = true> template <bool isLHLO = true>
@ -107,6 +120,35 @@ SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
return ret; return ret;
} }
/// Returns the constant value associated with the init value if the defining
/// operation is a constant.
Attribute GetInitValueAsConst(Value init) {
DenseElementsAttr attr;
if (!matchPattern(init, m_Constant(&attr))) return {};
auto type = attr.getType().dyn_cast<ShapedType>();
if (!type || type.getRank() != 0) return {};
return attr.getValue({});
}
/// Returns a permutation AffineMap that puts all reduction dimensions to the
/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
/// if `rank` is 4 and `reductionDims` is {1, 3}, then
/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
/// the AffineMap is returned.
AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
ArrayRef<int64_t> reduction_dims) {
llvm::SmallSetVector<int, 4> s;
for (auto dim : reduction_dims) s.insert(dim);
SmallVector<unsigned, 4> permutation;
for (int i = 0; i < rank; ++i)
if (!s.count(i)) permutation.push_back(i);
for (auto dim : reduction_dims) permutation.push_back(dim);
auto map = AffineMap::getPermutationMap(permutation, context);
return inversePermutation(map);
}
template <typename OpTy, bool isLHLO = true> template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> { class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
public: public:
@ -1226,6 +1268,146 @@ class DotGeneralOpOnTensorsConversion
} }
}; };
template <typename OpTy>
struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Only convert the body of reduction ops to std ops.
auto parent_op = op.getOperation()->getParentRegion()->getParentOp();
if (!isa<mhlo::ReduceOp, linalg::GenericOp, linalg::IndexedGenericOp>(
parent_op)) {
return failure();
}
if (!op.getResult().getType().template isa<TensorType>()) return failure();
if (llvm::all_of(args, [](Value arg) {
return arg.getType().template isa<TensorType>();
})) {
return failure();
}
Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(),
args, &rewriter);
rewriter.replaceOp(op, result);
return success();
}
};
SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
OpBuilder& b, Location loc, Value arg, ShapedType result_type,
ArrayRef<int64_t> reduction_dims) {
llvm::SmallSetVector<int, 4> s;
for (auto dim : reduction_dims) s.insert(dim);
SmallVector<unsigned, 4> parallel_dims;
SmallVector<Value, 8> dyn_shape;
int rank = arg.getType().cast<RankedTensorType>().getRank();
for (int i = 0, j = 0; i < rank; ++i) {
if (s.count(i)) continue;
if (!result_type.isDynamicDim(j++)) continue;
dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
}
return dyn_shape;
}
class ReduceRegionReturnOpConversion
: public OpConversionPattern<mhlo::ReturnOp> {
public:
using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ReturnOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, args);
return success();
}
};
class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
public:
using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ReduceOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
Location loc = op.getLoc();
mhlo::ReduceOp::Adaptor adaptor(args);
if (op.getNumOperands() != 2) {
return op.emitError("expects exactly two operands");
}
Value src = adaptor.operands()[0];
auto src_type = src.getType().cast<ShapedType>();
int src_rank = src_type.getRank();
if (!src_rank) {
return rewriter.notifyMatchFailure(op, "expects known-rank args");
}
// Check if init_value is constant. If so, inline the value into the region.
Value init_value = adaptor.init_values()[0];
Attribute init_const_val = GetInitValueAsConst(init_value);
if (init_const_val) {
init_value = rewriter.create<ConstantOp>(
init_value.getDefiningOp()->getLoc(), init_const_val);
} else {
init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
}
// Prepare indexing maps for linalg generic op. The elements are for src and
// dst. Transpose `src` to make the reduction loops be the innermost,
// because it's easier to fully utilize processors.
SmallVector<AffineMap, 3> indexing_maps;
SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
indexing_maps.emplace_back(GetTransposeMapForReduction(
rewriter.getContext(), src_rank, reduction_dims));
// The indexing map of `dst` should drop the reduction loops. Since the
// reduction loops now are all in the innermost, drops
// `reduction_dims.size()` dimensions. We don't need an inverse permutation
// here because they are the same.
SmallVector<AffineExpr, 4> exprs;
for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i));
indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0,
exprs, rewriter.getContext()));
SmallVector<Value, 2> inputs = {adaptor.operands()[0]};
Type result_type = op.getResult(0).getType();
auto shaped_type = result_type.cast<ShapedType>();
SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
rewriter, loc, adaptor.operands()[0], result_type.cast<ShapedType>(),
reduction_dims);
auto init_tensor =
rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape);
{
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<Type, 4> arg_types(shaped_type.getRank(),
rewriter.getIndexType());
Region& region = init_tensor.body();
Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
rewriter.setInsertionPointToEnd(block);
rewriter.create<tensor::YieldOp>(loc, init_value);
}
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
/*outputBuffers=*/ValueRange{init_tensor}, indexing_maps,
GetParallelAndReductionIterators(src_rank, reduction_dims.size()));
// Convert the signature of the body. The reduce op region apply function
// has a signature (lhs, rhs) -> output, all of the same tensor type t.
// This is converted to a function with the same signature but with
// element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
// be converted to "(f32, f32, f32)".
Region& region = linalg_op.region();
rewriter.inlineRegionBefore(op.body(), region, region.end());
TypeConverter::SignatureConversion signatureConverter(2);
signatureConverter.addInputs(0, src_type.getElementType());
signatureConverter.addInputs(1, src_type.getElementType());
rewriter.applySignatureConversion(&region, signatureConverter);
rewriter.replaceOp(op, linalg_op.getResults());
return success();
}
};
void populateLHLOToLinalgConversionPattern(MLIRContext* context, void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
@ -1356,54 +1538,57 @@ namespace mhlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context, void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
patterns patterns->insert<
->insert<BroadcastConverter<mhlo::BroadcastOp, false>, BroadcastConverter<mhlo::BroadcastOp, false>,
ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter, ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>, PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>, PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>, PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::Atan2Op, false>, PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>, PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::ClampOp, false>, PointwiseToLinalgConverter<mhlo::ClampOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>, PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>, PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>, PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>, PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>, PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>, PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>, PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::FloorOp, false>, PointwiseToLinalgConverter<mhlo::FloorOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>, PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>, PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>, PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::LogisticOp, false>, PointwiseToLinalgConverter<mhlo::LogisticOp, false>,
PointwiseToLinalgConverter<mhlo::Log1pOp, false>, PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>, PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>, PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>, PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>, PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::NotOp, false>, PointwiseToLinalgConverter<mhlo::NotOp, false>,
PointwiseToLinalgConverter<mhlo::OrOp, false>, PointwiseToLinalgConverter<mhlo::OrOp, false>,
PointwiseToLinalgConverter<mhlo::PowOp, false>, PointwiseToLinalgConverter<mhlo::PowOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>, PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>, PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>, PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>, PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>, PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>, PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>, PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
PointwiseToLinalgConverter<mhlo::SignOp, false>, PointwiseToLinalgConverter<mhlo::SignOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>, PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>, PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>, PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>, PointwiseToLinalgConverter<mhlo::TanhOp, false>,
PointwiseToLinalgConverter<mhlo::XorOp, false>, PointwiseToLinalgConverter<mhlo::XorOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>, ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>, ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>, TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
DotOpOnTensorsConversion, DotGeneralOpOnTensorsConversion>( DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
context); patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
ReduceRegionXLAOpConversion<mhlo::MinOp>,
ReduceRegionXLAOpConversion<mhlo::MaxOp>,
ReduceRegionReturnOpConversion>(context);
} }
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() { std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {

View File

@ -980,3 +980,178 @@ func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>)
tensor<4xf32>) -> tensor<4xf32> tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
// -----
func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
%0 = "mhlo.reduce"(%arg0, %arg1) ({
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = mhlo.add %arg3, %arg4 : tensor<i32>
"mhlo.return"(%1) : (tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>
return %0 : tensor<5xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_add
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @reduce_minimum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
%0 = "mhlo.reduce"(%arg0, %arg1) ({
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = mhlo.minimum %arg3, %arg4 : tensor<i32>
"mhlo.return"(%1) : (tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>
return %0 : tensor<5xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_minimum
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @reduce_maximum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
%0 = "mhlo.reduce"(%arg0, %arg1) ({
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = mhlo.maximum %arg3, %arg4 : tensor<i32>
"mhlo.return"(%1) : (tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32>
return %0 : tensor<5xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_maximum
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<4xi32> {
%0 = "mhlo.reduce"(%arg0, %arg1) ({
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = mhlo.maximum %arg3, %arg4 : tensor<i32>
"mhlo.return"(%1) : (tensor<i32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_dim0
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> {
%cst = constant dense<0xFF800000> : tensor<f32>
%0 = "mhlo.reduce"(%arg0, %cst) ({
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
%1 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @reduce_init_const
// CHECK: %[[INIT:.*]] = constant 0xFF800000 : f32
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<1xf32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>,
%arg1: tensor<i32>) -> tensor<4xi32> {
%0 = "mhlo.reduce"(%arg0, %arg1) ({
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%1 = mhlo.add %arg2, %arg3 : tensor<i32>
"mhlo.return"(%1) : (tensor<i32>) -> ()
}) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<5x4x3xi32>, tensor<i32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-LABEL: @reduce_multi_dimensions
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32> {
%0 = "mhlo.reduce"(%arg0, %arg1) ({
^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
%1 = mhlo.add %arg3, %arg4 : tensor<i32>
"mhlo.return"(%1) : (tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xi32>, tensor<i32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}}tensor<?x?xi32>)
// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<?xi32>)
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32