Fix rebase errors. (#378)
This commit is contained in:
		
							parent
							
								
									6c7ff180f9
								
							
						
					
					
						commit
						bee32e2041
					
				|  | @ -58,17 +58,17 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> : | |||
| 
 | ||||
| include "dialect/onnx/onnxop.inc" | ||||
| 
 | ||||
| def ONNXFullGemmOp: ONNX_Op<"full_gemm",	 | ||||
|     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {	 | ||||
|   let summary = "ONNX general matrix multiply operation";	 | ||||
|   let description = [{	 | ||||
| def ONNXFullGemmOp: ONNX_Op<"FullGemm", | ||||
|     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   let summary = "ONNX general matrix multiply operation"; | ||||
|   let description = [{ | ||||
| 
 | ||||
|     The "onnx.gemm" generic matrix multiplication with bias.	 | ||||
|     The "onnx.gemm" generic matrix multiplication with bias. | ||||
| 
 | ||||
|   }];	 | ||||
|   }]; | ||||
| 
 | ||||
|   let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);	 | ||||
|   let results = (outs AnyTensor);	 | ||||
|   let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in); | ||||
|   let results = (outs AnyTensor); | ||||
| } | ||||
| 
 | ||||
| #endif // ONNX_OPS | ||||
|  |  | |||
|  | @ -30,7 +30,7 @@ def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>; | |||
| // Pattern-Match and Rewrite | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| // onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.full_gemm(%X, %Y, %Z) | ||||
| // onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z) | ||||
| def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), | ||||
|                                  (ONNXFullGemmOp $m1, $m2, $m3), | ||||
| 				 [(HasOneUse $res)]>; | ||||
|  |  | |||
|  | @ -82,10 +82,10 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { | |||
|     // All operations which do not return a ranked tensor type have dynamic
 | ||||
|     // shaped outputs. All those operation need to implement the inferShape()
 | ||||
|     // method.
 | ||||
|     if (op->getName().getStringRef() != "onnx.add" && | ||||
|         op->getName().getStringRef() != "onnx.matmul" && | ||||
|         op->getName().getStringRef() != "onnx.gemm" && | ||||
|         op->getName().getStringRef() != "onnx.full_gemm") | ||||
|     if (op->getName().getStringRef() != "onnx.Add" && | ||||
|         op->getName().getStringRef() != "onnx.MatMul" && | ||||
|         op->getName().getStringRef() != "onnx.Gemm" && | ||||
|         op->getName().getStringRef() != "onnx.FullGemm") | ||||
|       return false; | ||||
|     return llvm::any_of(op->getResultTypes(), | ||||
|         [](Type result_type) { return !result_type.isa<RankedTensorType>(); }); | ||||
|  |  | |||
|  | @ -1,6 +1,14 @@ | |||
| 
 | ||||
| import os | ||||
| import sys | ||||
| import re | ||||
| import platform | ||||
| import subprocess | ||||
| 
 | ||||
| import lit.util | ||||
| import lit.formats | ||||
| from lit.llvm import llvm_config | ||||
| from lit.llvm.subst import FindTool | ||||
| from lit.llvm.subst import ToolSubst | ||||
| 
 | ||||
| # name: The name of this test suite. | ||||
|  |  | |||
|  | @ -2,7 +2,7 @@ | |||
| import lit.llvm | ||||
| 
 | ||||
| config.llvm_tools_dir = "@MLIR_TOOLS_DIR@" | ||||
| config.mlir_obj_root = "@MLIR_BUILD_DIR@" | ||||
| config.mlir_obj_root = "@LLVM_BUILD@" | ||||
| config.mlir_tools_dir = "@MLIR_TOOLS_DIR@" | ||||
| config.suffixes = ['.mlir'] | ||||
| 
 | ||||
|  |  | |||
|  | @ -2,13 +2,10 @@ | |||
| 
 | ||||
| //CHECK: module { | ||||
| module { | ||||
|  func @test_sigmoid() { | ||||
|    %0 = "frontend.input t1"() : () -> tensor<10x10xf32> | ||||
|    %1 = "frontend.input t2"() : () -> tensor<10x10xf32> | ||||
|    %2 = "frontend.input t3"() : () -> tensor<10x10xf32> | ||||
|    // CHECK: %{{[0-9]+}} = "onnx.full_gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|    %3 = "onnx.MatMul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|    %4 = "onnx.Add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|    %5 = "frontend.output t4"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|  func @test_sigmoid(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||||
|    // CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|    %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|    %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> | ||||
|    "std.return"(%1) : (tensor<10x10xf32>) -> () | ||||
|  } | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue