Fix reshape op. (#53)

This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-28 10:21:08 -05:00 committed by GitHub
parent 7c889548a7
commit f00206cecf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 7 deletions

View File

@ -1118,12 +1118,17 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
// the result.
Value index = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
// Load index from array of indices.
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
Value int64LoadedVal = rewriter.create<ZeroExtendIOp>(
loc, loadedVal, rewriter.getIntegerType(64));
// Check if the loaded index is already the correct width of 64 bits.
// Convert the value to a 64 bit integer if needed.
Value int64LoadedVal = loadedVal;
if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
int64LoadedVal = rewriter.create<ZeroExtendIOp>(
loc, loadedVal, rewriter.getIntegerType(64));
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
allocOperands.push_back(rewriter.create<IndexCastOp>(
loc, loadedVal, rewriter.getIndexType()));
loc, loadedVal, rewriter.getIndexType()));
}
AllocOp allocateMemref =
rewriter.create<AllocOp>(loc, memRefType, allocOperands);

View File

@ -83,11 +83,17 @@ public:
Value int64Size = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
// Is volatile (set to false).
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// Memcpy call
rewriter.create<CallOp>(
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value>(
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size,
isVolatile}));
rewriter.eraseOp(op);
return matchSuccess();
@ -107,9 +113,10 @@ private:
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
llvmVoidTy,
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
false);
// Insert the memcpy function into the body of the parent module.

View File

@ -154,6 +154,9 @@ test_to_enable = [
# SoftsignOp:
"test_softsign_cpu",
"test_softsign_example_cpu",
# ReshapeOp:
"test_reshape_reordered_all_dims_cpu",
]
# Extract name of all test cases.

View File

@ -4,13 +4,14 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
// CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
}