Error fix3 (#145)
* removed warning missing return, dangling else * fixed errors, made sure to return false in all shape inference failures * shape inference use LogicalResults as return value * format fixed * format error Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
		
							parent
							
								
									c20aa6980e
								
							
						
					
					
						commit
						4f8fd9d1bf
					
				|  | @ -490,6 +490,7 @@ Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter, | ||||||
|     return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
 |     return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
 | ||||||
|   } else { |   } else { | ||||||
|     emitError(loc, "unsupported element type"); |     emitError(loc, "unsupported element type"); | ||||||
|  |     return nullptr; | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -258,12 +258,12 @@ struct ONNXMatMulOpLowering : public ConversionPattern { | ||||||
|         for (auto arg : loopBatchIVs) |         for (auto arg : loopBatchIVs) | ||||||
|           loopBatchKNIVs.emplace_back(arg); |           loopBatchKNIVs.emplace_back(arg); | ||||||
|       loopBatchKNIVs.emplace_back(loopKIVs[0]); |       loopBatchKNIVs.emplace_back(loopKIVs[0]); | ||||||
|       if (BShape.size() >= 2) |       if (BShape.size() >= 2) { | ||||||
|         if (AShape.size() >= 2) |         if (AShape.size() >= 2) | ||||||
|           loopBatchKNIVs.emplace_back(loopMNIVs[1]); |           loopBatchKNIVs.emplace_back(loopMNIVs[1]); | ||||||
|         else |         else | ||||||
|           loopBatchKNIVs.emplace_back(loopMNIVs[0]); |           loopBatchKNIVs.emplace_back(loopMNIVs[0]); | ||||||
| 
 |       } | ||||||
|       // Matmul computation
 |       // Matmul computation
 | ||||||
|       auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs); |       auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs); | ||||||
|       auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs); |       auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs); | ||||||
|  |  | ||||||
|  | @ -32,6 +32,7 @@ public: | ||||||
|       return LoopType::get(parser.getBuilder().getContext()); |       return LoopType::get(parser.getBuilder().getContext()); | ||||||
| 
 | 
 | ||||||
|     parser.emitError(parser.getCurrentLocation(), "Unknown type"); |     parser.emitError(parser.getCurrentLocation(), "Unknown type"); | ||||||
|  |     return nullptr; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /// Print a type registered to this dialect.
 |   /// Print a type registered to this dialect.
 | ||||||
|  |  | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { | ||||||
| 
 | 
 | ||||||
|   let methods = [ |   let methods = [ | ||||||
|     InterfaceMethod<"Infer and set the output shape for the current operation.", |     InterfaceMethod<"Infer and set the output shape for the current operation.", | ||||||
|                     "bool", "inferShapes"> |                     "LogicalResult", "inferShapes"> | ||||||
|   ]; |   ]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -36,7 +36,7 @@ public: | ||||||
|     f.walk([&](mlir::Operation *op) { |     f.walk([&](mlir::Operation *op) { | ||||||
|       if (returnsDynamicShape(op)) { |       if (returnsDynamicShape(op)) { | ||||||
|         if (auto shape_op = dyn_cast<ShapeInference>(op)) { |         if (auto shape_op = dyn_cast<ShapeInference>(op)) { | ||||||
|           if (!shape_op.inferShapes()) { |           if (failed(shape_op.inferShapes())) { | ||||||
|             op->emitError("unable to infer shape of operation without shape " |             op->emitError("unable to infer shape of operation without shape " | ||||||
|                           "inference method"); |                           "inference method"); | ||||||
|             return signalPassFailure(); |             return signalPassFailure(); | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue