Hazard3/test/sim/rvpy/rvpy

492 lines
15 KiB
Python
Executable File

#!/usr/bin/env python3
# Minimal RISC-V interpreter, supporting RV32IM + Zicsr only, with trace disassembly
import argparse
import sys
# Register width in bits.
XLEN = 32
# All-ones mask, XLEN bits wide; used to view signed register values as unsigned.
XLEN_MASK = (1 << XLEN) - 1
def extract(bits, msb, lsb):
    """Return the unsigned bit field ``bits[msb:lsb]`` (both bounds inclusive)."""
    # Keep everything up to and including msb, then drop the bits below lsb.
    keep_mask = (1 << (msb + 1)) - 1
    return (bits & keep_mask) >> lsb
def sext(bits, sign_bit):
    """Sign-extend ``bits`` from bit position ``sign_bit`` to a Python int.

    Bits above ``sign_bit`` are discarded; the result is negative when
    bit ``sign_bit`` of ``bits`` is set.
    """
    sign_mask = 1 << sign_bit
    value = bits & ((sign_mask << 1) - 1)
    if bits & sign_mask:
        # The sign bit carries weight -2**sign_bit rather than +2**sign_bit.
        value -= sign_mask << 1
    return value
def concat_extract(bits, msb_lsb_pairs, signed=True):
    """Concatenate several bit fields of ``bits`` into one value.

    ``msb_lsb_pairs`` is an iterable of ``(msb, lsb)`` pairs, most
    significant field first. When ``signed`` is true the concatenated value
    is sign-extended from its top bit; otherwise it is returned unsigned.
    """
    value = 0
    total_width = 0
    for msb, lsb in msb_lsb_pairs:
        field_width = msb - lsb + 1
        value = (value << field_width) | extract(bits, msb, lsb)
        total_width += field_width
    return sext(value, total_width - 1) if signed else value
# Note these handy functions are not used much in the main loop, because CPython is unable
# to inline them. Inlining them by hand, along with similar changes, results in a ~3x
# performance increase. :(
def imm_i(instr):
    """Decode the sign-extended I-type immediate, instr[31:20].

    Equivalent to ``concat_extract(instr, ((31, 20),))``, written as plain
    bit arithmetic.
    """
    unsigned_imm = instr >> 20
    # instr[31] lands on bit 12 here; subtracting it gives bit 11 negative weight.
    sign_weight = (instr >> 19) & 0x1000
    return unsigned_imm - sign_weight
def imm_s(instr):
    """Decode the sign-extended S-type immediate, {instr[31:25], instr[11:7]}.

    Equivalent to ``concat_extract(instr, ((31, 25), (11, 7)))``, written as
    plain bit arithmetic.
    """
    upper = (instr >> 20) & 0xfe0      # instr[31:25] -> imm[11:5]
    lower = (instr >> 7) & 0x1f        # instr[11:7]  -> imm[4:0]
    sign_weight = (instr >> 19) & 0x1000
    return upper + lower - sign_weight
def imm_u(instr):
# return concat_extract(instr, ((31, 12),)) << 12
return instr & 0xfffff000 - (instr << 1 & 0x100000000)
def imm_b(instr):
    """Decode the sign-extended B-type (branch) byte offset."""
    # Fields listed most significant first: imm[12], imm[11], imm[10:5], imm[4:1].
    offset_fields = ((31, 31), (7, 7), (30, 25), (11, 8))
    halfword_offset = concat_extract(instr, offset_fields)
    # imm[0] is implicitly zero.
    return halfword_offset << 1
def imm_j(instr):
    """Decode the sign-extended J-type (jump) byte offset."""
    # Fields listed most significant first: imm[20], imm[19:12], imm[11], imm[10:1].
    offset_fields = ((31, 31), (19, 12), (20, 20), (30, 21))
    halfword_offset = concat_extract(instr, offset_fields)
    # imm[0] is implicitly zero.
    return halfword_offset << 1
class FlatMemory:
    """Flat RAM model of ``size`` bytes, stored as a list of 32-bit words.

    Reads are unsigned. Writes allow signed or unsigned values and convert
    implicitly to unsigned. Multi-byte accesses are little-endian.
    """

    def __init__(self, size):
        """Allocate ``size`` bytes (``size >> 2`` words) of zeroed memory."""
        self.size = size
        self.mem = [0] * (size >> 2)

    def get8(self, addr):
        """Return the unsigned byte at ``addr``."""
        assert(addr >= 0 and addr < self.size)
        return (self.mem[addr >> 2] >> ((addr & 0x3) * 8)) & 0xff

    def put8(self, addr, data):
        """Store the low 8 bits of ``data`` at ``addr``."""
        assert(addr >= 0 and addr < self.size)
        assert(data >= -1 << 7 and data < 1 << 8)
        shift = 8 * (addr & 0x3)
        # Clear the target byte lane, then merge in the new byte.
        self.mem[addr >> 2] &= ~(0xff << shift)
        self.mem[addr >> 2] |= (data & 0xff) << shift

    def get16(self, addr):
        """Return the unsigned halfword at ``addr``.

        Bit 0 of ``addr`` is ignored, so the access is always to the aligned
        halfword containing ``addr``.
        """
        # Bounds check matching get8/get32; without it a negative address
        # would silently wrap around through Python's negative list indexing.
        assert(addr >= 0 and addr + 1 < self.size)
        return (self.mem[addr >> 2] >> ((addr & 0x2) * 8)) & 0xffff

    def put16(self, addr, data):
        """Store the low 16 bits of ``data`` at ``addr``, one byte at a time."""
        assert(data >= -1 << 15 and data < 1 << 16)
        for i in range(2):
            self.put8(addr + i, (data >> (8 * i)) & 0xff)

    def get32(self, addr):
        """Return the unsigned word containing ``addr`` (low two address bits are ignored)."""
        assert(addr >= 0 and addr + 3 < self.size)
        return self.mem[addr >> 2]

    def put32(self, addr, data):
        """Store the low 32 bits of ``data`` in the word containing ``addr``."""
        assert(data >= -1 << 31 and data < 1 << 32)
        assert(addr >= 0 and addr + 3 < self.size)
        self.mem[addr >> 2] = data & 0xffff_ffff

    def loadbin(self, data, offs):
        """Copy a binary image into memory starting at byte offset ``offs``.

        ``data`` is either a ``bytes``/``bytearray`` object or a file object
        opened in ``"rb"`` mode, which is read to the end.
        """
        if not isinstance(data, (bytes, bytearray)):
            # must be fh
            assert(data.mode == "rb")
            data = data.read()
        # <=, not <: an image that ends exactly at the top of memory still fits.
        assert(offs + len(data) <= self.size)
        for i, b in enumerate(data):
            self.put8(offs + i, b)
class TBExit(Exception):
    """Raised when the guest writes to the testbench exit register; carries the written value."""
class MemWithTBIO(FlatMemory):
    """FlatMemory plus a small block of write-only testbench IO registers.

    Word writes at or above ``TB_IO_BASE`` are decoded as IO instead of
    being stored to RAM:

    * ``TB_IO_PRINT_CHAR`` -- print the written value as a character
    * ``TB_IO_PRINT_INT``  -- print the written value as 8 hex digits
    * ``TB_IO_EXIT``       -- raise ``TBExit`` carrying the written value
    """

    TB_IO_BASE = 0x80000000
    TB_IO_PRINT_CHAR = TB_IO_BASE + 0x0
    TB_IO_PRINT_INT = TB_IO_BASE + 0x4
    TB_IO_EXIT = TB_IO_BASE + 0x8

    def __init__(self, size, io_log_fmt="IO: {}\n"):
        """Create ``size`` bytes of RAM; ``io_log_fmt`` is the ``str.format`` template for IO output."""
        super().__init__(size)
        self.io_log_fmt = io_log_fmt

    def put32(self, addr, data):
        """Store a word to RAM, or perform the IO action selected by ``addr``.

        Raises ``TBExit`` for a write to ``TB_IO_EXIT``. Writes to any other
        address in the IO region are reported on stdout and otherwise ignored.
        """
        if addr < self.TB_IO_BASE:
            super().put32(addr, data)
        elif addr == self.TB_IO_PRINT_CHAR:
            sys.stdout.write(self.io_log_fmt.format(chr(data)))
        elif addr == self.TB_IO_PRINT_INT:
            # Writes may carry signed values (see FlatMemory); mask to 32 bits
            # so a negative word prints as e.g. ffffffff rather than -0000001.
            sys.stdout.write(self.io_log_fmt.format(f"{data & 0xffff_ffff:08x}"))
        elif addr == self.TB_IO_EXIT:
            raise TBExit(data)
        else:
            print(f"Unknown IO address {addr:08x}")
class RVCSR:
    """Minimal CSR file: a scratch register and one free-running counter.

    ``MCYCLE``, ``MTIME`` and ``MINSTRET`` all read back the same ``mcycle``
    counter. Only ``MCYCLE`` and ``MSCRATCH`` are writable; writes to any
    other address are ignored.
    """

    # Write operations accepted by write()
    WRITE = 0
    WRITE_SET = 1
    WRITE_CLEAR = 2

    # CSR addresses
    MSCRATCH = 0x340
    MCYCLE = 0xb00
    MTIME = 0xb01
    MINSTRET = 0xb02

    def __init__(self):
        self.mcycle = 0
        self.mscratch = 0

    def step(self):
        """Advance the counter by one."""
        self.mcycle += 1

    def read(self, addr, side_effect=True):
        """Return the value of the CSR at ``addr``, or ``None`` if it is not implemented.

        ``side_effect`` is accepted but currently has no effect on the result.
        """
        # Close your eyes
        if addr in (RVCSR.MCYCLE, RVCSR.MTIME, RVCSR.MINSTRET):
            return self.mcycle
        elif addr == RVCSR.MSCRATCH:
            return self.mscratch
        else:
            return None

    def write(self, addr, data, op=0):
        """Write, set bits in, or clear bits in the CSR at ``addr``.

        ``op`` is one of ``WRITE``, ``WRITE_SET`` or ``WRITE_CLEAR``. Writes
        to unimplemented or read-only addresses are silently ignored.
        """
        if op in (RVCSR.WRITE_SET, RVCSR.WRITE_CLEAR):
            current = self.read(addr, side_effect=False)
            if current is None:
                # Unimplemented CSR: there is no value to modify, and a
                # bitwise operation on None would raise TypeError.
                return
            if op == RVCSR.WRITE_CLEAR:
                data = current & ~data
            else:
                data = current | data
        if addr == RVCSR.MCYCLE:
            self.mcycle = data
        elif addr == RVCSR.MSCRATCH:
            self.mscratch = data
class RVCore:
    """RV32 interpreter core: 32 integer registers, a PC, a memory and a CSR file.

    Register values are held as signed Python ints; step() normalises every
    result to the signed 32-bit range when it is written back.
    """

    def __init__(self, mem, reset_vector=0xc0):
        """Create a core that executes from ``mem``, starting at ``reset_vector``.

        ``mem`` must expose a word list ``mem`` (used for instruction fetch)
        plus ``get8/get16/get32`` and ``put8/put16/put32``.
        """
        self.regs = [0] * 32
        self.mem = mem
        self.pc = reset_vector
        self.csr = RVCSR()

    def _csr_write(self, csr_addr, data, op):
        """Forward a CSR write, skipping set/clear on CSRs the CSR file does not implement."""
        # RVCSR.read() returns None for unimplemented CSRs, which a
        # read-modify-write (set/clear) cannot operate on.
        if op != RVCSR.WRITE and self.csr.read(csr_addr, side_effect=False) is None:
            return
        self.csr.write(csr_addr, data, op=op)

    def step(self, instr=None, log=True):
        """Execute one instruction and advance the PC and the CSR counter.

        ``instr`` is the 32-bit instruction word to execute; when ``None`` it
        is fetched from memory at the current PC. When ``log`` is true a
        one-line trace (PC, encoding, disassembly, register and PC updates)
        is printed. Invalid or unrecognised instructions are reported on
        stdout and otherwise treated as no-ops.
        """
        if instr is None:
            # Fetch straight from the backing word list.
            instr = self.mem.mem[self.pc >> 2]
        regnum_rs1 = instr >> 15 & 0x1f
        regnum_rs2 = instr >> 20 & 0x1f
        regnum_rd = instr >> 7 & 0x1f
        rs1 = self.regs[regnum_rs1]
        rs2 = self.regs[regnum_rs2]
        # None means "no register write" / "fall through to pc + 4".
        rd_wdata = None
        pc_wdata = None
        log_disasm = None
        instr_invalid = False
        opc = instr >> 2 & 0x1f
        funct3 = instr >> 12 & 0x7
        funct7 = instr >> 25 & 0x7f
        # Major opcodes, instr[6:2]
        OPC_LOAD = 0b00_000
        OPC_MISC_MEM = 0b00_011
        OPC_OP_IMM = 0b00_100
        OPC_AUIPC = 0b00_101
        OPC_STORE = 0b01_000
        OPC_OP = 0b01_100
        OPC_LUI = 0b01_101
        OPC_BRANCH = 0b11_000
        OPC_JALR = 0b11_001
        OPC_JAL = 0b11_011
        OPC_SYSTEM = 0b11_100
        if opc == OPC_OP:
            if log: log_reg_str = f" x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
            if funct7 == 0b00_00000:
                if funct3 == 0b000:
                    if log: log_disasm = "add" + log_reg_str
                    rd_wdata = rs1 + rs2
                elif funct3 == 0b001:
                    if log: log_disasm = "sll" + log_reg_str
                    rd_wdata = rs1 << (rs2 & 0x1f)
                elif funct3 == 0b010:
                    if log: log_disasm = "slt" + log_reg_str
                    rd_wdata = rs1 < rs2
                elif funct3 == 0b011:
                    if log: log_disasm = "sltu" + log_reg_str
                    rd_wdata = (rs1 & XLEN_MASK) < (rs2 & XLEN_MASK)
                elif funct3 == 0b100:
                    if log: log_disasm = "xor" + log_reg_str
                    rd_wdata = rs1 ^ rs2
                elif funct3 == 0b101:
                    if log: log_disasm = "srl" + log_reg_str
                    rd_wdata = (rs1 & XLEN_MASK) >> (rs2 & 0x1f)
                elif funct3 == 0b110:
                    if log: log_disasm = "or" + log_reg_str
                    rd_wdata = rs1 | rs2
                elif funct3 == 0b111:
                    if log: log_disasm = "and" + log_reg_str
                    rd_wdata = rs1 & rs2
                else:
                    instr_invalid = True
            elif funct7 == 0b01_00000:
                if funct3 == 0b000:
                    if log: log_disasm = "sub" + log_reg_str
                    rd_wdata = rs1 - rs2
                elif funct3 == 0b101:
                    if log: log_disasm = "sra" + log_reg_str
                    rd_wdata = rs1 >> (rs2 & 0x1f)
                else:
                    instr_invalid = True
            elif funct7 == 0b00_00001:
                if funct3 < 0b100:
                    if log:
                        mul_instr_name = {0b000: "mul", 0b001: "mulh", 0b010: "mulhsu", 0b011: "mulhu"}[funct3]
                        log_disasm = f"{mul_instr_name} x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
                    # Operand a is unsigned only for mulhu; b is unsigned for mulhsu and mulhu.
                    mul_op_a = rs1 & XLEN_MASK if funct3 == 0b011 else rs1
                    mul_op_b = rs2 & XLEN_MASK if funct3 in (0b010, 0b011) else rs2
                    mul_result = mul_op_a * mul_op_b
                    if funct3 != 0b000:
                        # The mulh* variants return the upper half of the product.
                        mul_result >>= 32
                    rd_wdata = sext(mul_result, XLEN - 1)
                else:
                    if log:
                        div_instr_name = {0b100: "div", 0b101: "divu", 0b110: "rem", 0b111: "remu"}[funct3]
                        log_disasm = f"{div_instr_name} x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
                    # Division by zero yields all-ones (div/divu) or the dividend (rem/remu).
                    # int(a / b) truncates towards zero, whereas Python's // floors.
                    if funct3 == 0b100:
                        rd_wdata = -1 if rs2 == 0 else int(rs1 / rs2)
                    elif funct3 == 0b101:
                        rd_wdata = -1 if rs2 == 0 else sext((rs1 & XLEN_MASK) // (rs2 & XLEN_MASK), XLEN - 1)
                    elif funct3 == 0b110:
                        rd_wdata = rs1 if rs2 == 0 else rs1 - int(rs1 / rs2) * rs2
                    elif funct3 == 0b111:
                        rd_wdata = rs1 if rs2 == 0 else sext((rs1 & XLEN_MASK) % (rs2 & XLEN_MASK), XLEN - 1)
                    else:
                        instr_invalid = True
            else:
                instr_invalid = True
        elif opc == OPC_OP_IMM:
            imm = (instr >> 20) - (instr >> 19 & 0x1000) # imm_i(instr)
            if funct3 == 0b000:
                if log: log_disasm = f"addi x{regnum_rd}, x{regnum_rs1}, {imm}"
                rd_wdata = rs1 + imm
            elif funct3 == 0b010:
                if log: log_disasm = f"slti x{regnum_rd}, x{regnum_rs1}, {imm}"
                rd_wdata = 1 * (rs1 < imm)
            elif funct3 == 0b011:
                if log: log_disasm = f"sltiu x{regnum_rd}, x{regnum_rs1}, {imm & XLEN_MASK}"
                rd_wdata = 1 * (rs1 & XLEN_MASK < imm & XLEN_MASK)
            elif funct3 == 0b100:
                if log: log_disasm = f"xori x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
                rd_wdata = rs1 ^ imm
            elif funct3 == 0b110:
                if log: log_disasm = f"ori x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
                rd_wdata = rs1 | imm
            elif funct3 == 0b111:
                if log: log_disasm = f"andi x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
                rd_wdata = rs1 & imm
            elif funct3 == 0b001 or funct3 == 0b101:
                # shamt is regnum_rs2
                if funct7 == 0b00_00000 and funct3 == 0b001:
                    if log: log_disasm = f"slli x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
                    rd_wdata = rs1 << regnum_rs2
                elif funct7 == 0b00_00000 and funct3 == 0b101:
                    if log: log_disasm = f"srli x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
                    rd_wdata = (rs1 & XLEN_MASK) >> regnum_rs2
                elif funct7 == 0b01_00000 and funct3 == 0b101:
                    if log: log_disasm = f"srai x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
                    rd_wdata = rs1 >> regnum_rs2
                else:
                    instr_invalid = True
            else:
                instr_invalid = True
        elif opc == OPC_JAL:
            rd_wdata = self.pc + 4
            # pc_wdata = self.pc + imm_j(instr)
            pc_wdata = self.pc + (instr >> 20 & 0x7fe) + (instr >> 9 & 0x800) + (instr & 0xff000) - (instr >> 11 & 0x100000)
            if log: log_disasm = f"jal x{regnum_rd}, {pc_wdata & XLEN_MASK:08x}"
        elif opc == OPC_JALR:
            imm = imm_i(instr)
            if log: log_disasm = f"jalr x{regnum_rd}, x{regnum_rs1}, {imm}"
            rd_wdata = self.pc + 4
            # JALR clears LSB always
            pc_wdata = (rs1 + imm) & -2
        elif opc == OPC_BRANCH:
            # target = self.pc + imm_b(instr)
            target = self.pc + (instr >> 7 & 0x1e) + (instr >> 20 & 0x7e0) + (instr << 4 & 0x800) - (instr >> 19 & 0x1000)
            taken = False
            # Masked like the jal trace, so a wrapped target never prints with a minus sign.
            if log: log_branch_str = f" x{regnum_rs1}, x{regnum_rs2}, {target & XLEN_MASK:08x}"
            if funct3 == 0b000:
                if log: log_disasm = "beq" + log_branch_str
                taken = rs1 == rs2
            elif funct3 == 0b001:
                if log: log_disasm = "bne" + log_branch_str
                taken = rs1 != rs2
            elif funct3 == 0b100:
                if log: log_disasm = "blt" + log_branch_str
                taken = rs1 < rs2
            elif funct3 == 0b101:
                if log: log_disasm = "bge" + log_branch_str
                taken = rs1 >= rs2
            elif funct3 == 0b110:
                if log: log_disasm = "bltu" + log_branch_str
                taken = (rs1 & XLEN_MASK) < (rs2 & XLEN_MASK)
            elif funct3 == 0b111:
                if log: log_disasm = "bgeu" + log_branch_str
                taken = (rs1 & XLEN_MASK) >= (rs2 & XLEN_MASK)
            else:
                instr_invalid = True
            if taken:
                pc_wdata = target
        elif opc == OPC_LOAD:
            imm = imm_i(instr)
            if log: log_load_str = f" x{regnum_rd}, {imm}(x{regnum_rs1})"
            load_addr = imm + rs1 & XLEN_MASK
            if funct3 == 0b000:
                if log: log_disasm = "lb" + log_load_str
                rd_wdata = self.mem.get8(load_addr)
                # Memory reads are unsigned; sign-extend from bit 7.
                rd_wdata -= rd_wdata << 1 & 0x100
            elif funct3 == 0b001:
                if log: log_disasm = "lh" + log_load_str
                rd_wdata = self.mem.get16(load_addr)
                rd_wdata -= rd_wdata << 1 & 0x10000
            elif funct3 == 0b010:
                if log: log_disasm = "lw" + log_load_str
                rd_wdata = self.mem.get32(load_addr)
                rd_wdata -= rd_wdata << 1 & 0x100000000
            elif funct3 == 0b100:
                if log: log_disasm = "lbu" + log_load_str
                rd_wdata = self.mem.get8(load_addr)
            elif funct3 == 0b101:
                if log: log_disasm = "lhu" + log_load_str
                rd_wdata = self.mem.get16(load_addr)
            else:
                instr_invalid = True
        elif opc == OPC_STORE:
            imm = imm_s(instr)
            if log: log_store_str = f" x{regnum_rs2}, {imm}(x{regnum_rs1})"
            store_addr = imm + rs1 & XLEN_MASK
            if funct3 == 0b000:
                if log: log_disasm = "sb" + log_store_str
                self.mem.put8(store_addr, rs2 & (1 << 8) - 1)
            elif funct3 == 0b001:
                if log: log_disasm = "sh" + log_store_str
                self.mem.put16(store_addr, rs2 & (1 << 16) - 1)
            elif funct3 == 0b010:
                if log: log_disasm = "sw" + log_store_str
                self.mem.put32(store_addr, rs2)
            else:
                instr_invalid = True
        elif opc == OPC_LUI:
            imm = imm_u(instr)
            if log: log_disasm = f"lui x{regnum_rd}, 0x{(imm & XLEN_MASK) >> 12:05x}"
            rd_wdata = imm
        elif opc == OPC_AUIPC:
            imm = imm_u(instr)
            if log: log_disasm = f"auipc x{regnum_rd}, 0x{(imm & XLEN_MASK) >> 12:05x}"
            rd_wdata = self.pc + imm
        elif opc == OPC_SYSTEM:
            csr_addr = extract(instr, 31, 20)
            if funct3 == 0b000 and funct7 == 0b00_00000:
                # ecall/ebreak are distinguished by instr[24:20], i.e. the rs2 field.
                if regnum_rs2 == 0:
                    if log: log_disasm = "*UNHANDLED* ecall"
                    pass
                elif regnum_rs2 == 1:
                    if log: log_disasm = "*UNHANDLED* ebreak"
                    pass
                else:
                    instr_invalid = True
            elif funct3 in (0b001, 0b010, 0b011):
                # Register forms: the source operand is rs1. instr[24:20] is
                # part of the CSR address here, not a register number.
                if log:
                    instr_name = {0b001: "csrrw", 0b010: "csrrs", 0b011: "csrrc"}[funct3]
                    log_disasm = f"{instr_name} x{regnum_rd}, 0x{csr_addr:x}, x{regnum_rs1}"
                csr_write_op = funct3 - 0b001
                # csrrw with rd == x0 does not read the CSR.
                if csr_write_op != RVCSR.WRITE or regnum_rd != 0:
                    rd_wdata = self.csr.read(csr_addr)
                # csrrs/csrrc with rs1 == x0 do not write the CSR.
                if csr_write_op == RVCSR.WRITE or regnum_rs1 != 0:
                    self._csr_write(csr_addr, rs1, csr_write_op)
            elif funct3 in (0b101, 0b110, 0b111):
                # Immediate forms: the rs1 field holds a 5-bit unsigned immediate.
                if log:
                    instr_name = {0b101: "csrrwi", 0b110: "csrrsi", 0b111: "csrrci"}[funct3]
                    log_disasm = f"{instr_name} x{regnum_rd}, 0x{csr_addr:x}, 0x{regnum_rs1:x}"
                csr_write_op = funct3 - 0b101
                if csr_write_op != RVCSR.WRITE or regnum_rd != 0:
                    rd_wdata = self.csr.read(csr_addr)
                if csr_write_op == RVCSR.WRITE or regnum_rs1 != 0:
                    self._csr_write(csr_addr, regnum_rs1, csr_write_op)
            else:
                instr_invalid = True
        elif opc == OPC_MISC_MEM:
            if instr == 0b0000_0000_0000_00000_001_00000_0001111:
                if log: log_disasm = "fence.i"
                pass
            elif (instr & 0b1111_0000_0000_11111_111_11111_1111111) == 0b0000_0000_0000_00000_000_00000_0001111:
                if log: log_disasm = f"fence {extract(instr, 27, 24):04b}, {extract(instr, 23, 20):04b}"
                pass
            else:
                instr_invalid = True
        else:
            # Major opcode not implemented by this interpreter.
            instr_invalid = True
        if log:
            log_str = f"{self.pc:08x}: ({instr:08x}) {log_disasm if log_disasm is not None else '':<25}"
            if rd_wdata is not None and regnum_rd != 0:
                log_str += f" : x{regnum_rd:<2} <- {rd_wdata & XLEN_MASK:08x}"
            else:
                log_str += " : " + 15 * " "
            if pc_wdata is not None:
                log_str += f" : pc <- {pc_wdata & XLEN_MASK:08x}"
            else:
                log_str += " :"
            print(log_str)
        if rd_wdata is not None and regnum_rd != 0:
            # Wrap to 32 bits and store as a signed value.
            self.regs[regnum_rd] = (rd_wdata & 0xffffffff) - (rd_wdata << 1 & 0x100000000)
        # Remember where this instruction was fetched from before moving on,
        # so the invalid-instruction report names the right address.
        instr_pc = self.pc
        # Keep the PC an unsigned 32-bit value: a negative PC would index the
        # memory list from its end instead of faulting.
        if pc_wdata is None:
            self.pc = (self.pc + 4) & XLEN_MASK
        else:
            self.pc = pc_wdata & XLEN_MASK
        if instr_invalid:
            print(f"Invalid instruction at {instr_pc:08x}: {instr:08x}")
        self.csr.step()
def anyint(x):
    """Parse an integer string in any base ``int`` can infer from its prefix (0x, 0o, 0b or decimal)."""
    infer_base_from_prefix = 0
    return int(x, infer_base_from_prefix)
def main(argv):
    """Load a flat binary, run it for a bounded number of steps, then optionally dump memory.

    ``argv`` is the argument list without the program name. The run ends
    when ``--cycles`` instructions have executed or when the program writes
    to the testbench exit register.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("binfile")
    parser.add_argument("--memsize", default = 1 << 24, type = anyint)
    parser.add_argument("--cycles", default = int(1e4), type = anyint)
    parser.add_argument("--dump", nargs=2, action="append", type=anyint)
    parser.add_argument("--quiet", "-q", action="store_true")
    args = parser.parse_args(argv)
    if args.quiet:
        mem = MemWithTBIO(args.memsize, io_log_fmt="{}")
    else:
        mem = MemWithTBIO(args.memsize)
    # Close the image file as soon as it has been copied into memory.
    with open(args.binfile, "rb") as binfile:
        mem.loadbin(binfile, 0)
    rv = RVCore(mem)
    # Counts the steps started so far; initialised so that the summary below
    # is still well-defined when --cycles is 0.
    cycles_run = 0
    try:
        for cycles_run in range(1, args.cycles + 1):
            rv.step(log=not args.quiet)
    except TBExit as e:
        print(f"Processor halted simulation with exit code {e}")
    except BrokenPipeError:
        # Output was piped into something that closed early (e.g. head); exit quietly.
        sys.exit(0)
    print(f"Ran for {cycles_run} cycles")
    for start, end in args.dump or []:
        print(f"Dumping memory from {start:08x} to {end:08x}:")
        for i, addr in enumerate(range(start, end)):
            # 16 bytes per output line
            sep = "\n" if i % 16 == 15 else " "
            sys.stdout.write(f"{mem.get8(addr):02x}{sep}")
        print("")
# Script entry point: pass the command-line arguments, minus the program name, to main().
if __name__ == "__main__":
    main(sys.argv[1:])