Fix scalar entry point parameter lowering issue. (#78)

* Fix scalar entry point parameter lowering issue.

* Enable scalar bias test.

* Nit. Improve comments and remove debug code.

* Make the helper function static and move it up front.

* Move helper function to top of the file.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tian Jin 2020-02-13 13:50:05 +08:00 committed by GitHub
parent e5677bba1f
commit 937bbec265
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 9 deletions

View File

@ -41,6 +41,23 @@ static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
return SymbolRefAttr::get(funcName, context); return SymbolRefAttr::get(funcName, context);
} }
static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
  // Returns the rank of the tensor described by a lowered MemRef struct type.
  //
  // In the common case the MemRef struct has 5 fields, and the fields at
  // index 3 and 4 (the 4th and 5th) are arrays holding one entry per tensor
  // dimension, so the length of either array is the rank. When the tensor is
  // a scalar those two arrays would be empty, and the struct collapses to
  // just its first 3 fields. See
  // https://github.com/llvm/llvm-project/blob/master/mlir/docs/ConversionToLLVMDialect.md#memref-types.
  const auto fieldCount = memRefTy.getStructNumElements();
  assert((fieldCount == 3 || fieldCount == 5) &&
         "Expect MemRef type to contain either 3 or 5 elements.");

  // A 3-field struct carries no size/stride arrays: rank-0 (scalar) MemRef.
  const bool describesScalar = (fieldCount == 3);
  if (describesScalar)
    return 0;

  // Otherwise the rank is the length of the array stored at field index 3.
  return memRefTy.getStructElementType(3).getArrayNumElements();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlMemcpyOpLowering // KRNL to LLVM: KrnlMemcpyOpLowering
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -91,9 +108,8 @@ public:
// Memcpy call // Memcpy call
rewriter.create<CallOp>( rewriter.create<CallOp>(
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value>( ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size, int64Size, isVolatile}));
isVolatile}));
rewriter.eraseOp(op); rewriter.eraseOp(op);
return matchSuccess(); return matchSuccess();
@ -116,7 +132,8 @@ private:
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
auto llvmFnType = LLVM::LLVMType::getFunctionTy( auto llvmFnType = LLVM::LLVMType::getFunctionTy(
llvmVoidTy, llvmVoidTy,
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), ArrayRef<mlir::LLVM::LLVMType>(
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
false); false);
// Insert the memcpy function into the body of the parent module. // Insert the memcpy function into the body of the parent module.
@ -261,8 +278,7 @@ public:
// it in the wrapped Output. // it in the wrapped Output.
auto outMemRef = outputMemRefs.getResult(0); auto outMemRef = outputMemRefs.getResult(0);
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>(); auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
auto outMemRefRank = auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
outMemRefTy.getStructElementType(3).getArrayNumElements();
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>( auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
auto outDynMemRef = callApi(rewriter, loc, apiRegistry, auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
@ -376,7 +392,7 @@ private:
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)})); rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
// Get rank, sizes array ptr and strides array ptr. // Get rank, sizes array ptr and strides array ptr.
auto rank = memRefTy.getStructElementType(3).getArrayNumElements(); auto rank = getRankFromMemRefType(memRefTy);
auto sizesArrayPtr = auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
auto stridesArrayPtr = auto stridesArrayPtr =
@ -428,7 +444,7 @@ private:
callApi(rewriter, loc, apiRegistry, API::SET_DATA, callApi(rewriter, loc, apiRegistry, API::SET_DATA,
{outDynMemRef, outMemRefDataPtr}); {outDynMemRef, outMemRefDataPtr});
auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements(); auto rank = getRankFromMemRefType(outMemRefTy);
auto sizesArrayPtr = auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
auto stridesArrayPtr = auto stridesArrayPtr =

View File

@ -99,7 +99,7 @@ test_to_enable = [
"test_gemm_beta_cpu", "test_gemm_beta_cpu",
"test_gemm_default_matrix_bias_cpu", "test_gemm_default_matrix_bias_cpu",
# "test_gemm_default_no_bias_cpu", <- error, need support for optional operands # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
# "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why? "test_gemm_default_scalar_bias_cpu",
"test_gemm_default_single_elem_vector_bias_cpu", "test_gemm_default_single_elem_vector_bias_cpu",
"test_gemm_default_vector_bias_cpu", "test_gemm_default_vector_bias_cpu",
"test_gemm_default_zero_bias_cpu", "test_gemm_default_zero_bias_cpu",