diff options
Diffstat (limited to 'test/perf/stresstest.py')
| -rw-r--r-- | test/perf/stresstest.py | 174 |
1 files changed, 174 insertions, 0 deletions
diff --git a/test/perf/stresstest.py b/test/perf/stresstest.py new file mode 100644 index 000000000..cf9404f53 --- /dev/null +++ b/test/perf/stresstest.py @@ -0,0 +1,174 @@ +import gc +import sys +import timeit +import cProfile + +from sqlalchemy import MetaData, Table, Column +from sqlalchemy.types import * +from sqlalchemy.orm import mapper, clear_mappers + +metadata = MetaData() + +def gen_table(num_fields, field_type, metadata): + return Table('test', metadata, + Column('id', Integer, primary_key=True), + *[Column("field%d" % fnum, field_type) + for fnum in range(num_fields)]) + +def insert(test_table, num_fields, num_records, genvalue, verbose=True): + if verbose: + print "building insert values...", + sys.stdout.flush() + values = [dict(("field%d" % fnum, genvalue(rnum, fnum)) + for fnum in range(num_fields)) + for rnum in range(num_records)] + if verbose: + print "inserting...", + sys.stdout.flush() + def db_insert(): + test_table.insert().execute(values) + sys.modules['__main__'].db_insert = db_insert + timing = timeit.timeit("db_insert()", + "from __main__ import db_insert", + number=1) + if verbose: + print "%s" % round(timing, 3) + +def check_result(results, num_fields, genvalue, verbose=True): + if verbose: + print "checking...", + sys.stdout.flush() + for rnum, row in enumerate(results): + expected = tuple([rnum + 1] + + [genvalue(rnum, fnum) for fnum in range(num_fields)]) + assert row == expected, "got: %s\nexpected: %s" % (row, expected) + return True + +def avgdev(values, comparison): + return sum(value - comparison for value in values) / len(values) + +def nicer_res(values, printvalues=False): + if printvalues: + print values + min_time = min(values) + return round(min_time, 3), round(avgdev(values, min_time), 2) + +def profile_func(func_name, verbose=True): + if verbose: + print "profiling...", + sys.stdout.flush() + cProfile.run('%s()' % func_name, 'prof') + +def time_func(func_name, num_tests=1, verbose=True): + if verbose: + print "timing...", + sys.stdout.flush() + timings = timeit.repeat('%s()' % func_name, + "from __main__ import %s" % func_name, + number=num_tests, repeat=5) + avg, dev = nicer_res(timings) + if verbose: + print "%s (%s)" % (avg, dev) + else: + print avg + +def profile_and_time(func_name, num_tests=1): + profile_func(func_name) + time_func(func_name, num_tests) + +def iter_results(raw_results): + return [tuple(row) for row in raw_results] + +def getattr_results(raw_results): + return [ + (r.id, + r.field0, r.field1, r.field2, r.field3, r.field4, + r.field5, r.field6, r.field7, r.field8, r.field9) + for r in raw_results] + +def fetchall(test_table): + def results(): + return test_table.select().order_by(test_table.c.id).execute() \ + .fetchall() + return results + +def hashable_set(l): + hashables = [] + for o in l: + try: + hash(o) + hashables.append(o) + except: + pass + return set(hashables) + +def prepare(field_type, genvalue, engineurl='sqlite://', + num_fields=10, num_records=1000, freshdata=True, verbose=True): + global metadata + metadata.clear() + metadata.bind = engineurl + test_table = gen_table(num_fields, field_type, metadata) + if freshdata: + metadata.drop_all() + metadata.create_all() + insert(test_table, num_fields, num_records, genvalue, verbose) + return test_table + +def time_dbfunc(test_table, test_func, genvalue, + class_=None, + getresults_func=None, + num_fields=10, num_records=1000, num_tests=1, + check_results=check_result, profile=True, + check_leaks=True, print_leaks=False, verbose=True): + if verbose: + print "testing '%s'..." % test_func.__name__, + sys.stdout.flush() + if class_ is not None: + clear_mappers() + mapper(class_, test_table) + if getresults_func is None: + getresults_func = fetchall(test_table) + def test(): + return test_func(getresults_func()) + sys.modules['__main__'].test = test + if check_leaks: + gc.collect() + objects_before = gc.get_objects() + num_objects_before = len(objects_before) + hashable_objects_before = hashable_set(objects_before) +# gc.set_debug(gc.DEBUG_LEAK) + if check_results: + check_results(test(), num_fields, genvalue, verbose) + if check_leaks: + gc.collect() + objects_after = gc.get_objects() + num_objects_after = len(objects_after) + num_leaks = num_objects_after - num_objects_before + hashable_objects_after = hashable_set(objects_after) + diff = hashable_objects_after - hashable_objects_before + ldiff = len(diff) + if print_leaks and ldiff < num_records: + print "\n*** hashable objects leaked (%d) ***" % ldiff + print '\n'.join(map(str, diff)) + print "***\n" + + if num_leaks > num_records: + print "(leaked: %d !)" % num_leaks, + if profile: + profile_func('test', verbose) + time_func('test', num_tests, verbose) + +def profile_and_time_dbfunc(test_func, field_type, genvalue, + class_=None, + getresults_func=None, + engineurl='sqlite://', freshdata=True, + num_fields=10, num_records=1000, num_tests=1, + check_results=check_result, profile=True, + check_leaks=True, print_leaks=False, verbose=True): + test_table = prepare(field_type, genvalue, engineurl, + num_fields, num_records, freshdata, verbose) + time_dbfunc(test_table, test_func, genvalue, class_, + getresults_func, + num_fields, num_records, num_tests, + check_results, profile, + check_leaks, print_leaks, verbose) |
