Fix reshape op. (#53)
This commit is contained in:
parent
7c889548a7
commit
f00206cecf
|
@ -1118,8 +1118,13 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
// the result.
|
// the result.
|
||||||
Value index = rewriter.create<ConstantOp>(
|
Value index = rewriter.create<ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||||
|
// Load index from array of indices.
|
||||||
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||||
Value int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
// Check if the loaded index is already the correct width of 64 bits.
|
||||||
|
// Convert the value to a 64 bit integer if needed.
|
||||||
|
Value int64LoadedVal = loadedVal;
|
||||||
|
if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
|
||||||
|
int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
||||||
loc, loadedVal, rewriter.getIntegerType(64));
|
loc, loadedVal, rewriter.getIntegerType(64));
|
||||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
|
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
|
||||||
allocOperands.push_back(rewriter.create<IndexCastOp>(
|
allocOperands.push_back(rewriter.create<IndexCastOp>(
|
||||||
|
|
|
@ -83,11 +83,17 @@ public:
|
||||||
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
||||||
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
||||||
|
|
||||||
|
// Is volatile (set to false).
|
||||||
|
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||||
|
|
||||||
// Memcpy call
|
// Memcpy call
|
||||||
rewriter.create<CallOp>(
|
rewriter.create<CallOp>(
|
||||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||||
ArrayRef<Value>(
|
ArrayRef<Value>(
|
||||||
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
|
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size,
|
||||||
|
isVolatile}));
|
||||||
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
|
@ -107,9 +113,10 @@ private:
|
||||||
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
|
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
||||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
|
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
|
||||||
llvmVoidTy,
|
llvmVoidTy,
|
||||||
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
|
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
||||||
false);
|
false);
|
||||||
|
|
||||||
// Insert the memcpy function into the body of the parent module.
|
// Insert the memcpy function into the body of the parent module.
|
||||||
|
|
|
@ -154,6 +154,9 @@ test_to_enable = [
|
||||||
# SoftsignOp:
|
# SoftsignOp:
|
||||||
"test_softsign_cpu",
|
"test_softsign_cpu",
|
||||||
"test_softsign_example_cpu",
|
"test_softsign_example_cpu",
|
||||||
|
|
||||||
|
# ReshapeOp:
|
||||||
|
"test_reshape_reordered_all_dims_cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Extract name of all test cases.
|
# Extract name of all test cases.
|
||||||
|
|
|
@ -4,13 +4,14 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
|
||||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
|
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
|
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
|
||||||
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||||
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||||
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
|
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
|
||||||
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
|
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
|
||||||
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
||||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void
|
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
|
||||||
|
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
|
||||||
// CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
// CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue