Fix scalar entry point parameter lowering issue. (#78)
* Fix scalar entry point parameter lowering issue.
* Enable scalar bias test.
* Nit. Improve comments and remove debug code.
* Make helper function static, move to upfront position.
* Move helper function to top of the file.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
		
							parent
							
								
									e5677bba1f
								
							
						
					
					
						commit
						937bbec265
					
				| 
						 | 
					@ -41,6 +41,23 @@ static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
 | 
				
			||||||
  return SymbolRefAttr::get(funcName, context);
 | 
					  return SymbolRefAttr::get(funcName, context);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
 | 
				
			||||||
 | 
					  // Usually a MemRef is a 5-element struct, where the 4th and 5th elements in
 | 
				
			||||||
 | 
					  // this struct are arrays whose size is the rank of the tensor. In the event
 | 
				
			||||||
 | 
					  // that the corresponding tensor of this MemRef is a scalar, the 4th and 5th
 | 
				
			||||||
 | 
					  // elements will have 0-length, which in turn causes the MemRef struct to
 | 
				
			||||||
 | 
					  // degenerate into a 3-element struct. For more information, refer to
 | 
				
			||||||
 | 
					  // https://github.com/llvm/llvm-project/blob/master/mlir/docs/ConversionToLLVMDialect.md#memref-types.
 | 
				
			||||||
 | 
					  auto numElems = memRefTy.getStructNumElements();
 | 
				
			||||||
 | 
					  assert((numElems == 3 || numElems == 5) &&
 | 
				
			||||||
 | 
					         "Expect MemRef type to contain either 3 or 5 elements.");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (numElems == 3)
 | 
				
			||||||
 | 
					    return 0; // MemRef refers to a scalar.
 | 
				
			||||||
 | 
					  else
 | 
				
			||||||
 | 
					    return memRefTy.getStructElementType(3).getArrayNumElements();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// KRNL to LLVM: KrnlMemcpyOpLowering
 | 
					// KRNL to LLVM: KrnlMemcpyOpLowering
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					@ -91,9 +108,8 @@ public:
 | 
				
			||||||
    // Memcpy call
 | 
					    // Memcpy call
 | 
				
			||||||
    rewriter.create<CallOp>(
 | 
					    rewriter.create<CallOp>(
 | 
				
			||||||
        loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
 | 
					        loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
 | 
				
			||||||
        ArrayRef<Value>(
 | 
					        ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
 | 
				
			||||||
            {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size,
 | 
					                         int64Size, isVolatile}));
 | 
				
			||||||
             isVolatile}));
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    rewriter.eraseOp(op);
 | 
					    rewriter.eraseOp(op);
 | 
				
			||||||
    return matchSuccess();
 | 
					    return matchSuccess();
 | 
				
			||||||
| 
						 | 
					@ -116,7 +132,8 @@ private:
 | 
				
			||||||
    auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
 | 
					    auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
 | 
				
			||||||
    auto llvmFnType = LLVM::LLVMType::getFunctionTy(
 | 
					    auto llvmFnType = LLVM::LLVMType::getFunctionTy(
 | 
				
			||||||
        llvmVoidTy,
 | 
					        llvmVoidTy,
 | 
				
			||||||
        ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
 | 
					        ArrayRef<mlir::LLVM::LLVMType>(
 | 
				
			||||||
 | 
					            {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
 | 
				
			||||||
        false);
 | 
					        false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Insert the memcpy function into the body of the parent module.
 | 
					    // Insert the memcpy function into the body of the parent module.
 | 
				
			||||||
| 
						 | 
					@ -261,8 +278,7 @@ public:
 | 
				
			||||||
    // it in the wrapped Output.
 | 
					    // it in the wrapped Output.
 | 
				
			||||||
    auto outMemRef = outputMemRefs.getResult(0);
 | 
					    auto outMemRef = outputMemRefs.getResult(0);
 | 
				
			||||||
    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
 | 
					    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
 | 
				
			||||||
    auto outMemRefRank =
 | 
					    auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
 | 
				
			||||||
        outMemRefTy.getStructElementType(3).getArrayNumElements();
 | 
					 | 
				
			||||||
    auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
 | 
					    auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
 | 
				
			||||||
        loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
 | 
					        loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
 | 
				
			||||||
    auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
 | 
					    auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
 | 
				
			||||||
| 
						 | 
					@ -376,7 +392,7 @@ private:
 | 
				
			||||||
        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
 | 
					        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Get rank, sizes array ptr and strides array ptr.
 | 
					    // Get rank, sizes array ptr and strides array ptr.
 | 
				
			||||||
    auto rank = memRefTy.getStructElementType(3).getArrayNumElements();
 | 
					    auto rank = getRankFromMemRefType(memRefTy);
 | 
				
			||||||
    auto sizesArrayPtr =
 | 
					    auto sizesArrayPtr =
 | 
				
			||||||
        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
 | 
					        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
 | 
				
			||||||
    auto stridesArrayPtr =
 | 
					    auto stridesArrayPtr =
 | 
				
			||||||
| 
						 | 
					@ -428,7 +444,7 @@ private:
 | 
				
			||||||
    callApi(rewriter, loc, apiRegistry, API::SET_DATA,
 | 
					    callApi(rewriter, loc, apiRegistry, API::SET_DATA,
 | 
				
			||||||
            {outDynMemRef, outMemRefDataPtr});
 | 
					            {outDynMemRef, outMemRefDataPtr});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto rank = outMemRefTy.getStructElementType(3).getArrayNumElements();
 | 
					    auto rank = getRankFromMemRefType(outMemRefTy);
 | 
				
			||||||
    auto sizesArrayPtr =
 | 
					    auto sizesArrayPtr =
 | 
				
			||||||
        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
 | 
					        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
 | 
				
			||||||
    auto stridesArrayPtr =
 | 
					    auto stridesArrayPtr =
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -99,7 +99,7 @@ test_to_enable = [
 | 
				
			||||||
    "test_gemm_beta_cpu",
 | 
					    "test_gemm_beta_cpu",
 | 
				
			||||||
    "test_gemm_default_matrix_bias_cpu",
 | 
					    "test_gemm_default_matrix_bias_cpu",
 | 
				
			||||||
    # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
 | 
					    # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
 | 
				
			||||||
    # "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why?
 | 
					    "test_gemm_default_scalar_bias_cpu",
 | 
				
			||||||
    "test_gemm_default_single_elem_vector_bias_cpu",
 | 
					    "test_gemm_default_single_elem_vector_bias_cpu",
 | 
				
			||||||
    "test_gemm_default_vector_bias_cpu",
 | 
					    "test_gemm_default_vector_bias_cpu",
 | 
				
			||||||
    "test_gemm_default_zero_bias_cpu",
 | 
					    "test_gemm_default_zero_bias_cpu",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue