#!/usr/bin/env python3

# Minimal RISC-V interpreter, RV32IM + Zicsr only, with trace disassembly

import argparse
import sys

XLEN = 32
XLEN_MASK = (1 << XLEN) - 1


def extract(bits, msb, lsb):
    """Return the unsigned bit field bits[msb:lsb] (both bounds inclusive)."""
    return (bits & ((1 << (msb + 1)) - 1)) >> lsb


def sext(bits, sign_bit):
    """Sign-extend `bits`, treating bit number `sign_bit` as the sign bit.

    Bits above `sign_bit` are discarded before extension.
    """
    return (bits & ((1 << (sign_bit + 1)) - 1)) - ((bits & (1 << sign_bit)) << 1)


def concat_extract(bits, msb_lsb_pairs, signed=True):
    """Concatenate several bit fields of `bits`, most-significant field first.

    `msb_lsb_pairs` is an iterable of (msb, lsb) pairs. When `signed` is true
    the concatenated value is sign-extended from its topmost bit.
    """
    accum = 0
    accum_count = 0
    for msb, lsb in msb_lsb_pairs:
        accum = (accum << (msb - lsb + 1)) | extract(bits, msb, lsb)
        accum_count += msb - lsb + 1
    if signed:
        accum = sext(accum, accum_count - 1)
    return accum


# Note these handy functions are not used much in the main loop, because
# CPython is unable to inline them. This and similar changes result in a ~3x
# performance increase. :(

def imm_i(instr):
    """Decode the sign-extended I-format immediate."""
    # Equivalent to: concat_extract(instr, ((31, 20),))
    return (instr >> 20) - (instr >> 19 & 0x1000)


def imm_s(instr):
    """Decode the sign-extended S-format immediate."""
    # Equivalent to: concat_extract(instr, ((31, 25), (11, 7)))
    return (instr >> 20 & 0xfe0) + (instr >> 7 & 0x1f) - (instr >> 19 & 0x1000)


def imm_u(instr):
    """Decode the U-format immediate (already shifted left by 12), sign-extended."""
    # Equivalent to: concat_extract(instr, ((31, 12),)) << 12
    # The parentheses matter: binary `-` binds tighter than `&`, so without
    # them the subtraction would be applied to the mask instead of the result.
    return (instr & 0xfffff000) - (instr << 1 & 0x100000000)


def imm_b(instr):
    """Decode the sign-extended B-format (branch) byte offset."""
    return concat_extract(instr, ((31, 31), (7, 7), (30, 25), (11, 8))) << 1


def imm_j(instr):
    """Decode the sign-extended J-format (jump) byte offset."""
    return concat_extract(instr, ((31, 31), (19, 12), (20, 20), (30, 21))) << 1


class FlatMemory:
    """Byte-addressed memory of `size` bytes, stored as a list of 32-bit words.

    Reads are unsigned. Writes allow signed or unsigned values and convert
    implicitly to unsigned. Multi-byte accesses are little-endian.
    """

    def __init__(self, size):
        self.size = size
        self.mem = [0] * (size >> 2)

    def get8(self, addr):
        """Read one byte."""
        assert(addr >= 0 and addr < self.size)
        return self.mem[addr >> 2] >> (addr & 0x3) * 8 & 0xff

    def put8(self, addr, data):
        """Write one byte."""
        assert(addr >= 0 and addr < self.size)
        assert(data >= -1 << 7 and data < 1 << 8)
        self.mem[addr >> 2] &= ~(0xff << 8 * (addr & 0x3))
        self.mem[addr >> 2] |= (data & 0xff) << 8 * (addr % 4)

    def get16(self, addr):
        """Read a halfword. Bit 0 of `addr` is ignored."""
        assert(addr >= 0 and addr + 1 < self.size)
        return self.mem[addr >> 2] >> (addr & 0x2) * 8 & 0xffff

    def put16(self, addr, data):
        """Write a halfword as two byte writes (bounds are checked by put8)."""
        assert(data >= -1 << 15 and data < 1 << 16)
        for i in range(2):
            self.put8(addr + i, data >> 8 * i & 0xff)

    def get32(self, addr):
        """Read a word. The low two bits of `addr` are ignored."""
        assert(addr >= 0 and addr + 3 < self.size)
        return self.mem[addr >> 2]

    def put32(self, addr, data):
        """Write a word. The low two bits of `addr` are ignored."""
        assert(data >= -1 << 31 and data < 1 << 32)
        assert(addr >= 0 and addr + 3 < self.size)
        self.mem[addr >> 2] = data & 0xffff_ffff

    def loadbin(self, data, offs):
        """Copy a binary image into memory starting at byte offset `offs`.

        `data` is either a bytes-like object or a file handle opened in "rb"
        mode, which is read to the end.
        """
        if not isinstance(data, (bytes, bytearray)):
            # must be fh
            assert(data.mode == "rb")
            data = data.read()
        # An image that ends exactly at the last byte of memory still fits.
        assert(offs + len(data) <= self.size)
        for i, b in enumerate(data):
            self.put8(offs + i, b)


class TBExit(Exception):
    """Raised when the program writes to the testbench exit register."""
    pass


class MemWithTBIO(FlatMemory):
    """FlatMemory plus testbench IO registers reachable through word writes."""

    TB_IO_BASE = 0x80000000
    TB_IO_PRINT_CHAR = TB_IO_BASE + 0x0
    TB_IO_PRINT_INT = TB_IO_BASE + 0x4
    TB_IO_EXIT = TB_IO_BASE + 0x8

    def __init__(self, size, io_log_fmt="IO: {}\n"):
        super().__init__(size)
        self.io_log_fmt = io_log_fmt

    def put32(self, addr, data):
        """Write a word to RAM, or dispatch it to a testbench IO register.

        Raises TBExit(data) when the exit register is written.
        """
        if addr < self.TB_IO_BASE:
            super().put32(addr, data)
        elif addr == self.TB_IO_PRINT_CHAR:
            sys.stdout.write(self.io_log_fmt.format(chr(data)))
        elif addr == self.TB_IO_PRINT_INT:
            # Registers hold signed values; print the raw 32-bit pattern.
            sys.stdout.write(self.io_log_fmt.format(f"{data & 0xffff_ffff:08x}"))
        elif addr == self.TB_IO_EXIT:
            raise TBExit(data)
        else:
            print(f"Unknown IO address {addr:08x}")


class RVCSR:
    """Tiny CSR file: a cycle counter and mscratch. Other CSRs read as zero."""

    # Write operation encodings, as passed to write(op=...)
    WRITE = 0
    WRITE_SET = 1
    WRITE_CLEAR = 2

    MSCRATCH = 0x340
    MCYCLE = 0xb00
    MTIME = 0xb01
    MINSTRET = 0xb02

    def __init__(self):
        self.mcycle = 0
        self.mscratch = 0

    def step(self):
        """Advance the cycle counter by one."""
        self.mcycle += 1

    def read(self, addr, side_effect=True):
        """Return the value of the CSR at `addr` (0 if unimplemented).

        `side_effect` is accepted but currently has no effect.
        """
        # Close your eyes: MTIME and MINSTRET simply alias the cycle counter.
        if addr in (RVCSR.MCYCLE, RVCSR.MTIME, RVCSR.MINSTRET):
            return self.mcycle
        elif addr == RVCSR.MSCRATCH:
            return self.mscratch
        else:
            return 0

    def write(self, addr, data, op=0):
        """Write, set bits in, or clear bits in the CSR at `addr`.

        Only MCYCLE and MSCRATCH are writable; other addresses are ignored.
        """
        if op == RVCSR.WRITE_CLEAR:
            data = self.read(addr, side_effect=False) & ~data
        elif op == RVCSR.WRITE_SET:
            data = self.read(addr, side_effect=False) | data
        if addr == RVCSR.MCYCLE:
            self.mcycle = data
        elif addr == RVCSR.MSCRATCH:
            self.mscratch = data


class RVCore:
    """RV32IM + Zicsr interpreter core with a simple stall-cycle model.

    Registers are held as signed Python ints in the range [-2**31, 2**31).
    """

    def __init__(self, mem, reset_vector=0x40):
        self.regs = [0] * 32
        self.mem = mem
        self.pc = reset_vector
        self.csr = RVCSR()
        # Register number written by the previous load/multiply, whose result
        # arrives late and stalls an immediately dependent instruction.
        self.stage3_result = None

    def step(self, instr=None, log=True, cycle_accurate=True):
        """Execute one instruction.

        instr: instruction word to execute; fetched from memory at the current
            PC when None.
        log: print one trace line (disassembly, register and PC updates).
        cycle_accurate: add modelled stall cycles to the cycle counter.

        Exceptions raised by the memory object (e.g. TBExit) propagate.
        """
        if instr is None:
            instr = self.mem.mem[self.pc >> 2]
        regnum_rs1 = instr >> 15 & 0x1f
        regnum_rs2 = instr >> 20 & 0x1f
        regnum_rd = instr >> 7 & 0x1f
        rs1 = self.regs[regnum_rs1]
        rs2 = self.regs[regnum_rs2]

        rd_wdata = None
        pc_wdata = None
        log_disasm = None
        instr_invalid = False
        stall_cycles = 0
        stage3_result_next = None

        opc = instr >> 2 & 0x1f
        funct3 = instr >> 12 & 0x7
        funct7 = instr >> 25 & 0x7f

        OPC_LOAD = 0b00_000
        OPC_MISC_MEM = 0b00_011
        OPC_OP_IMM = 0b00_100
        OPC_AUIPC = 0b00_101
        OPC_STORE = 0b01_000
        OPC_OP = 0b01_100
        OPC_LUI = 0b01_101
        OPC_BRANCH = 0b11_000
        OPC_JALR = 0b11_001
        OPC_JAL = 0b11_011
        OPC_SYSTEM = 0b11_100

        if opc == OPC_OP:
            if log:
                log_reg_str = f" x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
            if funct7 == 0b00_00000:
                if funct3 == 0b000:
                    if log: log_disasm = "add" + log_reg_str
                    rd_wdata = rs1 + rs2
                elif funct3 == 0b001:
                    if log: log_disasm = "sll" + log_reg_str
                    rd_wdata = rs1 << (rs2 & 0x1f)
                elif funct3 == 0b010:
                    if log: log_disasm = "slt" + log_reg_str
                    rd_wdata = rs1 < rs2
                elif funct3 == 0b011:
                    if log: log_disasm = "sltu" + log_reg_str
                    rd_wdata = (rs1 & XLEN_MASK) < (rs2 & XLEN_MASK)
                elif funct3 == 0b100:
                    if log: log_disasm = "xor" + log_reg_str
                    rd_wdata = rs1 ^ rs2
                elif funct3 == 0b101:
                    if log: log_disasm = "srl" + log_reg_str
                    rd_wdata = (rs1 & XLEN_MASK) >> (rs2 & 0x1f)
                elif funct3 == 0b110:
                    if log: log_disasm = "or" + log_reg_str
                    rd_wdata = rs1 | rs2
                elif funct3 == 0b111:
                    if log: log_disasm = "and" + log_reg_str
                    rd_wdata = rs1 & rs2
                else:
                    instr_invalid = True
            elif funct7 == 0b01_00000:
                if funct3 == 0b000:
                    if log: log_disasm = "sub" + log_reg_str
                    rd_wdata = rs1 - rs2
                elif funct3 == 0b101:
                    if log: log_disasm = "sra" + log_reg_str
                    rd_wdata = rs1 >> (rs2 & 0x1f)
                else:
                    instr_invalid = True
            elif funct7 == 0b00_00001:
                if funct3 < 0b100:
                    if log:
                        mul_instr_name = {0b000: "mul", 0b001: "mulh", 0b010: "mulhsu", 0b011: "mulhu"}[funct3]
                        log_disasm = f"{mul_instr_name} x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
                    # Operands are signed unless the variant says otherwise.
                    mul_op_a = rs1 & XLEN_MASK if funct3 == 0b011 else rs1
                    mul_op_b = rs2 & XLEN_MASK if funct3 in (0b010, 0b011) else rs2
                    mul_result = mul_op_a * mul_op_b
                    if funct3 != 0b000:
                        mul_result >>= 32
                    rd_wdata = sext(mul_result, XLEN - 1)
                    stage3_result_next = regnum_rd
                else:
                    if log:
                        div_instr_name = {0b100: "div", 0b101: "divu", 0b110: "rem", 0b111: "remu"}[funct3]
                        log_disasm = f"{div_instr_name} x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
                    # Division by zero yields all-ones (quotient) or the
                    # dividend (remainder). int() truncates toward zero.
                    if funct3 == 0b100:
                        rd_wdata = -1 if rs2 == 0 else int(rs1 / rs2)
                    elif funct3 == 0b101:
                        rd_wdata = -1 if rs2 == 0 else sext((rs1 & XLEN_MASK) // (rs2 & XLEN_MASK), XLEN - 1)
                    elif funct3 == 0b110:
                        rd_wdata = rs1 if rs2 == 0 else rs1 - int(rs1 / rs2) * rs2
                    elif funct3 == 0b111:
                        rd_wdata = rs1 if rs2 == 0 else sext((rs1 & XLEN_MASK) % (rs2 & XLEN_MASK), XLEN - 1)
                    else:
                        instr_invalid = True
                    stall_cycles = 18
            else:
                instr_invalid = True
            if not instr_invalid:
                stall_cycles += regnum_rs1 == self.stage3_result or regnum_rs2 == self.stage3_result

        elif opc == OPC_OP_IMM:
            imm = (instr >> 20) - (instr >> 19 & 0x1000)  # imm_i(instr)
            if funct3 == 0b000:
                if log: log_disasm = f"addi x{regnum_rd}, x{regnum_rs1}, {imm}"
                rd_wdata = rs1 + imm
            elif funct3 == 0b010:
                if log: log_disasm = f"slti x{regnum_rd}, x{regnum_rs1}, {imm}"
                rd_wdata = 1 * (rs1 < imm)
            elif funct3 == 0b011:
                if log: log_disasm = f"sltiu x{regnum_rd}, x{regnum_rs1}, {imm & XLEN_MASK}"
                rd_wdata = 1 * (rs1 & XLEN_MASK < imm & XLEN_MASK)
            elif funct3 == 0b100:
                if log: log_disasm = f"xori x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
                rd_wdata = rs1 ^ imm
            elif funct3 == 0b110:
                if log: log_disasm = f"ori x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
                rd_wdata = rs1 | imm
            elif funct3 == 0b111:
                if log: log_disasm = f"andi x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
                rd_wdata = rs1 & imm
            elif funct3 == 0b001 or funct3 == 0b101:
                # shamt is regnum_rs2
                if funct7 == 0b00_00000 and funct3 == 0b001:
                    if log: log_disasm = f"slli x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
                    rd_wdata = rs1 << regnum_rs2
                elif funct7 == 0b00_00000 and funct3 == 0b101:
                    if log: log_disasm = f"srli x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
                    rd_wdata = (rs1 & XLEN_MASK) >> regnum_rs2
                elif funct7 == 0b01_00000 and funct3 == 0b101:
                    if log: log_disasm = f"srai x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
                    rd_wdata = rs1 >> regnum_rs2
                else:
                    instr_invalid = True
            else:
                instr_invalid = True
            if not instr_invalid:
                stall_cycles += regnum_rs1 == self.stage3_result

        elif opc == OPC_JAL:
            rd_wdata = self.pc + 4
            # pc_wdata = self.pc + imm_j(instr)
            pc_wdata = self.pc + (instr >> 20 & 0x7fe) + (instr >> 9 & 0x800) + (instr & 0xff000) - (instr >> 11 & 0x100000)
            if log: log_disasm = f"jal x{regnum_rd}, {pc_wdata & XLEN_MASK:08x}"
            stall_cycles = 1

        elif opc == OPC_JALR:
            stall_cycles += regnum_rs1 == self.stage3_result
            imm = imm_i(instr)
            if log: log_disasm = f"jalr x{regnum_rd}, x{regnum_rs1}, {imm}"
            rd_wdata = self.pc + 4
            # JALR clears LSB always
            pc_wdata = (rs1 + imm) & -2
            # Accumulate, so the jump penalty does not discard the hazard
            # stall counted above.
            stall_cycles += 1

        elif opc == OPC_BRANCH:
            # target = self.pc + imm_b(instr)
            target = self.pc + (instr >> 7 & 0x1e) + (instr >> 20 & 0x7e0) + (instr << 4 & 0x800) - (instr >> 19 & 0x1000)
            taken = False
            if log:
                log_branch_str = f" x{regnum_rs1}, x{regnum_rs2}, {target & XLEN_MASK:08x}"
            if funct3 == 0b000:
                if log: log_disasm = "beq" + log_branch_str
                taken = rs1 == rs2
            elif funct3 == 0b001:
                if log: log_disasm = "bne" + log_branch_str
                taken = rs1 != rs2
            elif funct3 == 0b100:
                if log: log_disasm = "blt" + log_branch_str
                taken = rs1 < rs2
            elif funct3 == 0b101:
                if log: log_disasm = "bge" + log_branch_str
                taken = rs1 >= rs2
            elif funct3 == 0b110:
                if log: log_disasm = "bltu" + log_branch_str
                taken = (rs1 & XLEN_MASK) < (rs2 & XLEN_MASK)
            elif funct3 == 0b111:
                if log: log_disasm = "bgeu" + log_branch_str
                taken = (rs1 & XLEN_MASK) >= (rs2 & XLEN_MASK)
            else:
                instr_invalid = True
            if not instr_invalid:
                stall_cycles += regnum_rs1 == self.stage3_result or regnum_rs2 == self.stage3_result
            if taken:
                pc_wdata = target
                stall_cycles += 1

        elif opc == OPC_LOAD:
            imm = imm_i(instr)
            if log:
                log_load_str = f" x{regnum_rd}, {imm}(x{regnum_rs1})"
            load_addr = imm + rs1 & XLEN_MASK
            if funct3 == 0b000:
                if log: log_disasm = "lb" + log_load_str
                rd_wdata = self.mem.get8(load_addr)
                rd_wdata -= rd_wdata << 1 & 0x100
            elif funct3 == 0b001:
                if log: log_disasm = "lh" + log_load_str
                rd_wdata = self.mem.get16(load_addr)
                rd_wdata -= rd_wdata << 1 & 0x10000
            elif funct3 == 0b010:
                if log: log_disasm = "lw" + log_load_str
                rd_wdata = self.mem.get32(load_addr)
                rd_wdata -= rd_wdata << 1 & 0x100000000
            elif funct3 == 0b100:
                if log: log_disasm = "lbu" + log_load_str
                rd_wdata = self.mem.get8(load_addr)
            elif funct3 == 0b101:
                if log: log_disasm = "lhu" + log_load_str
                rd_wdata = self.mem.get16(load_addr)
            else:
                instr_invalid = True
            if not instr_invalid:
                stall_cycles += regnum_rs1 == self.stage3_result
                stage3_result_next = regnum_rd

        elif opc == OPC_STORE:
            imm = imm_s(instr)
            if log:
                log_store_str = f" x{regnum_rs2}, {imm}(x{regnum_rs1})"
            store_addr = imm + rs1 & XLEN_MASK
            if funct3 == 0b000:
                if log: log_disasm = "sb" + log_store_str
                self.mem.put8(store_addr, rs2 & (1 << 8) - 1)
            elif funct3 == 0b001:
                if log: log_disasm = "sh" + log_store_str
                self.mem.put16(store_addr, rs2 & (1 << 16) - 1)
            elif funct3 == 0b010:
                if log: log_disasm = "sw" + log_store_str
                self.mem.put32(store_addr, rs2)
            else:
                instr_invalid = True
            if not instr_invalid:
                stall_cycles += regnum_rs1 == self.stage3_result

        elif opc == OPC_LUI:
            imm = imm_u(instr)
            if log: log_disasm = f"lui x{regnum_rd}, 0x{(imm & XLEN_MASK) >> 12:05x}"
            rd_wdata = imm

        elif opc == OPC_AUIPC:
            imm = imm_u(instr)
            if log: log_disasm = f"auipc x{regnum_rd}, 0x{(imm & XLEN_MASK) >> 12:05x}"
            rd_wdata = self.pc + imm

        elif opc == OPC_SYSTEM:
            csr_addr = extract(instr, 31, 20)
            if funct3 == 0b000 and funct7 == 0b00_00000:
                if regnum_rs2 == 0:
                    if log: log_disasm = "*UNHANDLED* ecall"
                elif regnum_rs2 == 1:
                    if log: log_disasm = "*UNHANDLED* ebreak"
                else:
                    instr_invalid = True
            elif funct3 in (0b001, 0b010, 0b011):
                # The CSR address occupies bits 31:20 (overlapping the rs2
                # field), so the source operand is rs1.
                if log:
                    instr_name = {0b001: "csrrw", 0b010: "csrrs", 0b011: "csrrc"}[funct3]
                    log_disasm = f"{instr_name} x{regnum_rd}, 0x{csr_addr:x}, x{regnum_rs1}"
                csr_write_op = funct3 - 0b001
                # csrrw with rd = x0 skips the read; csrrs/csrrc with
                # rs1 = x0 skip the write.
                if csr_write_op != RVCSR.WRITE or regnum_rd != 0:
                    rd_wdata = self.csr.read(csr_addr)
                if csr_write_op == RVCSR.WRITE or regnum_rs1 != 0:
                    self.csr.write(csr_addr, rs1, op=csr_write_op)
                stall_cycles += regnum_rs1 == self.stage3_result
            elif funct3 in (0b101, 0b110, 0b111):
                # Immediate forms: the rs1 field holds a 5-bit unsigned
                # immediate rather than a register number.
                csr_uimm = regnum_rs1
                if log:
                    instr_name = {0b101: "csrrwi", 0b110: "csrrsi", 0b111: "csrrci"}[funct3]
                    log_disasm = f"{instr_name} x{regnum_rd}, 0x{csr_addr:x}, 0x{csr_uimm:x}"
                csr_write_op = funct3 - 0b101
                if csr_write_op != RVCSR.WRITE or regnum_rd != 0:
                    rd_wdata = self.csr.read(csr_addr)
                if csr_write_op == RVCSR.WRITE or csr_uimm != 0:
                    self.csr.write(csr_addr, csr_uimm, op=csr_write_op)
            else:
                instr_invalid = True

        elif opc == OPC_MISC_MEM:
            if instr == 0b0000_0000_0000_00000_001_00000_0001111:
                if log: log_disasm = "fence.i"
            elif (instr & 0b1111_0000_0000_11111_111_11111_1111111) == 0b0000_0000_0000_00000_000_00000_0001111:
                if log: log_disasm = f"fence {extract(instr, 27, 24):04b}, {extract(instr, 23, 20):04b}"
            else:
                instr_invalid = True

        else:
            # Major opcode outside the implemented RV32IM + Zicsr subset.
            instr_invalid = True

        if log:
            log_str = f"{self.pc:08x}: ({instr:08x}) {log_disasm if log_disasm is not None else '':<25}"
            if rd_wdata is not None and regnum_rd != 0:
                log_str += f" : x{regnum_rd:<2} <- {rd_wdata & XLEN_MASK:08x}"
            else:
                log_str += " : " + 15 * " "
            if pc_wdata is not None:
                log_str += f" : pc <- {pc_wdata & XLEN_MASK:08x}"
            else:
                log_str += " :"
            print(log_str)

        if rd_wdata is not None and regnum_rd != 0:
            # Wrap to 32 bits and store as a signed value.
            self.regs[regnum_rd] = (rd_wdata & 0xffffffff) - (rd_wdata << 1 & 0x100000000)

        # Remember where this instruction lived before the PC moves on, so the
        # invalid-instruction message reports the right address.
        instr_addr = self.pc
        if pc_wdata is None:
            pc_wdata = self.pc + 4
        # Keep the PC an unsigned 32-bit value: a negative PC would index the
        # memory list from its end on the next fetch.
        self.pc = pc_wdata & XLEN_MASK

        if instr_invalid:
            print(f"Invalid instruction at {instr_addr:08x}: {instr:08x}")

        self.csr.step()
        if cycle_accurate:
            self.csr.mcycle += stall_cycles
        # x0 never carries a dependency.
        self.stage3_result = stage3_result_next if stage3_result_next != 0 else None


def anyint(x):
    """Parse an integer literal in any base Python understands (e.g. 0x10)."""
    return int(x, 0)


def main(argv):
    """Command-line entry point: load a flat binary at address 0 and run it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("binfile")
    parser.add_argument("--memsize", default=1 << 24, type=anyint)
    parser.add_argument("--cycles", default=int(1e4), type=anyint)
    parser.add_argument("--dump", nargs=2, action="append", type=anyint)
    parser.add_argument("--quiet", "-q", action="store_true")
    args = parser.parse_args(argv)

    if args.quiet:
        mem = MemWithTBIO(args.memsize, io_log_fmt="{}")
    else:
        mem = MemWithTBIO(args.memsize)
    with open(args.binfile, "rb") as binfile:
        mem.loadbin(binfile, 0)

    rv = RVCore(mem)
    # Counted separately from the loop variable so the report is well defined
    # even when --cycles is 0.
    steps_run = 0
    try:
        for _ in range(args.cycles):
            steps_run += 1
            rv.step(log=not args.quiet)
    except TBExit as e:
        print(f"Processor halted simulation with exit code {e}")
    except BrokenPipeError:
        sys.exit(0)
    print(f"Ran for {steps_run} cycles")

    for start, end in args.dump or []:
        print(f"Dumping memory from {start:08x} to {end:08x}:")
        for i, addr in enumerate(range(start, end)):
            sep = "\n" if i % 16 == 15 else " "
            sys.stdout.write(f"{mem.get8(addr):02x}{sep}")
        print("")


if __name__ == "__main__":
    main(sys.argv[1:])