Upstream mhlo.reduce lowering to Linalg to MHLO repo.
In IREE, we use indexed generic op to handle the initial value. However, we lower it to a generic op that carries an init_tensor here, and leave the handle of initialization problem to later passes. PiperOrigin-RevId: 354294807
This commit is contained in:
		
							parent
							
								
									39589add22
								
							
						
					
					
						commit
						30ce82790d
					
				|  | @ -18,6 +18,7 @@ limitations under the License. | |||
| #include <numeric> | ||||
| 
 | ||||
| #include "llvm/ADT/STLExtras.h" | ||||
| #include "llvm/ADT/SetVector.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" | ||||
|  | @ -35,19 +36,31 @@ limitations under the License. | |||
| #include "mlir/IR/BuiltinTypes.h" | ||||
| #include "mlir/IR/Location.h" | ||||
| #include "mlir/IR/MLIRContext.h" | ||||
| #include "mlir/IR/Matchers.h" | ||||
| #include "mlir/IR/Operation.h" | ||||
| #include "mlir/IR/OperationSupport.h" | ||||
| #include "mlir/IR/PatternMatch.h" | ||||
| #include "mlir/IR/TypeUtilities.h" | ||||
| #include "mlir/Pass/Pass.h" | ||||
| #include "mlir/Pass/PassManager.h" | ||||
| #include "mlir/Transforms/DialectConversion.h" | ||||
| 
 | ||||
| namespace mlir { | ||||
| namespace { | ||||
| 
 | ||||
| /// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
 | ||||
| /// are "parallel" except the last `nReduction` elements, where are "reduction"
 | ||||
| /// attributes.
 | ||||
| SmallVector<StringRef, 3> GetParallelAndReductionIterators( | ||||
|     unsigned nLoops, unsigned nReduction) { | ||||
|   SmallVector<StringRef, 3> res(nLoops - nReduction, | ||||
|                                 getParallelIteratorTypeName()); | ||||
|   res.append(nReduction, getReductionIteratorTypeName()); | ||||
|   return res; | ||||
| } | ||||
| 
 | ||||
| SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) { | ||||
|   static constexpr StringRef kParallelIterType = "parallel"; | ||||
|   return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType); | ||||
|   return GetParallelAndReductionIterators(nParallelLoops, 0); | ||||
| } | ||||
| 
 | ||||
| template <bool isLHLO = true> | ||||
|  | @ -107,6 +120,35 @@ SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) { | |||
|   return ret; | ||||
| } | ||||
| 
 | ||||
| /// Returns the constant value associated with the init value if the defining
 | ||||
| /// operation is a constant.
 | ||||
| Attribute GetInitValueAsConst(Value init) { | ||||
|   DenseElementsAttr attr; | ||||
|   if (!matchPattern(init, m_Constant(&attr))) return {}; | ||||
|   auto type = attr.getType().dyn_cast<ShapedType>(); | ||||
|   if (!type || type.getRank() != 0) return {}; | ||||
|   return attr.getValue({}); | ||||
| } | ||||
| 
 | ||||
| /// Returns a permutation AffineMap that puts all reduction dimensions to the
 | ||||
| /// last. The order of parallel loops and reduction loops are all sorted. E.g.,
 | ||||
| /// if `rank` is 4 and `reductionDims` is {1, 3}, then
 | ||||
| /// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
 | ||||
| /// the AffineMap is returned.
 | ||||
| AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank, | ||||
|                                       ArrayRef<int64_t> reduction_dims) { | ||||
|   llvm::SmallSetVector<int, 4> s; | ||||
|   for (auto dim : reduction_dims) s.insert(dim); | ||||
| 
 | ||||
|   SmallVector<unsigned, 4> permutation; | ||||
|   for (int i = 0; i < rank; ++i) | ||||
|     if (!s.count(i)) permutation.push_back(i); | ||||
|   for (auto dim : reduction_dims) permutation.push_back(dim); | ||||
| 
 | ||||
|   auto map = AffineMap::getPermutationMap(permutation, context); | ||||
|   return inversePermutation(map); | ||||
| } | ||||
| 
 | ||||
| template <typename OpTy, bool isLHLO = true> | ||||
| class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> { | ||||
|  public: | ||||
|  | @ -1226,6 +1268,146 @@ class DotGeneralOpOnTensorsConversion | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| template <typename OpTy> | ||||
| struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> { | ||||
|   using OpConversionPattern<OpTy>::OpConversionPattern; | ||||
|   LogicalResult matchAndRewrite( | ||||
|       OpTy op, ArrayRef<Value> args, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     // Only convert the body of reduction ops to std ops.
 | ||||
|     auto parent_op = op.getOperation()->getParentRegion()->getParentOp(); | ||||
|     if (!isa<mhlo::ReduceOp, linalg::GenericOp, linalg::IndexedGenericOp>( | ||||
|             parent_op)) { | ||||
|       return failure(); | ||||
|     } | ||||
|     if (!op.getResult().getType().template isa<TensorType>()) return failure(); | ||||
|     if (llvm::all_of(args, [](Value arg) { | ||||
|           return arg.getType().template isa<TensorType>(); | ||||
|         })) { | ||||
|       return failure(); | ||||
|     } | ||||
|     Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(), | ||||
|                                                         args, &rewriter); | ||||
|     rewriter.replaceOp(op, result); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| SmallVector<Value, 8> GetReduceOpInitTensorDynSizes( | ||||
|     OpBuilder& b, Location loc, Value arg, ShapedType result_type, | ||||
|     ArrayRef<int64_t> reduction_dims) { | ||||
|   llvm::SmallSetVector<int, 4> s; | ||||
|   for (auto dim : reduction_dims) s.insert(dim); | ||||
| 
 | ||||
|   SmallVector<unsigned, 4> parallel_dims; | ||||
|   SmallVector<Value, 8> dyn_shape; | ||||
|   int rank = arg.getType().cast<RankedTensorType>().getRank(); | ||||
|   for (int i = 0, j = 0; i < rank; ++i) { | ||||
|     if (s.count(i)) continue; | ||||
|     if (!result_type.isDynamicDim(j++)) continue; | ||||
|     dyn_shape.push_back(b.create<DimOp>(loc, arg, i)); | ||||
|   } | ||||
| 
 | ||||
|   return dyn_shape; | ||||
| } | ||||
| 
 | ||||
| class ReduceRegionReturnOpConversion | ||||
|     : public OpConversionPattern<mhlo::ReturnOp> { | ||||
|  public: | ||||
|   using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern; | ||||
|   LogicalResult matchAndRewrite( | ||||
|       mhlo::ReturnOp op, ArrayRef<Value> args, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, args); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> { | ||||
|  public: | ||||
|   using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern; | ||||
|   LogicalResult matchAndRewrite( | ||||
|       mhlo::ReduceOp op, ArrayRef<Value> args, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     Location loc = op.getLoc(); | ||||
|     mhlo::ReduceOp::Adaptor adaptor(args); | ||||
|     if (op.getNumOperands() != 2) { | ||||
|       return op.emitError("expects exactly two operands"); | ||||
|     } | ||||
|     Value src = adaptor.operands()[0]; | ||||
|     auto src_type = src.getType().cast<ShapedType>(); | ||||
|     int src_rank = src_type.getRank(); | ||||
|     if (!src_rank) { | ||||
|       return rewriter.notifyMatchFailure(op, "expects known-rank args"); | ||||
|     } | ||||
| 
 | ||||
|     // Check if init_value is constant. If so, inline the value into the region.
 | ||||
|     Value init_value = adaptor.init_values()[0]; | ||||
|     Attribute init_const_val = GetInitValueAsConst(init_value); | ||||
|     if (init_const_val) { | ||||
|       init_value = rewriter.create<ConstantOp>( | ||||
|           init_value.getDefiningOp()->getLoc(), init_const_val); | ||||
|     } else { | ||||
|       init_value = rewriter.create<tensor::ExtractOp>(loc, init_value); | ||||
|     } | ||||
| 
 | ||||
|     // Prepare indexing maps for linalg generic op. The elements are for src and
 | ||||
|     // dst. Transpose `src` to make the reduction loops be the innermost,
 | ||||
|     // because it's easier to fully utilize processors.
 | ||||
|     SmallVector<AffineMap, 3> indexing_maps; | ||||
|     SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions()); | ||||
|     indexing_maps.emplace_back(GetTransposeMapForReduction( | ||||
|         rewriter.getContext(), src_rank, reduction_dims)); | ||||
| 
 | ||||
|     // The indexing map of `dst` should drop the reduction loops. Since the
 | ||||
|     // reduction loops now are all in the innermost, drops
 | ||||
|     // `reduction_dims.size()` dimensions. We don't need an inverse permutation
 | ||||
|     // here because they are the same.
 | ||||
|     SmallVector<AffineExpr, 4> exprs; | ||||
|     for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i) | ||||
|       exprs.push_back(rewriter.getAffineDimExpr(i)); | ||||
|     indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0, | ||||
|                                               exprs, rewriter.getContext())); | ||||
| 
 | ||||
|     SmallVector<Value, 2> inputs = {adaptor.operands()[0]}; | ||||
|     Type result_type = op.getResult(0).getType(); | ||||
|     auto shaped_type = result_type.cast<ShapedType>(); | ||||
|     SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes( | ||||
|         rewriter, loc, adaptor.operands()[0], result_type.cast<ShapedType>(), | ||||
|         reduction_dims); | ||||
|     auto init_tensor = | ||||
|         rewriter.create<tensor::GenerateOp>(loc, result_type, dyn_shape); | ||||
|     { | ||||
|       OpBuilder::InsertionGuard guard(rewriter); | ||||
|       SmallVector<Type, 4> arg_types(shaped_type.getRank(), | ||||
|                                      rewriter.getIndexType()); | ||||
|       Region& region = init_tensor.body(); | ||||
|       Block* block = rewriter.createBlock(®ion, region.begin(), arg_types); | ||||
|       rewriter.setInsertionPointToEnd(block); | ||||
|       rewriter.create<tensor::YieldOp>(loc, init_value); | ||||
|     } | ||||
| 
 | ||||
|     auto linalg_op = rewriter.create<linalg::GenericOp>( | ||||
|         loc, /*resultTensorTypes=*/op.getResultTypes(), inputs, | ||||
|         /*outputBuffers=*/ValueRange{init_tensor}, indexing_maps, | ||||
|         GetParallelAndReductionIterators(src_rank, reduction_dims.size())); | ||||
| 
 | ||||
|     // Convert the signature of the body. The reduce op region apply function
 | ||||
|     // has a signature (lhs, rhs) -> output, all of the same tensor type t.
 | ||||
|     // This is converted to a function with the same signature but with
 | ||||
|     // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
 | ||||
|     // be converted to "(f32, f32, f32)".
 | ||||
|     Region& region = linalg_op.region(); | ||||
|     rewriter.inlineRegionBefore(op.body(), region, region.end()); | ||||
|     TypeConverter::SignatureConversion signatureConverter(2); | ||||
|     signatureConverter.addInputs(0, src_type.getElementType()); | ||||
|     signatureConverter.addInputs(1, src_type.getElementType()); | ||||
|     rewriter.applySignatureConversion(®ion, signatureConverter); | ||||
|     rewriter.replaceOp(op, linalg_op.getResults()); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| void populateLHLOToLinalgConversionPattern(MLIRContext* context, | ||||
|                                            OwningRewritePatternList* patterns) { | ||||
|   // clang-format off
 | ||||
|  | @ -1356,54 +1538,57 @@ namespace mhlo { | |||
| 
 | ||||
| void populateHLOToLinalgConversionPattern(MLIRContext* context, | ||||
|                                           OwningRewritePatternList* patterns) { | ||||
|   patterns | ||||
|       ->insert<BroadcastConverter<mhlo::BroadcastOp, false>, | ||||
|                ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter, | ||||
|                HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::AbsOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::AddOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::AndOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::Atan2Op, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::CeilOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ClampOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::CompareOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ComplexOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ConvertOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::CopyOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::CosOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::DivOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ExpOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::FloorOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ImagOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::LogOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::LogisticOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::Log1pOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::MaxOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::MinOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::MulOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::NegOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::NotOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::OrOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::PowOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::RealOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::RemOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::RsqrtOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::SelectOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::SignOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::SinOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::SqrtOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::SubOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::TanhOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::XorOp, false>, | ||||
|                ReshapeOpConverter<mhlo::ReshapeOp, false>, | ||||
|                ReverseConverter<mhlo::ReverseOp, false>, | ||||
|                TransposeConverter<mhlo::TransposeOp, false>, | ||||
|                DotOpOnTensorsConversion, DotGeneralOpOnTensorsConversion>( | ||||
|           context); | ||||
|   patterns->insert< | ||||
|       BroadcastConverter<mhlo::BroadcastOp, false>, | ||||
|       ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter, | ||||
|       HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::AbsOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::AddOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::AndOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::Atan2Op, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::CeilOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ClampOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::CompareOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ComplexOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ConvertOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::CopyOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::CosOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::DivOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ExpOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::FloorOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ImagOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::LogOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::LogisticOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::Log1pOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::MaxOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::MinOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::MulOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::NegOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::NotOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::OrOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::PowOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::RealOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::RemOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::RsqrtOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::SelectOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::SignOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::SinOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::SqrtOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::SubOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::TanhOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::XorOp, false>, | ||||
|       ReshapeOpConverter<mhlo::ReshapeOp, false>, | ||||
|       ReverseConverter<mhlo::ReverseOp, false>, | ||||
|       TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion, | ||||
|       DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context); | ||||
|   patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>, | ||||
|                    ReduceRegionXLAOpConversion<mhlo::MinOp>, | ||||
|                    ReduceRegionXLAOpConversion<mhlo::MaxOp>, | ||||
|                    ReduceRegionReturnOpConversion>(context); | ||||
| } | ||||
| 
 | ||||
| std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() { | ||||
|  |  | |||
|  | @ -980,3 +980,178 @@ func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>) | |||
|       tensor<4xf32>) -> tensor<4xf32> | ||||
|   return %0 : tensor<4xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> { | ||||
|   %0 = "mhlo.reduce"(%arg0, %arg1) ({ | ||||
|   ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>): | ||||
|     %1 = mhlo.add %arg3, %arg4 : tensor<i32> | ||||
|     "mhlo.return"(%1) : (tensor<i32>) -> () | ||||
|   }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32> | ||||
|   return %0 : tensor<5xi32> | ||||
| } | ||||
| // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> | ||||
| // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> | ||||
| // CHECK-LABEL: @reduce_add | ||||
| // CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32> | ||||
| // CHECK: %[[INIT_TENSOR:.*]] = tensor.generate | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] | ||||
| // CHECK-SAME: iterator_types = ["parallel", "reduction"] | ||||
| // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) | ||||
| // CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_minimum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> { | ||||
|   %0 = "mhlo.reduce"(%arg0, %arg1) ({ | ||||
|   ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>): | ||||
|     %1 = mhlo.minimum %arg3, %arg4 : tensor<i32> | ||||
|     "mhlo.return"(%1) : (tensor<i32>) -> () | ||||
|   }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32> | ||||
|   return %0 : tensor<5xi32> | ||||
| } | ||||
| // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> | ||||
| // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> | ||||
| // CHECK-LABEL: @reduce_minimum | ||||
| // CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32> | ||||
| // CHECK: %[[INIT_TENSOR:.*]] = tensor.generate | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] | ||||
| // CHECK-SAME: iterator_types = ["parallel", "reduction"] | ||||
| // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) | ||||
| // CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): | ||||
| // CHECK-NEXT:   %[[CMP:.*]] = cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_maximum(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<5xi32> { | ||||
|   %0 = "mhlo.reduce"(%arg0, %arg1) ({ | ||||
|   ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>): | ||||
|     %1 = mhlo.maximum %arg3, %arg4 : tensor<i32> | ||||
|     "mhlo.return"(%1) : (tensor<i32>) -> () | ||||
|   }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<5xi32> | ||||
|   return %0 : tensor<5xi32> | ||||
| } | ||||
| // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> | ||||
| // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> | ||||
| // CHECK-LABEL: @reduce_maximum | ||||
| // CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32> | ||||
| // CHECK: %[[INIT_TENSOR:.*]] = tensor.generate | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] | ||||
| // CHECK-SAME: iterator_types = ["parallel", "reduction"] | ||||
| // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) | ||||
| // CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<5xi32>) | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): | ||||
| // CHECK-NEXT:   %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor<i32>) -> tensor<4xi32> { | ||||
|   %0 = "mhlo.reduce"(%arg0, %arg1) ({ | ||||
|   ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>): | ||||
|     %1 = mhlo.maximum %arg3, %arg4 : tensor<i32> | ||||
|     "mhlo.return"(%1) : (tensor<i32>) -> () | ||||
|   }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xi32>, tensor<i32>) -> tensor<4xi32> | ||||
|   return %0 : tensor<4xi32> | ||||
| } | ||||
| // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)> | ||||
| // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> | ||||
| // CHECK-LABEL: @reduce_dim0 | ||||
| // CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32> | ||||
| // CHECK: %[[INIT_TENSOR:.*]] = tensor.generate | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] | ||||
| // CHECK-SAME: iterator_types = ["parallel", "reduction"] | ||||
| // CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) | ||||
| // CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>) | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): | ||||
| // CHECK-NEXT:   %[[CMP:.*]] = cmpi sgt, %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> { | ||||
|   %cst = constant dense<0xFF800000> : tensor<f32> | ||||
|   %0 = "mhlo.reduce"(%arg0, %cst) ({ | ||||
|   ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors | ||||
|     %1 = mhlo.add %arg1, %arg2 : tensor<f32> | ||||
|     "mhlo.return"(%1) : (tensor<f32>) -> () | ||||
|   }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32> | ||||
|   return %0 : tensor<1xf32> | ||||
| } | ||||
| // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> | ||||
| // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> | ||||
| // CHECK-LABEL: @reduce_init_const | ||||
| // CHECK: %[[INIT:.*]] = constant 0xFF800000 : f32 | ||||
| // CHECK: %[[INIT_TENSOR:.*]] = tensor.generate | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] | ||||
| // CHECK-SAME: iterator_types = ["parallel", "reduction"] | ||||
| // CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>) | ||||
| // CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<1xf32>) | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>, | ||||
|                               %arg1: tensor<i32>) -> tensor<4xi32> { | ||||
|   %0 = "mhlo.reduce"(%arg0, %arg1) ({ | ||||
|   ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>): | ||||
|     %1 = mhlo.add %arg2, %arg3 : tensor<i32> | ||||
|     "mhlo.return"(%1) : (tensor<i32>) -> () | ||||
|   }) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<5x4x3xi32>, tensor<i32>) -> tensor<4xi32> | ||||
|   return %0 : tensor<4xi32> | ||||
| } | ||||
| // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> | ||||
| // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)> | ||||
| // CHECK-LABEL: @reduce_multi_dimensions | ||||
| // CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32> | ||||
| // CHECK: %[[INIT_TENSOR:.*]] = tensor.generate | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] | ||||
| // CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"] | ||||
| // CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>) | ||||
| // CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<4xi32>) | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32> { | ||||
|   %0 = "mhlo.reduce"(%arg0, %arg1) ({ | ||||
|   ^bb0(%arg3: tensor<i32>, %arg4 : tensor<i32>): | ||||
|     %1 = mhlo.add %arg3, %arg4 : tensor<i32> | ||||
|     "mhlo.return"(%1) : (tensor<i32>) -> () | ||||
|   }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xi32>, tensor<i32>) -> tensor<?xi32> | ||||
|   return %0 : tensor<?xi32> | ||||
| } | ||||
| // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> | ||||
| // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> | ||||
| // CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32> | ||||
| // CHECK: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32> | ||||
| // CHECK: %[[C0:.*]] = constant 0 : index | ||||
| // CHECK: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32> | ||||
| // CHECK: %[[INIT_TENSOR:.*]] = tensor.generate | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] | ||||
| // CHECK-SAME: iterator_types = ["parallel", "reduction"] | ||||
| // CHECK-SAME: ins(%{{.*}}tensor<?x?xi32>) | ||||
| // CHECK-SAME: outs(%[[INIT_TENSOR]] : tensor<?xi32>) | ||||
| // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : i32 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue