Transition to value-typed Value, rename Value* -> Value, in accordance with upstream MLIR style change.
This commit is contained in:
		
							parent
							
								
									33f988e18a
								
							
						
					
					
						commit
						0582846864
					
				|  | @ -71,7 +71,7 @@ struct OnnxOnnfSymbolMapping { | ||||||
|    *  @param name onnx tensor name. |    *  @param name onnx tensor name. | ||||||
|    *  @return onnf tensor corresponding to `name`. |    *  @return onnf tensor corresponding to `name`. | ||||||
|    */ |    */ | ||||||
|   mlir::Value *GetTensorByOnnxName(std::string name) { |   mlir::Value GetTensorByOnnxName(std::string name) { | ||||||
|     assert(onnx_name2onnf_tensor.find(legalize_name(name)) != |     assert(onnx_name2onnf_tensor.find(legalize_name(name)) != | ||||||
|                             onnx_name2onnf_tensor.end() && |                             onnx_name2onnf_tensor.end() && | ||||||
|                         "Tensor not found"); |                         "Tensor not found"); | ||||||
|  | @ -81,9 +81,9 @@ struct OnnxOnnfSymbolMapping { | ||||||
|   /*!
 |   /*!
 | ||||||
|    *  Add a new mapping from onnx tensor name to MLIR symbol. |    *  Add a new mapping from onnx tensor name to MLIR symbol. | ||||||
|    *  @param name onnx tensor name. |    *  @param name onnx tensor name. | ||||||
|    *  @param tensor MLIR Value* pointer. |    *  @param tensor MLIR Value. | ||||||
|    */ |    */ | ||||||
|   void AddMapping(std::string name, mlir::Value *tensor) { |   void AddMapping(std::string name, mlir::Value tensor) { | ||||||
|     assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 && |     assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 && | ||||||
|                         "Tensor already exists."); |                         "Tensor already exists."); | ||||||
|     onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); |     onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); | ||||||
|  | @ -97,7 +97,7 @@ private: | ||||||
|   /*!
 |   /*!
 | ||||||
|    *  mapping from onnx tensor names to MLIR tensor. |    *  mapping from onnx tensor names to MLIR tensor. | ||||||
|    */ |    */ | ||||||
|   std::map<std::string, mlir::Value*> onnx_name2onnf_tensor; |   std::map<std::string, mlir::Value> onnx_name2onnf_tensor; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| class FrontendGenImpl { | class FrontendGenImpl { | ||||||
|  | @ -192,13 +192,13 @@ private: | ||||||
| 
 | 
 | ||||||
|   /*!
 |   /*!
 | ||||||
|    * Import an input tensor symbol by recording a new entry in frontend_symbols_ |    * Import an input tensor symbol by recording a new entry in frontend_symbols_ | ||||||
|    * recording the mapping between legalized onnx tensor name and mlir::Value* |    * recording the mapping between legalized onnx tensor name and mlir::Value | ||||||
|    * for further lookup in computation node importing. |    * for further lookup in computation node importing. | ||||||
|    * @param input onnx input tensor ValueInfoProto. |    * @param input onnx input tensor ValueInfoProto. | ||||||
|    * @param symbol mlir input argument. |    * @param symbol mlir input argument. | ||||||
|    */ |    */ | ||||||
|   void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, |   void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, | ||||||
|                                mlir::Value *symbol) { |                                mlir::Value symbol) { | ||||||
|     auto input_tensor_legalized_name = legalize_name(input.name()); |     auto input_tensor_legalized_name = legalize_name(input.name()); | ||||||
|     assert( |     assert( | ||||||
|         !frontend_symbols_.ContainKey(input_tensor_legalized_name) && |         !frontend_symbols_.ContainKey(input_tensor_legalized_name) && | ||||||
|  | @ -480,7 +480,7 @@ private: | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   void ImportNodeGeneric(onnx::NodeProto node) { |   void ImportNodeGeneric(onnx::NodeProto node) { | ||||||
|     std::vector<mlir::Value *> inputs; |     std::vector<mlir::Value> inputs; | ||||||
|     for (auto item : node.input()) { |     for (auto item : node.input()) { | ||||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { |       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); |         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||||
|  | @ -515,7 +515,7 @@ private: | ||||||
|       onnx::NodeProto node, int nIn, int nOut, |       onnx::NodeProto node, int nIn, int nOut, | ||||||
|       std::initializer_list<std::tuple<std::string, std::string, std::string>> |       std::initializer_list<std::tuple<std::string, std::string, std::string>> | ||||||
|           attrs) { |           attrs) { | ||||||
|     std::vector<mlir::Value *> inputs; |     std::vector<mlir::Value> inputs; | ||||||
|     for (auto item : node.input()) { |     for (auto item : node.input()) { | ||||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { |       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); |         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||||
|  | @ -562,7 +562,7 @@ private: | ||||||
|       onnx::NodeProto node, int nIn, int nOut, |       onnx::NodeProto node, int nIn, int nOut, | ||||||
|       std::initializer_list<std::tuple<std::string, std::string, std::string>> |       std::initializer_list<std::tuple<std::string, std::string, std::string>> | ||||||
|           attrs) { |           attrs) { | ||||||
|     std::vector<mlir::Value *> inputs; |     std::vector<mlir::Value> inputs; | ||||||
|     for (auto item : node.input()) { |     for (auto item : node.input()) { | ||||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { |       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); |         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||||
|  | @ -633,7 +633,7 @@ private: | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   void ImportNode(onnx::NodeProto node) { |   void ImportNode(onnx::NodeProto node) { | ||||||
|     std::vector<mlir::Value *> inputs; |     std::vector<mlir::Value> inputs; | ||||||
|     for (auto item : node.input()) { |     for (auto item : node.input()) { | ||||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { |       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); |         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||||
|  | @ -662,17 +662,17 @@ private: | ||||||
|    * Import output tensor, by doing the following: |    * Import output tensor, by doing the following: | ||||||
|    * - Add the type of this output tensor to a list of tensor |    * - Add the type of this output tensor to a list of tensor | ||||||
|    *   types representing return types of this graph function. |    *   types representing return types of this graph function. | ||||||
|    * - Add this output tensor to the list of mlir::Value* |    * - Add this output tensor to the list of mlir::Value | ||||||
|    *   to be returned by the function representing computation graph. |    *   to be returned by the function representing computation graph. | ||||||
|    * @param output onnx output tensor ValueInfoProto. |    * @param output onnx output tensor ValueInfoProto. | ||||||
|    * @param ret_types a vector of tensor types representing graph's |    * @param ret_types a vector of tensor types representing graph's | ||||||
|    *   output tensor types. |    *   output tensor types. | ||||||
|    * @param ret_vals a vector of mlir Value* representing graph's |    * @param ret_vals a vector of mlir Value representing graph's | ||||||
|    *   output tensor. |    *   output tensor. | ||||||
|    */ |    */ | ||||||
|   void ImportOutputTensor(const onnx::ValueInfoProto &output, |   void ImportOutputTensor(const onnx::ValueInfoProto &output, | ||||||
|                           llvm::SmallVectorImpl<mlir::Type> &ret_types, |                           llvm::SmallVectorImpl<mlir::Type> &ret_types, | ||||||
|                           llvm::SmallVectorImpl<mlir::Value *> &ret_vals) { |                           llvm::SmallVectorImpl<mlir::Value> &ret_vals) { | ||||||
|     auto output_tensor_legalized_name = legalize_name(output.name()); |     auto output_tensor_legalized_name = legalize_name(output.name()); | ||||||
|     assert( |     assert( | ||||||
|         frontend_symbols_.ContainKey(output_tensor_legalized_name) && |         frontend_symbols_.ContainKey(output_tensor_legalized_name) && | ||||||
|  | @ -722,7 +722,7 @@ private: | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     llvm::SmallVector<mlir::Type, 4> ret_types; |     llvm::SmallVector<mlir::Type, 4> ret_types; | ||||||
|     llvm::SmallVector<mlir::Value *, 4> ret_vals; |     llvm::SmallVector<mlir::Value, 4> ret_vals; | ||||||
|     // Import the output tensors
 |     // Import the output tensors
 | ||||||
|     for (const auto &output : graph.output()) { |     for (const auto &output : graph.output()) { | ||||||
|       ImportOutputTensor(output, ret_types, ret_vals); |       ImportOutputTensor(output, ret_types, ret_vals); | ||||||
|  |  | ||||||
|  | @ -9,8 +9,9 @@ namespace onnf { | ||||||
| 
 | 
 | ||||||
| using namespace mlir; | using namespace mlir; | ||||||
| 
 | 
 | ||||||
| ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ParseResult | ||||||
|     const Type& operandType, Value*& operand) { | KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType, | ||||||
|  |                                                Value &operand) { | ||||||
|   // If operand queue is empty, parse more operands and cache them.
 |   // If operand queue is empty, parse more operands and cache them.
 | ||||||
|   if (_operandRefQueue.empty()) { |   if (_operandRefQueue.empty()) { | ||||||
|     // Parse operand types:
 |     // Parse operand types:
 | ||||||
|  | @ -27,7 +28,7 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||||
|     auto operand_ref = _operandRefQueue.front(); |     auto operand_ref = _operandRefQueue.front(); | ||||||
|     _operandRefQueue.pop(); |     _operandRefQueue.pop(); | ||||||
| 
 | 
 | ||||||
|     llvm::SmallVector<Value*, 1> operands; |     llvm::SmallVector<Value, 1> operands; | ||||||
|     _parser.resolveOperand(operand_ref, operandType, operands); |     _parser.resolveOperand(operand_ref, operandType, operands); | ||||||
|     operand = operands.front(); |     operand = operands.front(); | ||||||
|     return success(); |     return success(); | ||||||
|  | @ -38,8 +39,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||||
|     const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { |     const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) { | ||||||
|   Value* operand = nullptr; |   Value operand = nullptr; | ||||||
|   if (ParseOptionalOperand(operandType, operand)) |   if (ParseOptionalOperand(operandType, operand)) | ||||||
|     return failure(); |     return failure(); | ||||||
| 
 | 
 | ||||||
|  | @ -47,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand( | ||||||
|   return success(); |   return success(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ParseResult KrnlDialectOperandParser::ParseOperand( | ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType, | ||||||
|     const Type& operandType, Value*& operand) { |                                                    Value &operand) { | ||||||
|   if (ParseOptionalOperand(operandType, operand)) |   if (ParseOptionalOperand(operandType, operand)) | ||||||
|     return _parser.emitError( |     return _parser.emitError( | ||||||
|         _parser.getCurrentLocation(), "Expecting an operand."); |         _parser.getCurrentLocation(), "Expecting an operand."); | ||||||
|  | @ -56,7 +57,7 @@ ParseResult KrnlDialectOperandParser::ParseOperand( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ParseResult KrnlDialectOperandParser::ParseOperand( | ParseResult KrnlDialectOperandParser::ParseOperand( | ||||||
|     const Type& operandType, llvm::SmallVectorImpl<Value*>& operandList) { |     const Type &operandType, llvm::SmallVectorImpl<Value> &operandList) { | ||||||
|   if (ParseOptionalOperand(operandType, operandList)) |   if (ParseOptionalOperand(operandType, operandList)) | ||||||
|     return _parser.emitError( |     return _parser.emitError( | ||||||
|         _parser.getCurrentLocation(), "Expecting an operand."); |         _parser.getCurrentLocation(), "Expecting an operand."); | ||||||
|  | @ -129,7 +130,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) { | ||||||
|   boundMaps.emplace_back(AffineMapAttr::get(map)); |   boundMaps.emplace_back(AffineMapAttr::get(map)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void KrnlIterateOperandPack::pushOperandBound(mlir::Value* operand) { | void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) { | ||||||
|   if (boundMaps.size() % 2 == 0) |   if (boundMaps.size() % 2 == 0) | ||||||
|     _operands.emplace_back(inputLoops[boundMaps.size() / 2]); |     _operands.emplace_back(inputLoops[boundMaps.size() / 2]); | ||||||
|   AffineMap map = builder.getSymbolIdentityMap(); |   AffineMap map = builder.getSymbolIdentityMap(); | ||||||
|  |  | ||||||
|  | @ -17,20 +17,22 @@ class KrnlDialectOperandParser { | ||||||
|       : _parser(parser), _builder(parser.getBuilder()){}; |       : _parser(parser), _builder(parser.getBuilder()){}; | ||||||
| 
 | 
 | ||||||
|   // Parse an optional operand.
 |   // Parse an optional operand.
 | ||||||
|   mlir::ParseResult ParseOptionalOperand( |   mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType, | ||||||
|       const mlir::Type& operandType, mlir::Value*& operand); |                                          mlir::Value &operand); | ||||||
| 
 | 
 | ||||||
|   // Parse an optional operand and push it to an operand list.
 |   // Parse an optional operand and push it to an operand list.
 | ||||||
|   mlir::ParseResult ParseOptionalOperand(const mlir::Type& operandType, |   mlir::ParseResult | ||||||
|       llvm::SmallVectorImpl<mlir::Value*>& operandList); |   ParseOptionalOperand(const mlir::Type &operandType, | ||||||
|  |                        llvm::SmallVectorImpl<mlir::Value> &operandList); | ||||||
| 
 | 
 | ||||||
|   // Parse a required operand.
 |   // Parse a required operand.
 | ||||||
|   mlir::ParseResult ParseOperand( |   mlir::ParseResult ParseOperand(const mlir::Type &operandType, | ||||||
|       const mlir::Type& operandType, mlir::Value*& operand); |                                  mlir::Value &operand); | ||||||
| 
 | 
 | ||||||
|   // Parse a required operand and push it to an operand list.
 |   // Parse a required operand and push it to an operand list.
 | ||||||
|   mlir::ParseResult ParseOperand(const mlir::Type& operandType, |   mlir::ParseResult | ||||||
|       llvm::SmallVectorImpl<mlir::Value*>& operandList); |   ParseOperand(const mlir::Type &operandType, | ||||||
|  |                llvm::SmallVectorImpl<mlir::Value> &operandList); | ||||||
| 
 | 
 | ||||||
|   // Do we have more operands to parse?
 |   // Do we have more operands to parse?
 | ||||||
|   bool hasOperandLeft() { return !_operandRefQueue.empty(); } |   bool hasOperandLeft() { return !_operandRefQueue.empty(); } | ||||||
|  | @ -63,11 +65,10 @@ void printBound(mlir::AffineMapAttr boundMap, | ||||||
| namespace mlir { | namespace mlir { | ||||||
| 
 | 
 | ||||||
| struct KrnlIterateOperandPack { | struct KrnlIterateOperandPack { | ||||||
|   KrnlIterateOperandPack(mlir::Builder& builder, |   KrnlIterateOperandPack(mlir::Builder &builder, | ||||||
|       llvm::ArrayRef<mlir::Value*> inputLoops, |                          llvm::ArrayRef<mlir::Value> inputLoops, | ||||||
|       llvm::ArrayRef<mlir::Value*> optimizedLoops) |                          llvm::ArrayRef<mlir::Value> optimizedLoops) | ||||||
|       : builder(builder), |       : builder(builder), inputLoops(inputLoops), | ||||||
|         inputLoops(inputLoops), |  | ||||||
|         optimizedLoops(optimizedLoops) { |         optimizedLoops(optimizedLoops) { | ||||||
|     _operands.insert( |     _operands.insert( | ||||||
|         _operands.end(), optimizedLoops.begin(), optimizedLoops.end()); |         _operands.end(), optimizedLoops.begin(), optimizedLoops.end()); | ||||||
|  | @ -75,9 +76,9 @@ struct KrnlIterateOperandPack { | ||||||
| 
 | 
 | ||||||
|   void pushConstantBound(int64_t bound); |   void pushConstantBound(int64_t bound); | ||||||
| 
 | 
 | ||||||
|   void pushOperandBound(mlir::Value* operand); |   void pushOperandBound(mlir::Value operand); | ||||||
| 
 | 
 | ||||||
|   llvm::SmallVector<mlir::Value*, 8> getOperands() const { return _operands; } |   llvm::SmallVector<mlir::Value, 8> getOperands() const { return _operands; } | ||||||
| 
 | 
 | ||||||
|   mlir::ArrayAttr getAttributes() const { |   mlir::ArrayAttr getAttributes() const { | ||||||
|     return builder.getArrayAttr(boundMaps); |     return builder.getArrayAttr(boundMaps); | ||||||
|  | @ -90,11 +91,11 @@ struct KrnlIterateOperandPack { | ||||||
|  private: |  private: | ||||||
|   int _boundIdx = 0; |   int _boundIdx = 0; | ||||||
| 
 | 
 | ||||||
|   llvm::SmallVector<mlir::Value*, 8> _operands; |   llvm::SmallVector<mlir::Value, 8> _operands; | ||||||
| 
 | 
 | ||||||
|   llvm::SmallVector<mlir::Attribute, 8> boundMaps; |   llvm::SmallVector<mlir::Attribute, 8> boundMaps; | ||||||
| 
 | 
 | ||||||
|   llvm::ArrayRef<mlir::Value*> inputLoops, optimizedLoops; |   llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops; | ||||||
| 
 | 
 | ||||||
|   mlir::Builder& builder; |   mlir::Builder& builder; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | @ -44,21 +44,21 @@ static MemRefType convertTensorToMemRef(TensorType type) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Insert an allocation and deallocation for the given MemRefType.
 | /// Insert an allocation and deallocation for the given MemRefType.
 | ||||||
| static Value *insertAllocAndDealloc(MemRefType type, Location loc, | static Value insertAllocAndDealloc(MemRefType type, Location loc, | ||||||
|                                    PatternRewriter &rewriter, |                                    PatternRewriter &rewriter, | ||||||
|                                    bool insertDealloc, |                                    bool insertDealloc, | ||||||
|                                     ArrayRef<Value *> operands = {}) { |                                    ArrayRef<Value> operands = {}) { | ||||||
|   // Put together alloc operands for any dynamic dimensions of the memref.
 |   // Put together alloc operands for any dynamic dimensions of the memref.
 | ||||||
|   AllocOp alloc; |   AllocOp alloc; | ||||||
|   if (!operands.empty()) { |   if (!operands.empty()) { | ||||||
|     auto memRefShape = type.getShape(); |     auto memRefShape = type.getShape(); | ||||||
|     auto rank = memRefShape.size(); |     auto rank = memRefShape.size(); | ||||||
| 
 | 
 | ||||||
|     std::map<int, Value *> fromOperands; |     std::map<int, Value> fromOperands; | ||||||
|     for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { |     for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { | ||||||
|       int memRefDimIdx = rank - 1 - reversedIdx; |       int memRefDimIdx = rank - 1 - reversedIdx; | ||||||
|       if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
 |       if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
 | ||||||
|         Value *maxDim = nullptr; |         Value maxDim = nullptr; | ||||||
|         for (int i = 0; i < operands.size(); i++) { |         for (int i = 0; i < operands.size(); i++) { | ||||||
|           auto operandShape = |           auto operandShape = | ||||||
|               operands[i]->getType().cast<MemRefType>().getShape(); |               operands[i]->getType().cast<MemRefType>().getShape(); | ||||||
|  | @ -85,7 +85,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc, | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     SmallVector<Value *, 4> allocOperands; |     SmallVector<Value, 4> allocOperands; | ||||||
|     for (int i = 0; i < rank; ++i) |     for (int i = 0; i < rank; ++i) | ||||||
|       if (memRefShape[i] < 0) |       if (memRefShape[i] < 0) | ||||||
|         allocOperands.push_back(fromOperands[i]); |         allocOperands.push_back(fromOperands[i]); | ||||||
|  | @ -146,14 +146,14 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { | ||||||
| 
 | 
 | ||||||
| // Get run-time dimension information for unknown dimensions used for
 | // Get run-time dimension information for unknown dimensions used for
 | ||||||
| // broadcasting.
 | // broadcasting.
 | ||||||
| std::map<int, std::map<int, Value *>> | std::map<int, std::map<int, Value>> | ||||||
| getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, | getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, | ||||||
|                       MemRefType memRefType, ArrayRef<Value *> operands) { |                       MemRefType memRefType, ArrayRef<Value> operands) { | ||||||
|   auto memRefShape = memRefType.getShape(); |   auto memRefShape = memRefType.getShape(); | ||||||
|   int64_t rank = memRefShape.size(); |   int64_t rank = memRefShape.size(); | ||||||
|   // For unknown dimensions, we need to get dimension values at runtime in
 |   // For unknown dimensions, we need to get dimension values at runtime in
 | ||||||
|   // order to do broadcasting.
 |   // order to do broadcasting.
 | ||||||
|   std::map<int, std::map<int, Value *>> DimInfo; |   std::map<int, std::map<int, Value>> DimInfo; | ||||||
|   // For each result dimension, compute the number of sharing operands.
 |   // For each result dimension, compute the number of sharing operands.
 | ||||||
|   // Sharing operands are operands sharing the same index (counting from the
 |   // Sharing operands are operands sharing the same index (counting from the
 | ||||||
|   // rightmost to the leftmost) for a given dimension.
 |   // rightmost to the leftmost) for a given dimension.
 | ||||||
|  | @ -173,7 +173,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, | ||||||
|   // We only care about unknown dimensions whose number of sharing operands is
 |   // We only care about unknown dimensions whose number of sharing operands is
 | ||||||
|   // more than one, since they are potentially broadcasted dimensions.
 |   // more than one, since they are potentially broadcasted dimensions.
 | ||||||
|   for (int i = 0; i < operands.size(); ++i) { |   for (int i = 0; i < operands.size(); ++i) { | ||||||
|     std::map<int, Value *> broadcastedDims; |     std::map<int, Value> broadcastedDims; | ||||||
|     auto shape = operands[i]->getType().cast<MemRefType>().getShape(); |     auto shape = operands[i]->getType().cast<MemRefType>().getShape(); | ||||||
|     int size = shape.size(); |     int size = shape.size(); | ||||||
|     for (int j = 0; j < shape.size(); ++j) { |     for (int j = 0; j < shape.size(); ++j) { | ||||||
|  | @ -192,17 +192,17 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, | ||||||
| 
 | 
 | ||||||
| // Extract induction variables that are used for broadcasting values of a
 | // Extract induction variables that are used for broadcasting values of a
 | ||||||
| // given operand.
 | // given operand.
 | ||||||
| std::vector<Value *> | std::vector<Value> | ||||||
| getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, | getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, | ||||||
|                           ArrayRef<Value *> loopIVs, Value *operand, |                           ArrayRef<Value> loopIVs, Value operand, | ||||||
|                           std::map<int, Value *> broadcastedDims) { |                           std::map<int, Value> broadcastedDims) { | ||||||
|   // `operand` must have a ranked type. This should have been checked by the
 |   // `operand` must have a ranked type. This should have been checked by the
 | ||||||
|   // shape inference pass.
 |   // shape inference pass.
 | ||||||
|   auto operandShape = operand->getType().cast<MemRefType>().getShape(); |   auto operandShape = operand->getType().cast<MemRefType>().getShape(); | ||||||
|   auto rank = operandShape.size(); |   auto rank = operandShape.size(); | ||||||
|   auto loopCount = loopIVs.size(); |   auto loopCount = loopIVs.size(); | ||||||
| 
 | 
 | ||||||
|   std::vector<Value *> newLoopIVs; |   std::vector<Value> newLoopIVs; | ||||||
|   for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { |   for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { | ||||||
|     auto dimIdx = rank - 1 - reversedIdx; |     auto dimIdx = rank - 1 - reversedIdx; | ||||||
|     auto loopIdx = loopCount - 1 - reversedIdx; |     auto loopIdx = loopCount - 1 - reversedIdx; | ||||||
|  | @ -247,7 +247,7 @@ struct ScalarOp<ONNXMulOp> { | ||||||
| template <> | template <> | ||||||
| struct ScalarOp<ONNXDivOp> { | struct ScalarOp<ONNXDivOp> { | ||||||
|   using FOp = DivFOp; |   using FOp = DivFOp; | ||||||
|   using IOp = DivISOp; |   using IOp = SignedDivIOp; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
|  | @ -295,8 +295,8 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp; | ||||||
| // Scalar unary ops for lowering to Krnl dialect.
 | // Scalar unary ops for lowering to Krnl dialect.
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <typename UnaryOp> | template <typename UnaryOp> | ||||||
| Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, | Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, | ||||||
|                           ArrayRef<Value *> operands, |                          ArrayRef<Value> operands, | ||||||
|                          ConversionPatternRewriter &rewriter) { |                          ConversionPatternRewriter &rewriter) { | ||||||
|   /* Lower UnaryOp to Ops in the Standard dialect.
 |   /* Lower UnaryOp to Ops in the Standard dialect.
 | ||||||
|    */ |    */ | ||||||
|  | @ -318,14 +318,13 @@ Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types, | ||||||
| // Scalar unary ops for lowering ONNXTanhOp
 | // Scalar unary ops for lowering ONNXTanhOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op, | Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
|                                       ArrayRef<Type> result_types, |                                      ArrayRef<Value> operands, | ||||||
|                                       ArrayRef<Value *> operands, |  | ||||||
|                                      ConversionPatternRewriter &rewriter) { |                                      ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 |   // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||||
|   //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
 |   //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
| 
 | 
 | ||||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); |   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||||
|   auto neg = rewriter.create<SubFOp>(loc, zero, operand); |   auto neg = rewriter.create<SubFOp>(loc, zero, operand); | ||||||
|  | @ -342,14 +341,13 @@ Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op, | ||||||
| // Scalar unary ops for lowering ONNXSinhOp
 | // Scalar unary ops for lowering ONNXSinhOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op, | Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
|                                       ArrayRef<Type> result_types, |                                      ArrayRef<Value> operands, | ||||||
|                                       ArrayRef<Value *> operands, |  | ||||||
|                                      ConversionPatternRewriter &rewriter) { |                                      ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 |   // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||||
|   //                         ConstantOp 2)
 |   //                         ConstantOp 2)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
| 
 | 
 | ||||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); |   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||||
|   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); |   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); | ||||||
|  | @ -366,14 +364,13 @@ Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op, | ||||||
| // Scalar unary ops for lowering ONNXCoshOp
 | // Scalar unary ops for lowering ONNXCoshOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op, | Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
|                                       ArrayRef<Type> result_types, |                                      ArrayRef<Value> operands, | ||||||
|                                       ArrayRef<Value *> operands, |  | ||||||
|                                      ConversionPatternRewriter &rewriter) { |                                      ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 |   // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | ||||||
|   //                         ConstantOp 2)
 |   //                         ConstantOp 2)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
| 
 | 
 | ||||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); |   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||||
|   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); |   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); | ||||||
|  | @ -390,14 +387,14 @@ Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op, | ||||||
| // Scalar unary ops for lowering ONNXSigmoidOp
 | // Scalar unary ops for lowering ONNXSigmoidOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, | Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, | ||||||
|                                         ArrayRef<Type> result_types, |                                         ArrayRef<Type> result_types, | ||||||
|                                          ArrayRef<Value *> operands, |                                         ArrayRef<Value> operands, | ||||||
|                                         ConversionPatternRewriter &rewriter) { |                                         ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
 |   // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
 | ||||||
|   //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
 |   //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
| 
 | 
 | ||||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); |   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||||
|   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); |   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); | ||||||
|  | @ -413,8 +410,8 @@ Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op, | ||||||
| // Scalar unary ops for lowering ONNXHardSigmoidOp
 | // Scalar unary ops for lowering ONNXHardSigmoidOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXHardSigmoidOp>( | Value mapToLowerScalarOp<ONNXHardSigmoidOp>( | ||||||
|     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands, |     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||||
|     ConversionPatternRewriter &rewriter) { |     ConversionPatternRewriter &rewriter) { | ||||||
|   // %Y = AddFOp(MulFOp(alpha, %X), beta)
 |   // %Y = AddFOp(MulFOp(alpha, %X), beta)
 | ||||||
|   // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
 |   // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
 | ||||||
|  | @ -424,7 +421,7 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>( | ||||||
|   //                                  %Z,
 |   //                                  %Z,
 | ||||||
|   //                                  Constant 1)
 |   //                                  Constant 1)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
|   auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha"); |   auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha"); | ||||||
|   auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta"); |   auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta"); | ||||||
| 
 | 
 | ||||||
|  | @ -449,14 +446,14 @@ Value *mapToLowerScalarOp<ONNXHardSigmoidOp>( | ||||||
| // Scalar unary ops for lowering ONNXEluOp
 | // Scalar unary ops for lowering ONNXEluOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
|                                      ArrayRef<Value *> operands, |                                     ArrayRef<Value> operands, | ||||||
|                                     ConversionPatternRewriter &rewriter) { |                                     ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 |   // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||||
|   //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
 |   //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
 | ||||||
|   //                          %X)
 |   //                          %X)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
| 
 | 
 | ||||||
|   auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha"); |   auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha"); | ||||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); |   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||||
|  | @ -478,15 +475,14 @@ Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
| // Scalar unary ops for lowering ONNXReluOp
 | // Scalar unary ops for lowering ONNXReluOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op, | Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
|                                       ArrayRef<Type> result_types, |                                      ArrayRef<Value> operands, | ||||||
|                                       ArrayRef<Value *> operands, |  | ||||||
|                                      ConversionPatternRewriter &rewriter) { |                                      ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 |   // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||||
|   //                           ConstantOp 0,
 |   //                           ConstantOp 0,
 | ||||||
|   //                           %X)
 |   //                           %X)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
| 
 | 
 | ||||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); |   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||||
|   auto lessThanZero = |   auto lessThanZero = | ||||||
|  | @ -500,15 +496,15 @@ Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op, | ||||||
| // Scalar unary ops for lowering ONNXLeakyReluOp
 | // Scalar unary ops for lowering ONNXLeakyReluOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value * | Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, | ||||||
| mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types, |                                           ArrayRef<Type> result_types, | ||||||
|                                     ArrayRef<Value *> operands, |                                           ArrayRef<Value> operands, | ||||||
|                                           ConversionPatternRewriter &rewriter) { |                                           ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 |   // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | ||||||
|   //                                MulFOp(alpha, %X),
 |   //                                MulFOp(alpha, %X),
 | ||||||
|   //                                %X)
 |   //                                %X)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
| 
 | 
 | ||||||
|   auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha"); |   auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha"); | ||||||
|   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); |   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); | ||||||
|  | @ -525,9 +521,8 @@ mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
| // Scalar unary ops for lowering ONNXSeluOp
 | // Scalar unary ops for lowering ONNXSeluOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op, | Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
|                                       ArrayRef<Type> result_types, |                                      ArrayRef<Value> operands, | ||||||
|                                       ArrayRef<Value *> operands, |  | ||||||
|                                      ConversionPatternRewriter &rewriter) { |                                      ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
 |   // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
 | ||||||
|   //                           MulFOp(gamma, %X),
 |   //                           MulFOp(gamma, %X),
 | ||||||
|  | @ -535,7 +530,7 @@ Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op, | ||||||
|   //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
 |   //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
 | ||||||
|   //                                         alpha)))
 |   //                                         alpha)))
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
|   auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha"); |   auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha"); | ||||||
|   auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma"); |   auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma"); | ||||||
| 
 | 
 | ||||||
|  | @ -558,13 +553,12 @@ Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op, | ||||||
| // Scalar unary ops for lowering ONNXReciprocalOp
 | // Scalar unary ops for lowering ONNXReciprocalOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value * | Value mapToLowerScalarOp<ONNXReciprocalOp>( | ||||||
| mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types, |     Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||||
|                                      ArrayRef<Value *> operands, |  | ||||||
|     ConversionPatternRewriter &rewriter) { |     ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
 |   // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *operand = operands[0]; |   Value operand = operands[0]; | ||||||
| 
 | 
 | ||||||
|   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); |   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); | ||||||
|   auto result = rewriter.create<DivFOp>(loc, one, operand); |   auto result = rewriter.create<DivFOp>(loc, one, operand); | ||||||
|  | @ -576,15 +570,15 @@ mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
| // Scalar unary ops for lowering ONNXMaxOp
 | // Scalar unary ops for lowering ONNXMaxOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, | Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
|                                      ArrayRef<Value *> operands, |                                     ArrayRef<Value> operands, | ||||||
|                                     ConversionPatternRewriter &rewriter) { |                                     ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
 |   // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
 | ||||||
|   //                              %X,
 |   //                              %X,
 | ||||||
|   //                              %Y)
 |   //                              %Y)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *lhs = operands[0]; |   Value lhs = operands[0]; | ||||||
|   Value *rhs = operands[1]; |   Value rhs = operands[1]; | ||||||
|   auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); |   auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); | ||||||
|   auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); |   auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); | ||||||
|   return result; |   return result; | ||||||
|  | @ -594,15 +588,15 @@ Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
| // Scalar unary ops for lowering ONNXMinOp
 | // Scalar unary ops for lowering ONNXMinOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| template <> | template <> | ||||||
| Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types, | Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types, | ||||||
|                                      ArrayRef<Value *> operands, |                                     ArrayRef<Value> operands, | ||||||
|                                     ConversionPatternRewriter &rewriter) { |                                     ConversionPatternRewriter &rewriter) { | ||||||
|   // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
 |   // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
 | ||||||
|   //                              %X,
 |   //                              %X,
 | ||||||
|   //                              %Y)
 |   //                              %Y)
 | ||||||
|   auto loc = op->getLoc(); |   auto loc = op->getLoc(); | ||||||
|   Value *lhs = operands[0]; |   Value lhs = operands[0]; | ||||||
|   Value *rhs = operands[1]; |   Value rhs = operands[1]; | ||||||
|   auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); |   auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); | ||||||
|   auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); |   auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); | ||||||
|   return result; |   return result; | ||||||
|  | @ -615,7 +609,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | ||||||
|   ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) |   ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) | ||||||
|       : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} |       : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} | ||||||
|   PatternMatchResult |   PatternMatchResult | ||||||
|   matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||||
|                   ConversionPatternRewriter &rewriter) const final { |                   ConversionPatternRewriter &rewriter) const final { | ||||||
|     // TODO: Check that the types are valid.
 |     // TODO: Check that the types are valid.
 | ||||||
|     // An element-wise unary operation must have all operands and the result of
 |     // An element-wise unary operation must have all operands and the result of
 | ||||||
|  | @ -632,7 +626,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | ||||||
|     // dimensions with the result at this pre-optimization phase.
 |     // dimensions with the result at this pre-optimization phase.
 | ||||||
|     // TODO: verify that dimensions match.
 |     // TODO: verify that dimensions match.
 | ||||||
|     // TODO: can the dimension of the result differ after optimizations?
 |     // TODO: can the dimension of the result differ after optimizations?
 | ||||||
|     Value *alloc; |     Value alloc; | ||||||
|     bool insertDealloc = checkInsertDealloc(op); |     bool insertDealloc = checkInsertDealloc(op); | ||||||
| 
 | 
 | ||||||
|     if (hasAllConstantDimensions(memRefType)) |     if (hasAllConstantDimensions(memRefType)) | ||||||
|  | @ -647,7 +641,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | ||||||
| 
 | 
 | ||||||
|     // Define loops.
 |     // Define loops.
 | ||||||
|     auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); |     auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); | ||||||
|     std::vector<Value *> originalLoops; |     std::vector<Value> originalLoops; | ||||||
|     originalLoops.reserve(rank); |     originalLoops.reserve(rank); | ||||||
|     for (auto result : loopsOp.getResults()) { |     for (auto result : loopsOp.getResults()) { | ||||||
|       originalLoops.push_back(result); |       originalLoops.push_back(result); | ||||||
|  | @ -655,7 +649,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | ||||||
| 
 | 
 | ||||||
|     // Define loop optimization.
 |     // Define loop optimization.
 | ||||||
|     auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); |     auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); | ||||||
|     std::vector<Value *> optimizedLoops; |     std::vector<Value> optimizedLoops; | ||||||
|     optimizedLoops.reserve(rank); |     optimizedLoops.reserve(rank); | ||||||
|     for (auto result : optimizedLoopsOp.getResults()) { |     for (auto result : optimizedLoopsOp.getResults()) { | ||||||
|       optimizedLoops.push_back(result); |       optimizedLoops.push_back(result); | ||||||
|  | @ -695,7 +689,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { | ||||||
|     rewriter.setInsertionPointToStart(&iterationBlock); |     rewriter.setInsertionPointToStart(&iterationBlock); | ||||||
| 
 | 
 | ||||||
|     // Handle the operation:
 |     // Handle the operation:
 | ||||||
|     SmallVector<Value *, 4> loopIVs; |     SmallVector<Value, 4> loopIVs; | ||||||
|     for (auto arg : iterationBlock.getArguments()) |     for (auto arg : iterationBlock.getArguments()) | ||||||
|       loopIVs.push_back(arg); |       loopIVs.push_back(arg); | ||||||
| 
 | 
 | ||||||
|  | @ -718,7 +712,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | ||||||
|   ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) |   ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) | ||||||
|       : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} |       : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} | ||||||
|   PatternMatchResult |   PatternMatchResult | ||||||
|   matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||||
|                   ConversionPatternRewriter &rewriter) const final { |                   ConversionPatternRewriter &rewriter) const final { | ||||||
|     // TODO: Check that the types are valid.
 |     // TODO: Check that the types are valid.
 | ||||||
|     // An element-wise variadic operation must have all operands and the result
 |     // An element-wise variadic operation must have all operands and the result
 | ||||||
|  | @ -730,7 +724,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | ||||||
|     // Insert an allocation and deallocation for the result of this operation.
 |     // Insert an allocation and deallocation for the result of this operation.
 | ||||||
|     auto memRefType = convertTensorToMemRef(tensorType); |     auto memRefType = convertTensorToMemRef(tensorType); | ||||||
| 
 | 
 | ||||||
|     Value *alloc; |     Value alloc; | ||||||
|     bool insertDealloc = checkInsertDealloc(op); |     bool insertDealloc = checkInsertDealloc(op); | ||||||
|     // If the output has a dynamic dimension, we compute its dimension at
 |     // If the output has a dynamic dimension, we compute its dimension at
 | ||||||
|     // runtime by using dimensions from the operands.
 |     // runtime by using dimensions from the operands.
 | ||||||
|  | @ -749,7 +743,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | ||||||
| 
 | 
 | ||||||
|     // Define loops.
 |     // Define loops.
 | ||||||
|     auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); |     auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); | ||||||
|     std::vector<Value *> originalLoops; |     std::vector<Value> originalLoops; | ||||||
|     originalLoops.reserve(rank); |     originalLoops.reserve(rank); | ||||||
|     for (auto result : loopsOp.getResults()) { |     for (auto result : loopsOp.getResults()) { | ||||||
|       originalLoops.push_back(result); |       originalLoops.push_back(result); | ||||||
|  | @ -757,7 +751,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | ||||||
| 
 | 
 | ||||||
|     // Define loop optimization.
 |     // Define loop optimization.
 | ||||||
|     auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); |     auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); | ||||||
|     std::vector<Value *> optimizedLoops; |     std::vector<Value> optimizedLoops; | ||||||
|     optimizedLoops.reserve(rank); |     optimizedLoops.reserve(rank); | ||||||
|     for (auto result : optimizedLoopsOp.getResults()) { |     for (auto result : optimizedLoopsOp.getResults()) { | ||||||
|       optimizedLoops.push_back(result); |       optimizedLoops.push_back(result); | ||||||
|  | @ -781,7 +775,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | ||||||
| 
 | 
 | ||||||
|     // Get run-time dimension information for unknown dimensions used for
 |     // Get run-time dimension information for unknown dimensions used for
 | ||||||
|     // broadcasting.
 |     // broadcasting.
 | ||||||
|     std::map<int, std::map<int, Value *>> broadcastedDimInfo = |     std::map<int, std::map<int, Value>> broadcastedDimInfo = | ||||||
|         getBroadcastedDimInfo(loc, rewriter, memRefType, operands); |         getBroadcastedDimInfo(loc, rewriter, memRefType, operands); | ||||||
| 
 | 
 | ||||||
|     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); |     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); | ||||||
|  | @ -801,12 +795,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { | ||||||
|     rewriter.setInsertionPointToStart(&iterationBlock); |     rewriter.setInsertionPointToStart(&iterationBlock); | ||||||
| 
 | 
 | ||||||
|     // Handle the operation:
 |     // Handle the operation:
 | ||||||
|     SmallVector<Value *, 4> loopIVs; |     SmallVector<Value, 4> loopIVs; | ||||||
|     for (auto arg : iterationBlock.getArguments()) |     for (auto arg : iterationBlock.getArguments()) | ||||||
|       loopIVs.push_back(arg); |       loopIVs.push_back(arg); | ||||||
| 
 | 
 | ||||||
|     // Fold over operands for each of their scalar values
 |     // Fold over operands for each of their scalar values
 | ||||||
|     Value *accumulated, *next; |     Value accumulated, next; | ||||||
|     auto accumulatedLoopIVs = getLoopIVsForBroadcasting( |     auto accumulatedLoopIVs = getLoopIVsForBroadcasting( | ||||||
|         loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); |         loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); | ||||||
|     accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs); |     accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs); | ||||||
|  | @ -831,17 +825,17 @@ struct ONNXReshapeOpLowering : public ConversionPattern { | ||||||
|       : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} |       : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} | ||||||
| 
 | 
 | ||||||
|   PatternMatchResult |   PatternMatchResult | ||||||
|   matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||||
|                   ConversionPatternRewriter &rewriter) const final { |                   ConversionPatternRewriter &rewriter) const final { | ||||||
|     auto tensorType = (*op->result_type_begin()).cast<TensorType>(); |     auto tensorType = (*op->result_type_begin()).cast<TensorType>(); | ||||||
|     auto loc = op->getLoc(); |     auto loc = op->getLoc(); | ||||||
| 
 | 
 | ||||||
|     // Insert an allocation and deallocation for the result of this operation.
 |     // Insert an allocation and deallocation for the result of this operation.
 | ||||||
|     auto memRefType = convertTensorToMemRef(tensorType); |     auto memRefType = convertTensorToMemRef(tensorType); | ||||||
|     Value *alloc; |     Value alloc; | ||||||
| 
 | 
 | ||||||
|     // Compute size in bytes.
 |     // Compute size in bytes.
 | ||||||
|     Value *tensorSize = rewriter.create<ConstantOp>( |     Value tensorSize = rewriter.create<ConstantOp>( | ||||||
|         loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), |         loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), | ||||||
|                                      getMemRefEltSizeInBytes(memRefType))); |                                      getMemRefEltSizeInBytes(memRefType))); | ||||||
|     bool insertDealloc = checkInsertDealloc(op); |     bool insertDealloc = checkInsertDealloc(op); | ||||||
|  | @ -849,14 +843,14 @@ struct ONNXReshapeOpLowering : public ConversionPattern { | ||||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); |       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); | ||||||
|     } else { |     } else { | ||||||
|       auto memRefShape = memRefType.getShape(); |       auto memRefShape = memRefType.getShape(); | ||||||
|       SmallVector<Value *, 4> allocOperands; |       SmallVector<Value, 4> allocOperands; | ||||||
|       for (int i = 0; i < memRefShape.size(); ++i) { |       for (int i = 0; i < memRefShape.size(); ++i) { | ||||||
|         // The shape array can always be used to construct shape information of
 |         // The shape array can always be used to construct shape information of
 | ||||||
|         // the result.
 |         // the result.
 | ||||||
|         Value *index = rewriter.create<ConstantOp>( |         Value index = rewriter.create<ConstantOp>( | ||||||
|             loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); |             loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); | ||||||
|         Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); |         Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); | ||||||
|         Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>( |         Value int64LoadedVal = rewriter.create<ZeroExtendIOp>( | ||||||
|             loc, loadedVal, rewriter.getIntegerType(64)); |             loc, loadedVal, rewriter.getIntegerType(64)); | ||||||
|         tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal); |         tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal); | ||||||
|         allocOperands.push_back(rewriter.create<IndexCastOp>( |         allocOperands.push_back(rewriter.create<IndexCastOp>( | ||||||
|  |  | ||||||
|  | @ -30,7 +30,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> { | ||||||
|       operandItr++; |       operandItr++; | ||||||
| 
 | 
 | ||||||
|       // Organize operands into lower/upper bounds in affine.for ready formats.
 |       // Organize operands into lower/upper bounds in affine.for ready formats.
 | ||||||
|       SmallVector<Value *, 4> lbOperands, ubOperands; |       SmallVector<Value, 4> lbOperands, ubOperands; | ||||||
|       AffineMap lbMap, ubMap; |       AffineMap lbMap, ubMap; | ||||||
|       for (int boundType = 0; boundType < 2; boundType++) { |       for (int boundType = 0; boundType < 2; boundType++) { | ||||||
|         auto &operands = boundType == 0 ? lbOperands : ubOperands; |         auto &operands = boundType == 0 ? lbOperands : ubOperands; | ||||||
|  |  | ||||||
|  | @ -51,7 +51,7 @@ public: | ||||||
|       : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} |       : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} | ||||||
| 
 | 
 | ||||||
|   PatternMatchResult |   PatternMatchResult | ||||||
|   matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||||
|                   ConversionPatternRewriter &rewriter) const override { |                   ConversionPatternRewriter &rewriter) const override { | ||||||
|     auto *context = op->getContext(); |     auto *context = op->getContext(); | ||||||
|     auto loc = op->getLoc(); |     auto loc = op->getLoc(); | ||||||
|  | @ -66,27 +66,27 @@ public: | ||||||
|     // First operand.
 |     // First operand.
 | ||||||
|     Type dstType = |     Type dstType = | ||||||
|         operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1); |         operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1); | ||||||
|     Value *alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>( |     Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>( | ||||||
|         loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); |         loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); | ||||||
|     Value *alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>( |     Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>( | ||||||
|         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); |         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); | ||||||
| 
 | 
 | ||||||
|     // Second operand.
 |     // Second operand.
 | ||||||
|     Type srcType = |     Type srcType = | ||||||
|         operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1); |         operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1); | ||||||
|     Value *alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>( |     Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>( | ||||||
|         loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); |         loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); | ||||||
|     Value *alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>( |     Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>( | ||||||
|         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); |         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); | ||||||
| 
 | 
 | ||||||
|     // Size.
 |     // Size.
 | ||||||
|     Value *int64Size = rewriter.create<LLVM::SExtOp>( |     Value int64Size = rewriter.create<LLVM::SExtOp>( | ||||||
|         loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); |         loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); | ||||||
| 
 | 
 | ||||||
|     // Memcpy call
 |     // Memcpy call
 | ||||||
|     rewriter.create<CallOp>( |     rewriter.create<CallOp>( | ||||||
|         loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), |         loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), | ||||||
|         ArrayRef<Value *>( |         ArrayRef<Value>( | ||||||
|             {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size})); |             {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size})); | ||||||
| 
 | 
 | ||||||
|     rewriter.eraseOp(op); |     rewriter.eraseOp(op); | ||||||
|  | @ -210,7 +210,7 @@ public: | ||||||
| 
 | 
 | ||||||
|     // Retrieve dynamic mem refs from wrapped input, and convert every one of
 |     // Retrieve dynamic mem refs from wrapped input, and convert every one of
 | ||||||
|     // them to static mem refs.
 |     // them to static mem refs.
 | ||||||
|     SmallVector<Value *, 4> staticInputs; |     SmallVector<Value, 4> staticInputs; | ||||||
|     auto wrappedInput = entryPointEntryBlock.getArgument(0); |     auto wrappedInput = entryPointEntryBlock.getArgument(0); | ||||||
|     for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) { |     for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) { | ||||||
|       // Call API function to retrieve the i-th dynamic memref.
 |       // Call API function to retrieve the i-th dynamic memref.
 | ||||||
|  | @ -225,13 +225,12 @@ public: | ||||||
|       auto memRefTy = memRefPtrTy.getPointerElementTy(); |       auto memRefTy = memRefPtrTy.getPointerElementTy(); | ||||||
|       auto one = rewriter.create<LLVM::ConstantOp>( |       auto one = rewriter.create<LLVM::ConstantOp>( | ||||||
|           loc, int32Ty, rewriter.getI32IntegerAttr(1)); |           loc, int32Ty, rewriter.getI32IntegerAttr(1)); | ||||||
|       Value *ptrToMemRef = |       Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one, | ||||||
|           rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one, |  | ||||||
|                                                           /*alignment=*/0); |                                                           /*alignment=*/0); | ||||||
| 
 | 
 | ||||||
|       // Fill in the memref underlying ptrToMemRef with information extracted
 |       // Fill in the memref underlying ptrToMemRef with information extracted
 | ||||||
|       // from dynMemRef.
 |       // from dynMemRef.
 | ||||||
|       fillPtrToMemRefWithDynMemRef(*dynMemRef, *ptrToMemRef, rewriter, loc, |       fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc, | ||||||
|                                    apiRegistry, llvmDialect); |                                    apiRegistry, llvmDialect); | ||||||
| 
 | 
 | ||||||
|       // ptrToMemRef will be an input to main computation graph function.
 |       // ptrToMemRef will be an input to main computation graph function.
 | ||||||
|  | @ -261,8 +260,8 @@ public: | ||||||
|         loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); |         loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); | ||||||
|     auto outDynMemRef = callApi(rewriter, loc, apiRegistry, |     auto outDynMemRef = callApi(rewriter, loc, apiRegistry, | ||||||
|                                 API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); |                                 API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); | ||||||
|     fillDynMemRefWithMemRef(*outMemRef, *outDynMemRef, rewriter, loc, |     fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry, | ||||||
|                             apiRegistry, llvmDialect); |                             llvmDialect); | ||||||
|     auto zero = rewriter.create<LLVM::ConstantOp>( |     auto zero = rewriter.create<LLVM::ConstantOp>( | ||||||
|         loc, int32Ty, rewriter.getI32IntegerAttr(0)); |         loc, int32Ty, rewriter.getI32IntegerAttr(0)); | ||||||
|     callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, |     callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, | ||||||
|  | @ -270,7 +269,7 @@ public: | ||||||
| 
 | 
 | ||||||
|     // Return wrapped output.
 |     // Return wrapped output.
 | ||||||
|     rewriter.create<LLVM::ReturnOp>(loc, |     rewriter.create<LLVM::ReturnOp>(loc, | ||||||
|                                     SmallVector<Value *, 1>({wrappedOutput})); |                                     SmallVector<Value, 1>({wrappedOutput})); | ||||||
|     return matchSuccess(); |     return matchSuccess(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -315,11 +314,11 @@ private: | ||||||
| 
 | 
 | ||||||
|   // Call a registered API, return the return SSA values if only one result is
 |   // Call a registered API, return the return SSA values if only one result is
 | ||||||
|   // returned, otherwise return nullptr.
 |   // returned, otherwise return nullptr.
 | ||||||
|   Value *callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, |   Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, | ||||||
|                  API apiId, ArrayRef<Value *> params) const { |                 API apiId, ArrayRef<Value> params) const { | ||||||
|     auto returnVals = rewriter.create<LLVM::CallOp>( |     auto returnVals = rewriter.create<LLVM::CallOp>( | ||||||
|         loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef, |         loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef, | ||||||
|         ArrayRef<Value *>(params)); |         ArrayRef<Value>(params)); | ||||||
|     if (returnVals.getNumResults() == 1) |     if (returnVals.getNumResults() == 1) | ||||||
|       return returnVals.getResult(0); |       return returnVals.getResult(0); | ||||||
|     return nullptr; |     return nullptr; | ||||||
|  | @ -348,12 +347,11 @@ private: | ||||||
|     auto memRefTy = memRefPtrTy.getPointerElementTy(); |     auto memRefTy = memRefPtrTy.getPointerElementTy(); | ||||||
|     auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); |     auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); | ||||||
| 
 | 
 | ||||||
|     Value *memRef = |     Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, ptrToMemRef); | ||||||
|         rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, &ptrToMemRef); |  | ||||||
| 
 | 
 | ||||||
|     // Set dataPtr and alignedDataPtr;
 |     // Set dataPtr and alignedDataPtr;
 | ||||||
|     auto dataPtr = |     auto dataPtr = | ||||||
|         callApi(rewriter, loc, apiRegistry, API::GET_DATA, {&dynMemRef}); |         callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef}); | ||||||
|     dataPtr = rewriter.create<LLVM::BitcastOp>( |     dataPtr = rewriter.create<LLVM::BitcastOp>( | ||||||
|         loc, memRefTy.getStructElementType(0), dataPtr); |         loc, memRefTy.getStructElementType(0), dataPtr); | ||||||
|     memRef = rewriter.create<LLVM::InsertValueOp>( |     memRef = rewriter.create<LLVM::InsertValueOp>( | ||||||
|  | @ -373,9 +371,9 @@ private: | ||||||
|     // Get rank, sizes array ptr and strides array ptr.
 |     // Get rank, sizes array ptr and strides array ptr.
 | ||||||
|     auto rank = memRefTy.getStructElementType(3).getArrayNumElements(); |     auto rank = memRefTy.getStructElementType(3).getArrayNumElements(); | ||||||
|     auto sizesArrayPtr = |     auto sizesArrayPtr = | ||||||
|         callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&dynMemRef}); |         callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef}); | ||||||
|     auto stridesArrayPtr = |     auto stridesArrayPtr = | ||||||
|         callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&dynMemRef}); |         callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef}); | ||||||
| 
 | 
 | ||||||
|     for (decltype(rank) i = 0; i < rank; i++) { |     for (decltype(rank) i = 0; i < rank; i++) { | ||||||
|       auto dimIdx = rewriter.create<LLVM::ConstantOp>( |       auto dimIdx = rewriter.create<LLVM::ConstantOp>( | ||||||
|  | @ -384,7 +382,7 @@ private: | ||||||
|       // Insert size of the dimension.
 |       // Insert size of the dimension.
 | ||||||
|       auto dimSizePtr = rewriter.create<LLVM::GEPOp>( |       auto dimSizePtr = rewriter.create<LLVM::GEPOp>( | ||||||
|           loc, int64Ty.getPointerTo(), sizesArrayPtr, |           loc, int64Ty.getPointerTo(), sizesArrayPtr, | ||||||
|           ArrayRef<Value *>({dimIdx})); |           ArrayRef<Value>({dimIdx})); | ||||||
|       auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(), |       auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(), | ||||||
|                                                    dimSizePtr); |                                                    dimSizePtr); | ||||||
|       memRef = rewriter.create<LLVM::InsertValueOp>( |       memRef = rewriter.create<LLVM::InsertValueOp>( | ||||||
|  | @ -395,7 +393,7 @@ private: | ||||||
|       // Insert stride of the dimension.
 |       // Insert stride of the dimension.
 | ||||||
|       auto dimStridePtr = rewriter.create<LLVM::GEPOp>( |       auto dimStridePtr = rewriter.create<LLVM::GEPOp>( | ||||||
|           loc, int64Ty.getPointerTo(), sizesArrayPtr, |           loc, int64Ty.getPointerTo(), sizesArrayPtr, | ||||||
|           ArrayRef<Value *>({dimIdx})); |           ArrayRef<Value>({dimIdx})); | ||||||
|       auto dimStride = rewriter.create<LLVM::LoadOp>( |       auto dimStride = rewriter.create<LLVM::LoadOp>( | ||||||
|           loc, int64Ty.getPointerTo(), dimStridePtr); |           loc, int64Ty.getPointerTo(), dimStridePtr); | ||||||
|       memRef = rewriter.create<LLVM::InsertValueOp>( |       memRef = rewriter.create<LLVM::InsertValueOp>( | ||||||
|  | @ -404,7 +402,7 @@ private: | ||||||
|               {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); |               {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     rewriter.create<LLVM::StoreOp>(loc, memRef, &ptrToMemRef); |     rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef, |   void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef, | ||||||
|  | @ -415,19 +413,19 @@ private: | ||||||
|     auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); |     auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); | ||||||
| 
 | 
 | ||||||
|     // Extract the data pointer, and record it in dynamic mem ref created.
 |     // Extract the data pointer, and record it in dynamic mem ref created.
 | ||||||
|     Value *outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>( |     Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>( | ||||||
|         loc, outMemRefTy.getStructElementType(0), &outMemRef, |         loc, outMemRefTy.getStructElementType(0), outMemRef, | ||||||
|         rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); |         rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); | ||||||
|     outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>( |     outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>( | ||||||
|         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); |         loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); | ||||||
|     callApi(rewriter, loc, apiRegistry, API::SET_DATA, |     callApi(rewriter, loc, apiRegistry, API::SET_DATA, | ||||||
|             {&outDynMemRef, outMemRefDataPtr}); |             {outDynMemRef, outMemRefDataPtr}); | ||||||
| 
 | 
 | ||||||
|     auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements(); |     auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements(); | ||||||
|     auto sizesArrayPtr = |     auto sizesArrayPtr = | ||||||
|         callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {&outDynMemRef}); |         callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef}); | ||||||
|     auto stridesArrayPtr = |     auto stridesArrayPtr = | ||||||
|         callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {&outDynMemRef}); |         callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef}); | ||||||
| 
 | 
 | ||||||
