Transition to value-typed Value: rename Value* -> Value, in accordance with the upstream MLIR style change.
This commit is contained in:
parent 33f988e18a
commit 0582846864
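For context, the upstream change replaces the pointer-based `mlir::Value *` idiom with a small, trivially copyable handle that is stored and passed by value. A minimal sketch of the new idiom follows; the `collectOperands` helper is illustrative only, not code from this commit:

```cpp
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"

// Old idiom: containers and parameters held raw pointers.
//   llvm::SmallVector<mlir::Value *, 4> vals;
//   void use(mlir::Value *v);
//
// New idiom: mlir::Value is a cheap handle wrapping the underlying IR
// object, so it is copied freely and stored by value.
static void collectOperands(mlir::Operation *op,
                            llvm::SmallVectorImpl<mlir::Value> &out) {
  for (mlir::Value operand : op->getOperands())
    out.push_back(operand); // copies the handle, not the IR object
}
```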
				|  | @ -71,7 +71,7 @@ struct OnnxOnnfSymbolMapping { | |||
|    *  @param name onnx tensor name. | ||||
|    *  @return onnf tensor corresponding to `name`. | ||||
|    */ | ||||
|   mlir::Value *GetTensorByOnnxName(std::string name) { | ||||
|   mlir::Value GetTensorByOnnxName(std::string name) { | ||||
|     assert(onnx_name2onnf_tensor.find(legalize_name(name)) != | ||||
|                             onnx_name2onnf_tensor.end() && | ||||
|                         "Tensor not found"); | ||||
|  | @ -81,9 +81,9 @@ struct OnnxOnnfSymbolMapping { | |||
|   /*!
 | ||||
|    *  Add a new mapping from onnx tensor name to MLIR symbol. | ||||
|    *  @param name onnx tensor name. | ||||
|    *  @param tensor MLIR Value* pointer. | ||||
|    *  @param tensor MLIR Value. | ||||
|    */ | ||||
|   void AddMapping(std::string name, mlir::Value *tensor) { | ||||
|   void AddMapping(std::string name, mlir::Value tensor) { | ||||
|     assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 && | ||||
|                         "Tensor already exists."); | ||||
|     onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); | ||||
|  | @ -97,7 +97,7 @@ private: | |||
|   /*!
 | ||||
|    *  mapping from onnx tensor names to MLIR tensor. | ||||
|    */ | ||||
|   std::map<std::string, mlir::Value*> onnx_name2onnf_tensor; | ||||
|   std::map<std::string, mlir::Value> onnx_name2onnf_tensor; | ||||
| }; | ||||
| 
 | ||||
| class FrontendGenImpl { | ||||
|  | @ -192,13 +192,13 @@ private: | |||
| 
 | ||||
|   /*!
 | ||||
|    * Import an input tensor symbol by adding a new entry to frontend_symbols_ | ||||
|    * recording the mapping between legalized onnx tensor name and mlir::Value* | ||||
|    * recording the mapping between legalized onnx tensor name and mlir::Value | ||||
|    * for further lookup in computation node importing. | ||||
|    * @param input onnx input tensor ValueInfoProto. | ||||
|    * @param symbol mlir input argument. | ||||
|    */ | ||||
|   void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, | ||||
|                                mlir::Value *symbol) { | ||||
|                                mlir::Value symbol) { | ||||
|     auto input_tensor_legalized_name = legalize_name(input.name()); | ||||
|     assert( | ||||
|         !frontend_symbols_.ContainKey(input_tensor_legalized_name) && | ||||
|  | @ -480,7 +480,7 @@ private: | |||
|   } | ||||
| 
 | ||||
|   void ImportNodeGeneric(onnx::NodeProto node) { | ||||
|     std::vector<mlir::Value *> inputs; | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     for (auto item : node.input()) { | ||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||
|  | @ -515,7 +515,7 @@ private: | |||
|       onnx::NodeProto node, int nIn, int nOut, | ||||
|       std::initializer_list<std::tuple<std::string, std::string, std::string>> | ||||
|           attrs) { | ||||
|     std::vector<mlir::Value *> inputs; | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     for (auto item : node.input()) { | ||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||
|  | @ -562,7 +562,7 @@ private: | |||
|       onnx::NodeProto node, int nIn, int nOut, | ||||
|       std::initializer_list<std::tuple<std::string, std::string, std::string>> | ||||
|           attrs) { | ||||
|     std::vector<mlir::Value *> inputs; | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     for (auto item : node.input()) { | ||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||
|  | @ -633,7 +633,7 @@ private: | |||
|   } | ||||
| 
 | ||||
|   void ImportNode(onnx::NodeProto node) { | ||||
|     std::vector<mlir::Value *> inputs; | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     for (auto item : node.input()) { | ||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||
|  | @ -662,17 +662,17 @@ private: | |||
|    * Import an output tensor by doing the following: | ||||
|    * - Add the type of this output tensor to a list of tensor | ||||
|    *   types representing return types of this graph function. | ||||
|    * - Add this output tensor to the list of mlir::Value* | ||||
|    * - Add this output tensor to the list of mlir::Value | ||||
|    *   to be returned by the function representing the computation graph. | ||||
|    * @param output onnx output tensor ValueInfoProto. | ||||
|    * @param ret_types a vector of tensor types representing graph's | ||||
|    *   output tensor types. | ||||
|    * @param ret_vals a vector of mlir Value* representing graph's | ||||
|    * @param ret_vals a vector of mlir Value representing graph's | ||||
|    *   output tensor. | ||||
|    */ | ||||
|   void ImportOutputTensor(const onnx::ValueInfoProto &output, | ||||
|                           llvm::SmallVectorImpl<mlir::Type> &ret_types, | ||||
|                           llvm::SmallVectorImpl<mlir::Value *> &ret_vals) { | ||||
|                           llvm::SmallVectorImpl<mlir::Value> &ret_vals) { | ||||
|     auto output_tensor_legalized_name = legalize_name(output.name()); | ||||
|     assert( | ||||
|         frontend_symbols_.ContainKey(output_tensor_legalized_name) && | ||||
|  | @ -722,7 +722,7 @@ private: | |||
|     } | ||||
| 
 | ||||
|     llvm::SmallVector<mlir::Type, 4> ret_types; | ||||
|     llvm::SmallVector<mlir::Value *, 4> ret_vals; | ||||
|     llvm::SmallVector<mlir::Value, 4> ret_vals; | ||||
|     // Import the output tensors
 | ||||
|     for (const auto &output : graph.output()) { | ||||
|       ImportOutputTensor(output, ret_types, ret_vals); | ||||
|  |  | |||
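The hunks above show the importer's symbol table moving from `std::map<std::string, mlir::Value*>` to `std::map<std::string, mlir::Value>`. Because the handle is copyable, the table can hold entries directly, with no pointer lifetime to track. A stripped-down sketch (names simplified; not the repository's actual class):

```cpp
#include "mlir/IR/Value.h"
#include <map>
#include <string>

struct SymbolTable {
  // Store the handle by value; no ownership or aliasing concerns.
  void add(const std::string &name, mlir::Value v) { table.emplace(name, v); }
  // Returning a copy of the handle is as cheap as returning a pointer was.
  mlir::Value lookup(const std::string &name) const { return table.at(name); }

private:
  std::map<std::string, mlir::Value> table;
};
```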
|  | @ -9,8 +9,9 @@ namespace onnf { | |||
| 
 | ||||
| using namespace mlir; | ||||
| 
 | ||||
| ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||
|     const Type& operandType, Value*& operand) { | ||||
| ParseResult | ||||
| KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType, | ||||
|                                                Value &operand) { | ||||
|   // If the operand queue is empty, parse more operands and cache them.
 | ||||
|   if (_operandRefQueue.empty()) { | ||||
|     // Parse operand types:
 | ||||
|  | @ -27,7 +28,7 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | |||
|     auto operand_ref = _operandRefQueue.front(); | ||||
|     _operandRefQueue.pop(); | ||||
| 
 | ||||
|     llvm::SmallVector<Value*, 1> operands; | ||||
|     llvm::SmallVector<Value, 1> operands; | ||||
|     _parser.resolveOperand(operand_ref, operandType, operands); | ||||
|     operand = operands.front(); | ||||
|     return success(); | ||||
|  | @ -38,8 +39,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | |||
| } | ||||
| 
 | ||||
| ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||
|     const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { | ||||
|   Value* operand = nullptr; | ||||
|     const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) { | ||||
|   Value operand = nullptr; | ||||
|   if (ParseOptionalOperand(operandType, operand)) | ||||
|     return failure(); | ||||
| 
 | ||||
|  | @ -47,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | |||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| ParseResult KrnlDialectOperandParser::ParseOperand( | ||||
|     const Type& operandType, Value*& operand) { | ||||
| ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType, | ||||
|                                                    Value &operand) { | ||||
|   if (ParseOptionalOperand(operandType, operand)) | ||||
|     return _parser.emitError( | ||||
|         _parser.getCurrentLocation(), "Expecting an operand."); | ||||
|  | @ -56,7 +57,7 @@ ParseResult KrnlDialectOperandParser::ParseOperand( | |||
| } | ||||
| 
 | ||||
| ParseResult KrnlDialectOperandParser::ParseOperand( | ||||
|     const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { | ||||
|     const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) { | ||||
|   if (ParseOptionalOperand(operandType, operandList)) | ||||
|     return _parser.emitError( | ||||
|         _parser.getCurrentLocation(), "Expecting an operand."); | ||||
|  | @ -129,7 +130,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) { | |||
|   boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||
| } | ||||
| 
 | ||||
| void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) { | ||||
| void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) { | ||||
|   if (boundMaps.size() % 2 == 0) | ||||
|     _operands.emplace_back(inputLoops[boundMaps.size() / 2]); | ||||
|   AffineMap map = builder.getSymbolIdentityMap(); | ||||
|  |  | |||
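With the signatures above, `ParseOperand` now fills a `mlir::Value &` out-parameter instead of a `mlir::Value *&`. A hedged sketch of a call site follows; the `parseMyOp` wrapper and its index-typed operand are hypothetical, not part of this commit:

```cpp
using namespace mlir;

// Hypothetical custom-op parser showing the new out-parameter convention.
ParseResult parseMyOp(OpAsmParser &parser, OperationState &result) {
  KrnlDialectOperandParser operandParser(parser);
  Value operand; // default-constructed null handle, filled on success
  if (operandParser.ParseOperand(parser.getBuilder().getIndexType(), operand))
    return failure(); // ParseOperand already emitted a diagnostic
  // ... use `operand` to populate `result` ...
  return success();
}
```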
|  | @ -17,20 +17,22 @@ class KrnlDialectOperandParser { | |||
|       : _parser(parser), _builder(parser.getBuilder()){}; | ||||
| 
 | ||||
|   // Parse an optional operand.
 | ||||
|   mlir::ParseResult ParseOptionalOperand( | ||||
|       const mlir::Type& operandType, mlir::Value*& operand); | ||||
|   mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType, | ||||
|                                          mlir::Value &operand); | ||||
| 
 | ||||
|   // Parse an optional operand and push it to an operand list.
 | ||||
|   mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType, | ||||
|       llvm::SmallVectorImpl<mlir::Value*>& operandList); | ||||
|   mlir::ParseResult | ||||
|   ParseOptionalOperand(const mlir::Type &operandType, | ||||
|                        llvm::SmallVectorImpl<mlir::Value> &operandList); | ||||
| 
 | ||||
|   // Parse a required operand.
 | ||||
|   mlir::ParseResult ParseOperand( | ||||
|       const mlir::Type& operandType, mlir::Value*& operand); | ||||
|   mlir::ParseResult ParseOperand(const mlir::Type &operandType, | ||||
|                                  mlir::Value &operand); | ||||
| 
 | ||||
|   // Parse a required operand and push it to an operand list.
 | ||||
|   mlir::ParseResult ParseOperand(const mlir::Type& operandType, | ||||
|       llvm::SmallVectorImpl<mlir::Value*>& operandList); | ||||
|   mlir::ParseResult | ||||
|   ParseOperand(const mlir::Type &operandType, | ||||
|                llvm::SmallVectorImpl<mlir::Value> &operandList); | ||||
| 
 | ||||
|   // Do we have more operands to parse?
 | ||||
|   bool hasOperandLeft() { return !_operandRefQueue.empty(); } | ||||
|  | @ -63,11 +65,10 @@ void printBound(mlir::AffineMapAttr boundMap, | |||
| namespace mlir { | ||||
| 
 | ||||
| struct KrnlIterateOperandPack { | ||||
|   KrnlIterateOperandPack(mlir::Builder& builder, | ||||
|       llvm::ArrayRef<mlir::Value*> inputLoops, | ||||
|       llvm::ArrayRef<mlir::Value*> optimizedLoops) | ||||
|       : builder(builder), | ||||
|         inputLoops(inputLoops), | ||||
|   KrnlIterateOperandPack(mlir::Builder &builder, | ||||
|                          llvm::ArrayRef<mlir::Value> inputLoops, | ||||
|                          llvm::ArrayRef<mlir::Value> optimizedLoops) | ||||
|       : builder(builder), inputLoops(inputLoops), | ||||
|         optimizedLoops(optimizedLoops) { | ||||
|     _operands.insert( | ||||
|         _operands.end(), optimizedLoops.begin(), optimizedLoops.end()); | ||||
|  | @ -75,9 +76,9 @@ struct KrnlIterateOperandPack { | |||
| 
 | ||||
|   void pushConstantBound(int64_t bound); | ||||
| 
 | ||||
|   void pushOperandBound(mlir::Value* operand); | ||||
|   void pushOperandBound(mlir::Value operand); | ||||
| 
 | ||||
|   llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; } | ||||
|   llvm::SmallVector<mlir::Value, 8> getOperands() const { return _operands; } | ||||
| 
 | ||||
|   mlir::ArrayAttr getAttributes() const { | ||||
|     return builder.getArrayAttr(boundMaps); | ||||
|  | @ -90,11 +91,11 @@ struct KrnlIterateOperandPack { | |||
|  private: | ||||
|   int _boundIdx = 0; | ||||
| 
 | ||||
|   llvm::SmallVector<mlir::Value*, 8> _operands; | ||||
|   llvm::SmallVector<mlir::Value, 8> _operands; | ||||
| 
 | ||||
|   llvm::SmallVector<mlir::Attribute, 8> boundMaps; | ||||
| 
 | ||||
|   llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops; | ||||
|   llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops; | ||||
| 
 | ||||
|   mlir::Builder& builder; | ||||
| }; | ||||
|  |  | |||
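The operand pack above now carries `mlir::Value` handles for loop variables and bounds. A short sketch of assembling bounds under the new signatures; the surrounding values are assumed to come from krnl.define_loops / krnl.optimize_loops results, and the helper itself is illustrative, not repository code:

```cpp
void buildBounds(mlir::Builder &builder,
                 llvm::ArrayRef<mlir::Value> inputLoops,
                 llvm::ArrayRef<mlir::Value> optimizedLoops,
                 mlir::Value dynamicUpperBound) {
  mlir::KrnlIterateOperandPack pack(builder, inputLoops, optimizedLoops);
  pack.pushConstantBound(0);                // constant lower bound
  pack.pushOperandBound(dynamicUpperBound); // SSA upper bound, passed by value
  // pack.getOperands() and pack.getAttributes() then feed the
  // krnl.iterate builder.
}
```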
|  | @ -44,21 +44,21 @@ static MemRefType convertTensorToMemRef(TensorType type) { | |||
| } | ||||
| 
 | ||||
| /// Insert an allocation and deallocation for the given MemRefType.
 | ||||
| static Value *insertAllocAndDealloc(MemRefType type, Location loc, | ||||
|                                     PatternRewriter &rewriter, | ||||
|                                     bool insertDealloc, | ||||
|                                     ArrayRef<Value *> operands = {}) { | ||||
| static Value insertAllocAndDealloc(MemRefType type, Location loc, | ||||
|                                    PatternRewriter &rewriter, | ||||
|                                    bool insertDealloc, | ||||
|                                    ArrayRef<Value> operands = {}) { | ||||
|   // Put together alloc operands for any dynamic dimensions of the memref.
 | ||||
|   AllocOp alloc; | ||||
|   if (!operands.empty()) { | ||||
|     auto memRefShape = type.getShape(); | ||||
|     auto rank = memRefShape.size(); | ||||
| 
 | ||||
|     std::map<int, Value *> fromOperands; | ||||
|     std::map<int, Value> fromOperands; | ||||
|     for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { | ||||
|       int memRefDimIdx = rank - 1 - reversedIdx; | ||||
|       if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
 | ||||
|         Value *maxDim = nullptr; | ||||
|         Value maxDim = nullptr; | ||||
|         for (int i = 0; i < operands.size(); i++) { | ||||
|           auto operandShape = | ||||
|               operands[i]->getType().cast<MemRefType>().getShape(); | ||||
|  | @ -85,7 +85,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc, | |||
|       } | ||||
|     } | ||||
| 
 | ||||
|     SmallVector<Value *, 4> allocOperands; | ||||
|     SmallVector<Value, 4> allocOperands; | ||||
|     for (int i = 0; i < rank; ++i) | ||||
|       if (memRefShape[i] < 0) | ||||
|         allocOperands.push_back(fromOperands[i]); | ||||
|  | @ -146,14 +146,14 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { | |||
| 
 | ||||
| // Get run-time dimension information for unknown dimensions used for
 | ||||
| // broadcasting.
 | ||||
| std::map<int, std::map<int, Value *>> | ||||
| std::map<int, std::map<int, Value>> | ||||
| getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, | ||||
|                       MemRefType memRefType, ArrayRef<Value *> operands) { | ||||
|                       MemRefType memRefType, ArrayRef<Value> operands) { | ||||
|   auto memRefShape = memRefType.getShape(); | ||||
|   int64_t rank = memRefShape.size(); | ||||
|   // For unknown dimensions, we need to get dimension values at runtime in
 | ||||
|   // order to do broadcasting.
 | ||||
|   std::map<int, std::map<int, Value *>> DimInfo; | ||||
|   std::map<int, std::map<int, Value>> DimInfo; | ||||
|   // For each result dimension, compute the number of sharing operands.
 | ||||
|   // Sharing operands are operands sharing the same index (counting from the
 | ||||
|   // rightmost to the leftmost) for a given dimension.
 | ||||
|  | @ -173,7 +173,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, | |||
|   // We only care about unknown dimensions whose number of sharing operands is
 | ||||
|   // more than one, since they are potentially broadcasted dimensions.
 | ||||
|   for (int i = 0; i < operands.size(); ++i) { | ||||
|     std::map<int, Value *> broadcastedDims; | ||||
|     std::map<int, Value> broadcastedDims; | ||||
|     auto shape = operands[i]->getType().cast<MemRefType>().getShape(); | ||||
|     int size = shape.size(); | ||||
|     for (int j = 0; j < shape.size(); ++j) { | ||||
|  | @ -192,17 +192,17 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, | |||
| 
 | ||||
| // Extract induction variables that are used for broadcasting values of a
 | ||||
| // given operand.
 | ||||
| std::vector<Value *> | ||||
| std::vector<Value> | ||||
| getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, | ||||
|                           ArrayRef<Value *> loopIVs, Value *operand, | ||||
|                           std::map<int, Value *> broadcastedDims) { | ||||
|                           ArrayRef<Value> loopIVs, Value operand, | ||||
|                           std::map<int, Value> broadcastedDims) { | ||||
|   // `operand` must have a ranked type. This should have been checked by the
 | ||||
|   // shape inference pass.
 | ||||
|   auto operandShape = operand->getType().cast<MemRefType>().getShape(); | ||||
|   auto rank = operandShape.size(); | ||||
|   auto loopCount = loopIVs.size(); | ||||
| 
 | ||||
|   std::vector<Value *> newLoopIVs; | ||||
|   std::vector<Value> newLoopIVs; | ||||
|   for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { | ||||
|     auto dimIdx = rank - 1 - reversedIdx; | ||||
|     auto loopIdx = loopCount - 1 - reversedIdx; | ||||
|  | @ -247,7 +247,7 @@ struct ScalarOp<ONNXMulOp> { | |||
| template <> | ||||
| struct ScalarOp<ONNXDivOp> { | ||||
|   using FOp = DivFOp; | ||||
|   using IOp = DivISOp; | ||||
|   using IOp = SignedDivIOp; | ||||
| }; | ||||
| 
 | ||||
| template <> | ||||
|  | @ -295,9 +295,9 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp; | |||
| // Scalar unary ops for lowering to Krnl dialect.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <typename UnaryOp> | ||||
| Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, | ||||
|                           ArrayRef<Value *> operands, | ||||
|                           ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, | ||||
|                          ArrayRef<Value> operands, | ||||
|                          ConversionPatternRewriter &rewriter) { | ||||
|   /* Lower UnaryOp to Ops in the Standard dialect.
 | ||||
|    */ | ||||
|   auto loc = op->getLoc(); | ||||
|  | @ -318,14 +318,13 @@ Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXTanhOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op, | ||||
|                                       ArrayRef<Type> result_types, | ||||
|                                       ArrayRef<Value *> operands, | ||||
|                                       ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||
|   //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
| 
 | ||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||
|   auto neg = rewriter.create<SubFOp>(loc, zero, operand); | ||||
|  | @ -342,14 +341,13 @@ Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXSinhOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op, | ||||
|                                       ArrayRef<Type> result_types, | ||||
|                                       ArrayRef<Value *> operands, | ||||
|                                       ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||
|   //                         ConstantOp 2)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
| 
 | ||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||
|   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); | ||||
|  | @ -366,14 +364,13 @@ Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXCoshOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op, | ||||
|                                       ArrayRef<Type> result_types, | ||||
|                                       ArrayRef<Value *> operands, | ||||
|                                       ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||
|   //                         ConstantOp 2)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
| 
 | ||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||
|   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); | ||||
|  | @ -390,14 +387,14 @@ Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXSigmoidOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, | ||||
|                                          ArrayRef<Type> result_types, | ||||
|                                          ArrayRef<Value *> operands, | ||||
|                                          ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, | ||||
|                                         ArrayRef<Type> result_types, | ||||
|                                         ArrayRef<Value> operands, | ||||
|                                         ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
 | ||||
|   //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
| 
 | ||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||
|   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); | ||||
|  | @ -413,8 +410,8 @@ Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXHardSigmoidOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXHardSigmoidOp>( | ||||
|     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands, | ||||
| Value mapToLowerScalarOp<ONNXHardSigmoidOp>( | ||||
|     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||
|     ConversionPatternRewriter &rewriter) { | ||||
|   // %Y = AddFOp(MulFOp(alpha, %X), beta)
 | ||||
|   // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
 | ||||
|  | @ -424,7 +421,7 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>( | |||
|   //                                  %Z,
 | ||||
|   //                                  Constant 1)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
|   auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha"); | ||||
|   auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta"); | ||||
| 
 | ||||
|  | @ -449,14 +446,14 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>( | |||
| // Scalar unary ops for lowering ONNXEluOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value *> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                     ArrayRef<Value> operands, | ||||
|                                     ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||
|   //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
 | ||||
|   //                          %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
| 
 | ||||
|   auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha"); | ||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||
|  | @ -478,15 +475,14 @@ Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXReluOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op, | ||||
|                                       ArrayRef<Type> result_types, | ||||
|                                       ArrayRef<Value *> operands, | ||||
|                                       ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||
|   //                           ConstantOp 0,
 | ||||
|   //                           %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
| 
 | ||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||
|   auto lessThanZero = | ||||
|  | @ -500,15 +496,15 @@ Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXLeakyReluOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value * | ||||
| mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                     ArrayRef<Value *> operands, | ||||
|                                     ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, | ||||
|                                           ArrayRef<Type> result_types, | ||||
|                                           ArrayRef<Value> operands, | ||||
|                                           ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||
|   //                                MulFOp(alpha, %X),
 | ||||
|   //                                %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
| 
 | ||||
|   auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha"); | ||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||
|  | @ -525,17 +521,16 @@ mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXSeluOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op, | ||||
|                                       ArrayRef<Type> result_types, | ||||
|                                       ArrayRef<Value *> operands, | ||||
|                                       ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
 | ||||
|   //                           MulFOp(gamma, %X),
 | ||||
|   //                           MulFOp(gamma,
 | ||||
|   //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
 | ||||
|   //                                         alpha)))
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
|   auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha"); | ||||
|   auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma"); | ||||
| 
 | ||||
|  | @ -558,13 +553,12 @@ Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op, | |||
| // Scalar unary ops for lowering ONNXReciprocalOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value * | ||||
| mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value *> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXReciprocalOp>( | ||||
|     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||
|     ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *operand = operands[0]; | ||||
|   Value operand = operands[0]; | ||||
| 
 | ||||
|   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); | ||||
|   auto result = rewriter.create<DivFOp>(loc, one, operand); | ||||
|  | @ -576,15 +570,15 @@ mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXMaxOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value *> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                     ArrayRef<Value> operands, | ||||
|                                     ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
 | ||||
|   //                              %X,
 | ||||
|   //                              %Y)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *lhs = operands[0]; | ||||
|   Value *rhs = operands[1]; | ||||
|   Value lhs = operands[0]; | ||||
|   Value rhs = operands[1]; | ||||
|   auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); | ||||
|   auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); | ||||
|   return result; | ||||
|  | @ -594,15 +588,15 @@ Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, | |||
| // Scalar unary ops for lowering ONNXMinOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| template <> | ||||
| Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                      ArrayRef<Value *> operands, | ||||
|                                      ConversionPatternRewriter &rewriter) { | ||||
| Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types, | ||||
|                                     ArrayRef<Value> operands, | ||||
|                                     ConversionPatternRewriter &rewriter) { | ||||
|   // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
 | ||||
|   //                              %X,
 | ||||
|   //                              %Y)
 | ||||
|   auto loc = op->getLoc(); | ||||
|   Value *lhs = operands[0]; | ||||
|   Value *rhs = operands[1]; | ||||
|   Value lhs = operands[0]; | ||||
|   Value rhs = operands[1]; | ||||
|   auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); | ||||
|   auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); | ||||
|   return result; | ||||
|  | @ -615,7 +609,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | |||
|   ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) | ||||
|       : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} | ||||
|   PatternMatchResult | ||||
|   matchAndRewrite(Operation *op, ArrayRef<Value *> operands, | ||||
|   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||
|                   ConversionPatternRewriter &rewriter) const final { | ||||
|     // TODO: Check that the types are valid.
 | ||||
|     // An element-wise unary operation must have all operands and the result of
 | ||||
|  | @ -632,7 +626,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | |||
|     // dimensions with the result at this pre-optimization phase.
 | ||||
|     // TODO: verify that dimensions match.
 | ||||
|     // TODO: can the dimension of the result differ after optimizations?
 | ||||
|     Value *alloc; | ||||
|     Value alloc; | ||||
|     bool insertDealloc = checkInsertDealloc(op); | ||||
| 
 | ||||
|     if (hasAllConstantDimensions(memRefType)) | ||||
|  | @ -647,7 +641,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | |||
| 
 | ||||
|     // Define loops.
 | ||||
|     auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); | ||||
|     std::vector<Value *> originalLoops; | ||||
|     std::vector<Value> originalLoops; | ||||
|     originalLoops.reserve(rank); | ||||
|     for (auto result : loopsOp.getResults()) { | ||||
|       originalLoops.push_back(result); | ||||
|  | @ -655,7 +649,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | |||
| 
 | ||||
|     // Define loop optimization.
 | ||||
|     auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); | ||||
|     std::vector<Value *> optimizedLoops; | ||||
|     std::vector<Value> optimizedLoops; | ||||
|     optimizedLoops.reserve(rank); | ||||
|     for (auto result : optimizedLoopsOp.getResults()) { | ||||
|       optimizedLoops.push_back(result); | ||||
|  | @ -695,7 +689,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | |||
|     rewriter.setInsertionPointToStart(&iterationBlock); | ||||
| 
 | ||||
|     // Handle the operation:
 | ||||
|     SmallVector<Value *, 4> loopIVs; | ||||
|     SmallVector<Value, 4> loopIVs; | ||||
|     for (auto arg : iterationBlock.getArguments()) | ||||
|       loopIVs.push_back(arg); | ||||
| 
 | ||||
|  | @ -718,7 +712,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
|   ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) | ||||
|       : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} | ||||
|   PatternMatchResult | ||||
|   matchAndRewrite(Operation *op, ArrayRef<Value *> operands, | ||||
|   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||
|                   ConversionPatternRewriter &rewriter) const final { | ||||
|     // TODO: Check that the types are valid.
 | ||||
|     // An element-wise variadic operation must have all operands and the result
 | ||||
|  | @ -730,7 +724,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
|     // Insert an allocation and deallocation for the result of this operation.
 | ||||
|     auto memRefType = convertTensorToMemRef(tensorType); | ||||
| 
 | ||||
|     Value *alloc; | ||||
|     Value alloc; | ||||
|     bool insertDealloc = checkInsertDealloc(op); | ||||
|     // If the output has a dynamic dimension, we compute its dimension at
 | ||||
|     // runtime by using dimensions from the operands.
 | ||||
|  | @ -749,7 +743,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
| 
 | ||||
|     // Define loops.
 | ||||
|     auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); | ||||
|     std::vector<Value *> originalLoops; | ||||
|     std::vector<Value> originalLoops; | ||||
|     originalLoops.reserve(rank); | ||||
|     for (auto result : loopsOp.getResults()) { | ||||
|       originalLoops.push_back(result); | ||||
|  | @ -757,7 +751,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
| 
 | ||||
|     // Define loop optimization.
 | ||||
|     auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); | ||||
|     std::vector<Value *> optimizedLoops; | ||||
|     std::vector<Value> optimizedLoops; | ||||
|     optimizedLoops.reserve(rank); | ||||
|     for (auto result : optimizedLoopsOp.getResults()) { | ||||
|       optimizedLoops.push_back(result); | ||||
|  | @ -781,7 +775,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
| 
 | ||||
|     // Get run-time dimension information for unknown dimensions used for
 | ||||
|     // broadcasting.
 | ||||
|     std::map<int, std::map<int, Value *>> broadcastedDimInfo = | ||||
|     std::map<int, std::map<int, Value>> broadcastedDimInfo = | ||||
|         getBroadcastedDimInfo(loc, rewriter, memRefType, operands); | ||||
| 
 | ||||
|     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); | ||||
|  | @ -801,12 +795,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | |||
|     rewriter.setInsertionPointToStart(&iterationBlock); | ||||
| 
 | ||||
|     // Handle the operation:
 | ||||
|     SmallVector<Value *, 4> loopIVs; | ||||
|     SmallVector<Value, 4> loopIVs; | ||||
|     for (auto arg : iterationBlock.getArguments()) | ||||
|       loopIVs.push_back(arg); | ||||
| 
 | ||||
|     // Fold over operands for each of their scalar values
 | ||||
|     Value *accumulated, *next; | ||||
|     Value accumulated, next; | ||||
|     auto accumulatedLoopIVs = getLoopIVsForBroadcasting( | ||||
|         loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); | ||||
|     accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs); | ||||
|  | @ -831,17 +825,17 @@ struct ONNXReshapeOpLowering : public ConversionPattern { | |||
|       : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} | ||||
| 
 | ||||
|   PatternMatchResult | ||||
|   matchAndRewrite(Operation *op, ArrayRef<Value *> operands, | ||||
|   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||
|                   ConversionPatternRewriter &rewriter) const final { | ||||
|     auto tensorType = (*op->result_type_begin()).cast<TensorType>(); | ||||
|     auto loc = op->getLoc(); | ||||
| 
 | ||||
|     // Insert an allocation and deallocation for the result of this operation.
 | ||||
|     auto memRefType = convertTensorToMemRef(tensorType); | ||||
|     Value *alloc; | ||||
|     Value alloc; | ||||
| 
 | ||||
|     // Compute size in bytes.
 | ||||
|     Value *tensorSize = rewriter.create<ConstantOp>( | ||||
|     Value tensorSize = rewriter.create<ConstantOp>( | ||||
|         loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), | ||||
|                                      getMemRefEltSizeInBytes(memRefType))); | ||||
|     bool insertDealloc = checkInsertDealloc(op); | ||||
|  | @ -849,14 +843,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern { | |||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); | ||||
|     } else { | ||||
|       auto memRefShape = memRefType.getShape(); | ||||
|       SmallVector<Value *, 4> allocOperands; | ||||
|       SmallVector<Value, 4> allocOperands; | ||||
|       for (int i = 0; i < memRefShape.size(); ++i) { | ||||
|         // The shape array can always be used to construct shape information of
 | ||||
|         // the result.
 | ||||
|         Value *index = rewriter.create<ConstantOp>( | ||||
|         Value index = rewriter.create<ConstantOp>( | ||||
|             loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); | ||||
|         Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); | ||||
|         Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>( | ||||
|         Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); | ||||
|         Value int64LoadedVal = rewriter.create<ZeroExtendIOp>( | ||||
|             loc, loadedVal, rewriter.getIntegerType(64)); | ||||
|         tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal); | ||||
|         allocOperands.push_back(rewriter.create<IndexCastOp>( | ||||
|  |  | |||
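Each scalar-lowering specialization in the file above now consumes and returns `Value` handles directly. As a shape reference, here is a hedged sketch of one more specialization in the same style; `ONNXNegOp` is used purely for illustration and is not part of this commit:

```cpp
template <>
Value mapToLowerScalarOp<ONNXNegOp>(Operation *op, ArrayRef<Type> result_types,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) {
  // ONNXNegOp(%X) = SubFOp(ConstantOp 0, %X)
  auto loc = op->getLoc();
  Value operand = operands[0]; // a copyable handle, no longer a Value *
  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
  return rewriter.create<SubFOp>(loc, zero, operand);
}
```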
|  | @ -30,7 +30,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> { | |||
|       operandItr++; | ||||
| 
 | ||||
|       // Organize operands into lower/upper bounds in affine.for ready formats.
 | ||||
|       SmallVector<Value *, 4> lbOperands, ubOperands; | ||||
|       SmallVector<Value, 4> lbOperands, ubOperands; | ||||
|       AffineMap lbMap, ubMap; | ||||
|       for (int boundType = 0; boundType < 2; boundType++) { | ||||
|         auto &operands = boundType == 0 ? lbOperands : ubOperands; | ||||
|  |  | |||
|  | @ -51,7 +51,7 @@ public: | |||
|       : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} | ||||
| 
 | ||||
|   PatternMatchResult | ||||
|   matchAndRewrite(Operation *op, ArrayRef<Value *> operands, | ||||
|   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||
|                   ConversionPatternRewriter &rewriter) const override { | ||||
|     auto *context = op->getContext(); | ||||
|     auto loc = op->getLoc(); | ||||
|  | @ -66,27 +66,27 @@ public: | |||
|     // First operand.
 | ||||
|     Type dstType = | ||||
|         operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1); | ||||
|     Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>( | ||||
|     Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>( | ||||
|         loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); | ||||
|     Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>( | ||||
|     Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>( | ||||
|         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); | ||||
| 
 | ||||
|     // Second operand.
 | ||||
|     Type srcType = | ||||
|         operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1); | ||||
|     Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>( | ||||
|     Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>( | ||||
|         loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); | ||||
|     Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>( | ||||
|     Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>( | ||||
|         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); | ||||
| 
 | ||||
|     // Size.
 | ||||
|     Value *int64Size = rewriter.create<LLVM::SExtOp>( | ||||
|     Value int64Size = rewriter.create<LLVM::SExtOp>( | ||||
|         loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); | ||||
| 
 | ||||
|     // Memcpy call
 | ||||
|     rewriter.create<CallOp>( | ||||
|         loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), | ||||
|         ArrayRef<Value *>( | ||||
|         ArrayRef<Value>( | ||||
|             {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size})); | ||||
| 
 | ||||
|     rewriter.eraseOp(op); | ||||
|  | @ -210,7 +210,7 @@ public: | |||
| 
 | ||||
|     // Retrieve dynamic mem refs from wrapped input, and convert every one of
 | ||||
|     // them to static mem refs.
 | ||||
|     SmallVector<Value *, 4> staticInputs; | ||||
|     SmallVector<Value, 4> staticInputs; | ||||
|     auto wrappedInput = entryPointEntryBlock.getArgument(0); | ||||
|     for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) { | ||||
|       // Call API function to retrieve the i-th dynamic memref.
 | ||||
|  | @ -225,13 +225,12 @@ public: | |||
|       auto memRefTy = memRefPtrTy.getPointerElementTy(); | ||||
|       auto one = rewriter.create<LLVM::ConstantOp>( | ||||
|           loc, int32Ty, rewriter.getI32IntegerAttr(1)); | ||||
|       Value *ptrToMemRef = | ||||
|           rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one, | ||||
|                                           /*alignment=*/0); | ||||
|       Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one, | ||||
|                                                           /*alignment=*/0); | ||||
| 
 | ||||
|       // Fill in the memref underlying ptrToMemRef with information extracted
 | ||||
|       // from dynMemRef.
 | ||||
|       fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc, | ||||
|       fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc, | ||||
|                                    apiRegistry, llvmDialect); | ||||
| 
 | ||||
|       // ptrToMemRef will be an input to main computation graph function.
 | ||||
|  | @ -261,8 +260,8 @@ public: | |||
|         loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); | ||||
|     auto outDynMemRef = callApi(rewriter, loc, apiRegistry, | ||||
|                                 API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); | ||||
|     fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc, | ||||
|                             apiRegistry, llvmDialect); | ||||
|     fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry, | ||||
|                             llvmDialect); | ||||
|     auto zero = rewriter.create<LLVM::ConstantOp>( | ||||
|         loc, int32Ty, rewriter.getI32IntegerAttr(0)); | ||||
|     callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, | ||||
|  | @ -270,7 +269,7 @@ public: | |||
| 
 | ||||
|     // Return wrapped output.
 | ||||
|     rewriter.create<LLVM::ReturnOp>(loc, | ||||
|                                     SmallVector<Value *, 1>({wrappedOutput})); | ||||
|                                     SmallVector<Value, 1>({wrappedOutput})); | ||||
|     return matchSuccess(); | ||||
|   } | ||||
| 
 | ||||
|  | @ -315,11 +314,11 @@ private: | |||
| 
 | ||||
|   // Call a registered API; return its single SSA result if the call produces
 | ||||
|   // exactly one result, otherwise return nullptr.
 | ||||
|   Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, | ||||
|                  API apiId, ArrayRef<Value *> params) const { | ||||
|   Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, | ||||
|                 API apiId, ArrayRef<Value> params) const { | ||||
|     auto returnVals = rewriter.create<LLVM::CallOp>( | ||||
|         loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef, | ||||
|         ArrayRef<Value *>(params)); | ||||
|         ArrayRef<Value>(params)); | ||||
|     if (returnVals.getNumResults() == 1) | ||||
|       return returnVals.getResult(0); | ||||
|     return nullptr; | ||||
|  | @ -348,12 +347,11 @@ private: | |||
|     auto memRefTy = memRefPtrTy.getPointerElementTy(); | ||||
|     auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); | ||||
| 
 | ||||
|     Value *memRef = | ||||
|         rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef); | ||||
|     Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, ptrToMemRef); | ||||
| 
 | ||||
|     // Set dataPtr and alignedDataPtr;
 | ||||
|     auto dataPtr = | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef}); | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef}); | ||||
|     dataPtr = rewriter.create<LLVM::BitcastOp>( | ||||
|         loc, memRefTy.getStructElementType(0), dataPtr); | ||||
|     memRef = rewriter.create<LLVM::InsertValueOp>( | ||||
|  | @ -373,9 +371,9 @@ private: | |||
|     // Get rank, sizes array ptr and strides array ptr.
 | ||||
|     auto rank = memRefTy.getStructElementType(3).getArrayNumElements(); | ||||
|     auto sizesArrayPtr = | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef}); | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef}); | ||||
|     auto stridesArrayPtr = | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef}); | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef}); | ||||
| 
 | ||||
|     for (decltype(rank) i = 0; i < rank; i++) { | ||||
|       auto dimIdx = rewriter.create<LLVM::ConstantOp>( | ||||
|  | @ -384,7 +382,7 @@ private: | |||
|       // Insert size of the dimension.
 | ||||
|       auto dimSizePtr = rewriter.create<LLVM::GEPOp>( | ||||
|           loc, int64Ty.getPointerTo(), sizesArrayPtr, | ||||
|           ArrayRef<Value *>({dimIdx})); | ||||
|           ArrayRef<Value>({dimIdx})); | ||||
|       auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(), | ||||
|                                                    dimSizePtr); | ||||
|       memRef = rewriter.create<LLVM::InsertValueOp>( | ||||
|  | @ -395,7 +393,7 @@ private: | |||
|       // Insert stride of the dimension.
 | ||||
|       auto dimStridePtr = rewriter.create<LLVM::GEPOp>( | ||||
|           loc, int64Ty.getPointerTo(), sizesArrayPtr, | ||||
|           ArrayRef<Value *>({dimIdx})); | ||||
|           ArrayRef<Value>({dimIdx})); | ||||
|       auto dimStride = rewriter.create<LLVM::LoadOp>( | ||||
|           loc, int64Ty.getPointerTo(), dimStridePtr); | ||||
|       memRef = rewriter.create<LLVM::InsertValueOp>( | ||||
|  | @ -404,7 +402,7 @@ private: | |||
|               {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); | ||||
|     } | ||||
| 
 | ||||
|     rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef); | ||||
|     rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef); | ||||
|   } | ||||
| 
 | ||||
|   void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef, | ||||
|  | @ -415,19 +413,19 @@ private: | |||
|     auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); | ||||
| 
 | ||||
|     // Extract the data pointer, and record it in dynamic mem ref created.
 | ||||
|     Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>( | ||||
|         loc, outMemRefTy.getStructElementType(0), &outMemRef, | ||||
|     Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>( | ||||
|         loc, outMemRefTy.getStructElementType(0), outMemRef, | ||||
|         rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); | ||||
|     outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>( | ||||
|         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); | ||||
|     callApi(rewriter, loc, apiRegistry, API::SET_DATA, | ||||
|             {&outDynMemRef, outMemRefDataPtr}); | ||||
|             {outDynMemRef, outMemRefDataPtr}); | ||||
| 
 | ||||
|     auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements(); | ||||
|     auto sizesArrayPtr = | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef}); | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef}); | ||||
|     auto stridesArrayPtr = | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef}); | ||||
|         callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef}); | ||||
| 
 | ||||
|     for (decltype(rank) i = 0; i < rank; i++) { | ||||
|       auto dimIdx = rewriter.create<LLVM::ConstantOp>( | ||||
|  | @ -435,22 +433,22 @@ private: | |||
| 
 | ||||
|       // Transfer size of dimension from memref to dynamic memref.
 | ||||
|       auto dimSize = rewriter.create<LLVM::ExtractValueOp>( | ||||
|           loc, int64Ty, &outMemRef, | ||||
|           loc, int64Ty, outMemRef, | ||||
|           rewriter.getArrayAttr( | ||||
|               {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); | ||||
|       auto dimSizePtr = rewriter.create<LLVM::GEPOp>( | ||||
|           loc, int64Ty.getPointerTo(), sizesArrayPtr, | ||||
|           ArrayRef<Value *>({dimIdx})); | ||||
|           ArrayRef<Value>({dimIdx})); | ||||
|       rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr); | ||||
| 
 | ||||
|       // Transfer stride of dimension from memref to dynamic memref.
 | ||||
|       auto dimStride = rewriter.create<LLVM::ExtractValueOp>( | ||||
|           loc, int64Ty, &outMemRef, | ||||
|           loc, int64Ty, outMemRef, | ||||
|           rewriter.getArrayAttr( | ||||
|               {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); | ||||
|       auto dimStridePtr = rewriter.create<LLVM::GEPOp>( | ||||
|           loc, int64Ty.getPointerTo(), stridesArrayPtr, | ||||
|           ArrayRef<Value *>({dimIdx})); | ||||
|           ArrayRef<Value>({dimIdx})); | ||||
|       rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr); | ||||
|     } | ||||
|   } | ||||
|  |  | |||