Fix reshape op. (#53)
This commit is contained in:
		
							parent
							
								
									7c889548a7
								
							
						
					
					
						commit
						f00206cecf
					
				| 
						 | 
					@ -1118,8 +1118,13 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
 | 
				
			||||||
        // the result.
 | 
					        // the result.
 | 
				
			||||||
        Value index = rewriter.create<ConstantOp>(
 | 
					        Value index = rewriter.create<ConstantOp>(
 | 
				
			||||||
            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
 | 
					            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
 | 
				
			||||||
 | 
					        // Load index from array of indices.
 | 
				
			||||||
        Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
 | 
					        Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
 | 
				
			||||||
        Value int64LoadedVal = rewriter.create<ZeroExtendIOp>(
 | 
					        // Check if the loaded index is already the correct width of 64 bits.
 | 
				
			||||||
 | 
					        // Convert the value to a 64 bit integer if needed.
 | 
				
			||||||
 | 
					        Value int64LoadedVal = loadedVal;
 | 
				
			||||||
 | 
					        if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
 | 
				
			||||||
 | 
					          int64LoadedVal = rewriter.create<ZeroExtendIOp>(
 | 
				
			||||||
              loc, loadedVal, rewriter.getIntegerType(64));
 | 
					              loc, loadedVal, rewriter.getIntegerType(64));
 | 
				
			||||||
        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
 | 
					        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
 | 
				
			||||||
        allocOperands.push_back(rewriter.create<IndexCastOp>(
 | 
					        allocOperands.push_back(rewriter.create<IndexCastOp>(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -83,11 +83,17 @@ public:
 | 
				
			||||||
    Value int64Size = rewriter.create<LLVM::SExtOp>(
 | 
					    Value int64Size = rewriter.create<LLVM::SExtOp>(
 | 
				
			||||||
        loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
 | 
					        loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Is volatile (set to false).
 | 
				
			||||||
 | 
					    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
 | 
				
			||||||
 | 
					        loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
 | 
				
			||||||
 | 
					        rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Memcpy call
 | 
					    // Memcpy call
 | 
				
			||||||
    rewriter.create<CallOp>(
 | 
					    rewriter.create<CallOp>(
 | 
				
			||||||
        loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
 | 
					        loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
 | 
				
			||||||
        ArrayRef<Value>(
 | 
					        ArrayRef<Value>(
 | 
				
			||||||
            {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
 | 
					            {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size,
 | 
				
			||||||
 | 
					             isVolatile}));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    rewriter.eraseOp(op);
 | 
					    rewriter.eraseOp(op);
 | 
				
			||||||
    return matchSuccess();
 | 
					    return matchSuccess();
 | 
				
			||||||
| 
						 | 
					@ -107,9 +113,10 @@ private:
 | 
				
			||||||
    auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
 | 
					    auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
 | 
				
			||||||
    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
 | 
					    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
 | 
				
			||||||
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
 | 
					    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
 | 
				
			||||||
 | 
					    auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
 | 
				
			||||||
    auto llvmFnType = LLVM::LLVMType::getFunctionTy(
 | 
					    auto llvmFnType = LLVM::LLVMType::getFunctionTy(
 | 
				
			||||||
        llvmVoidTy,
 | 
					        llvmVoidTy,
 | 
				
			||||||
        ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
 | 
					        ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
 | 
				
			||||||
        false);
 | 
					        false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Insert the memcpy function into the body of the parent module.
 | 
					    // Insert the memcpy function into the body of the parent module.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -154,6 +154,9 @@ test_to_enable = [
 | 
				
			||||||
    # SoftsignOp:
 | 
					    # SoftsignOp:
 | 
				
			||||||
    "test_softsign_cpu",
 | 
					    "test_softsign_cpu",
 | 
				
			||||||
    "test_softsign_example_cpu",
 | 
					    "test_softsign_example_cpu",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ReshapeOp:
 | 
				
			||||||
 | 
					    "test_reshape_reordered_all_dims_cpu",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Extract name of all test cases.
 | 
					# Extract name of all test cases.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,13 +4,14 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
 | 
				
			||||||
  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
 | 
					  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
 | 
				
			||||||
  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
 | 
					  // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
 | 
				
			||||||
  // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
 | 
					  // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
 | 
				
			||||||
  // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
 | 
					  // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
 | 
				
			||||||
  // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
 | 
					  // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
 | 
				
			||||||
  // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
					  // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
  // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
 | 
					  // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
 | 
				
			||||||
  // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
 | 
					  // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
 | 
				
			||||||
  // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void
 | 
					  // CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
 | 
				
			||||||
 | 
					  // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
 | 
				
			||||||
  // CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
 | 
					  // CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue