-rwxr-xr-x 18022 saferewrite-20250505/syminsn-sparc32 raw
#!/usr/bin/env python3
"""Symbolically execute a single SPARC32 instruction and print its outputs.

The script writes a small C harness that feeds registers, flags and one
instruction word from stdin into unicorn, runs one instruction, and writes
the resulting registers and flags to stdout.  The harness is compiled with
gcc and then explored with angr using a symbolic stdin.  The final
register/flag expressions are printed as a straight-line list of
assignments by ``unroll_print``.

Usage: syminsn-sparc32 [insn-in-hex [tweak-bits]]
"""

import functools
import io
import operator
import os
import resource
import subprocess
import sys

import angr
import claripy

binary = 'syminsn-sparc32-x'

symflags = True
maxsplit = 300  # max number of universes within an angr run

hostlittle = True
targetlittle = False

flags = 'cf', 'vf', 'zf', 'nf'

regs32 = [f'g{i}' for i in range(1, 8)]
regs32 += [f'o{i}' for i in range(8)]
regs32 += [f'l{i}' for i in range(8)]
regs32 += [f'i{i}' for i in range(8)]
regs = [(r, 32) for r in regs32]

# The harness writes one stdout packet per register and one per flag.
expectedoutputs = len(flags) + len(regs)

# Symbolic stdin contents, in the order the C harness reads them:
# all registers, then one byte per flag, then the instruction word.
invars = []

for r, rbits in regs:
    varname = 'in_' + r
    variable = claripy.BVS(varname, rbits, explicit_name=True)
    if hostlittle:
        # The harness fread()s raw host-order bytes into a uint32_t.
        variable = claripy.Reverse(variable)
    invars += [variable]

for r in flags:
    if symflags:
        varname = 'in_' + r
        variable = claripy.BVS(varname, 1, explicit_name=True)
    else:
        variable = claripy.BVV(0, 1)
    # Each flag occupies one full byte on stdin.
    invars += [variable.zero_extend(7)]

insn = 0x82c08003  # flags,g1 = g2+g3+carry
tweak = 0
if len(sys.argv) >= 2:
    insn = int(sys.argv[1], base=16)
if len(sys.argv) >= 3:
    tweak = int(sys.argv[2])
print(f'exploring insn {hex(insn)} tweak {tweak}')
sys.stdout.flush()

insn = claripy.BVV(insn, 32)
if tweak > 0:
    # OR a symbolic value into the bottom `tweak` bits of the instruction.
    imm = claripy.BVS('in_n', tweak, explicit_name=True)
    insn |= imm.zero_extend(32 - tweak)
if targetlittle:
    insn = claripy.Reverse(insn)
invars += [insn]

with open(f'{binary}.c', 'w') as f:
    f.write(r'''#include <inttypes.h>
#include <string.h>
#include <stdio.h>
#include <stdarg.h>
#include <errno.h>
#include <unicorn/unicorn.h>
#include <time.h>

long long flags[] = {
  20, // C
  21, // V
  22, // Z
  23 // N
} ;

long long regs32[] = {
  UC_SPARC_REG_G1,
  UC_SPARC_REG_G2,
  UC_SPARC_REG_G3,
  UC_SPARC_REG_G4,
  UC_SPARC_REG_G5,
  UC_SPARC_REG_G6,
  UC_SPARC_REG_G7,
  UC_SPARC_REG_O0,
  UC_SPARC_REG_O1,
  UC_SPARC_REG_O2,
  UC_SPARC_REG_O3,
  UC_SPARC_REG_O4,
  UC_SPARC_REG_O5,
  UC_SPARC_REG_O6,
  UC_SPARC_REG_O7,
  UC_SPARC_REG_L0,
  UC_SPARC_REG_L1,
  UC_SPARC_REG_L2,
  UC_SPARC_REG_L3,
  UC_SPARC_REG_L4,
  UC_SPARC_REG_L5,
  UC_SPARC_REG_L6,
  UC_SPARC_REG_L7,
  UC_SPARC_REG_I0,
  UC_SPARC_REG_I1,
  UC_SPARC_REG_I2,
  UC_SPARC_REG_I3,
  UC_SPARC_REG_I4,
  UC_SPARC_REG_I5,
  UC_SPARC_REG_I6,
  UC_SPARC_REG_I7,
} ;

int vasprintf(char **x,const char *fmt,va_list ap)
{
  int len;
  va_list apcopy;
  va_copy(apcopy,ap);
  len = vsnprintf(0,0,fmt,apcopy);
  va_end(apcopy);
  *x = 0;
  if (len < 0) return -1;
  len += 1;
  *x = malloc(len);
  if (!*x) return -1;
  return vsnprintf(*x,len,fmt,ap);
}

int clock_gettime(clockid_t clockid, struct timespec *tp)
{
  tp->tv_sec = 0;
  tp->tv_nsec = 0;
  return 0;
}

int gettimeofday(struct timeval *tv,void *tz)
{
  tv->tv_sec = 0;
  tv->tv_usec = 0;
  return 0;
}

int main()
{
  uc_engine *uc;
  uc_err err;
  uint64_t insnpos = 0x10000;
  unsigned char r8;
  uint32_t r32;
  long long i;

  err = uc_open(UC_ARCH_SPARC,UC_MODE_SPARC32|UC_MODE_BIG_ENDIAN,&uc);
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_open failed: %s\n",uc_strerror(err));
    exit(111);
  }
  err = uc_mem_map(uc,insnpos,insnpos+4096,UC_PROT_ALL);
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_mem_map failed: %s\n",uc_strerror(err));
    exit(111);
  }
  for (long long i = 0;i < sizeof(regs32)/sizeof(regs32[0]);++i) {
    fread(&r32,4,1,stdin);
    uc_reg_write(uc,regs32[i],&r32);
  }
  r32 = 0;
  for (long long i = 0;i < sizeof(flags)/sizeof(flags[0]);++i) {
    fread(&r8,1,1,stdin);
    r8 &= 1;
    r32 |= ((uint32_t) r8) << flags[i];
  }
  uc_reg_write(uc,UC_SPARC_REG_PSR,&r32);
  fread(&r32,4,1,stdin);
  uc_mem_write(uc,insnpos,&r32,4);
  err = uc_emu_start(uc,insnpos,insnpos+4,0,1);
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_emu_start failed: %s\n",uc_strerror(err));
    exit(111);
  }
  for (i = 0;i < sizeof(regs32)/sizeof(regs32[0]);++i) {
    uc_reg_read(uc,regs32[i],&r32);
    fwrite(&r32,4,1,stdout);
  }
  uc_reg_read(uc,UC_SPARC_REG_PSR,&r32);
  for (long long i = 0;i < sizeof(flags)/sizeof(flags[0]);++i) {
    r8 = 1 & (r32 >> flags[i]);
    fwrite(&r8,1,1,stdout);
  }
  return 0;
}
''')

subprocess.run(
    f'gcc -Os -o {binary} {binary}.c setjmp.s -lunicorn'.split(),
    check=True,
)


def clarikey(e):
    """Return a hashable key identifying the claripy expression ``e``.

    Uses ``e.cache_key`` when that attribute exists and falls back to
    ``e.hash()`` otherwise.
    """
    try:
        # e.g. angr 9.2.102
        return e.cache_key
    except AttributeError:
        # e.g. angr 9.2.144
        return e.hash()


# Binary comparison ops handled by values(): op name -> (relation, signed).
_COMPARISONS = {
    '__le__': (operator.le, False),
    'ULE': (operator.le, False),
    '__lt__': (operator.lt, False),
    'ULT': (operator.lt, False),
    '__ge__': (operator.ge, False),
    'UGE': (operator.ge, False),
    '__gt__': (operator.gt, False),
    'UGT': (operator.gt, False),
    'SLE': (operator.le, True),
    'SLT': (operator.lt, True),
    'SGE': (operator.ge, True),
    'SGT': (operator.gt, True),
}

# Same-width bitvector ops that values() folds over all arguments.
_ARITHMETIC = (
    '__add__', '__mul__', 'SDiv', '__floordiv__', 'SMod', '__mod__',
    '__sub__', '__lshift__', 'LShR', '__rshift__',
    '__and__', '__or__', '__xor__',
)


def values(terms, replacements):
    """Evaluate claripy terms on concrete inputs without calling a solver.

    input: ``replacements`` maps clarikey(variable) to an object whose
    ``.args[0]`` is the integer value of that variable.

    output: a pair ``(V, W)``.  ``V`` is a dictionary mapping clarikey to
    pairs ``(b, i)`` where ``i`` is a ``b``-bit value; it includes all
    ``terms`` and every subterm visited.  ``V`` is ``None`` if the terms use
    variables outside ``replacements`` or an unsupported operation.  ``W`` is
    a set of warning tags ('mul', 'div') for operations encountered.
    """
    V = {}
    W = set()  # warnings

    def evaluate(t):
        key = clarikey(t)
        if key in V:
            return True
        op = t.op

        if op == 'BoolV':
            V[key] = 1, t.args[0]
            return True
        if op == 'BVV':
            V[key] = t.size(), t.args[0]
            return True
        if op == 'BVS':
            if key not in replacements:
                return False
            V[key] = t.size(), replacements[key].args[0]
            return True

        if op == 'Extract':
            assert len(t.args) == 3
            top, bot, inner = t.args
            if not evaluate(inner):
                return False
            innerbits, innerval = V[clarikey(inner)]
            assert innerbits > top
            assert top >= bot
            assert bot >= 0
            # (2<<top)-1 masks bits 0..top inclusive.
            V[key] = top + 1 - bot, ((innerval & ((2 << top) - 1)) >> bot)
            return True

        if op in ('SignExt', 'ZeroExt'):
            assert len(t.args) == 2
            extend, inner = t.args
            if not evaluate(inner):
                return False
            x0bits, x0 = V[clarikey(inner)]
            assert extend >= 0
            if op == 'SignExt' and x0 >= (1 << (x0bits - 1)):
                # Top bit set: reinterpret as negative, then wrap into
                # the wider unsigned range.
                x0 -= 1 << x0bits
                x0 += 1 << (x0bits + extend)
            V[key] = x0bits + extend, x0
            return True

        # Everything below takes only term arguments; evaluate them first.
        for a in t.args:
            if not evaluate(a):
                return False
        x = [V[clarikey(a)] for a in t.args]

        if op == 'Concat':
            y = 0
            ybits = 0
            for xbitsi, xi in x:
                y <<= xbitsi
                y += xi
                ybits += xbitsi
            V[key] = ybits, y
            return True

        if op == 'Reverse':
            assert len(x) == 1
            ybits, x0 = x[0]
            assert ybits % 8 == 0
            y = 0
            for i in range(ybits // 8):
                y = (y << 8) + ((x0 >> (8 * i)) & 255)
            V[key] = ybits, y
            return True

        if op in ('__eq__', '__ne__'):
            assert len(x) == 2
            assert x[0][0] == x[1][0]
            if op == '__eq__':
                V[key] = 1, (x[0][1] == x[1][1])
            else:
                V[key] = 1, (x[0][1] != x[1][1])
            return True

        if op in _ARITHMETIC:
            bits = x[0][0]
            assert all(xi[0] == bits for xi in x)
            modulus = 2**bits
            flip = 2**(bits - 1)

            def signed(u):
                # Two's-complement reinterpretation of a `bits`-bit value.
                return (u ^ flip) - flip

            if op == '__and__':
                reduction = (lambda a, b: a & b)
            elif op == '__or__':
                reduction = (lambda a, b: a | b)
            elif op == '__xor__':
                reduction = (lambda a, b: a ^ b)
            elif op == '__add__':
                reduction = (lambda a, b: (a + b) % modulus)
            elif op == '__sub__':
                reduction = (lambda a, b: (a - b) % modulus)
            elif op == '__lshift__':
                reduction = (lambda a, b: (a << b) % modulus)
            elif op == 'LShR':
                reduction = (lambda a, b: (a >> b) % modulus)
            elif op == '__rshift__':
                def reduction(a, b):
                    # XXX: what are semantics for signed(b) outside
                    # [0:bits]? also check __lshift__, LShR
                    return (signed(a) >> signed(b)) % modulus
            elif op == '__mul__':
                reduction = (lambda a, b: (a * b) % modulus)
                W.add('mul')
            elif op == '__floordiv__':
                def reduction(a, b):
                    if b == 0:
                        return 0
                    return (a // b) % modulus
                W.add('div')
            elif op == '__mod__':
                def reduction(a, b):
                    if b == 0:
                        return a
                    return (a % b) % modulus
                W.add('div')
            elif op == 'SDiv':
                def reduction(a, b):
                    if b == 0:
                        return 0
                    sa, sb = signed(a), signed(b)
                    # sdiv definition in Z3:
                    # - The floor of [t1/t2] if t2 is different from
                    #   zero, and [t1*t2 >= 0].
                    # - The ceiling of [t1/t2] if t2 is different from
                    #   zero, and [t1*t2 < 0].
                    if sa * sb >= 0:
                        q = sa // sb
                    else:
                        q = -((-sa) // sb)
                    return q % modulus
                W.add('div')
            elif op == 'SMod':
                def reduction(a, b):
                    if b == 0:
                        return a
                    sa, sb = signed(a), signed(b)
                    # srem definition in Z3:
                    # It is defined as t1 - (t1 /s t2) * t2, where /s
                    # represents signed division.  The most significant
                    # bit (sign) of the result is equal to the most
                    # significant bit of t1.
                    if sa * sb >= 0:
                        r = sa - sb * (sa // sb)
                    else:
                        r = sa + sb * ((-sa) // sb)
                    return r % modulus
                W.add('div')
            else:
                print('values: internal error %s, falling back to Z3' % op)
                return False
            V[key] = bits, functools.reduce(reduction, (xi[1] for xi in x))
            return True

        if op == '__invert__':
            assert len(x) == 1
            bits = x[0][0]
            V[key] = bits, (1 << bits) - 1 - x[0][1]
            return True

        if op == 'Not':
            assert len(x) == 1
            assert all(xi[0] == 1 for xi in x)
            V[key] = 1, 1 - x[0][1]
            return True

        if op in ('And', 'Or'):
            assert all(xi[0] == 1 for xi in x)
            if op == 'And':
                reduction = (lambda a, b: a * b)
            else:
                reduction = (lambda a, b: a + b - a * b)
            V[key] = 1, functools.reduce(reduction, (xi[1] for xi in x))
            return True

        if op == 'If':
            assert len(x) == 3
            assert x[0][0] == 1
            V[key] = x[1] if x[0][1] else x[2]
            return True

        if op in _COMPARISONS:
            assert len(x) == 2
            bits = x[0][0]
            assert bits == x[1][0]
            relation, is_signed = _COMPARISONS[op]
            x0, x1 = x[0][1], x[1][1]
            if is_signed:
                # Flipping the sign bit maps signed order onto unsigned
                # order.
                flip = 2**(bits - 1)
                x0 ^= flip
                x1 ^= flip
            V[key] = (1, relation(x0, x1))
            return True

        # XXX: add support for more
        print('values: unsupported operation %s, falling back to Z3' % op)
        return False

    # XXX: also add more validation for all of the above

    for t in terms:
        if not evaluate(t):
            return None, W
    return V, W


def unroll_print(names, unrolled, f):
    """Write the terms in ``unrolled`` to ``f`` as numbered assignments.

    Every distinct subterm is written once as ``vN = ...`` after the lines
    for its arguments.  Then, for each position in ``names``, a line
    ``name = vN`` binds that name to the term at the same position in
    ``unrolled``.
    """
    walked = {}  # clarikey -> assigned number N

    def walk(t):
        key = clarikey(t)
        if key in walked:
            return walked[key]
        op = t.op
        # Build the right-hand side first: walking the arguments writes
        # their lines and must happen before this term gets its number.
        if op == 'BoolV':
            rhs = 'bool(%d)' % t.args[0]
        elif op == 'BVV':
            rhs = 'constant(%d,%d)' % (t.size(), t.args[0])
        elif op == 'BVS':
            rhs = '%s' % t.args[0]
        elif op == 'Extract':
            assert len(t.args) == 3
            source = 'v%d' % walk(t.args[2])
            rhs = 'Extract(%s,%d,%d)' % (source, t.args[0], t.args[1])
        elif op in ['SignExt', 'ZeroExt']:
            assert len(t.args) == 2
            source = 'v%d' % walk(t.args[1])
            rhs = '%s(%s,%d)' % (op, source, t.args[0])
        else:
            sources = ['v%d' % walk(a) for a in t.args]
            rhs = '%s(%s)' % (op, ','.join(sources))
        walknext = len(walked) + 1
        f.write('v%d = %s\n' % (walknext, rhs))
        walked[key] = walknext
        return walknext

    for term in unrolled:
        walk(term)

    for pos, varname in enumerate(names):
        f.write('%s = v%s\n' % (varname, walk(unrolled[pos])))


sys.setrecursionlimit(1000000)

add_options = {
    angr.options.SYMBOLIC_WRITE_ADDRESSES,
    angr.options.CONSERVATIVE_READ_STRATEGY,
    angr.options.CONSERVATIVE_WRITE_STRATEGY,
    angr.options.ZERO_FILL_UNCONSTRAINED_MEMORY,
    angr.options.ZERO_FILL_UNCONSTRAINED_REGISTERS,
    angr.options.CONSTRAINT_TRACKING_IN_SOLVER,
    angr.options.SIMPLIFY_EXPRS,
    angr.options.SIMPLIFY_MEMORY_READS,
    angr.options.SIMPLIFY_MEMORY_WRITES,
    angr.options.SIMPLIFY_REGISTER_READS,
    angr.options.SIMPLIFY_REGISTER_WRITES,
    angr.options.SIMPLIFY_CONSTRAINTS,
    angr.options.SIMPLIFY_RETS,
}
add_options |= angr.options.unicorn - {angr.options.UNICORN_SYM_REGS_SUPPORT}

stdin = angr.SimFile(
    '/dev/stdin',
    content=claripy.Concat(*invars),
    has_end=True,
)

avoidsimprocedures = (
    'clock_gettime',
    '_setjmp',
    '__setjmp',
    '___setjmp',
    'longjmp',
    '_longjmp',
    '__longjmp',
    '___longjmp',
    'sigsetjmp',
    '_sigsetjmp',
    '__sigsetjmp',
    '___sigsetjmp',
    'siglongjmp',
    '_siglongjmp',
    '__siglongjmp',
    '___siglongjmp',
)

proj = angr.Project(
    binary,
    exclude_sim_procedures_list=avoidsimprocedures,
    auto_load_libs=False,
    force_load_libs=['libunicorn.so'],
)


class posix_memalign(angr.SimProcedure):
    """Hook: allocate from the state's heap and store the pointer."""

    def run(self, sim_ptr, sim_alignment, sim_size):
        # sim_alignment is not used by this hook.
        result = self.state.heap._malloc(sim_size)
        self.state.memory.store(sim_ptr, result, size=8)
        return claripy.BVV(0, 32)


proj.hook_symbol('posix_memalign', posix_memalign())


class sysconf(angr.SimProcedure):
    """Hook: answer sysconf() with the value from the analysis host."""

    def run(self, num):
        return claripy.BVV(os.sysconf(num.concrete_value), 64)


proj.hook_symbol('sysconf', sysconf())


class getpagesize(angr.SimProcedure):
    """Hook: answer getpagesize() with the analysis host's page size."""

    def run(self):
        return claripy.BVV(resource.getpagesize(), 32)


proj.hook_symbol('getpagesize', getpagesize())


class strerror(angr.SimProcedure):
    """Hook: return a fresh NUL-terminated copy of the host's message."""

    def run(self, num):
        # Encode once so the allocation and store sizes are byte counts
        # that match the stored data even for non-ASCII messages.
        message = (os.strerror(num.concrete_value) + '\0').encode('utf-8')
        malloc = angr.SIM_PROCEDURES["libc"]["malloc"]
        where = self.inline_call(malloc, len(message)).ret_expr
        self.state.memory.store(
            where, claripy.BVV(message), size=len(message))
        return where


proj.hook_symbol('strerror', strerror())

proj.hook_symbol('mmap64', angr.procedures.posix.mmap.mmap())

state = proj.factory.full_init_state(
    add_options=add_options,
    args=[binary],
    stdin=stdin,
)
print('options', state.options._options)
simgr = proj.factory.simgr(state)

printingsteps = False
fast = True

step = 0
while True:
    if printingsteps:
        print(f'step {step} fast {fast} deadended {len(simgr.deadended)} active {len(simgr.active)}')
        for epos, e in enumerate(simgr.active):
            print(f'active {epos}:')
            try:
                for ins in proj.factory.block(e.addr).capstone.insns:
                    print(ins)
            except Exception:
                # Disassembly here is best-effort debugging output only.
                print('exception')
        sys.stdout.flush()
    step += 1

    if len(simgr.errored) > 0:
        raise Exception(simgr.errored)

    if len(simgr.deadended) + len(simgr.active) > maxsplit:
        constraintbuf = io.StringIO()
        constraintnames = [f'exited{i}' for i in range(len(simgr.deadended))]
        constraintnames += [f'active{i}' for i in range(len(simgr.active))]
        unroll_print(
            constraintnames,
            [claripy.And(*e.solver.constraints)
             for e in simgr.deadended + simgr.active],
            constraintbuf,
        )
        raise Exception(f'limiting split to {maxsplit}\n{constraintbuf.getvalue()}')

    if len(simgr.active) == 0:
        break

    # XXX: should find a documented interface to do this
    # Once any active state's stdin read position is not a plain int,
    # drop the unicorn options.
    # NOTE(review): this edits the initial state's option set; confirm
    # that the change reaches the states currently in simgr.active.
    if any(not isinstance(e.posix.fd[0]._read_pos, int) for e in simgr.active):
        state.options -= angr.options.unicorn
        fast = False

    simgr.step()

exits = simgr.deadended
assert len(exits) > 0

ok = True
comment = f'have {len(exits)} exits\n'
for epos, e in enumerate(exits):
    receivedpackets = len(e.posix.stdout.content)
    if receivedpackets != expectedoutputs:
        ok = False
        comment += f'exit {epos} stdout packets {receivedpackets} expecting {expectedoutputs} stderr packets {len(e.posix.stderr.content)}\n'
        # Include whatever the harness wrote to stderr in the report.
        for packetpos, packet in enumerate(e.posix.stderr.content):
            if not (packet[0].concrete and packet[1].concrete):
                comment += f'stderr symbolic packet {packetpos}\n'
                continue
            x = packet[0].concrete_value
            n = packet[1].concrete_value
            # Low n bytes of x, most significant byte first.
            todo = (x & ((1 << (8 * n)) - 1)).to_bytes(n, 'big')
            for line in todo.splitlines():
                comment += f'stderr packet {packetpos} line: {line}\n'

if not ok:
    constraintbuf = io.StringIO()
    constraintnames = [f'exited{i}' for i in range(len(exits))]
    unroll_print(
        constraintnames,
        [claripy.And(claripy.true, *e.solver.constraints) for e in exits],
        constraintbuf,
    )
    raise Exception(comment + constraintbuf.getvalue())

if len(exits) > 1:
    mergedexit, _, _ = exits[0].merge(
        *exits[1:],
        merge_conditions=[e2.solver.constraints for e2 in exits],
    )
else:
    mergedexit = exits[0]

names = []
results = []
packetpos = 0

for r, rbits in regs:
    names += ['out_' + r]
    packet = mergedexit.posix.stdout.content[packetpos]
    packetpos += 1
    assert packet[1].concrete_value == rbits // 8
    xi = packet[0]
    if hostlittle:
        # The harness fwrite()s the host-order bytes of each uint32_t.
        xi = claripy.Reverse(xi)
    results += [xi]

for r in flags:
    names += ['out_' + r]
    packet = mergedexit.posix.stdout.content[packetpos]
    packetpos += 1
    assert packet[1].concrete_value == 1
    xi = packet[0]
    # Each flag is written as one byte; only bit 0 carries the flag.
    results += [claripy.Extract(0, 0, xi)]

unroll_print(names, results, sys.stdout)