-rwxr-xr-x 31359 saferewrite-20210904/analyze
#!/usr/bin/env python3
# Compiler command lines (compiler plus flags) every implementation is built with.
compilerlist = (
  'clang -O1 -fwrapv -march=native',
  'gcc -O3 -march=native -mtune=native',
)

# Number of input vectors used by the execute stage.
numrandomtests = 16

# Library functions that angr must not replace with its own models.
avoidsimprocedures = (
  'memcmp', # we want to test the real libc memcmp
)

# Bit width of each integer type name accepted in an api file.
typebits = {'int%d' % width: width for width in (8,16,32,64)}
import sys
import os
import shutil
import subprocess
import angr
import claripy
import multiprocessing
import random
import traceback
import functools
def _default_cores():
  # CPUs this process may run on; platforms without sched_getaffinity
  # fall back to the total CPU count.
  try:
    return len(os.sched_getaffinity(0))
  except AttributeError:
    return multiprocessing.cpu_count()

# Worker-process count: the CORES environment variable overrides the
# detected value, and at least one worker is always used.
os_cores = max(1, int(os.getenv('CORES', default=_default_cores())))
import resource
def cputime():
  """Return user CPU seconds used so far by this process plus its waited-for children."""
  total = 0.0
  for who in (resource.RUSAGE_SELF, resource.RUSAGE_CHILDREN):
    total += resource.getrusage(who).ru_utime
  return total
def notetime(builddir,what,time):
  """Print a timing line and append '<what> <seconds>' to builddir/analysis/seconds."""
  entry = '%s %.6f' % (what,time)
  print('%s seconds %s' % (builddir,entry))
  sys.stdout.flush()
  with open('%s/analysis/seconds' % builddir,'a') as log:
    log.write(entry + '\n')
def note(builddir,conclusion,contents=None):
  """Record an analysis conclusion for builddir.

  Prints "<builddir> <conclusion>" and (re)creates the file
  builddir/analysis/<conclusion>. The file holds str(contents), or is
  left empty when contents is None.
  """
  print('%s %s' % (builddir,conclusion))
  sys.stdout.flush()
  with open('%s/analysis/%s' % (builddir,conclusion),'w') as f:
    # 'is not None' rather than '!= None': contents may be an object with
    # overloaded comparison operators, for which '!=' need not yield a bool.
    if contents is not None:
      f.write(str(contents))
# expression trees are walked recursively below, so allow deep recursion
sys.setrecursionlimit(1000000)

startdir = os.getcwd()
# startdir is embedded in a linker command string that is later split on
# spaces, so only characters needing no quoting are accepted.
_startdirchars = frozenset('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./')
assert frozenset(startdir) <= _startdirchars

# every run starts from an empty build directory
shutil.rmtree('build',ignore_errors=True)
os.makedirs('build')
# Discover primitives: each subdirectory p of a top-level directory o
# (currently just "src") that contains an api file. Directories with the
# sticky bit set are reported and skipped.
primitives = []
for o in ('src',):
  o = o.strip()
  if not o or not os.path.isdir(o):
    continue
  if os.stat(o).st_mode & 0o1000 == 0o1000:
    print('%s sticky, skipping' % o)
    sys.stdout.flush()
    continue
  for p in sorted(os.listdir(o)):
    primdir = '%s/%s' % (o,p)
    if not os.path.isdir(primdir):
      continue
    if os.stat(primdir).st_mode & 0o1000 == 0o1000:
      print('%s sticky, skipping' % primdir)
      sys.stdout.flush()
      continue
    if not os.path.exists('%s/api' % primdir):
      print('%s/api nonexistent, skipping' % primdir)
      sys.stdout.flush()
      continue
    primitives.append((o,p))
def _parse_api(path):
  """Parse one api file into (inputs,outputs,funargs,funargtypes,funname,funret,funrettype).

  Recognized lines (others are ignored):
    call NAME                    name of the C function
    return TYPE SYM              function result, recorded as output SYM
    in|out|inout TYPE SYM        argument passed as *alloc_SYM
    in|out|inout TYPE SYM COUNT  argument passed as alloc_SYM (COUNT entries)
  inputs/outputs are lists of (csymbol,bitsperentry,entries).
  """
  inputs,outputs = [],[]
  funargs,funargtypes = [],[]
  funname = None
  funret = None
  funrettype = 'void'
  with open(path) as f:
    for rawline in f:
      words = rawline.split()
      if not words:
        continue
      kind = words[0]
      if kind == 'call':
        funname = words[1]
      elif kind == 'return':
        bitsperentry = typebits[words[1]]
        csymbol = words[2]
        assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
        outputs.append((csymbol,bitsperentry,1))
        funret = 'alloc_%s' % csymbol
        funrettype = 'uint%d_t' % bitsperentry
      elif kind in ('in','out','inout'):
        bitsperentry = typebits[words[1]]
        csymbol = words[2]
        assert all(c in 'abcdefghijklmnopqrstuvwxyz' for c in csymbol)
        pointer = len(words) != 3
        entries = int(words[3]) if pointer else 1
        if kind != 'out':
          inputs.append((csymbol,bitsperentry,entries))
        if kind != 'in':
          outputs.append((csymbol,bitsperentry,entries))
        if pointer:
          funargs.append('alloc_%s' % csymbol)
          funargtypes.append('uint%d_t *' % bitsperentry)
        else:
          funargs.append('*alloc_%s' % csymbol)
          funargtypes.append('uint%d_t' % bitsperentry)
      # XXX: support constant inputs
  return inputs,outputs,funargs,funargtypes,funname,funret,funrettype

op_api = {}
for o,p in primitives:
  op_api[o,p] = _parse_api('%s/%s/api' % (o,p))
def input_example_str(inputs,x):
  """Format input vector x as one 'in_<sym>_<index> = <value>' line per entry.

  inputs is a list of (csymbol,bitsperentry,entries); x must hold exactly
  one value per entry, in the same order.
  """
  names = ['in_%s_%d' % (csymbol,e) for csymbol,bitsperentry,entries in inputs for e in range(entries)]
  lines = []
  for pos,varname in enumerate(names):
    lines.append('%s = %d\n' % (varname,x[pos]))
  assert len(names) == len(x)
  return ''.join(lines)
def output_example_str(outputs,y):
  """Format output vector y as one 'out_<sym>_<index> = <value>' line per entry.

  outputs is a list of (csymbol,bitsperentry,entries); y must hold exactly
  one value per entry, in the same order.
  """
  pieces = []
  ypos = 0
  for csymbol,bitsperentry,entries in outputs:
    for e in range(entries):
      pieces.append('out_%s_%d = %d\n' % (csymbol,e,y[ypos]))
      ypos += 1
  assert ypos == len(y)
  return ''.join(pieces)
# File names the build step itself writes into a build directory; files
# in an implementation directory with one of these names are flagged.
reservedfilenames = ('library.so.1','analysis') + tuple(
  'analysis-%s%s' % (kind,suffix)
  for kind in ('execute','valgrind','angr')
  for suffix in ('','.c')
)
# opimplementations[o,p]: implementation subdirectories of primitive o/p
# whose file names all pass the checks below. build/<p>/<i> is created for
# each accepted implementation.
opimplementations = {}
for o,p in primitives:
  opimplementations[o,p] = []
  for i in sorted(os.listdir('%s/%s' % (o,p))):
    implementationdir = '%s/%s/%s' % (o,p,i)
    if not os.path.isdir(implementationdir): continue
    if os.stat(implementationdir).st_mode & 0o1000 == 0o1000:
      print('%s/%s/%s sticky, skipping' % (o,p,i))
      sys.stdout.flush()
      continue
    # Check every file before deciding. The flag is set once per
    # implementation, not per file, so one bad file rejects the whole
    # implementation and the implementation is added at most once.
    ok = True
    for f in sorted(os.listdir(implementationdir)):
      if f in reservedfilenames:
        print('%s/%s/%s/%s reserved filename' % (o,p,i,f))
        ok = False
      if any(fi not in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.' for fi in f):
        print('%s/%s/%s/%s prohibited character' % (o,p,i,f))
        ok = False
    sys.stdout.flush()
    if not ok: continue
    opimplementations[o,p].append(i)
    # per-compiler build directories are created later under this one
    os.makedirs('build/%s/%s' % (p,i),exist_ok=True)
def compile(o,p,i,compiler):
  """Build implementation o/p/i with one compiler.

  Copies the implementation into build/<p>/<i>/<compilerword>, writes
  crypto_int{8,16,32,64}.h and crypto_uint{8,16,32,64}.h there, generates
  three C harnesses (analysis-execute.c, analysis-valgrind.c,
  analysis-angr.c) around the function described by op_api[o,p], compiles
  the implementation's .c/.s/.S files plus the harnesses with `compiler`,
  and links one executable per harness. Compile and link CPU times are
  recorded with notetime().

  Returns (o,p,i,compiler,ok). ok is False if a compiler or linker could
  not be started or exited nonzero, which is also recorded as a warning-*
  note. Otherwise ok is True. Note: this name shadows the builtin compile().
  """
  compilerword = compiler.replace(' ','_').replace('=','_')
  implementationdir = '%s/%s/%s' % (o,p,i)
  builddir = 'build/%s/%s/%s' % (p,i,compilerword)
  inputs,outputs,funargs,funargtypes,funname,funret,funrettype = op_api[o,p]
  files = sorted(os.listdir(implementationdir))
  cfiles = [x for x in files if x.endswith('.c')]
  sfiles = [x for x in files if x.endswith('.s') or x.endswith('.S')]
  # only C and assembly sources are compiled; other files are just copied
  files = cfiles + sfiles
  shutil.copytree(implementationdir,builddir)
  os.makedirs('%s/analysis' % builddir)
  # crypto_intN.h / crypto_uintN.h #define crypto_intN / crypto_uintN as intN_t / uintN_t
  for bits in 8,16,32,64:
    with open('%s/crypto_int%d.h' % (builddir,bits),'w') as f:
      f.write('#include <inttypes.h>\n')
      f.write('#define crypto_int%d int%d_t' % (bits,bits))
    with open('%s/crypto_uint%d.h' % (builddir,bits),'w') as f:
      f.write('#include <inttypes.h>\n')
      f.write('#define crypto_uint%d uint%d_t' % (bits,bits))
  # Generate the three harnesses. Each declares the function, one
  # static_<sym> array per input/output, malloc's an alloc_<sym> buffer per
  # symbol and calls the function on the alloc_* buffers. They differ in how
  # data moves:
  #   execute:  read inputs from stdin into static_*, copy them into alloc_*,
  #             call, copy outputs back to static_*, print them one per line
  #   valgrind: call on the freshly malloc'd buffers, which are never
  #             initialized (presumably so memcheck flags data-dependent
  #             behavior; NOTE(review): confirm intent)
  #   angr:     copy static_* into alloc_*, call, copy outputs back to
  #             static_*; unroll_worker() stores symbolic inputs into static_*
  for analysis in 'execute','valgrind','angr':
    with open('%s/analysis-%s.c' % (builddir,analysis),'w') as f:
      f.write('#include <stdio.h>\n')
      f.write('#include <stdlib.h>\n')
      f.write('#include <string.h>\n')
      f.write('#include <inttypes.h>\n')
      f.write('\n')
      # function declaration
      f.write('extern ')
      if funrettype != None:
        f.write('%s ' % funrettype)
      f.write('%s(%s);\n' % (funname,','.join(funargtypes)))
      f.write('\n')
      # inout symbols appear in both lists; the repeated file-scope
      # declaration is an identical tentative definition in C
      for csymbol,bitsperentry,entries in inputs+outputs:
        f.write('uint%d_t static_%s[%d];\n' % (bitsperentry,csymbol,entries))
      f.write('\n')
      f.write('int main(int argc,char **argv)\n')
      f.write('{\n')
      # entries*bitsperentry/8 is a float in Python 3; %d formats it as an integer
      for csymbol,bitsperentry,entries in inputs:
        f.write(' uint%d_t *alloc_%s = malloc(%d);\n' % (bitsperentry,csymbol,entries*bitsperentry/8))
      for csymbol,bitsperentry,entries in outputs:
        # inout symbols already got a buffer in the loop above
        if (csymbol,bitsperentry,entries) not in inputs:
          f.write(' uint%d_t *alloc_%s = malloc(%d);\n' % (bitsperentry,csymbol,entries*bitsperentry/8))
      f.write('\n')
      # XXX: resource limits
      if analysis == 'execute':
        # read each input entry as a decimal number from stdin; abort on bad input
        for csymbol,bitsperentry,entries in inputs:
          f.write(' for (long long i = 0;i < %d;++i) {\n' % entries)
          f.write(' unsigned long long x;\n')
          f.write(' if (scanf("%llu",&x) != 1) abort();\n')
          f.write(' static_%s[i] = x;\n' % csymbol)
          f.write(' }\n')
        f.write('\n')
      if analysis in ('execute','angr'):
        for csymbol,bitsperentry,entries in inputs:
          f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
          f.write(' alloc_%s[i] = static_%s[i];\n' % (csymbol,csymbol))
        f.write('\n')
      # the call itself; a 'return' output is stored into its alloc_ buffer
      f.write(' ')
      if funret != None:
        f.write('%s[0] = ' % funret)
      f.write('%s(%s);\n' % (funname,','.join(funargs)))
      f.write('\n')
      if analysis in ('execute','angr'):
        for csymbol,bitsperentry,entries in outputs:
          f.write(' for (long long i = 0;i < %d;++i)\n' % entries)
          f.write(' static_%s[i] = alloc_%s[i];\n' % (csymbol,csymbol))
        f.write('\n')
      if analysis == 'execute':
        # print each output entry as a decimal number on its own line
        for csymbol,bitsperentry,entries in outputs:
          f.write(' for (long long i = 0;i < %d;++i) {\n' % entries)
          f.write(' unsigned long long x = static_%s[i];\n' % csymbol)
          f.write(' printf("%llu\\n",x);\n')
          f.write(' }\n')
        f.write(' fflush(stdout);\n')
        f.write('\n')
      f.write(' return 0;\n')
      f.write('}\n')
  # ----- compile
  compiletime = -cputime()
  objfiles = []
  for f in files+['analysis-execute.c','analysis-valgrind.c','analysis-angr.c']:
    # the command is split on spaces, so no single argument may contain one
    command = '%s -Wall -fPIC -DCRYPTO_NAMESPACE(x)=x -c %s' % (compiler,f)
    try:
      proc = subprocess.Popen(command.split(),cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
      out,err = proc.communicate()
    except OSError:
      note(builddir,'warning-compilefailed',traceback.format_exc())
      return o,p,i,compiler,False
    # stderr is merged into stdout, so err is None here
    assert not err
    if out != '':
      # NOTE(review): note() overwrites, so only the last file's output is kept
      note(builddir,'warning-compileoutput',out)
    if proc.returncode:
      note(builddir,'warning-compilefailed','exit code %s' % proc.returncode)
      return o,p,i,compiler,False
    if f in files:
      # implementation objects (foo.c -> foo.o) are linked into every harness;
      # harness objects are linked one per executable below
      objfiles += ['.'.join(f.split('.')[:-1]+['o'])]
  compiletime += cputime()
  notetime(builddir,'compile',compiletime)
  # ----- link into executable
  linktime = -cputime()
  for analysis in 'execute','valgrind','angr':
    # static is hard-wired True, so the shared-library branch is currently unused
    static = True
    if static:
      # link the harness object directly with the implementation objects
      command = 'gcc -no-pie -o analysis-%s analysis-%s.o' % (analysis,analysis)
      command = command.split()
      command += objfiles
    else:
      # build library.so.1 from the implementation objects ...
      command = 'gcc -shared -Wl,-soname,library.so.1 -o library.so.1'
      command = command.split()
      command += objfiles
      try:
        proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
        out,err = proc.communicate()
      except OSError:
        note(builddir,'warning-linkfailed',traceback.format_exc())
        return o,p,i,compiler,False
      if out != '':
        note(builddir,'warning-linkoutput',out)
      assert not err
      if proc.returncode:
        note(builddir,'warning-linkfailed','exit code %s' % proc.returncode)
        return o,p,i,compiler,False
      # ... and make it findable as library.so, which is what -lrary
      # (lib + rary + .so) resolves to; rpath points at the build directory
      shutil.copy('%s/library.so.1' % builddir,'%s/library.so' % builddir)
      command = 'gcc -no-pie -o analysis-%s analysis-%s.o -Wl,-rpath=%s/%s -L. -lrary' % (analysis,analysis,startdir,builddir)
      command = command.split()
    # run whichever final link command was prepared above
    try:
      proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
      out,err = proc.communicate()
    except OSError:
      note(builddir,'warning-linkfailed',traceback.format_exc())
      return o,p,i,compiler,False
    if out != '':
      note(builddir,'warning-linkoutput',out)
    assert not err
    if proc.returncode:
      note(builddir,'warning-linkfailed','exit code %s' % proc.returncode)
      return o,p,i,compiler,False
  linktime += cputime()
  notetime(builddir,'link',linktime)
  return o,p,i,compiler,True
def wanttocompile():
  """Yield (o,p,i,compiler) for every implementation under every compiler."""
  for o,p in primitives:
    for i in opimplementations[o,p]:
      yield from ((o,p,i,compiler) for compiler in compilerlist)
# Compile everything in parallel. op_compiled[o,p] lists the
# (implementation,compiler) pairs that compiled and linked.
op_compiled = {key: [] for key in primitives}
with multiprocessing.Pool(os_cores) as pool:
  for o,p,i,compiler,ok in pool.starmap(compile,wanttocompile()):
    if ok:
      op_compiled[o,p].append((i,compiler))
print('----- execute')
# op_x[o,p]: numrandomtests input vectors with one value per input entry.
# Vector 0 is all zeros, vector 1 has every entry at its maximum
# (2**bits-1), and the remaining vectors are uniformly random.
op_x = {}
for o,p in primitives:
  inputs = op_api[o,p][0]
  vectors = []
  for execution in range(numrandomtests):
    x = []
    for csymbol,bitsperentry,entries in inputs:
      for e in range(entries):
        if execution == 0:
          x.append(0)
        elif execution == 1:
          x.append(2**bitsperentry-1)
        else:
          x.append(random.randrange(2**bitsperentry))
    vectors.append(x)
  op_x[o,p] = vectors
def execute(o,p,i,compiler):
  """Run one build's analysis-execute harness on every vector in op_x[o,p].

  Each vector is written to the harness's stdin as decimal lines, and the
  harness is expected to print one decimal line per output entry.

  Returns (o,p,i,compiler,results), where results[k] is the list of output
  values for op_x[o,p][k]. Returns (o,p,i,compiler,False) after recording a
  warning note if the harness cannot be started, exits nonzero, or prints
  output that does not fit the api (non-integer lines, wrong number of
  lines, or a value outside its output's bit width).
  """
  compilerword = compiler.replace(' ','_').replace('=','_')
  builddir = 'build/%s/%s/%s' % (p,i,compilerword)
  inputs,outputs,funargs,funargtypes,funname,funret,funrettype = op_api[o,p]
  # bit width of each line the harness should print, in order
  outputbits = [bitsperentry for csymbol,bitsperentry,entries in outputs for e in range(entries)]
  executetime = -cputime()
  results = []
  command = ['./analysis-execute']
  for x in op_x[o,p]:
    xstr = ''.join('%d\n' % r for r in x)
    try:
      proc = subprocess.Popen(command,cwd=builddir,stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
      ystr,err = proc.communicate(input=xstr)
    except OSError:
      note(builddir,'warning-executeerror',xstr)
      return o,p,i,compiler,False
    if proc.returncode != 0:
      note(builddir,'warning-executefailed',xstr+'exit code %s' % proc.returncode)
      return o,p,i,compiler,False
    # Validate explicitly rather than with assert: an AssertionError or
    # IndexError here would escape to Pool.starmap and abort the whole
    # run. y may also be malformed, so report the raw output instead.
    try:
      y = [int(s) for s in ystr.splitlines()]
    except ValueError:
      y = None
    wellformed = (y is not None
                  and len(y) == len(outputbits)
                  and all(0 <= yi < 2**bits for yi,bits in zip(y,outputbits)))
    if not wellformed:
      note(builddir,'warning-executebadformat',input_example_str(inputs,x)+'output:\n'+ystr)
      return o,p,i,compiler,False
    results.append(y)
  executetime += cputime()
  notetime(builddir,'execute',executetime)
  return o,p,i,compiler,results
def wanttoexecute():
  """Yield (o,p,i,compiler) for every build that compiled and linked."""
  for o,p in primitives:
    yield from ((o,p,i,compiler) for i,compiler in op_compiled[o,p])
# opic_y[o,p,i,compiler]: executed outputs for each op_x[o,p] vector,
# recorded only for builds whose execution succeeded.
opic_y = {}
with multiprocessing.Pool(os_cores) as pool:
  for o,p,i,compiler,results in pool.starmap(execute,wanttoexecute()):
    if results is not False:
      opic_y[o,p,i,compiler] = results
print('----- valgrind (can take some time)')
def valgrind(o,p,i,compiler):
  """Run one build's analysis-valgrind harness under valgrind and note the outcome.

  Notes 'unsafe-valgrindfailure' when valgrind exits with its error code 99,
  or exits 0 but its output mentions 'client request'. Notes
  'warning-valgrinderror' when valgrind cannot be started or exits with
  any other nonzero status. The CPU time is always recorded.
  """
  builddir = 'build/%s/%s/%s' % (p,i,compiler.replace(' ','_').replace('=','_'))
  valgrindtime = -cputime()
  command = ['valgrind','-q','--error-exitcode=99','./analysis-valgrind']
  try:
    proc = subprocess.Popen(command,cwd=builddir,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,universal_newlines=True)
    out,err = proc.communicate()
  except OSError:
    valgrindstatus = 'warning-valgrinderror'
  else:
    assert not err
    if proc.returncode == 99:
      valgrindstatus = 'unsafe-valgrindfailure'
    elif proc.returncode:
      valgrindstatus = 'warning-valgrinderror'
    elif 'client request' in out:
      valgrindstatus = 'unsafe-valgrindfailure'
    else:
      valgrindstatus = None
  if valgrindstatus is not None:
    note(builddir,valgrindstatus)
  valgrindtime += cputime()
  notetime(builddir,'valgrind',valgrindtime)
def wanttovalgrind():
  """Yield (o,p,i,compiler) for every build that compiled and linked."""
  for o,p in primitives:
    yield from ((o,p,i,compiler) for i,compiler in op_compiled[o,p])
with multiprocessing.Pool(os_cores) as pool:
  # valgrind() reports through note()/notetime(); its return values are unused
  for _ in pool.starmap(valgrind,wanttovalgrind()):
    pass
print('----- unroll (can take tons of time)')
# XXX: could do this in parallel with valgrind
# XXX: unrolled can be huge; pass through disk instead of RAM
def values(terms,replacements):
  """Evaluate claripy expressions concretely without calling a solver.

  replacements maps the cache_key of each input variable (BVS) to a
  concrete BVV whose args[0] is the value. Returns a dict V mapping the
  cache_key of every term in `terms`, and of every subterm visited, to a
  pair (bits,value): value is the term's value as a nonnegative bits-bit
  integer. Boolean results are 1-bit and may be stored as bool.

  Returns None if some term uses a variable missing from replacements or an
  operation not handled here; callers then fall back to Z3. May raise
  AssertionError when operand widths are inconsistent.

  Comparison operators on bit-vectors are treated as unsigned unless they
  are explicitly signed (SLE/SLT/SGE/SGT).
  """
  V = {}

  binaryops = ('__add__','__mul__','__sub__','__lshift__','LShR','__rshift__','__and__','__or__','__xor__')
  comparisons = ('__le__','ULE','__lt__','ULT','__ge__','UGE','__gt__','UGT','SLE','SLT','SGE','SGT')

  def evaluate(t):
    key = t.cache_key
    if key in V:
      return True
    op = t.op
    if op == 'BoolV':
      V[key] = 1,t.args[0]
      return True
    if op == 'BVV':
      V[key] = t.size(),t.args[0]
      return True
    if op == 'BVS':
      if key not in replacements: return False
      V[key] = t.size(),replacements[key].args[0]
      return True
    if op == 'Extract':
      assert len(t.args) == 3
      top,bot,child = t.args
      if not evaluate(child): return False
      x0bits,x0 = V[child.cache_key]
      assert x0bits > top
      assert top >= bot
      assert bot >= 0
      # keep bits bot..top inclusive
      V[key] = top+1-bot,(x0 & ((2<<top)-1)) >> bot
      return True
    if op in ('SignExt','ZeroExt'):
      assert len(t.args) == 2
      extend,child = t.args
      if not evaluate(child): return False
      x0bits,x0 = V[child.cache_key]
      assert extend >= 0
      newbits = x0bits+extend
      if op == 'SignExt' and x0 >= (1<<(x0bits-1)):
        # negative in two's complement: fill the new top bits with ones
        x0 += (1<<newbits)-(1<<x0bits)
      V[key] = newbits,x0
      return True

    for a in t.args:
      if not evaluate(a): return False
    x = [V[a.cache_key] for a in t.args]

    if op == 'Concat':
      y,ybits = 0,0
      for xbitsi,xi in x:
        y = (y << xbitsi) + xi
        ybits += xbitsi
      V[key] = ybits,y
      return True
    if op in ('__eq__','__ne__'):
      assert len(x) == 2
      assert x[0][0] == x[1][0]
      equal = x[0][1] == x[1][1]
      V[key] = 1,(equal if op == '__eq__' else not equal)
      return True
    if op in binaryops:
      bits = x[0][0]
      assert all(xi[0] == bits for xi in x)
      modulus = 1<<bits
      if op == '__add__': reduction = lambda s,u: (s+u) % modulus
      elif op == '__mul__': reduction = lambda s,u: (s*u) % modulus
      elif op == '__sub__': reduction = lambda s,u: (s-u) % modulus
      elif op == '__lshift__':
        # Shift amounts >= bits give 0 (as in SMT-LIB bvshl). Testing first
        # avoids building a huge Python int for e.g. a 64-bit shift amount.
        reduction = lambda s,u: 0 if u >= bits else (s<<u) % modulus
      elif op == 'LShR': reduction = lambda s,u: (s>>u) % modulus
      elif op == '__rshift__':
        # Arithmetic shift: s is reinterpreted as signed, and the shift amount
        # is unsigned. Amounts >= bits leave only copies of the sign bit
        # (SMT-LIB bvashr); Python's >> on negative ints does the same.
        flip = 1<<(bits-1)
        reduction = lambda s,u: (((s ^ flip) - flip) >> u) % modulus
      elif op == '__and__': reduction = lambda s,u: s & u
      elif op == '__or__': reduction = lambda s,u: s | u
      else: reduction = lambda s,u: s ^ u  # '__xor__'
      V[key] = bits,functools.reduce(reduction,(xi[1] for xi in x))
      return True
    if op == '__invert__':
      assert len(x) == 1
      bits = x[0][0]
      V[key] = bits,(1<<bits)-1-x[0][1]
      return True
    if op == 'Not':
      assert len(x) == 1
      assert x[0][0] == 1
      V[key] = 1,1-x[0][1]
      return True
    if op in ('And','Or'):
      assert all(xi[0] == 1 for xi in x)
      if op == 'And': reduction = lambda s,u: s*u
      else: reduction = lambda s,u: s+u-s*u
      V[key] = 1,functools.reduce(reduction,(xi[1] for xi in x))
      return True
    if op == 'If':
      assert len(x) == 3
      assert x[0][0] == 1
      V[key] = x[1] if x[0][1] else x[2]
      return True
    if op in comparisons:
      assert len(x) == 2
      bits = x[0][0]
      assert bits == x[1][0]
      x0,x1 = x[0][1],x[1][1]
      if op.startswith('S'):
        # flipping the sign bit maps two's-complement order onto unsigned order
        flip = 1<<(bits-1)
        x0,x1 = x0^flip,x1^flip
      if op in ('__le__','ULE','SLE'): result = x0 <= x1
      elif op in ('__lt__','ULT','SLT'): result = x0 < x1
      elif op in ('__ge__','UGE','SGE'): result = x0 >= x1
      else: result = x0 > x1
      V[key] = 1,result
      return True
    # XXX: add support for more
    print('values: unsupported operation %s, falling back to Z3' % op)
    return False

  # XXX: also add more validation for all of the above
  for t in terms:
    if not evaluate(t): return None
  return V
def unroll_print(outputs,unrolled,f):
  """Write the unrolled expressions to f as a straight-line program.

  Every distinct subterm gets one line 'vN = ...'. N counts from 1, and
  each term is printed after its arguments. These lines are followed by
  one 'out_<sym>_<index> = vN' line per output entry, in the order of
  outputs. unrolled holds one expression per output entry.
  """
  walked = {}  # cache_key -> N

  def walk(t):
    # Key on cache_key (as values() does) rather than on the AST itself:
    # claripy overloads == on ASTs to build a new expression, so an AST makes
    # an unreliable dict key.
    key = t.cache_key
    if key in walked: return walked[key]
    if t.op == 'BoolV':
      rhs = 'bool(%d)' % t.args[0]
    elif t.op == 'BVV':
      rhs = 'constant(%d,%d)' % (t.size(),t.args[0])
    elif t.op == 'BVS':
      rhs = '%s' % t.args[0]
    elif t.op == 'Extract':
      assert len(t.args) == 3
      rhs = 'Extract(v%d,%d,%d)' % (walk(t.args[2]),t.args[0],t.args[1])
    elif t.op in ('SignExt','ZeroExt'):
      assert len(t.args) == 2
      rhs = '%s(v%d,%d)' % (t.op,walk(t.args[1]),t.args[0])
    else:
      # walking the arguments first writes their lines before this one
      argnames = ['v%d' % walk(a) for a in t.args]
      rhs = '%s(%s)' % (t.op,','.join(argnames))
    walknext = len(walked)+1
    f.write('v%d = %s\n' % (walknext,rhs))
    walked[key] = walknext
    return walknext

  for x in unrolled:
    walk(x)
  unrolledpos = 0
  for csymbol,bitsperentry,entries in outputs:
    for e in range(entries):
      f.write('out_%s_%d = v%s\n' % (csymbol,e,walk(unrolled[unrolledpos])))
      unrolledpos += 1
def unroll_inputvars(inputs):
  """Return [(name,variable)] with one bitsperentry-bit claripy BVS per input entry.

  Variables are named in_<sym>_<index> (explicit_name=True, so claripy adds
  no suffix), in the order given by inputs.
  """
  return [
    (name,claripy.BVS(name,bitsperentry,explicit_name=True))
    for csymbol,bitsperentry,entries in inputs
    for name in ['in_%s_%d' % (csymbol,index) for index in range(entries)]
  ]
# XXX: probably better to merge into unroll()
def unroll_worker(binary,inputs,outputs):
  """Symbolically execute the angr harness `binary` and collect its outputs.

  Stores a symbolic variable in_<sym>_<index> into each static_<sym> input
  array, runs the program, and reads the static_<sym> output arrays at every
  exit (deadended state). Values from several exits are merged with If() on
  each exit's path constraint.

  Returns (numexits,ispartition,results):
    numexits < 1 signals failure, and results then describes it (the
      errored states, or the simulation manager when nothing deadended);
    ispartition is True when every input satisfies exactly one exit's
      path constraint;
    results holds one expression per output entry, in the order of outputs.
  """
  results = []
  proj = angr.Project(binary,exclude_sim_procedures_list=avoidsimprocedures)
  state = proj.factory.full_init_state()
  state.options |= {angr.options.LAZY_SOLVES}
  state.options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_MEMORY}
  state.options |= {angr.options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS}
  # turn off angr's simplification of expressions, register and memory accesses
  state.options -= {angr.options.SIMPLIFY_EXPRS}
  state.options -= {angr.options.SIMPLIFY_REGISTER_WRITES}
  state.options -= {angr.options.SIMPLIFY_MEMORY_WRITES}
  state.options -= {angr.options.SIMPLIFY_REGISTER_READS}
  state.options -= {angr.options.SIMPLIFY_MEMORY_READS}
  # one fresh symbolic variable per input entry, stored into static_<sym>
  for csymbol,bitsperentry,entries in inputs:
    xaddr = proj.loader.find_symbol('static_%s'%csymbol).rebased_addr
    for i in range(entries):
      varname = 'in_%s_%d'%(csymbol,i)
      variable = claripy.BVS(varname,bitsperentry,explicit_name=True)
      if bitsperentry == 8:
        state.mem[xaddr+i].char = variable
      elif bitsperentry == 16:
        state.mem[xaddr+2*i].short = variable
      elif bitsperentry == 32:
        state.mem[xaddr+4*i].int = variable
      elif bitsperentry == 64:
        state.mem[xaddr+8*i].long = variable
  simgr = proj.factory.simgr(state)
  simgr.run()
  if len(simgr.errored) > 0:
    return -1,False,simgr.errored
  exits = simgr.deadended
  if len(exits) == 0:
    # Report this rather than asserting: an exception here would propagate
    # through Pool.starmap and abort the whole analysis. unroll() turns
    # numexits < 1 into a warning-unrollerror note.
    return 0,False,'no deadended states: %s' % simgr
  # cannot be safe if there are multiple exits
  # for equivalence tests we'll merge exits below
  mergedconstraints = []
  for epos,e in enumerate(exits):
    # conjunction of this exit's path constraints
    mergedconstraint = e.solver.true
    for c in e.solver.constraints:
      mergedconstraint = e.solver.And(mergedconstraint,c)
    mergedconstraints += [mergedconstraint]
    resultpos = 0
    for csymbol,bitsperentry,entries in outputs:
      xaddr = proj.loader.find_symbol('static_%s'%csymbol).rebased_addr
      for i in range(entries):
        if bitsperentry == 8:
          xi = e.mem[xaddr+i].char.resolved
        elif bitsperentry == 16:
          xi = e.mem[xaddr+2*i].short.resolved
        elif bitsperentry == 32:
          xi = e.mem[xaddr+4*i].int.resolved
        elif bitsperentry == 64:
          xi = e.mem[xaddr+8*i].long.resolved
        if epos == 0:
          assert len(results) == resultpos
          results += [xi]
        else:
          # later exits override earlier ones where their path constraint holds
          results[resultpos] = e.solver.If(mergedconstraint,xi,results[resultpos])
        resultpos += 1
    assert resultpos == len(results)
  assert len(mergedconstraints) == len(exits)
  ispartition = True
  # are mergedconstraints a partition of all universes?
  # i.e.: in each universe, exactly one of the constraints is satisfied?
  # coverage: no input may violate every exit's constraint
  s = claripy.Solver()
  for c in mergedconstraints:
    s.add(claripy.Not(c))
  if s.satisfiable():
    ispartition = False
  # disjointness: no input may satisfy two exits' constraints
  for i in range(len(exits)):
    for j in range(i):
      s = claripy.Solver()
      s.add(mergedconstraints[i])
      s.add(mergedconstraints[j])
      if s.satisfiable():
        ispartition = False
  return len(exits),ispartition,results
def _unroll_mismatch(builddir,inputs,outputs,unrolled,x,y):
  # Check the unrolled expressions against outputs y executed on input
  # vector x. Return None when they agree, otherwise a description for a
  # warning-unrollmismatch note.
  inputvars = unroll_inputvars(inputs)
  assert len(inputvars) == len(x)
  replacements = {}
  for (varname,variable),value in zip(inputvars,x):
    replacements[variable.cache_key] = claripy.BVV(value,variable.size())
  V = None
  try:
    V = values(unrolled,replacements)
  except AssertionError:
    note(builddir,'warning-valuesfailed',traceback.format_exc())
    # proceed with z3 fallback below
  if V is not None:
    # V maps cache_key to (bits,value) pairs; compare the values
    unrolledy = [V[u.cache_key][1] for u in unrolled]
    if all(yi == ui for yi,ui in zip(y,unrolledy)):
      return None
  else:
    # Fall back on Z3: can any output differ from the executed one? (Only
    # non-input variables remain free after substitution.)
    substituted = [u.replace_dict(replacements) for u in unrolled]
    s = claripy.Solver()
    mismatch = claripy.false
    for yi,ui in zip(y,substituted):
      mismatch = claripy.Or(mismatch,ui != yi)
    s.add(mismatch)
    if not s.satisfiable():
      return None
    # no constraints are added between these queries, so they come from one model
    unrolledy = [s.eval(ui,1)[0] for ui in substituted]
  notestr = input_example_str(inputs,x)
  pos = 0
  for csymbol,bitsperentry,entries in outputs:
    for e in range(entries):
      varname = 'out_%s_%d'%(csymbol,e)
      notestr += 'executed_%s = %s\n' % (varname,y[pos])
      notestr += 'unrolled_%s = %s\n' % (varname,unrolledy[pos])
      pos += 1
  return notestr

def unroll(o,p,i,compiler):
  """Symbolically unroll one build and cross-check it against execute().

  Writes the unrolled program to analysis/unrolled. Records these notes:
    warning-unrollerror       angr errored or no state reached an exit
    warning-unrollnotpartition  exit constraints do not partition the inputs
    unsafe-unrollsplit-N      the harness has N > 1 exits
    warning-unrollmem         outputs use variables other than the inputs
    warning-unusedinputs      some input does not affect any output
    warning-unrollmismatch    the outputs disagree with execute() on an op_x vector

  Returns (o,p,i,compiler,unrolled) on success, else (o,p,i,compiler,False).
  """
  compilerword = compiler.replace(' ','_').replace('=','_')
  builddir = 'build/%s/%s/%s' % (p,i,compilerword)
  inputs,outputs,funargs,funargtypes,funname,funret,funrettype = op_api[o,p]
  unrolltime = -cputime()
  numexits,ispartition,unrolled = unroll_worker('%s/analysis-angr'%builddir,inputs,outputs)
  if numexits < 1:
    note(builddir,'warning-unrollerror',unrolled)
    return o,p,i,compiler,False
  if not ispartition:
    note(builddir,'warning-unrollnotpartition')
    return o,p,i,compiler,False
  if numexits > 1:
    note(builddir,'unsafe-unrollsplit-%d'%numexits)
  with open('%s/analysis/unrolled' % builddir,'w') as f:
    unroll_print(outputs,unrolled,f)
  okvars = set(vname for vname,v in unroll_inputvars(inputs))
  usedvars = set(v for x in unrolled for v in x.variables)
  if not usedvars.issubset(okvars):
    note(builddir,'warning-unrollmem')
  if not okvars.issubset(usedvars):
    note(builddir,'warning-unusedinputs')
  # cpu gave us outputs y given inputs x; does this match unrolled?
  for x,y in zip(op_x[o,p],opic_y[o,p,i,compiler]):
    notestr = _unroll_mismatch(builddir,inputs,outputs,unrolled,x,y)
    if notestr is not None:
      note(builddir,'warning-unrollmismatch',notestr)
      return o,p,i,compiler,False
  unrolltime += cputime()
  notetime(builddir,'unroll',unrolltime)
  return o,p,i,compiler,unrolled
def wanttounroll():
  """Yield (o,p,i,compiler) for compiled builds whose execution succeeded."""
  for o,p in primitives:
    for i,compiler in op_compiled[o,p]:
      key = (o,p,i,compiler)
      if key in opic_y:
        yield key
# opic_unrolled[o,p,i,compiler]: unrolled output expressions for each build
# that unroll() accepted.
opic_unrolled = {}
with multiprocessing.Pool(os_cores) as pool:
  for o,p,i,compiler,unrolled in pool.starmap(unroll,wanttounroll()):
    if unrolled is not False:
      opic_unrolled[o,p,i,compiler] = unrolled
print('----- compareunrolled (can take tons of time)')
def compareunrolled(o,p,i,compiler,source,sourcecompiler):
  """Compare build (i,compiler) of o/p against build (source,sourcecompiler).

  First compares the executed outputs on each op_x vector and notes
  'unsafe-randomtest-N-differentfrom-...' for each vector that differs.
  Then asks Z3 whether the two lists of unrolled expressions can differ on
  any input. If so, it notes 'unsafe-differentfrom-...' with a witness;
  if not, it notes 'equals-...'. It notes 'warning-z3failed' if Z3 raises.
  """
  compilerword = compiler.replace(' ','_').replace('=','_')
  sourcecompilerword = sourcecompiler.replace(' ','_').replace('=','_')
  builddir = 'build/%s/%s/%s' % (p,i,compilerword)
  inputs,outputs,funargs,funargtypes,funname,funret,funrettype = op_api[o,p]
  for pos,(x,y,z) in enumerate(zip(op_x[o,p],opic_y[o,p,i,compiler],opic_y[o,p,source,sourcecompiler])):
    if y != z:
      xstr = input_example_str(inputs,x)
      note(builddir,'unsafe-randomtest-%d-differentfrom-%s-%s' % (pos,source,sourcecompilerword),xstr)
      # could return at this point to save time
      # but to help validate symbolic testing we also want to see symbolic testing fail
  equivtime = -cputime()
  u1 = opic_unrolled[o,p,source,sourcecompiler]
  u2 = opic_unrolled[o,p,i,compiler]
  assert len(u1) == len(u2)
  # XXX: allow other equivalence-testing techniques
  s = claripy.Solver()
  different = claripy.false
  for u1j,u2j in zip(u1,u2):
    different = claripy.Or(different,u1j != u2j)
  s.add(different)
  try:
    mismatch = s.satisfiable()
  except claripy.errors.ClaripyZ3Error:
    # avoid crashing on the sort of bug fixed in https://github.com/angr/angr/pull/2887
    note(builddir,'warning-z3failed',traceback.format_exc())
    return
  if mismatch:
    # angr documentation says:
    # "If you don't add any constraints between two queries, the results will be consistent with each other."
    example = ''
    for vname,v in unroll_inputvars(inputs):
      example += '%s = %s\n' % (vname,s.eval(v,1)[0])
    unrolledpos = 0
    for csymbol,bitsperentry,entries in outputs:
      # loop variable is e, not i: i is this build's implementation name
      for e in range(entries):
        varname = 'out_%s_%d'%(csymbol,e)
        example += 'source_%s = %s\n' % (varname,s.eval(u1[unrolledpos],1)[0])
        example += 'target_%s = %s\n' % (varname,s.eval(u2[unrolledpos],1)[0])
        unrolledpos += 1
    note(builddir,'unsafe-differentfrom-%s-%s' % (source,sourcecompilerword),example)
  else:
    note(builddir,'equals-%s-%s' % (source,sourcecompilerword))
  equivtime += cputime()
  notetime(builddir,'equiv',equivtime)
def wanttocompareunrolled():
  """Yield (o,p,i,compiler,source,sourcecompiler) for each comparison to run.

  Each unrolled build is compared with the 'ref' implementation built with
  the same compiler. Builds of 'ref' itself are instead compared with the
  compilerlist[0] build of 'ref'. Pairs where either side was not unrolled
  are skipped, and so are self-comparisons.
  """
  source = 'ref' # XXX: allow each implementation to choose source
  for o,p in primitives:
    for i,compiler in op_compiled[o,p]:
      # XXX: maybe also allow choice of sourcecompiler
      sourcecompiler = compilerlist[0] if i == source else compiler
      target = (o,p,i,compiler)
      reference = (o,p,source,sourcecompiler)
      # XXX: could also do self-tests
      if target == reference:
        continue
      if target in opic_unrolled and reference in opic_unrolled:
        yield target + (source,sourcecompiler)
with multiprocessing.Pool(os_cores) as pool:
  # compareunrolled() reports through note()/notetime(); results are unused
  for _ in pool.starmap(compareunrolled,wanttocompareunrolled()):
    pass