diff options
Diffstat (limited to 'poky/meta/lib/oeqa/core/threaded.py')
-rw-r--r-- | poky/meta/lib/oeqa/core/threaded.py | 275 |
1 files changed, 275 insertions, 0 deletions
diff --git a/poky/meta/lib/oeqa/core/threaded.py b/poky/meta/lib/oeqa/core/threaded.py new file mode 100644 index 0000000000..2cafe03a21 --- /dev/null +++ b/poky/meta/lib/oeqa/core/threaded.py @@ -0,0 +1,275 @@ +# Copyright (C) 2017 Intel Corporation +# Released under the MIT license (see COPYING.MIT) + +import threading +import multiprocessing +import queue +import time + +from unittest.suite import TestSuite + +from oeqa.core.loader import OETestLoader +from oeqa.core.runner import OEStreamLogger, OETestResult, OETestRunner +from oeqa.core.context import OETestContext + +class OETestLoaderThreaded(OETestLoader): + def __init__(self, tc, module_paths, modules, tests, modules_required, + filters, process_num=0, *args, **kwargs): + super(OETestLoaderThreaded, self).__init__(tc, module_paths, modules, + tests, modules_required, filters, *args, **kwargs) + + self.process_num = process_num + + def discover(self): + suite = super(OETestLoaderThreaded, self).discover() + + if self.process_num <= 0: + self.process_num = min(multiprocessing.cpu_count(), + len(suite._tests)) + + suites = [] + for _ in range(self.process_num): + suites.append(self.suiteClass()) + + def _search_for_module_idx(suites, case): + """ + Cases in the same module needs to be run + in the same thread because PyUnit keeps track + of setUp{Module, Class,} and tearDown{Module, Class,}. + """ + + for idx in range(self.process_num): + suite = suites[idx] + for c in suite._tests: + if case.__module__ == c.__module__: + return idx + + return -1 + + def _search_for_depend_idx(suites, depends): + """ + Dependency cases needs to be run in the same + thread, because OEQA framework look at the state + of dependant test to figure out if skip or not. + """ + + for idx in range(self.process_num): + suite = suites[idx] + + for case in suite._tests: + if case.id() in depends: + return idx + return -1 + + def _get_best_idx(suites): + sizes = [len(suite._tests) for suite in suites] + return sizes.index(min(sizes)) + + def _fill_suites(suite): + idx = -1 + for case in suite: + if isinstance(case, TestSuite): + _fill_suites(case) + else: + idx = _search_for_module_idx(suites, case) + + depends = {} + if 'depends' in self.tc._registry: + depends = self.tc._registry['depends'] + + if idx == -1 and case.id() in depends: + case_depends = depends[case.id()] + idx = _search_for_depend_idx(suites, case_depends) + + if idx == -1: + idx = _get_best_idx(suites) + + suites[idx].addTest(case) + _fill_suites(suite) + + suites_tmp = suites + suites = [] + for suite in suites_tmp: + if len(suite._tests) > 0: + suites.append(suite) + + return suites + +class OEStreamLoggerThreaded(OEStreamLogger): + _lock = threading.Lock() + buffers = {} + + def write(self, msg): + tid = threading.get_ident() + + if not tid in self.buffers: + self.buffers[tid] = "" + + if msg: + self.buffers[tid] += msg + + def finish(self): + tid = threading.get_ident() + + self._lock.acquire() + self.logger.info('THREAD: %d' % tid) + self.logger.info('-' * 70) + for line in self.buffers[tid].split('\n'): + self.logger.info(line) + self._lock.release() + +class OETestResultThreadedInternal(OETestResult): + def _tc_map_results(self): + tid = threading.get_ident() + + # PyUnit generates a result for every test module run, test + # if the thread already has an entry to avoid lose the previous + # test module results. + if not tid in self.tc._results: + self.tc._results[tid] = {} + self.tc._results[tid]['failures'] = self.failures + self.tc._results[tid]['errors'] = self.errors + self.tc._results[tid]['skipped'] = self.skipped + self.tc._results[tid]['expectedFailures'] = self.expectedFailures + +class OETestResultThreaded(object): + _results = {} + _lock = threading.Lock() + + def __init__(self, tc): + self.tc = tc + + def _fill_tc_results(self): + tids = list(self.tc._results.keys()) + fields = ['failures', 'errors', 'skipped', 'expectedFailures'] + + for tid in tids: + result = self.tc._results[tid] + for field in fields: + if not field in self.tc._results: + self.tc._results[field] = [] + self.tc._results[field].extend(result[field]) + + def addResult(self, result, run_start_time, run_end_time): + tid = threading.get_ident() + + self._lock.acquire() + self._results[tid] = {} + self._results[tid]['result'] = result + self._results[tid]['run_start_time'] = run_start_time + self._results[tid]['run_end_time'] = run_end_time + self._results[tid]['result'] = result + self._lock.release() + + def wasSuccessful(self): + wasSuccessful = True + for tid in self._results.keys(): + wasSuccessful = wasSuccessful and \ + self._results[tid]['result'].wasSuccessful() + return wasSuccessful + + def stop(self): + for tid in self._results.keys(): + self._results[tid]['result'].stop() + + def logSummary(self, component, context_msg=''): + elapsed_time = (self.tc._run_end_time - self.tc._run_start_time) + + self.tc.logger.info("SUMMARY:") + self.tc.logger.info("%s (%s) - Ran %d tests in %.3fs" % (component, + context_msg, len(self.tc._registry['cases']), elapsed_time)) + if self.wasSuccessful(): + msg = "%s - OK - All required tests passed" % component + else: + msg = "%s - FAIL - Required tests failed" % component + self.tc.logger.info(msg) + + def logDetails(self): + if list(self._results): + tid = list(self._results)[0] + result = self._results[tid]['result'] + result.logDetails() + +class _Worker(threading.Thread): + """Thread executing tasks from a given tasks queue""" + def __init__(self, tasks, result, stream): + threading.Thread.__init__(self) + self.tasks = tasks + + self.result = result + self.stream = stream + + def run(self): + while True: + try: + func, args, kargs = self.tasks.get(block=False) + except queue.Empty: + break + + try: + run_start_time = time.time() + rc = func(*args, **kargs) + run_end_time = time.time() + self.result.addResult(rc, run_start_time, run_end_time) + self.stream.finish() + except Exception as e: + print(e) + finally: + self.tasks.task_done() + +class _ThreadedPool: + """Pool of threads consuming tasks from a queue""" + def __init__(self, num_workers, num_tasks, stream=None, result=None): + self.tasks = queue.Queue(num_tasks) + self.workers = [] + + for _ in range(num_workers): + worker = _Worker(self.tasks, result, stream) + self.workers.append(worker) + + def start(self): + for worker in self.workers: + worker.start() + + def add_task(self, func, *args, **kargs): + """Add a task to the queue""" + self.tasks.put((func, args, kargs)) + + def wait_completion(self): + """Wait for completion of all the tasks in the queue""" + self.tasks.join() + for worker in self.workers: + worker.join() + +class OETestRunnerThreaded(OETestRunner): + streamLoggerClass = OEStreamLoggerThreaded + + def __init__(self, tc, *args, **kwargs): + super(OETestRunnerThreaded, self).__init__(tc, *args, **kwargs) + self.resultclass = OETestResultThreadedInternal # XXX: XML reporting overrides at __init__ + + def run(self, suites): + result = OETestResultThreaded(self.tc) + + pool = _ThreadedPool(len(suites), len(suites), stream=self.stream, + result=result) + for s in suites: + pool.add_task(super(OETestRunnerThreaded, self).run, s) + pool.start() + pool.wait_completion() + result._fill_tc_results() + + return result + +class OETestContextThreaded(OETestContext): + loaderClass = OETestLoaderThreaded + runnerClass = OETestRunnerThreaded + + def loadTests(self, module_paths, modules=[], tests=[], + modules_manifest="", modules_required=[], filters={}, process_num=0): + if modules_manifest: + modules = self._read_modules_from_manifest(modules_manifest) + + self.loader = self.loaderClass(self, module_paths, modules, tests, + modules_required, filters, process_num) + self.suites = self.loader.discover() |