diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index 1402748..c1277ff 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -1118,12 +1118,17 @@ struct ONNXReshapeOpLowering : public ConversionPattern { // the result. Value index = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); + // Load index from array of indices. Value loadedVal = rewriter.create(loc, operands[1], index); - Value int64LoadedVal = rewriter.create( - loc, loadedVal, rewriter.getIntegerType(64)); + // Check if the loaded index is already the correct width of 64 bits. + // Convert the value to a 64 bit integer if needed. + Value int64LoadedVal = loadedVal; + if (loadedVal.getType().cast().getWidth() < 64) + int64LoadedVal = rewriter.create( + loc, loadedVal, rewriter.getIntegerType(64)); tensorSize = rewriter.create(loc, tensorSize, int64LoadedVal); allocOperands.push_back(rewriter.create( - loc, loadedVal, rewriter.getIndexType())); + loc, loadedVal, rewriter.getIndexType())); } AllocOp allocateMemref = rewriter.create(loc, memRefType, allocOperands); diff --git a/src/transform/lower_to_llvm.cpp b/src/transform/lower_to_llvm.cpp index be6d291..9d229f5 100644 --- a/src/transform/lower_to_llvm.cpp +++ b/src/transform/lower_to_llvm.cpp @@ -83,11 +83,17 @@ public: Value int64Size = rewriter.create( loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); + // Is volatile (set to false). + Value isVolatile = rewriter.create( + loc, LLVM::LLVMType::getInt1Ty(llvmDialect), + rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); + // Memcpy call rewriter.create( loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), ArrayRef( - {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size})); + {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size, + isVolatile})); rewriter.eraseOp(op); return matchSuccess(); @@ -107,9 +113,10 @@ private: auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect); auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); auto llvmFnType = LLVM::LLVMType::getFunctionTy( llvmVoidTy, - ArrayRef({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}), + ArrayRef({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), false); // Insert the memcpy function into the body of the parent module. diff --git a/test/backend/test.py b/test/backend/test.py index 39f1b6e..c850462 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -154,6 +154,9 @@ test_to_enable = [ # SoftsignOp: "test_softsign_cpu", "test_softsign_example_cpu", + + # ReshapeOp: + "test_reshape_reordered_all_dims_cpu", ] # Extract name of all test cases. diff --git a/test/mlir/krnl/reshape.mlir b/test/mlir/krnl/reshape.mlir index 2b46582..84a068c 100644 --- a/test/mlir/krnl/reshape.mlir +++ b/test/mlir/krnl/reshape.mlir @@ -4,13 +4,14 @@ func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi32>) -> tensor<*x %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor, tensor<4xi32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () - // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) + // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*"> // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*"> // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64 - // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void + // CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1 + // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void // CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> }