[MLIR] Add support for representing variadic reduce-window in HLO/LMHLO dialect.
- Fixed a subset of transformations to handle variadic reduce-window. PiperOrigin-RevId: 366278650
This commit is contained in:
		
							parent
							
								
									d1f697e618
								
							
						
					
					
						commit
						ff2cbfa2ec
					
				|  | @ -1257,6 +1257,7 @@ def HLO_TriangularSolveOp: HLO_Op<"triangular_solve", | |||
| 
 | ||||
| def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ | ||||
|       RecursiveSideEffects, | ||||
|       SameVariadicOperandSize, | ||||
|       SingleBlockImplicitTerminator<"ReturnOp"> | ||||
|     ]>, BASE_HLO_ReduceWindowOp { | ||||
| 
 | ||||
|  | @ -1264,8 +1265,8 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ | |||
|   // attributes are 1-d. Attributes' leading dimension should match rank of the | ||||
|   // inputs. | ||||
|   let arguments = (ins | ||||
|     HLO_Tensor:$operand, | ||||
|     HLO_Tensor:$init_value, | ||||
|     Variadic<HLO_Tensor>:$inputs, | ||||
|     Variadic<HLO_Tensor>:$init_values, | ||||
|     I64ElementsAttr:$window_dimensions, | ||||
|     // If strides or dilations attributes are missing then the default value is | ||||
|     // one for each of the input dimensions. Similarly, padding values are zero | ||||
|  | @ -1276,15 +1277,36 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ | |||
|     OptionalAttr<I64ElementsAttr>:$padding | ||||
|   ); | ||||
| 
 | ||||
|   let results = (outs HLO_Tensor); | ||||
|   let results = (outs Variadic<HLO_Tensor>); | ||||
| 
 | ||||
|   // TODO(hinsu): Verify that the attached body arguments and results are | ||||
|   // compatible with reduce op's operands. | ||||
|   let regions = (region SizedRegion<1>:$body); | ||||
| 
 | ||||
|   let hasCustomHLOConverter = 1; | ||||
|   // Builder for non-variadic version of the operation. | ||||
|   let builders = [ | ||||
|     OpBuilder<(ins "Type":$result_type, "Value":$operand, | ||||
|       "Value":$init_value, | ||||
|       "DenseIntElementsAttr":$window_dimensions, | ||||
|       "DenseIntElementsAttr":$window_strides, | ||||
|       "DenseIntElementsAttr":$base_dilations, | ||||
|       "DenseIntElementsAttr":$window_dilations, | ||||
|       "DenseIntElementsAttr":$padding), | ||||
|     [{ | ||||
|       build($_builder, $_state, TypeRange(result_type), ValueRange(operand), | ||||
|             ValueRange(init_value), window_dimensions, window_strides, | ||||
|             base_dilations, window_dilations, padding); | ||||
|     }]> | ||||
|   ]; | ||||
| 
 | ||||
|   let hasCustomHLOConverter = 1; | ||||
|   let verifier = [{ return Verify(*this); }]; | ||||
|   // TODO(hinsu): Implement custom printer and parser. | ||||
| 
 | ||||
|   let extraClassDeclaration = [{ | ||||
|      // Get the operation used for reduction applied to `result_index`th result. | ||||
|      Operation *getReductionOp(int result_index); | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> { | ||||
|  |  | |||
|  | @ -216,12 +216,12 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_Reduce | |||
|   let hasCanonicalizer = 1; | ||||
| } | ||||
| 
 | ||||
| def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp { | ||||
| 
 | ||||
| def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]>, | ||||
|                          BASE_HLO_ReduceWindowOp { | ||||
|   let arguments = (ins | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$operand, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$init_value, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$out, | ||||
|     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$inputs, | ||||
|     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values, | ||||
|     Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out, | ||||
|     I64ElementsAttr:$window_dimensions, | ||||
|     // If strides or dilations attributes are missing then the default value is | ||||
|     // one for each of the input dimensions. Similarly, padding values are zero | ||||
|  | @ -233,6 +233,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp { | |||
|   ); | ||||
| 
 | ||||
|   let regions = (region SizedRegion<1>:$body); | ||||
|   let verifier = [{ return Verify(*this); }]; | ||||
| } | ||||
| 
 | ||||
| // TODO(timshen): Add a custom syntax for this. | ||||
|  |  | |||
|  | @ -1728,6 +1728,38 @@ static LogicalResult Verify(RecvOp op) { | |||
| 
 | ||||
| OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // ReduceWindowOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| // For reduce-window, all `inputs` need to have compatible shapes.
 | ||||
| static LogicalResult Verify(ReduceWindowOp op) { | ||||
|   if (failed(verifyCompatibleShapes(op.inputs().getTypes()))) | ||||
|     return op.emitOpError() << "requires same shape for all inputs"; | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| // Get the operation used for reduction applied to `result_index`th result. Its
 | ||||
| // expected to be a binary operation that consumes `result_index`th and
 | ||||
| // `result_index + operands().size`th arguments of the body.
 | ||||
| Operation* ReduceWindowOp::getReductionOp(int result_index) { | ||||
|   auto return_op = cast<ReturnOp>(body().front().getTerminator()); | ||||
|   Operation* compute_op = return_op.results()[result_index].getDefiningOp(); | ||||
|   if (compute_op->getNumOperands() != 2) return nullptr; | ||||
|   auto arg0 = compute_op->getOperand(0).dyn_cast<BlockArgument>(); | ||||
|   auto arg1 = compute_op->getOperand(1).dyn_cast<BlockArgument>(); | ||||
|   if (!arg0 || !arg1) return nullptr; | ||||
|   int arg0_num = arg0.getArgNumber(); | ||||
|   int arg1_num = arg1.getArgNumber(); | ||||
|   int other_arg_index = result_index + inputs().size(); | ||||
|   if (arg0_num == result_index && arg1_num == other_arg_index) | ||||
|     return compute_op; | ||||
|   if (arg0_num == other_arg_index && arg1_num == result_index && | ||||
|       compute_op->hasTrait<OpTrait::IsCommutative>()) | ||||
|     return compute_op; | ||||
|   return nullptr; | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // ReverseOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  |  | |||
|  | @ -281,6 +281,17 @@ void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results, | |||
|   results.insert<RemoveCopyInReduceBody>(context); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // ReduceWindowOp.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| // For reduce-window, all `inputs` need to have compatible shapes.
 | ||||
| static LogicalResult Verify(ReduceWindowOp op) { | ||||
|   if (failed(verifyCompatibleShapes(op.inputs().getTypes()))) | ||||
|     return op.emitOpError() << "requires same shape for all operands"; | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace lmhlo
 | ||||
| }  // namespace mlir
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1687,40 +1687,34 @@ struct ReduceWindowOpOnTensorsConversion | |||
|   /// the pooling is determined based on the body of the reduce window
 | ||||
|   /// operation. This class enumerates the different variants.
 | ||||
|   enum class PoolingType { | ||||
|     kInvalid, | ||||
|     kMin, | ||||
|     kMax, | ||||
|     kAdd, | ||||
|   }; | ||||
| 
 | ||||
|   static PoolingType getPoolingType(Region& region) { | ||||
|     assert(region.getBlocks().size() == 1 && | ||||
|            "expected the region has exactlly one block"); | ||||
|     Block& block = region.front(); | ||||
|     assert(block.getOperations().size() == 2 && | ||||
|            "expected the block has exactlly two operations"); | ||||
|     auto op = block.begin(); | ||||
|     if (isa<mhlo::MinOp>(op)) return PoolingType::kMin; | ||||
|     if (isa<mhlo::MaxOp>(op)) return PoolingType::kMax; | ||||
|     if (isa<mhlo::AddOp>(op)) return PoolingType::kAdd; | ||||
| 
 | ||||
|     llvm_unreachable("unknown pooling type"); | ||||
|   static PoolingType getPoolingType(mhlo::ReduceWindowOp reduce_op, | ||||
|                                     int result_index) { | ||||
|     if (Operation* op = reduce_op.getReductionOp(result_index)) { | ||||
|       if (isa<mhlo::MinOp>(*op)) return PoolingType::kMin; | ||||
|       if (isa<mhlo::MaxOp>(*op)) return PoolingType::kMax; | ||||
|       if (isa<mhlo::AddOp>(*op)) return PoolingType::kAdd; | ||||
|     } | ||||
|     return PoolingType::kInvalid; | ||||
|   } | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite( | ||||
|       mhlo::ReduceWindowOp op, ArrayRef<Value> args, | ||||
|       ConversionPatternRewriter& rewriter) const override { | ||||
|     auto loc = op.getLoc(); | ||||
|     auto result_type = op.getResult().getType().cast<ShapedType>(); | ||||
|     if (result_type.getRank() != 4) { | ||||
|     int rank = op.getResultTypes()[0].cast<ShapedType>().getRank(); | ||||
|     if (rank != 4) { | ||||
|       return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op"); | ||||
|     } | ||||
| 
 | ||||
|     // Create a fake window dimension.
 | ||||
|     SmallVector<int64_t, 4> shapes; | ||||
|     SmallVector<int64_t, 2> shapes; | ||||
|     shapes.push_back(op.window_dimensions().getValue<int64_t>(1)); | ||||
|     shapes.push_back(op.window_dimensions().getValue<int64_t>(2)); | ||||
|     auto fake_window_dims = rewriter.create<linalg::InitTensorOp>( | ||||
|         loc, shapes, result_type.getElementType()); | ||||
| 
 | ||||
|     if (op.window_strides() && | ||||
|         (op.window_strides().getValue().getValue<int64_t>(0) != 1 || | ||||
|  | @ -1735,10 +1729,6 @@ struct ReduceWindowOpOnTensorsConversion | |||
|           op, "expected window_dimensions to be [1,x,y,1]"); | ||||
|     } | ||||
| 
 | ||||
|     if (!args[0].getType().cast<ShapedType>().getElementType().isF32()) { | ||||
|       return rewriter.notifyMatchFailure(op, "expected element type to be f32"); | ||||
|     } | ||||
| 
 | ||||
|     Attribute strides; | ||||
|     if (op.window_stridesAttr()) { | ||||
|       strides = rewriter.getI64VectorAttr( | ||||
|  | @ -1756,39 +1746,62 @@ struct ReduceWindowOpOnTensorsConversion | |||
|       dilations = rewriter.getI64VectorAttr({1, 1}); | ||||
|     } | ||||
| 
 | ||||
|     Value init_tensor = rewriter.create<linalg::InitTensorOp>( | ||||
|         loc, result_type.getShape(), result_type.getElementType()); | ||||
|     Value init_value = args[1]; | ||||
|     init_value = rewriter.create<tensor::ExtractOp>(loc, init_value); | ||||
|     Value filled_init_tensor = | ||||
|         rewriter.create<linalg::FillOp>(loc, init_tensor, init_value) | ||||
|             .getResult(0); | ||||
|     auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp { | ||||
|       return cast<linalg::LinalgOp>( | ||||
|           rewriter | ||||
|               .create<std::remove_pointer_t<decltype(type_ptr)>>( | ||||
|                   loc, ArrayRef<Type>{result_type}, | ||||
|                   ValueRange{args[0], fake_window_dims.getResult()}, | ||||
|                   filled_init_tensor, dilations, strides) | ||||
|               .getOperation()); | ||||
|     }; | ||||
|     linalg::LinalgOp pooling_op; | ||||
|     PoolingType pooling_type = getPoolingType(op.body()); | ||||
|     switch (pooling_type) { | ||||
|       case PoolingType::kMin: { | ||||
|         pooling_op = create_op(static_cast<linalg::PoolingNHWCMinOp*>(nullptr)); | ||||
|         break; | ||||
|     SmallVector<Value> pooling_ops; | ||||
| 
 | ||||
|     ArrayRef<Value> inputs = args.take_front(op.inputs().size()); | ||||
|     ArrayRef<Value> init_values = args.drop_front(op.inputs().size()); | ||||
|     for (auto it : llvm::zip(op.getResults(), inputs, init_values)) { | ||||
|       OpResult result = std::get<0>(it); | ||||
|       Value input = std::get<1>(it); | ||||
|       Value init_value = std::get<2>(it); | ||||
|       auto result_type = result.getType().cast<ShapedType>(); | ||||
|       if (!input.getType().cast<ShapedType>().getElementType().isF32()) { | ||||
|         return rewriter.notifyMatchFailure(op, | ||||
|                                            "expected element type to be f32"); | ||||
|       } | ||||
|       case PoolingType::kMax: { | ||||
|         pooling_op = create_op(static_cast<linalg::PoolingNHWCMaxOp*>(nullptr)); | ||||
|         break; | ||||
|       } | ||||
|       case PoolingType::kAdd: { | ||||
|         pooling_op = create_op(static_cast<linalg::PoolingNHWCSumOp*>(nullptr)); | ||||
|         break; | ||||
| 
 | ||||
|       // Create a fake window dimension.
 | ||||
|       auto fake_window_dims = rewriter.create<linalg::InitTensorOp>( | ||||
|           loc, shapes, result_type.getElementType()); | ||||
|       Value init_tensor = rewriter.create<linalg::InitTensorOp>( | ||||
|           loc, result_type.getShape(), result_type.getElementType()); | ||||
|       init_value = rewriter.create<tensor::ExtractOp>(loc, init_value); | ||||
|       Value filled_init_tensor = | ||||
|           rewriter.create<linalg::FillOp>(loc, init_tensor, init_value) | ||||
|               .getResult(0); | ||||
|       auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp { | ||||
|         return cast<linalg::LinalgOp>( | ||||
|             rewriter | ||||
|                 .create<std::remove_pointer_t<decltype(type_ptr)>>( | ||||
|                     loc, ArrayRef<Type>{result_type}, | ||||
|                     ValueRange{args[0], fake_window_dims.getResult()}, | ||||
|                     filled_init_tensor, dilations, strides) | ||||
|                 .getOperation()); | ||||
|       }; | ||||
|       linalg::LinalgOp pooling_op; | ||||
|       PoolingType pooling_type = getPoolingType(op, result.getResultNumber()); | ||||
|       switch (pooling_type) { | ||||
|         case PoolingType::kMin: { | ||||
|           pooling_op = | ||||
|               create_op(static_cast<linalg::PoolingNHWCMinOp*>(nullptr)); | ||||
|           break; | ||||
|         } | ||||
|         case PoolingType::kMax: { | ||||
|           pooling_op = | ||||
|               create_op(static_cast<linalg::PoolingNHWCMaxOp*>(nullptr)); | ||||
|           break; | ||||
|         } | ||||
|         case PoolingType::kAdd: { | ||||
|           pooling_op = | ||||
|               create_op(static_cast<linalg::PoolingNHWCSumOp*>(nullptr)); | ||||
|           break; | ||||
|         } | ||||
|         case PoolingType::kInvalid: | ||||
|           return rewriter.notifyMatchFailure(op, "unknown reduction operation"); | ||||
|       } | ||||
|       pooling_ops.push_back(pooling_op->getResult(0)); | ||||
|     } | ||||
|     rewriter.replaceOp(op, pooling_op->getResult(0)); | ||||
|     rewriter.replaceOp(op, pooling_ops); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
|  |  | |||
|  | @ -93,8 +93,8 @@ struct MappedIvs { | |||
| }; | ||||
| 
 | ||||
| template <typename OpTy> | ||||
| MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs, | ||||
|                               OpBuilder* b) { | ||||
| MappedIvs MapWindowIvsToInput(OpTy op, Value operand, ValueRange ivs, | ||||
|                               ValueRange window_ivs, OpBuilder* b) { | ||||
|   MappedIvs mapped_ivs; | ||||
| 
 | ||||
|   if (!op.window_strides().hasValue()) { | ||||
|  | @ -108,7 +108,6 @@ MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs, | |||
|   auto padding = op.padding().getValue(); | ||||
| 
 | ||||
|   auto loc = op.getLoc(); | ||||
|   auto operand = op.operand(); | ||||
|   auto operand_shape = operand.getType().template cast<MemRefType>().getShape(); | ||||
| 
 | ||||
|   // `in_bounds` is false when the mapped indices are in the padding area.
 | ||||
|  | @ -196,7 +195,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> { | |||
|   LogicalResult matchAndRewrite( | ||||
|       lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     // TODO(b/137624192) Implement variadic reduce.
 | ||||
|     // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
 | ||||
|     if (reduce_op.out().size() != 1) return failure(); | ||||
| 
 | ||||
|     scf::ReduceOp scf_reduce_op = | ||||
|  | @ -312,7 +311,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> { | |||
| //       value = input[I]
 | ||||
| //     else
 | ||||
| //       value = neutral_value
 | ||||
| //     accumulator = reduction_operator(output[O], value)
 | ||||
| //     accumulator = reduction_operator(accumulator, value)
 | ||||
| //   output[O] = accumulator
 | ||||
| //
 | ||||
| // Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
 | ||||
|  | @ -367,6 +366,9 @@ class ReduceWindowOpConverter | |||
|   LogicalResult matchAndRewrite( | ||||
|       lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
 | ||||
|     if (reduce_window_op.out().size() != 1) return failure(); | ||||
| 
 | ||||
|     scf::ParallelOp output_loop, window_loop; | ||||
|     std::tie(output_loop, window_loop) = | ||||
|         CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op, | ||||
|  | @ -387,14 +389,14 @@ class ReduceWindowOpConverter | |||
|       lmhlo::ReduceWindowOp reduce_window_op, | ||||
|       ConversionPatternRewriter* rewriter) const { | ||||
|     auto loc = reduce_window_op.getLoc(); | ||||
|     Value init_value = | ||||
|         rewriter->create<memref::LoadOp>(loc, reduce_window_op.init_value()); | ||||
|     Value init_value = rewriter->create<memref::LoadOp>( | ||||
|         loc, reduce_window_op.init_values()[0]); | ||||
| 
 | ||||
|     Value zero = rewriter->create<ConstantIndexOp>(loc, 0); | ||||
|     Value one = rewriter->create<ConstantIndexOp>(loc, 1); | ||||
| 
 | ||||
|     // Create an outer parallel loop that spans the output of ReduceWindowOp.
 | ||||
|     Value output = reduce_window_op.out(); | ||||
|     Value output = reduce_window_op.out()[0]; | ||||
|     auto output_loop = MakeLoopOverShape(loc, output, rewriter); | ||||
| 
 | ||||
|     // Create a nested loop that traverses the window.
 | ||||
|  | @ -429,22 +431,22 @@ class ReduceWindowOpConverter | |||
|           "`window_dilations` attributes yet. The attributes will be ignored."); | ||||
|     } | ||||
| 
 | ||||
|     Value operand = reduce_window_op.operand(); | ||||
|     auto operand_type = operand.getType().cast<MemRefType>(); | ||||
|     Value input = reduce_window_op.inputs()[0]; | ||||
|     auto input_type = input.getType().cast<MemRefType>(); | ||||
| 
 | ||||
|     // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
 | ||||
|     MappedIvs mapped_ivs = | ||||
|         MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(), | ||||
|                             window_loop.getInductionVars(), rewriter); | ||||
|     MappedIvs mapped_ivs = MapWindowIvsToInput( | ||||
|         reduce_window_op, input, output_loop.getInductionVars(), | ||||
|         window_loop.getInductionVars(), rewriter); | ||||
| 
 | ||||
|     auto elem_or_init = rewriter->create<scf::IfOp>( | ||||
|         loc, operand_type.getElementType(), mapped_ivs.in_bounds, | ||||
|         loc, input_type.getElementType(), mapped_ivs.in_bounds, | ||||
|         /*withElseRegion=*/true); | ||||
| 
 | ||||
|     OpBuilder then_builder = | ||||
|         elem_or_init.getThenBodyBuilder(rewriter->getListener()); | ||||
|     Value elem = then_builder.create<mlir::memref::LoadOp>( | ||||
|         loc, reduce_window_op.operand(), mapped_ivs.ivs); | ||||
|     Value elem = | ||||
|         then_builder.create<mlir::memref::LoadOp>(loc, input, mapped_ivs.ivs); | ||||
|     then_builder.create<scf::YieldOp>(loc, elem); | ||||
| 
 | ||||
|     OpBuilder else_builder = | ||||
|  | @ -611,9 +613,9 @@ class SelectAndScatterOpConverter | |||
|         OpBuilder::atBlockEnd(window_loops.inner_loop.getBody()); | ||||
| 
 | ||||
|     // Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
 | ||||
|     MappedIvs mapped_ivs = | ||||
|         MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(), | ||||
|                             window_loops.window_ivs, &inner_loop_b); | ||||
|     MappedIvs mapped_ivs = MapWindowIvsToInput( | ||||
|         s_and_s_op, s_and_s_op.operand(), loop_over_src.getInductionVars(), | ||||
|         window_loops.window_ivs, &inner_loop_b); | ||||
| 
 | ||||
|     IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs()); | ||||
| 
 | ||||
|  |  | |||
|  | @ -1863,6 +1863,43 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1 | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_window_sum_max_nhwc(%arg0: tensor<1x18x18x64xf32>, | ||||
|                              %arg1: tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) { | ||||
|   %0:2 = "mhlo.reduce_window"(%arg0, %arg0, %arg1, %arg1) ( { | ||||
|   ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>, %arg4: tensor<f32>, %arg5 : tensor<f32>): | ||||
|     %1 = mhlo.add %arg2, %arg4 : tensor<f32> | ||||
|     %2 = mhlo.maximum %arg3, %arg5 : tensor<f32> | ||||
|     "mhlo.return"(%1, %2) : (tensor<f32>, tensor<f32>) -> () | ||||
|   }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, | ||||
|       window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor<1x18x18x64xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) | ||||
|   return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @reduce_window_sum_max_nhwc | ||||
| // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]] | ||||
| // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]] | ||||
| // CHECK:         %[[WINDOW0:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32> | ||||
| // CHECK:         %[[INIT0:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32> | ||||
| // CHECK:         %[[INIT_VAL0:.+]] = tensor.extract %[[ARG1]][] : tensor<f32> | ||||
| // CHECK:         %[[FILL0:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32> | ||||
| // CHECK:         %[[RES0:.+]] = linalg.pooling_nhwc_sum | ||||
| // CHECK-SAME:      {dilations = dense<1> : vector<2xi64> | ||||
| // CHECK-SAME:       strides = dense<2> : vector<2xi64>} | ||||
| // CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW0]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>) | ||||
| // CHECK-SAME:      outs(%[[FILL0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> | ||||
| // CHECK:         %[[WINDOW1:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32> | ||||
| // CHECK:         %[[INIT1:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32> | ||||
| // CHECK:         %[[INIT_VAL1:.+]] = tensor.extract %[[ARG1]][] : tensor<f32> | ||||
| // CHECK:         %[[FILL1:.+]] = linalg.fill(%[[INIT1]], %[[INIT_VAL1]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32> | ||||
| // CHECK:         %[[RES1:.+]] = linalg.pooling_nhwc_max | ||||
| // CHECK-SAME:      {dilations = dense<1> : vector<2xi64> | ||||
| // CHECK-SAME:       strides = dense<2> : vector<2xi64>} | ||||
| // CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW1]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>) | ||||
| // CHECK-SAME:      outs(%[[FILL1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> | ||||
| // CHECK:         return %[[RES0]], %[[RES1]] | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @torch_select_index(%arg0: tensor<5x1x5xi32>, | ||||
|                          %arg1: tensor<2xi32>) ->  tensor<2x1x5xi32> { | ||||
|   %0 = "mhlo.torch_index_select"(%arg0, %arg1) { | ||||
|  |  | |||
|  | @ -1389,3 +1389,37 @@ func @custom_call_multiple_outputs(%x: tensor<2xf32>) -> tensor<2xf32> { | |||
|   %1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> | ||||
|   return %1 : tensor<2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK: func @reduce_window | ||||
| func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor<f32>, %init1: tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) { | ||||
|   %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ | ||||
|          ^bb0(%a0: tensor<f32>, %a1: tensor<i32>, %b0: tensor<f32>, %b1: tensor<i32>):  // no predecessors | ||||
|               %2 = mhlo.add %a0, %b0 : tensor<f32> | ||||
|               %3 = mhlo.add %a1, %b1 : tensor<i32> | ||||
|               %4 = "mhlo.tuple"(%2, %3) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> | ||||
|               "mhlo.return"(%4) : (tuple<tensor<f32>, tensor<i32>>) -> () | ||||
|             }) | ||||
|          { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, | ||||
|            window_dimensions = dense<[5, 1]> : tensor<2xi64>, | ||||
|            window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x2xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) | ||||
|   return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @reduce_window_invalid(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %init0: tensor<f32>, %init1: tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) { | ||||
|   // expected-error @+1 {{requires same shape for all inputs}} | ||||
|   %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ | ||||
|          ^bb0(%a0: tensor<f32>, %a1: tensor<i32>, %b0: tensor<f32>, %b1: tensor<i32>):  // no predecessors | ||||
|               %2 = mhlo.add %a0, %b0 : tensor<f32> | ||||
|               %3 = mhlo.add %a1, %b1 : tensor<i32> | ||||
|               %4 = "mhlo.tuple"(%2, %3) : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> | ||||
|               "mhlo.return"(%4) : (tuple<tensor<f32>, tensor<i32>>) -> () | ||||
|             }) | ||||
|          { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, | ||||
|            window_dimensions = dense<[5, 1]> : tensor<2xi64>, | ||||
|            window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x3xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) | ||||
|   return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32> | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue