-rwxr-xr-x 18022 saferewrite-20250505/syminsn-sparc32 raw
#!/usr/bin/env python3
import functools
import io
import os
import resource
import subprocess
import sys

import angr
import claripy
# Base name of the generated C harness (source goes to f'{binary}.c').
binary = 'syminsn-sparc32-x'
# True: incoming condition-code bits are symbolic inputs; False: they are 0.
symflags = True
maxsplit = 300  # max number of universes within an angr run
# Byte order of the machine running the harness vs. the emulated target.
hostlittle = True
targetlittle = False

# Condition codes, in the order of the C harness's flags[] (C, V, Z, N).
flags = ('cf', 'vf', 'zf', 'nf')
# Integer registers in the order of the C harness's regs32[]:
# g1..g7, then o0..o7, l0..l7, i0..i7.
regs32 = [f'g{i}' for i in range(1, 8)]
regs32 += [f'{bank}{i}' for bank in 'oli' for i in range(8)]
regs = [(regname, 32) for regname in regs32]
# One stdout packet per register plus one per flag.
expectedoutputs = len(flags) + len(regs)
# Symbolic stdin for the harness, in the order it is read:
# one 32-bit value per register, then one byte per flag.
invars = []
for regname, regbits in regs:
    regvar = claripy.BVS('in_' + regname, regbits, explicit_name=True)
    # The harness freads each register into a host uint32_t, so on a
    # little-endian host the bytes must be supplied in reverse order.
    invars.append(claripy.Reverse(regvar) if hostlittle else regvar)
for flagname in flags:
    if not symflags:
        flagvar = claripy.BVV(0, 1)
    else:
        flagvar = claripy.BVS('in_' + flagname, 1, explicit_name=True)
    # Each flag occupies one stdin byte; the harness keeps only bit 0.
    invars.append(flagvar.zero_extend(7))
# Instruction word to explore: argv[1] (hex) overrides the default, and
# argv[2] gives `tweak`, the number of low bits made symbolic.
insn = 0x82c08003  # flags,g1 = g2+g3+carry
tweak = 0
cliargs = sys.argv[1:]
if cliargs:
    insn = int(cliargs[0], base=16)
if len(cliargs) > 1:
    tweak = int(cliargs[1])
print(f'exploring insn {hex(insn)} tweak {tweak}', flush=True)
insn = claripy.BVV(insn, 32)
if tweak > 0:
    # OR a fresh symbolic `tweak`-bit value in_n into the low bits.
    symbolic_low = claripy.BVS('in_n', tweak, explicit_name=True)
    insn = insn | symbolic_low.zero_extend(32 - tweak)
if targetlittle:
    insn = claripy.Reverse(insn)
invars.append(insn)
# C harness executed under angr. From stdin it reads the 31 register values,
# then one byte per flag (bit 0 of each byte goes into PSR bits 20..23), then
# the instruction word. It runs that single instruction with unicorn in
# SPARC32 big-endian mode and writes the resulting registers and flag bits to
# stdout. It also defines its own vasprintf, clock_gettime and gettimeofday
# (the two time functions always report 0).
# NOTE(review): the harness calls uc_mem_map(uc,insnpos,insnpos+4096,...);
# unicorn's third argument is a length, so this maps more than one page --
# confirm that is intended.
harnesssource = r'''#include <inttypes.h>
#include <string.h>
#include <stdio.h>
#include <stdarg.h>
#include <errno.h>
#include <unicorn/unicorn.h>
#include <time.h>
long long flags[] = {
20, // C
21, // V
22, // Z
23 // N
} ;
long long regs32[] = {
UC_SPARC_REG_G1,
UC_SPARC_REG_G2,
UC_SPARC_REG_G3,
UC_SPARC_REG_G4,
UC_SPARC_REG_G5,
UC_SPARC_REG_G6,
UC_SPARC_REG_G7,
UC_SPARC_REG_O0,
UC_SPARC_REG_O1,
UC_SPARC_REG_O2,
UC_SPARC_REG_O3,
UC_SPARC_REG_O4,
UC_SPARC_REG_O5,
UC_SPARC_REG_O6,
UC_SPARC_REG_O7,
UC_SPARC_REG_L0,
UC_SPARC_REG_L1,
UC_SPARC_REG_L2,
UC_SPARC_REG_L3,
UC_SPARC_REG_L4,
UC_SPARC_REG_L5,
UC_SPARC_REG_L6,
UC_SPARC_REG_L7,
UC_SPARC_REG_I0,
UC_SPARC_REG_I1,
UC_SPARC_REG_I2,
UC_SPARC_REG_I3,
UC_SPARC_REG_I4,
UC_SPARC_REG_I5,
UC_SPARC_REG_I6,
UC_SPARC_REG_I7,
} ;
int vasprintf(char **x,const char *fmt,va_list ap)
{
int len;
va_list apcopy;
va_copy(apcopy,ap);
len = vsnprintf(0,0,fmt,apcopy);
va_end(apcopy);
*x = 0;
if (len < 0) return -1;
len += 1;
*x = malloc(len);
if (!*x) return -1;
return vsnprintf(*x,len,fmt,ap);
}
int clock_gettime(clockid_t clockid, struct timespec *tp)
{
tp->tv_sec = 0;
tp->tv_nsec = 0;
return 0;
}
int gettimeofday(struct timeval *tv,void *tz)
{
tv->tv_sec = 0;
tv->tv_usec = 0;
return 0;
}
int main()
{
uc_engine *uc;
uc_err err;
uint64_t insnpos = 0x10000;
unsigned char r8;
uint32_t r32;
long long i;
err = uc_open(UC_ARCH_SPARC,UC_MODE_SPARC32|UC_MODE_BIG_ENDIAN,&uc);
if (err != UC_ERR_OK) {
fprintf(stderr,"uc_open failed: %s\n",uc_strerror(err));
exit(111);
}
err = uc_mem_map(uc,insnpos,insnpos+4096,UC_PROT_ALL);
if (err != UC_ERR_OK) {
fprintf(stderr,"uc_mem_map failed: %s\n",uc_strerror(err));
exit(111);
}
for (long long i = 0;i < sizeof(regs32)/sizeof(regs32[0]);++i) {
fread(&r32,4,1,stdin);
uc_reg_write(uc,regs32[i],&r32);
}
r32 = 0;
for (long long i = 0;i < sizeof(flags)/sizeof(flags[0]);++i) {
fread(&r8,1,1,stdin);
r8 &= 1;
r32 |= ((uint32_t) r8) << flags[i];
}
uc_reg_write(uc,UC_SPARC_REG_PSR,&r32);
fread(&r32,4,1,stdin);
uc_mem_write(uc,insnpos,&r32,4);
err = uc_emu_start(uc,insnpos,insnpos+4,0,1);
if (err != UC_ERR_OK) {
fprintf(stderr,"uc_emu_start failed: %s\n",uc_strerror(err));
exit(111);
}
for (i = 0;i < sizeof(regs32)/sizeof(regs32[0]);++i) {
uc_reg_read(uc,regs32[i],&r32);
fwrite(&r32,4,1,stdout);
}
uc_reg_read(uc,UC_SPARC_REG_PSR,&r32);
for (long long i = 0;i < sizeof(flags)/sizeof(flags[0]);++i) {
r8 = 1 & (r32 >> flags[i]);
fwrite(&r8,1,1,stdout);
}
return 0;
}
'''
with open(f'{binary}.c','w') as f:
    f.write(harnesssource)
# setjmp.s (a path relative to the current directory) is linked in too.
gcccommand = f'gcc -Os -o {binary} {binary}.c setjmp.s -lunicorn'.split()
subprocess.run(gcccommand,check=True)
def clarikey(e):
    '''Return a hashable key identifying the claripy AST e.

    Older claripy exposes this as the e.cache_key attribute; newer claripy
    lacks that attribute and provides e.hash() instead.  Only a missing
    attribute selects the fallback; any other error propagates.
    '''
    try:
        # e.g. angr 9.2.102
        return e.cache_key
    except AttributeError:
        # e.g. angr 9.2.144
        return e.hash()
def values(terms,replacements):
    '''Evaluate claripy ASTs concretely, without calling a solver.

    terms: iterable of claripy ASTs to evaluate.
    replacements: dict mapping clarikey(v), for each symbolic variable v
      (op 'BVS'), to its value -- either a concrete claripy value whose
      .args[0] is the integer (e.g. a claripy.BVV) or a plain Python int
      (reduced modulo 2**size of the variable).

    Returns a pair (V,W):
      V maps clarikey(t) to a pair (b,i), i being the b-bit unsigned value
        of t (booleans are 1-bit), for every term and every subterm that was
        evaluated.  V is None instead if some term uses a variable outside
        replacements or an operation not handled here; the caller is then
        expected to fall back to Z3.
      W is a set of warning tags: 'mul' if a multiplication was seen, 'div'
        if a division or remainder was seen.
    '''
    V = {}
    W = set() # warnings

    def evaluate(t):
        # Record (bits,value) for t and its subterms in V.
        # Returns False if t cannot be evaluated here.
        key = clarikey(t)
        if key in V:
            return True
        if t.op == 'BoolV':
            V[key] = 1,t.args[0]
            return True
        if t.op == 'BVV':
            V[key] = t.size(),t.args[0]
            return True
        if t.op == 'BVS':
            if key not in replacements: return False
            r = replacements[key]
            if isinstance(r,int):
                V[key] = t.size(),r%(2**t.size())
            else:
                V[key] = t.size(),r.args[0]
            return True
        if t.op == 'Extract':
            # args: (top bit, bottom bit, operand); both bounds inclusive
            assert len(t.args) == 3
            top = t.args[0]
            bot = t.args[1]
            if not evaluate(t.args[2]): return False
            x0 = V[clarikey(t.args[2])]
            assert x0[0] > top
            assert top >= bot
            assert bot >= 0
            V[key] = top+1-bot,((x0[1] & ((2<<top)-1)) >> bot)
            return True
        if t.op in ('SignExt','ZeroExt'):
            # args: (number of bits to add, operand)
            assert len(t.args) == 2
            if not evaluate(t.args[1]): return False
            x0bits,x0 = V[clarikey(t.args[1])]
            extend = t.args[0]
            assert extend >= 0
            if t.op == 'SignExt' and x0 >= (1<<(x0bits-1)):
                # negative operand: the new high bits are copies of the sign bit
                x0 += ((1<<extend)-1) << x0bits
            V[key] = x0bits+extend,x0
            return True
        for a in t.args:
            if not evaluate(a): return False
        x = [V[clarikey(a)] for a in t.args]
        if t.op == 'Concat':
            # first argument supplies the most significant bits
            y = 0
            ybits = 0
            for xbitsi,xi in x:
                y <<= xbitsi
                y += xi
                ybits += xbitsi
            V[key] = ybits,y
            return True
        if t.op == 'Reverse':
            # byte reversal
            assert len(x) == 1
            xbits0,x0 = x[0]
            ybits = xbits0
            assert ybits%8 == 0
            y = 0
            for i in range(ybits//8):
                y = (y<<8)+((x0>>(8*i))&255)
            V[key] = ybits,y
            return True
        if t.op in ('__eq__','__ne__'):
            assert len(x) == 2
            assert x[0][0] == x[1][0]
            if t.op == '__eq__': V[key] = 1,(x[0][1]==x[1][1])
            elif t.op == '__ne__': V[key] = 1,(x[0][1]!=x[1][1])
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            return True
        if t.op in ('__add__','__mul__','SDiv','__floordiv__','SMod','__mod__','__sub__','__lshift__','LShR','__rshift__','__and__','__or__','__xor__'):
            bits = x[0][0]
            assert all(xi[0] == bits for xi in x)
            modulus = 2**bits
            flip = 2**(bits-1)
            # NOTE(review): a zero divisor yields 0 below (the dividend for
            # remainders). SMT-LIB 2.6 defines bvudiv by zero as all ones and
            # bvsdiv by zero as -1 or 1 depending on the dividend's sign --
            # confirm which semantics callers expect from this fast path.
            if t.op == '__and__': reduction = (lambda u,v:u&v)
            elif t.op == '__or__': reduction = (lambda u,v:u|v)
            elif t.op == '__xor__': reduction = (lambda u,v:u^v)
            elif t.op == '__add__': reduction = (lambda u,v:(u+v)%modulus)
            elif t.op == '__sub__': reduction = (lambda u,v:(u-v)%modulus)
            elif t.op == '__lshift__':
                # Shift amounts are unsigned; shifting by >= bits leaves 0.
                # Checking first avoids building a gigantic integer.
                reduction = (lambda u,v:0 if v >= bits else (u<<v)%modulus)
            elif t.op == 'LShR':
                reduction = (lambda u,v:0 if v >= bits else (u>>v)%modulus)
            elif t.op == '__rshift__':
                # Arithmetic shift right. The amount is unsigned; amounts
                # >= bits leave only copies of the sign bit.
                def reduction(u,v):
                    usigned = (u ^ flip) - flip
                    return (usigned >> min(v,bits))%modulus
            elif t.op == '__mul__':
                reduction = (lambda u,v:(u*v)%modulus)
                W.add('mul')
            elif t.op == '__floordiv__':
                def reduction(u,v):
                    if v == 0: return 0
                    return (u//v)%modulus
                W.add('div')
            elif t.op == '__mod__':
                def reduction(u,v):
                    if v == 0: return u
                    return (u%v)%modulus
                W.add('div')
            elif t.op == 'SDiv':
                def reduction(u,v):
                    if v == 0: return 0
                    usigned = (u ^ flip) - flip
                    vsigned = (v ^ flip) - flip
                    # sdiv definition in Z3:
                    # - The \c floor of [t1/t2] if \c t2 is different from zero, and [t1*t2 >= 0].
                    # - The \c ceiling of [t1/t2] if \c t2 is different from zero, and [t1*t2 < 0].
                    if usigned*vsigned >= 0:
                        q = usigned // vsigned
                    else:
                        q = -((-usigned) // vsigned)
                    return q%modulus
                W.add('div')
            elif t.op == 'SMod':
                def reduction(u,v):
                    if v == 0: return u
                    usigned = (u ^ flip) - flip
                    vsigned = (v ^ flip) - flip
                    # srem definition in Z3:
                    # It is defined as t1 - (t1 /s t2) * t2, where /s represents signed division.
                    # The most significant bit (sign) of the result is equal to the most significant bit of \c t1.
                    if usigned*vsigned >= 0:
                        rem = usigned - vsigned*(usigned // vsigned)
                    else:
                        rem = usigned + vsigned*((-usigned) // vsigned)
                    return rem%modulus
                W.add('div')
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            V[key] = bits,functools.reduce(reduction,(xi[1] for xi in x))
            return True
        if t.op == '__invert__':
            assert len(x) == 1
            bits = x[0][0]
            V[key] = bits,(1<<bits)-1-x[0][1]
            return True
        if t.op == 'Not':
            assert len(x) == 1
            assert all(xi[0] == 1 for xi in x)
            V[key] = 1,1-x[0][1]
            return True
        if t.op in ('And','Or'):
            assert all(xi[0] == 1 for xi in x)
            if t.op == 'And': reduction = (lambda u,v:u*v)
            elif t.op == 'Or': reduction = (lambda u,v:u+v-u*v)
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            V[key] = 1,functools.reduce(reduction,(xi[1] for xi in x))
            return True
        if t.op == 'If':
            assert len(x) == 3
            assert x[0][0] == 1
            if x[0][1]:
                V[key] = x[1]
            else:
                V[key] = x[2]
            return True
        if t.op in ('__le__','ULE','__lt__','ULT','__ge__','UGE','__gt__','UGT','SLE','SLT','SGE','SGT'):
            # plain operators are unsigned comparisons, like U*
            assert len(x) == 2
            bits = x[0][0]
            assert bits == x[1][0]
            flip = 2**(bits-1)
            x0,x1 = x[0][1],x[1][1]
            if t.op == '__le__': V[key] = (1,x0<=x1)
            elif t.op == 'ULE': V[key] = (1,x0<=x1)
            elif t.op == '__lt__': V[key] = (1,x0<x1)
            elif t.op == 'ULT': V[key] = (1,x0<x1)
            elif t.op == '__ge__': V[key] = (1,x0>=x1)
            elif t.op == 'UGE': V[key] = (1,x0>=x1)
            elif t.op == '__gt__': V[key] = (1,x0>x1)
            elif t.op == 'UGT': V[key] = (1,x0>x1)
            elif t.op == 'SLE': V[key] = (1,(x0^flip)<=(x1^flip))
            elif t.op == 'SLT': V[key] = (1,(x0^flip)<(x1^flip))
            elif t.op == 'SGE': V[key] = (1,(x0^flip)>=(x1^flip))
            elif t.op == 'SGT': V[key] = (1,(x0^flip)>(x1^flip))
            else:
                print('values: internal error %s, falling back to Z3' % t.op)
                return False
            return True
        # XXX: add support for more
        print('values: unsupported operation %s, falling back to Z3' % t.op)
        return False

    # XXX: also add more validation for all of the above
    for t in terms:
        if not evaluate(t): return None,W
    return V,W
def unroll_print(names,unrolled,f):
    '''Write claripy ASTs to the text stream f as straight-line assignments.

    Each distinct subterm (identified by clarikey) is written exactly once,
    after its operands, as "vK = ..." with K = 1, 2, ... in order of first
    completion.  Constants appear as constant(bits,value) or bool(value),
    variables by name, and other operations as op(operands).  Afterwards,
    for each position i, the line "names[i] = vK" binds names[i] to the
    node written for unrolled[i].
    '''
    numbering = {}

    def emit(node):
        # Return K for node's "vK" line, writing it (operands first) if new.
        key = clarikey(node)
        if key in numbering:
            return numbering[key]
        op = node.op
        args = node.args
        if op == 'BoolV':
            rhs = 'bool(%d)' % args[0]
        elif op == 'BVV':
            rhs = 'constant(%d,%d)' % (node.size(),args[0])
        elif op == 'BVS':
            rhs = '%s' % (args[0],)
        elif op == 'Extract':
            assert len(args) == 3
            rhs = 'Extract(v%d,%d,%d)' % (emit(args[2]),args[0],args[1])
        elif op in ('SignExt','ZeroExt'):
            assert len(args) == 2
            rhs = '%s(v%d,%d)' % (op,emit(args[1]),args[0])
        else:
            operands = ','.join('v%d' % emit(child) for child in args)
            rhs = '%s(%s)' % (op,operands)
        # numbered only after all operands have been written
        number = len(numbering)+1
        f.write('v%d = %s\n' % (number,rhs))
        numbering[key] = number
        return number

    for node in unrolled:
        emit(node)
    for pos,varname in enumerate(names):
        f.write('%s = v%s\n' % (varname,emit(unrolled[pos])))
# values() and unroll_print() walk claripy ASTs recursively, so allow deep recursion.
sys.setrecursionlimit(1000000)

# angr state options for the run, grouped by purpose.
opts = angr.options
memory_options = {
    opts.SYMBOLIC_WRITE_ADDRESSES,
    opts.CONSERVATIVE_READ_STRATEGY,
    opts.CONSERVATIVE_WRITE_STRATEGY,
    opts.ZERO_FILL_UNCONSTRAINED_MEMORY,
    opts.ZERO_FILL_UNCONSTRAINED_REGISTERS,
}
solver_options = {opts.CONSTRAINT_TRACKING_IN_SOLVER}
simplify_options = {
    opts.SIMPLIFY_EXPRS,
    opts.SIMPLIFY_MEMORY_READS,
    opts.SIMPLIFY_MEMORY_WRITES,
    opts.SIMPLIFY_REGISTER_READS,
    opts.SIMPLIFY_REGISTER_WRITES,
    opts.SIMPLIFY_CONSTRAINTS,
    opts.SIMPLIFY_RETS,
}
# angr's unicorn option set, minus symbolic-register support
unicorn_options = opts.unicorn - {opts.UNICORN_SYM_REGS_SUPPORT}
add_options = memory_options | solver_options | simplify_options | unicorn_options
# The harness's stdin: every symbolic input concatenated in read order.
stdincontent = claripy.Concat(*invars)
stdin = angr.SimFile('/dev/stdin',content=stdincontent,has_end=True)

# Symbols angr must not replace with its own SimProcedures: clock_gettime
# (the C harness defines its own) and the setjmp/longjmp family.
# NOTE(review): plain 'setjmp' is absent while plain 'longjmp' is present --
# confirm that is intended.
avoidsimprocedures = (
    'clock_gettime',
    '_setjmp', '__setjmp', '___setjmp',
    'longjmp', '_longjmp', '__longjmp', '___longjmp',
    'sigsetjmp', '_sigsetjmp', '__sigsetjmp', '___sigsetjmp',
    'siglongjmp', '_siglongjmp', '__siglongjmp', '___siglongjmp',
)
proj = angr.Project(
    binary,
    exclude_sim_procedures_list=avoidsimprocedures,
    auto_load_libs=False,
    force_load_libs=['libunicorn.so'],
)
class posix_memalign(angr.SimProcedure):
    '''Replacement for posix_memalign(memptr, alignment, size).

    Allocates size bytes from angr's heap plugin, writes the resulting
    pointer to *memptr and returns 0.
    NOTE(review): alignment is not enforced; the pointer is whatever
    heap._malloc returns -- confirm that is aligned enough for callers.
    '''
    def run(self,sim_ptr,sim_alignment,sim_size):
        result = self.state.heap._malloc(sim_size)
        # *memptr holds a native pointer: use the target's pointer width and
        # byte order (memory.store otherwise stores big-endian).
        self.state.memory.store(
            sim_ptr,
            result,
            size=self.state.arch.bytes,
            endness=self.state.arch.memory_endness,
        )
        return claripy.BVV(0,32)
proj.hook_symbol('posix_memalign',posix_memalign())
class sysconf(angr.SimProcedure):
    '''Replacement for sysconf(name), answered with the host's os.sysconf.

    Returns a 64-bit value.  If the host rejects the name (os.sysconf raises
    OSError, or the number does not fit the C int it expects), returns -1,
    the failure value of libc's sysconf, instead of raising out of the
    SimProcedure and erroring the state.
    '''
    def run(self,num):
        try:
            value = os.sysconf(num.concrete_value)
        except (OSError,OverflowError):
            value = -1
        return claripy.BVV(value,64)
proj.hook_symbol('sysconf',sysconf())
class getpagesize(angr.SimProcedure):
    '''Replacement for getpagesize(): the host's page size as a 32-bit value.'''
    def run(self):
        pagesize = resource.getpagesize()
        return claripy.BVV(pagesize,32)
proj.hook_symbol('getpagesize',getpagesize())
class strerror(angr.SimProcedure):
    '''Replacement for strerror(errnum).

    Copies the host's os.strerror text, UTF-8 encoded and NUL-terminated,
    into a buffer from the libc malloc SimProcedure and returns its address.
    '''
    def run(self,num):
        # Encode first so the malloc size and the stored size are both byte
        # counts (non-ASCII text has more bytes than characters).
        e = os.strerror(num.concrete_value).encode('utf-8')+b'\0'
        malloc = angr.SIM_PROCEDURES["libc"]["malloc"]
        where = self.inline_call(malloc,len(e)).ret_expr
        self.state.memory.store(where,claripy.BVV(e),size=len(e))
        return where
proj.hook_symbol('strerror',strerror())
# mmap64 is served by angr's own POSIX mmap SimProcedure.
proj.hook_symbol('mmap64',angr.procedures.posix.mmap.mmap())
state = proj.factory.full_init_state(
    args=[binary],
    stdin=stdin,
    add_options=add_options,
)
print('options',state.options._options)
simgr = proj.factory.simgr(state)

# Exploration-loop bookkeeping: per-step debug dump switch, whether unicorn
# is still in use (cleared once stdin reads turn symbolic), and step counter.
printingsteps = False
fast = True
step = 0
# Step the simulation manager until no state is active. Raise if any state
# errors, or if the total number of paths exceeds maxsplit; in that case
# the message includes each path's constraints.
while True:
    if printingsteps:
        print(f'step {step} fast {fast} deadended {len(simgr.deadended)} active {len(simgr.active)}')
        for epos,e in enumerate(simgr.active):
            print(f'active {epos}:')
            try:
                for ins in proj.factory.block(e.addr).capstone.insns:
                    print(ins)
            except Exception:
                # best-effort debug output: report blocks that cannot be disassembled
                print('exception')
        sys.stdout.flush()
    step += 1
    if len(simgr.errored) > 0:
        raise Exception(simgr.errored)
    if len(simgr.deadended)+len(simgr.active) > maxsplit:
        constraintbuf = io.StringIO()
        constraintnames = [f'exited{i}' for i in range(len(simgr.deadended))]
        constraintnames += [f'active{i}' for i in range(len(simgr.active))]
        unroll_print(constraintnames,[claripy.And(*e.solver.constraints) for e in simgr.deadended+simgr.active],constraintbuf)
        raise Exception(f'limiting split to {maxsplit}\n{constraintbuf.getvalue()}')
    if len(simgr.active) == 0:
        break
    # XXX: should find a documented interface to do this
    # Once some active state's stdin read position is no longer a plain int,
    # stop using unicorn. Options belong to each state, so clear them on the
    # states about to be stepped rather than on the initial `state`.
    if any(type(e.posix.fd[0]._read_pos) != int for e in simgr.active):
        for e in simgr.active:
            e.options -= angr.options.unicorn
        fast = False
    simgr.step()
# Every exited path must have written exactly expectedoutputs stdout
# packets. Otherwise, raise with a report: per bad exit, its packet counts
# and its stderr text, followed by each exit's path constraints.
exits = simgr.deadended
assert len(exits) > 0
ok = True
comment = f'have {len(exits)} exits\n'
for epos,e in enumerate(exits):
    receivedpackets = len(e.posix.stdout.content)
    if receivedpackets == expectedoutputs:
        continue
    ok = False
    stderrpackets = e.posix.stderr.content
    comment += f'exit {epos} stdout packets {receivedpackets} expecting {expectedoutputs} stderr packets {len(stderrpackets)}\n'
    for packetpos,packet in enumerate(stderrpackets):
        data,length = packet[0],packet[1]
        if not (data.concrete and length.concrete):
            comment += f'stderr symbolic packet {packetpos}\n'
            continue
        n = length.concrete_value
        # the low n bytes of the packet value, most significant byte first
        text = (data.concrete_value & ((1 << (8*n)) - 1)).to_bytes(n,'big')
        for line in text.splitlines():
            comment += f'stderr packet {packetpos} line: {line}\n'
if not ok:
    constraintbuf = io.StringIO()
    unroll_print(
        [f'exited{i}' for i in range(len(exits))],
        [claripy.And(claripy.true,*e.solver.constraints) for e in exits],
        constraintbuf,
    )
    raise Exception(comment+constraintbuf.getvalue())
# Merge all exits into one state, then map its stdout packets to out_*
# expressions (registers first, then flags, the order the harness writes
# them) and print them as straight-line code.
if len(exits) == 1:
    mergedexit = exits[0]
else:
    mergeconditions = [e2.solver.constraints for e2 in exits]
    mergedexit,_,_ = exits[0].merge(*exits[1:],merge_conditions=mergeconditions)
names = []
results = []
outpackets = mergedexit.posix.stdout.content
for pos,(r,rbits) in enumerate(regs):
    packet = outpackets[pos]
    assert packet[1].concrete_value == rbits//8
    # registers were fwrite()n from a host uint32_t: undo the host byte order
    value = claripy.Reverse(packet[0]) if hostlittle else packet[0]
    names.append('out_'+r)
    results.append(value)
for pos,r in enumerate(flags,start=len(regs)):
    packet = outpackets[pos]
    assert packet[1].concrete_value == 1
    # each flag is one byte holding 0 or 1; keep bit 0
    names.append('out_'+r)
    results.append(claripy.Extract(0,0,packet[0]))
unroll_print(names,results,sys.stdout)