Fix scalar entry point parameter lowering issue. (#78)
* Fix scalar entry point parameter lowering issue.
* Enable scalar bias test.
* Nit. Improve comments and remove debug code.
* Make helper function static, move to upfront position.
* Move helper function to top of the file.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
		
							parent
							
								
									e5677bba1f
								
							
						
					
					
						commit
						937bbec265
					
				| 
						 | 
				
			
			@ -41,6 +41,23 @@ static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
 | 
			
		|||
  return SymbolRefAttr::get(funcName, context);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Returns the rank encoded in the LLVM struct type of a lowered MemRef.
///
/// A lowered MemRef is normally a 5-element struct whose 4th and 5th elements
/// are arrays with one entry per tensor dimension. When the underlying tensor
/// is a scalar those arrays would have length 0, so the struct degenerates
/// into a 3-element struct instead. For more information, refer to
/// https://github.com/llvm/llvm-project/blob/master/mlir/docs/ConversionToLLVMDialect.md#memref-types.
///
/// \param memRefTy LLVM struct type describing the MemRef.
/// \return 0 for the 3-element (scalar) form; otherwise the number of entries
///         in the array held at struct index 3.
static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
  const auto fieldCount = memRefTy.getStructNumElements();
  assert((fieldCount == 3 || fieldCount == 5) &&
         "Expect MemRef type to contain either 3 or 5 elements.");

  // Full 5-element form: the rank is the length of the array at index 3.
  if (fieldCount != 3)
    return memRefTy.getStructElementType(3).getArrayNumElements();

  // Degenerate 3-element form: MemRef refers to a scalar.
  return 0;
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// KRNL to LLVM: KrnlMemcpyOpLowering
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
| 
						 | 
				
			
			@ -91,9 +108,8 @@ public:
 | 
			
		|||
    // Memcpy call
 | 
			
		||||
    rewriter.create<CallOp>(
 | 
			
		||||
        loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
 | 
			
		||||
        ArrayRef<Value>(
 | 
			
		||||
            {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size,
 | 
			
		||||
             isVolatile}));
 | 
			
		||||
        ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
 | 
			
		||||
                         int64Size, isVolatile}));
 | 
			
		||||
 | 
			
		||||
    rewriter.eraseOp(op);
 | 
			
		||||
    return matchSuccess();
 | 
			
		||||
| 
						 | 
				
			
			@ -116,7 +132,8 @@ private:
 | 
			
		|||
    auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
 | 
			
		||||
    auto llvmFnType = LLVM::LLVMType::getFunctionTy(
 | 
			
		||||
        llvmVoidTy,
 | 
			
		||||
        ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
 | 
			
		||||
        ArrayRef<mlir::LLVM::LLVMType>(
 | 
			
		||||
            {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
 | 
			
		||||
        false);
 | 
			
		||||
 | 
			
		||||
    // Insert the memcpy function into the body of the parent module.
 | 
			
		||||
| 
						 | 
				
			
			@ -261,8 +278,7 @@ public:
 | 
			
		|||
    // it in the wrapped Output.
 | 
			
		||||
    auto outMemRef = outputMemRefs.getResult(0);
 | 
			
		||||
    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
 | 
			
		||||
    auto outMemRefRank =
 | 
			
		||||
        outMemRefTy.getStructElementType(3).getArrayNumElements();
 | 
			
		||||
    auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
 | 
			
		||||
    auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
 | 
			
		||||
        loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
 | 
			
		||||
    auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
 | 
			
		||||
| 
						 | 
				
			
			@ -376,7 +392,7 @@ private:
 | 
			
		|||
        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
 | 
			
		||||
 | 
			
		||||
    // Get rank, sizes array ptr and strides array ptr.
 | 
			
		||||
    auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
 | 
			
		||||
    auto rank = getRankFromMemRefType(memRefTy);
 | 
			
		||||
    auto sizesArrayPtr =
 | 
			
		||||
        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
 | 
			
		||||
    auto stridesArrayPtr =
 | 
			
		||||
| 
						 | 
				
			
			@ -428,7 +444,7 @@ private:
 | 
			
		|||
    callApi(rewriter, loc, apiRegistry, API::SET_DATA,
 | 
			
		||||
            {outDynMemRef, outMemRefDataPtr});
 | 
			
		||||
 | 
			
		||||
    auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
 | 
			
		||||
    auto rank = getRankFromMemRefType(outMemRefTy);
 | 
			
		||||
    auto sizesArrayPtr =
 | 
			
		||||
        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
 | 
			
		||||
    auto stridesArrayPtr =
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -99,7 +99,7 @@ test_to_enable = [
 | 
			
		|||
    "test_gemm_beta_cpu",
 | 
			
		||||
    "test_gemm_default_matrix_bias_cpu",
 | 
			
		||||
    # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
 | 
			
		||||
    # "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why?
 | 
			
		||||
    "test_gemm_default_scalar_bias_cpu",
 | 
			
		||||
    "test_gemm_default_single_elem_vector_bias_cpu",
 | 
			
		||||
    "test_gemm_default_vector_bias_cpu",
 | 
			
		||||
    "test_gemm_default_zero_bias_cpu",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue