From 30ce82790d4ffef48f70e99a6f96f13ddbe857d8 Mon Sep 17 00:00:00 2001
From: Hanhan Wang
Date: Thu, 28 Jan 2021 05:44:49 -0800
Subject: [PATCH] Upstream mhlo.reduce lowering to Linalg to MHLO repo.

In IREE, we use an indexed generic op to handle the initial value. Here,
however, we lower to a generic op that carries an init_tensor, and leave
handling of the initialization to later passes.

PiperOrigin-RevId: 354294807
---
 .../mhlo/transforms/legalize_to_linalg.cc | 285 +++++++++++++++---
 tests/hlo-legalize-to-linalg.mlir         | 175 +++++++++++
 2 files changed, 410 insertions(+), 50 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 1486078..e267ce2 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include 

 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@@ -35,19 +36,31 @@ limitations under the License.
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"

 namespace mlir {
 namespace {

+/// Returns a vector that contains `nLoops` iterator type names. All of the
+/// entries are "parallel" except for the last `nReduction` entries, which are
+/// "reduction".
+SmallVector<StringRef, 3> GetParallelAndReductionIterators(
+    unsigned nLoops, unsigned nReduction) {
+  SmallVector<StringRef, 3> res(nLoops - nReduction,
+                                getParallelIteratorTypeName());
+  res.append(nReduction, getReductionIteratorTypeName());
+  return res;
+}
+
 SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
-  static constexpr StringRef kParallelIterType = "parallel";
-  return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
+  return GetParallelAndReductionIterators(nParallelLoops, 0);
 }

 template
@@ -107,6 +120,35 @@ SmallVector Extract1DVector(DenseIntElementsAttr elements) {
   return ret;
 }

+/// Returns the constant value associated with the init value if the defining
+/// operation is a constant.
+Attribute GetInitValueAsConst(Value init) {
+  DenseElementsAttr attr;
+  if (!matchPattern(init, m_Constant(&attr))) return {};
+  auto type = attr.getType().dyn_cast<ShapedType>();
+  if (!type || type.getRank() != 0) return {};
+  return attr.getValue({});
+}
+
+/// Returns a permutation AffineMap that moves all reduction dimensions to the
+/// end. Parallel dimensions and reduction dimensions each keep their relative
+/// order. E.g., if `rank` is 4 and `reduction_dims` is {1, 3}, the permutation
+/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse of this
+/// permutation map is returned.
+AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank, + ArrayRef reduction_dims) { + llvm::SmallSetVector s; + for (auto dim : reduction_dims) s.insert(dim); + + SmallVector permutation; + for (int i = 0; i < rank; ++i) + if (!s.count(i)) permutation.push_back(i); + for (auto dim : reduction_dims) permutation.push_back(dim); + + auto map = AffineMap::getPermutationMap(permutation, context); + return inversePermutation(map); +} + template class PointwiseToLinalgConverter : public OpConversionPattern { public: @@ -1226,6 +1268,146 @@ class DotGeneralOpOnTensorsConversion } }; +template +struct ReduceRegionXLAOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + OpTy op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + // Only convert the body of reduction ops to std ops. + auto parent_op = op.getOperation()->getParentRegion()->getParentOp(); + if (!isa( + parent_op)) { + return failure(); + } + if (!op.getResult().getType().template isa()) return failure(); + if (llvm::all_of(args, [](Value arg) { + return arg.getType().template isa(); + })) { + return failure(); + } + Value result = lmhlo::HloOpToStdScalarOp::map(op, args[0].getType(), + args, &rewriter); + rewriter.replaceOp(op, result); + return success(); + } +}; + +SmallVector GetReduceOpInitTensorDynSizes( + OpBuilder& b, Location loc, Value arg, ShapedType result_type, + ArrayRef reduction_dims) { + llvm::SmallSetVector s; + for (auto dim : reduction_dims) s.insert(dim); + + SmallVector parallel_dims; + SmallVector dyn_shape; + int rank = arg.getType().cast().getRank(); + for (int i = 0, j = 0; i < rank; ++i) { + if (s.count(i)) continue; + if (!result_type.isDynamicDim(j++)) continue; + dyn_shape.push_back(b.create(loc, arg, i)); + } + + return dyn_shape; +} + +class ReduceRegionReturnOpConversion + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ReturnOp op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + rewriter.replaceOpWithNewOp(op, args); + return success(); + } +}; + +class ReduceOnTensorsConversion : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ReduceOp op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + Location loc = op.getLoc(); + mhlo::ReduceOp::Adaptor adaptor(args); + if (op.getNumOperands() != 2) { + return op.emitError("expects exactly two operands"); + } + Value src = adaptor.operands()[0]; + auto src_type = src.getType().cast(); + int src_rank = src_type.getRank(); + if (!src_rank) { + return rewriter.notifyMatchFailure(op, "expects known-rank args"); + } + + // Check if init_value is constant. If so, inline the value into the region. + Value init_value = adaptor.init_values()[0]; + Attribute init_const_val = GetInitValueAsConst(init_value); + if (init_const_val) { + init_value = rewriter.create( + init_value.getDefiningOp()->getLoc(), init_const_val); + } else { + init_value = rewriter.create(loc, init_value); + } + + // Prepare indexing maps for linalg generic op. The elements are for src and + // dst. Transpose `src` to make the reduction loops be the innermost, + // because it's easier to fully utilize processors. 
+ SmallVector indexing_maps; + SmallVector reduction_dims = Extract1DVector(op.dimensions()); + indexing_maps.emplace_back(GetTransposeMapForReduction( + rewriter.getContext(), src_rank, reduction_dims)); + + // The indexing map of `dst` should drop the reduction loops. Since the + // reduction loops now are all in the innermost, drops + // `reduction_dims.size()` dimensions. We don't need an inverse permutation + // here because they are the same. + SmallVector exprs; + for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i) + exprs.push_back(rewriter.getAffineDimExpr(i)); + indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0, + exprs, rewriter.getContext())); + + SmallVector inputs = {adaptor.operands()[0]}; + Type result_type = op.getResult(0).getType(); + auto shaped_type = result_type.cast(); + SmallVector dyn_shape = GetReduceOpInitTensorDynSizes( + rewriter, loc, adaptor.operands()[0], result_type.cast(), + reduction_dims); + auto init_tensor = + rewriter.create(loc, result_type, dyn_shape); + { + OpBuilder::InsertionGuard guard(rewriter); + SmallVector arg_types(shaped_type.getRank(), + rewriter.getIndexType()); + Region& region = init_tensor.body(); + Block* block = rewriter.createBlock(®ion, region.begin(), arg_types); + rewriter.setInsertionPointToEnd(block); + rewriter.create(loc, init_value); + } + + auto linalg_op = rewriter.create( + loc, /*resultTensorTypes=*/op.getResultTypes(), inputs, + /*outputBuffers=*/ValueRange{init_tensor}, indexing_maps, + GetParallelAndReductionIterators(src_rank, reduction_dims.size())); + + // Convert the signature of the body. The reduce op region apply function + // has a signature (lhs, rhs) -> output, all of the same tensor type t. + // This is converted to a function with the same signature but with + // element types. E.g., "(tensor, tensor) -> tensor" will + // be converted to "(f32, f32, f32)". 
+ Region& region = linalg_op.region(); + rewriter.inlineRegionBefore(op.body(), region, region.end()); + TypeConverter::SignatureConversion signatureConverter(2); + signatureConverter.addInputs(0, src_type.getElementType()); + signatureConverter.addInputs(1, src_type.getElementType()); + rewriter.applySignatureConversion(®ion, signatureConverter); + rewriter.replaceOp(op, linalg_op.getResults()); + return success(); + } +}; + void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off @@ -1356,54 +1538,57 @@ namespace mhlo { void populateHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { - patterns - ->insert, - ConstConverter, HloDynamicBroadcastInDimConverter, - HloBroadcastInDimConverter, IotaConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - ReshapeOpConverter, - ReverseConverter, - TransposeConverter, - DotOpOnTensorsConversion, DotGeneralOpOnTensorsConversion>( - context); + patterns->insert< + BroadcastConverter, + ConstConverter, HloDynamicBroadcastInDimConverter, + HloBroadcastInDimConverter, IotaConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + TransposeConverter, DotOpOnTensorsConversion, + 
DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context); + patterns->insert, + ReduceRegionXLAOpConversion, + ReduceRegionXLAOpConversion, + ReduceRegionReturnOpConversion>(context); } std::unique_ptr> createLegalizeHloToLinalgPass() { diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 1fb52f5..51d2f90 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -980,3 +980,178 @@ func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>) tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } + +// ----- + +func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32> { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4 : tensor): + %1 = mhlo.add %arg3, %arg4 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor) -> tensor<5xi32> + return %0 : tensor<5xi32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-LABEL: @reduce_add +// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) +// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +func @reduce_minimum(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32> { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4 : tensor): + %1 = mhlo.minimum %arg3, %arg4 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor) -> tensor<5xi32> + return %0 : tensor<5xi32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-LABEL: @reduce_minimum +// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) +// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +func @reduce_maximum(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32> { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4 : tensor): + %1 = mhlo.maximum %arg3, %arg4 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor) -> tensor<5xi32> + return %0 : tensor<5xi32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-LABEL: @reduce_maximum +// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// 
CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) +// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<4xi32> { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4 : tensor): + %1 = mhlo.maximum %arg3, %arg4 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xi32>, tensor) -> tensor<4xi32> + return %0 : tensor<4xi32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-LABEL: @reduce_dim0 +// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) +// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>) +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> { + %cst = constant dense<0xFF800000> : tensor + %0 = "mhlo.reduce"(%arg0, %cst) ({ + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %1 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor) -> tensor<1xf32> + return %0 : tensor<1xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-LABEL: @reduce_init_const +// CHECK: %[[INIT:.*]] = constant 0xFF800000 : f32 +// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>) +// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<1xf32>) +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>, + %arg1: tensor) -> tensor<4xi32> { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<5x4x3xi32>, tensor) -> tensor<4xi32> + return %0 : tensor<4xi32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK-LABEL: @reduce_multi_dimensions +// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"] +// CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>) +// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>) +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, 
%[[RHS_IN:.*]]: i32):
+// CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// -----
+
+func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32> {
+  %0 = "mhlo.reduce"(%arg0, %arg1) ({
+  ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>):
+    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
+    "mhlo.return"(%1) : (tensor<i32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xi32>, tensor<i32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
+// CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
+// CHECK: %[[INIT_TENSOR:.*]] = tensor.generate
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%{{.*}}tensor<?x?xi32>)
+// CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<?xi32>)
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
+// CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
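
For reference, the end state this lowering targets looks roughly like the IR below. This is a hand-written sketch assembled from the CHECK lines of @reduce_add above, not verbatim pass output; the SSA value names, the lowered function name, and the exact printed form of tensor.generate and linalg.generic are assumptions.

// Illustrative sketch only (names and printing are assumed, see note above).
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>

func @reduce_add_lowered(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> {
  // Extract the scalar init value and splat it into an init tensor; folding
  // this initialization away is left to later passes, as the commit message notes.
  %init = tensor.extract %arg1[] : tensor<i32>
  %init_tensor = tensor.generate {
  ^bb0(%i: index):
    tensor.yield %init : i32
  } : tensor<5xi32>
  // One parallel loop for the kept dimension, one reduction loop for the
  // reduced dimension; the output map drops the (innermost) reduction loop.
  %result = linalg.generic {indexing_maps = [#map0, #map1],
                            iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<5x4xi32>)
      outs(%init_tensor : tensor<5xi32>) {
  ^bb0(%lhs: i32, %acc: i32):
    %sum = addi %lhs, %acc : i32
    linalg.yield %sum : i32
  } -> tensor<5xi32>
  return %result : tensor<5xi32>
}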