|     for (decltype(rank) i = 0; i < rank; i++) { |     for (decltype(rank) i = 0; i < rank; i++) { | ||||||
|       auto dimIdx = rewriter.create<LLVM::ConstantOp>( |       auto dimIdx = rewriter.create<LLVM::ConstantOp>( | ||||||
|  | @ -435,22 +433,22 @@ private: | ||||||
| 
 | 
 | ||||||
|       // Transfer size of dimension from memref to dynamic memref.
 |       // Transfer size of dimension from memref to dynamic memref.
 | ||||||
|       auto dimSize = rewriter.create<LLVM::ExtractValueOp>( |       auto dimSize = rewriter.create<LLVM::ExtractValueOp>( | ||||||
|           loc, int64Ty, &outMemRef, |           loc, int64Ty, outMemRef, | ||||||
|           rewriter.getArrayAttr( |           rewriter.getArrayAttr( | ||||||
|               {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); |               {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); | ||||||
|       auto dimSizePtr = rewriter.create<LLVM::GEPOp>( |       auto dimSizePtr = rewriter.create<LLVM::GEPOp>( | ||||||
|           loc, int64Ty.getPointerTo(), sizesArrayPtr, |           loc, int64Ty.getPointerTo(), sizesArrayPtr, | ||||||
|           ArrayRef<Value *>({dimIdx})); |           ArrayRef<Value>({dimIdx})); | ||||||
|       rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr); |       rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr); | ||||||
| 
 | 
 | ||||||
|       // Transfer stride of dimension from memref to dynamic memref.
 |       // Transfer stride of dimension from memref to dynamic memref.
 | ||||||
|       auto dimStride = rewriter.create<LLVM::ExtractValueOp>( |       auto dimStride = rewriter.create<LLVM::ExtractValueOp>( | ||||||
|           loc, int64Ty, &outMemRef, |           loc, int64Ty, outMemRef, | ||||||
|           rewriter.getArrayAttr( |           rewriter.getArrayAttr( | ||||||
|               {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); |               {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); | ||||||
|       auto dimStridePtr = rewriter.create<LLVM::GEPOp>( |       auto dimStridePtr = rewriter.create<LLVM::GEPOp>( | ||||||
|           loc, int64Ty.getPointerTo(), stridesArrayPtr, |           loc, int64Ty.getPointerTo(), stridesArrayPtr, | ||||||
|           ArrayRef<Value *>({dimIdx})); |           ArrayRef<Value>({dimIdx})); | ||||||
|       rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr); |       rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue