logilab/common

view testlib.py @ 1176:a6b5fe18df99

allow to put %(cls)s substitution in deprecation warning to get the actual class name
author Sylvain Thénault <sylvain.thenault@logilab.fr>
date Mon, 02 Aug 2010 20:06:57 +0200
parents f4b1e0d9ed0c
children 69b9648a0a00
line source
1 # -*- coding: utf-8 -*-
2 # copyright 2003-2010 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
3 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
4 #
5 # This file is part of logilab-common.
6 #
7 # logilab-common is free software: you can redistribute it and/or modify it under
8 # the terms of the GNU Lesser General Public License as published by the Free
9 # Software Foundation, either version 2.1 of the License, or (at your option) any
10 # later version.
11 #
12 # logilab-common is distributed in the hope that it will be useful, but WITHOUT
13 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
14 # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
15 # details.
16 #
17 # You should have received a copy of the GNU Lesser General Public License along
18 # with logilab-common. If not, see <http://www.gnu.org/licenses/>.
19 """Run tests.
21 This will find all modules whose name match a given prefix in the test
22 directory, and run them. Various command line options provide
23 additional facilities.
25 Command line options:
27 -v verbose -- run tests in verbose mode with output to stdout
28 -q quiet -- don't print anything except if a test fails
29 -t testdir -- directory where the tests will be found
30 -x exclude -- add a test to exclude
31 -p profile -- profiled execution
32 -c capture -- capture standard out/err during tests
33 -d dbc -- enable design-by-contract
34 -m match -- only run tests matching the tag pattern which follows
36 If no non-option arguments are present, prefixes used are 'test',
37 'regrtest', 'smoketest', 'unittest', 'func' and 'validation'.
39 """
40 __docformat__ = "restructuredtext en"
41 # modified copy of some functions from test/regrtest.py from PyXml
42 # disable camel case warning
43 # pylint: disable-msg=C0103
45 import sys
46 import os, os.path as osp
47 import re
48 import time
49 import getopt
50 import traceback
51 import inspect
52 import unittest
53 import difflib
54 import types
55 import tempfile
56 import math
57 from shutil import rmtree
58 from operator import itemgetter
59 import warnings
60 from compiler.consts import CO_GENERATOR
61 from ConfigParser import ConfigParser
62 from itertools import dropwhile
63 try:
64 from functools import wraps
65 except ImportError:
66 def wraps(wrapped):
67 def proxy(callable):
68 callable.__name__ = wrapped.__name__
69 return callable
70 return proxy
71 try:
72 from test import test_support
73 except ImportError:
74 # not always available
75 class TestSupport:
76 def unload(self, test):
77 pass
78 test_support = TestSupport()
80 # pylint: disable-msg=W0622
81 from logilab.common.compat import set, enumerate, any, sorted
82 # pylint: enable-msg=W0622
83 from logilab.common.modutils import load_module_from_name
84 from logilab.common.debugger import Debugger, colorize_source
85 from logilab.common.decorators import cached, classproperty
86 from logilab.common import textutils
__all__ = [
    'main', 'unittest_main', 'find_tests', 'run_test', 'spawn',
]

# prefixes of the test module names looked up by default (see find_tests)
DEFAULT_PREFIXES = (
    'test', 'regrtest', 'smoketest', 'unittest', 'func', 'validation',
)

# NOTE(review): not read in the visible part of this file; presumably the
# design-by-contract switch of the -d option described in the module
# docstring -- confirm
ENABLE_DBC = False

# file listing the tests which succeeded, read back when the `restart`
# option is set so that those tests are not run again
FILE_RESTART = ".pytest.restart"

# used by unittest to count the number of relevant levels in the traceback
__unittest = 1
def with_tempdir(callable):
    """A decorator ensuring no temporary file is left behind when the
    decorated function returns.

    Only works for temporary files created through the `tempfile` module:
    while the function runs, `tempfile.tempdir` points to a fresh directory
    which is removed afterwards, whatever the outcome of the call.
    """
    @wraps(callable)
    def proxy(*args, **kargs):
        # save the raw module attribute (possibly None) rather than the value
        # computed by gettempdir(), so that the tempfile module is left in
        # exactly the state it was found in
        old_tmpdir = tempfile.tempdir
        new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-")
        tempfile.tempdir = new_tmpdir
        try:
            return callable(*args, **kargs)
        finally:
            try:
                rmtree(new_tmpdir, ignore_errors=True)
            finally:
                tempfile.tempdir = old_tmpdir
    return proxy
def in_tempdir(callable):
    """A decorator running the decorated function with the temporary
    directory (as returned by `tempfile.gettempdir()`) as current working
    directory; the previous working directory is restored afterwards.
    """
    @wraps(callable)
    def proxy(*args, **kargs):
        old_cwd = os.getcwd()
        # gettempdir() rather than the raw `tempfile.tempdir` attribute,
        # which stays None until the tempfile module has computed it
        os.chdir(tempfile.gettempdir())
        try:
            return callable(*args, **kargs)
        finally:
            os.chdir(old_cwd)
    return proxy
def within_tempdir(callable):
    """A decorator running the decorated function inside a brand new
    temporary directory, which is removed once the function returns.
    """
    # in_tempdir is the inner decorator: its chdir happens after
    # with_tempdir has pointed tempfile at the fresh directory
    wrapped = in_tempdir(callable)
    proxy = with_tempdir(wrapped)
    proxy.__name__ = callable.__name__
    return proxy
141 def run_tests(tests, quiet, verbose, runner=None, capture=0):
142 """Execute a list of tests.
144 :rtype: tuple
145 :return: tuple (list of passed tests, list of failed tests, list of skipped tests)
146 """
147 good = []
148 bad = []
149 skipped = []
150 all_result = None
151 for test in tests:
152 if not quiet:
153 print
154 print '-'*80
155 print "Executing", test
156 result = run_test(test, verbose, runner, capture)
157 if type(result) is type(''):
158 # an unexpected error occurred
159 skipped.append( (test, result))
160 else:
161 if all_result is None:
162 all_result = result
163 else:
164 all_result.testsRun += result.testsRun
165 all_result.failures += result.failures
166 all_result.errors += result.errors
167 all_result.skipped += result.skipped
168 if result.errors or result.failures:
169 bad.append(test)
170 if verbose:
171 print "test", test, \
172 "failed -- %s errors, %s failures" % (
173 len(result.errors), len(result.failures))
174 else:
175 good.append(test)
177 return good, bad, skipped, all_result
def find_tests(testdir,
               prefixes=DEFAULT_PREFIXES, suffix=".py",
               excludes=(),
               remove_suffix=True):
    """
    Return a sorted list of all applicable test modules found in `testdir`.

    A file is selected when its name starts with one of `prefixes` and ends
    with `suffix` (every name matches an empty `suffix`). When
    `remove_suffix` is true, `suffix` is stripped from the returned names.
    Names listed in `excludes` (compared after suffix removal) are ignored.
    """
    tests = []
    for name in os.listdir(testdir):
        if suffix and not name.endswith(suffix):
            continue
        # a name matching several prefixes must only be reported once
        if not any(name.startswith(prefix) for prefix in prefixes):
            continue
        if remove_suffix and suffix:
            # guarded on `suffix`: name[:-0] would give an empty string
            name = name[:-len(suffix)]
        if name not in excludes:
            tests.append(name)
    tests.sort()
    return tests
def run_test(test, verbose, runner=None, capture=0):
    """
    Run a single test module.

    test -- the name of the test module
    verbose -- if true, print more messages (traceback on crash)
    runner -- the test runner, a `SkipAwareTextTestRunner` by default
    capture -- capture flag given to the default runner

    Return the runner's result, or an error message (a string) if the
    test module could not be loaded or crashed.
    """
    test_support.unload(test)
    try:
        m = load_module_from_name(test, path=sys.path)
        # use the module's `suite` attribute when there is one (calling it
        # when callable); fetching it with getattr ensures an AttributeError
        # raised *while building* the suite is not mistaken for a missing
        # `suite` attribute
        suite = getattr(m, 'suite', None)
        if suite is None:
            suite = unittest.TestLoader().loadTestsFromModule(m)
        elif callable(suite):
            suite = suite()
        if runner is None:
            runner = SkipAwareTextTestRunner(capture=capture) # verbosity=0)
        return runner.run(suite)
    except KeyboardInterrupt:
        raise
    except:
        # deliberately catch everything (even a SystemExit raised while
        # importing the module) so one broken module doesn't stop the run
        exc_type, exc_value = sys.exc_info()[:2]
        msg = "test %s crashed -- %s : %s" % (test, exc_type, exc_value)
        if verbose:
            traceback.print_exc()
        return msg
def _count(n, word):
    """Return "<n> <word>", `word` being pluralized with a trailing "s"
    unless `n` is exactly 1."""
    plural_mark = ''
    if n != 1:
        plural_mark = 's'
    return "%d %s%s" % (n, word, plural_mark)
240 ## PostMortem Debug facilities #####
241 def start_interactive_mode(result):
242 """starts an interactive shell so that the user can inspect errors
243 """
244 debuggers = result.debuggers
245 descrs = result.error_descrs + result.fail_descrs
246 if len(debuggers) == 1:
247 # don't ask for test name if there's only one failure
248 debuggers[0].start()
249 else:
250 while True:
251 testindex = 0
252 print "Choose a test to debug:"
253 # order debuggers in the same way than errors were printed
254 print "\n".join(['\t%s : %s' % (i, descr) for i, (_, descr)
255 in enumerate(descrs)])
256 print "Type 'exit' (or ^D) to quit"
257 print
258 try:
259 todebug = raw_input('Enter a test name: ')
260 if todebug.strip().lower() == 'exit':
261 print
262 break
263 else:
264 try:
265 testindex = int(todebug)
266 debugger = debuggers[descrs[testindex][0]]
267 except (ValueError, IndexError):
268 print "ERROR: invalid test number %r" % (todebug, )
269 else:
270 debugger.start()
271 except (EOFError, KeyboardInterrupt):
272 print
273 break
276 # test utils ##################################################################
277 from cStringIO import StringIO
class SkipAwareTestResult(unittest._TextTestResult):
    """Text test result which also records skipped tests and, in pdb mode,
    one post-mortem debugger per error/failure.

    Tracebacks can be colorized (`colorize`) and, in verbose mode
    (verbosity > 1), display the local variables of each frame.
    """

    def __init__(self, stream, descriptions, verbosity,
                 exitfirst=False, capture=0, printonly=None,
                 pdbmode=False, cvg=None, colorize=False):
        super(SkipAwareTestResult, self).__init__(stream,
                                                  descriptions, verbosity)
        self.skipped = []
        self.debuggers = []
        self.fail_descrs = []
        self.error_descrs = []
        self.exitfirst = exitfirst
        self.capture = capture
        self.printonly = printonly
        self.pdbmode = pdbmode
        self.cvg = cvg
        self.colorize = colorize
        self.pdbclass = Debugger
        self.verbose = verbosity > 1

    def descrs_for(self, flavour):
        """return the (debugger index, test description) pairs recorded for
        `flavour` ('error' or 'fail', case insensitive)"""
        return getattr(self, '%s_descrs' % flavour.lower())

    def _create_pdb(self, test_descr, flavour):
        """record `test_descr` for `flavour` and, in pdb mode, create a
        debugger on the traceback of the exception currently handled"""
        self.descrs_for(flavour).append( (len(self.debuggers), test_descr) )
        if self.pdbmode:
            self.debuggers.append(self.pdbclass(sys.exc_info()[2]))

    def _iter_valid_frames(self, frames):
        """only consider non-testlib frames when formatting traceback

        Leading frames coming from this module or from unittest are dropped.
        """
        def strip_ext(path):
            # compare paths without extension: a module imported from a
            # compiled file has a .pyc/.pyo __file__ while its frames
            # report the .py source file
            return osp.splitext(osp.abspath(path))[0]
        testlibs = (strip_ext(__file__), strip_ext(unittest.__file__))
        invalid = lambda fi: strip_ext(fi[1]) in testlibs
        for frameinfo in dropwhile(invalid, frames):
            yield frameinfo

    def _exc_info_to_string(self, err, test):
        """Converts a sys.exc_info()-style tuple of values into a string.

        This method is overridden here because we want to colorize
        lines if --color is passed, and display local variables if
        --verbose is passed. `test` is unused but part of the base class
        signature.
        """
        exctype, exc, tb = err
        output = ['Traceback (most recent call last)']
        frames = inspect.getinnerframes(tb)
        colorize = self.colorize
        for frame, filename, lineno, funcname, ctx, ctxindex in \
                self._iter_valid_frames(frames):
            filename = osp.abspath(filename)
            if ctx is None: # pyc files or C extensions for instance
                source = '<no source available>'
            else:
                source = ''.join(ctx)
            if colorize:
                filename = textutils.colorize_ansi(filename, 'magenta')
                source = colorize_source(source)
            output.append(' File "%s", line %s, in %s' % (filename, lineno, funcname))
            output.append(' %s' % source.strip())
            if self.verbose:
                output.append('')
                output.append(' ' + ' local variables '.center(66, '-'))
                for varname, value in sorted(frame.f_locals.items()):
                    output.append(' %s: %r' % (varname, value))
                    if varname == 'self': # special handy processing for self
                        try:
                            attributes = vars(value)
                        except TypeError:
                            # no __dict__ (builtin type, __slots__ class...):
                            # don't let it break the traceback formatting
                            attributes = {}
                        for attrname, attrvalue in sorted(attributes.items()):
                            output.append(' self.%s: %r' % (attrname, attrvalue))
                output.append(' ' + '-' * 66)
                output.append('')
        output.append(''.join(traceback.format_exception_only(exctype, exc)))
        return '\n'.join(output)

    def addError(self, test, err):
        """err == (exc_type, exc, tcbk)"""
        exc_type, exc, _ = err
        # isinstance also catches TestSkipped subclasses (InnerTestSkipped)
        if exc_type is TestSkipped or isinstance(exc, TestSkipped):
            self.addSkipped(test, exc)
        else:
            if self.exitfirst:
                self.shouldStop = True
            descr = self.getDescription(test)
            super(SkipAwareTestResult, self).addError(test, err)
            self._create_pdb(descr, 'error')

    def addFailure(self, test, err):
        if self.exitfirst:
            self.shouldStop = True
        descr = self.getDescription(test)
        super(SkipAwareTestResult, self).addFailure(test, err)
        self._create_pdb(descr, 'fail')

    def addSkipped(self, test, reason):
        """record `test` as skipped for `reason`"""
        self.skipped.append((test, self.getDescription(test), reason))
        if self.showAll:
            self.stream.writeln("SKIPPED")
        elif self.dots:
            self.stream.write('S')

    def printErrors(self):
        """print errors and failures, then the skipped tests"""
        super(SkipAwareTestResult, self).printErrors()
        self.printSkippedList()

    def printSkippedList(self):
        for _, descr, err in self.skipped: # test, descr, err
            self.stream.writeln(self.separator1)
            self.stream.writeln("%s: %s" % ('SKIPPED', descr))
            self.stream.writeln("\t%s" % err)

    def printErrorList(self, flavour, errors):
        """print `errors` under the `flavour` label (e.g. 'ERROR'), with the
        output captured during each test when available"""
        for (_, descr), (test, err) in zip(self.descrs_for(flavour), errors):
            self.stream.writeln(self.separator1)
            if self.colorize:
                self.stream.writeln("%s: %s" % (
                    textutils.colorize_ansi(flavour, color='red'), descr))
            else:
                self.stream.writeln("%s: %s" % (flavour, descr))

            self.stream.writeln(self.separator2)
            self.stream.writeln(err)
            captured_output = getattr(test, 'captured_output', None)
            if captured_output is None:
                continue # original unittest
            output, errput = captured_output()
            width = len(self.separator2)
            for content, name in ((output, 'stdout'), (errput, 'stderr')):
                if content:
                    self.stream.writeln(self.separator2)
                    self.stream.writeln(("captured %s" % name).center(width))
                    self.stream.writeln(self.separator2)
                    self.stream.writeln(content)
                else:
                    self.stream.writeln(('no %s' % name).center(width))
def run(self, result, runcondition=None, options=None):
    """Replacement for unittest.TestSuite.run which forwards `runcondition`
    and `options` to every test of the suite, stopping as soon as
    `result.shouldStop` is set.

    NOTE: any TypeError escaping the three-argument call triggers the
    fallback, which runs the test again with `result` only.
    """
    for test in self._tests:
        if result.shouldStop:
            break
        try:
            test(result, runcondition, options)
        except TypeError:
            # this might happen if a raw unittest.TestCase is defined
            # and used with python (and not pytest)
            message = ("%s should extend lgc.testlib.TestCase instead of "
                       "unittest.TestCase" % test)
            warnings.warn(message)
            test(result)
    return result
unittest.TestSuite.run = run
# backward compatibility: TestSuite might be imported from lgc.testlib
TestSuite = unittest.TestSuite

# python2.3 compat: make suites callable, every argument being forwarded
# to their `run` method
def __call__(self, *args, **kwds):
    run_method = self.run
    return run_method(*args, **kwds)
unittest.TestSuite.__call__ = __call__
class SkipAwareTextTestRunner(unittest.TextTestRunner):
    """Text test runner selecting the tests to run (skipped patterns, test
    name pattern, tags pattern) and reporting skipped tests."""

    def __init__(self, stream=sys.stderr, verbosity=1,
                 exitfirst=False, capture=False, printonly=None,
                 pdbmode=False, cvg=None, test_pattern=None,
                 skipped_patterns=(), colorize=False, batchmode=False,
                 options=None):
        super(SkipAwareTextTestRunner, self).__init__(stream=stream,
                                                      verbosity=verbosity)
        self.exitfirst = exitfirst
        self.capture = capture
        self.printonly = printonly
        self.pdbmode = pdbmode
        self.cvg = cvg
        self.test_pattern = test_pattern
        self.skipped_patterns = skipped_patterns
        self.colorize = colorize
        self.batchmode = batchmode
        self.options = options

    def _this_is_skipped(self, testedname):
        """return True if `testedname` contains one of the skipped patterns"""
        return any(pat in testedname for pat in self.skipped_patterns)

    def _runcondition(self, test, skipgenerator=True):
        """return True if `test` should be run according to the skipped
        patterns, the test name pattern and the tags pattern"""
        if isinstance(test, InnerTest):
            testname = test.name
        else:
            if isinstance(test, TestCase):
                meth = test._get_test_method()
                func = meth.im_func
                testname = '%s.%s' % (meth.im_class.__name__, func.__name__)
            elif isinstance(test, types.FunctionType):
                func = test
                testname = func.__name__
            elif isinstance(test, types.MethodType):
                func = test.im_func
                testname = '%s.%s' % (test.im_class.__name__, func.__name__)
            else:
                return True # Not sure when this happens

            if is_generator(func) and skipgenerator:
                return self.does_match_tags(func) # Let inner tests decide at run time

        if self._this_is_skipped(testname):
            return False # this was explicitly skipped
        if self.test_pattern is not None:
            try:
                classpattern, testpattern = self.test_pattern.split('.')
                klass, name = testname.split('.')
                if classpattern not in klass or testpattern not in name:
                    return False
            except ValueError:
                if self.test_pattern not in testname:
                    return False

        return self.does_match_tags(test)

    def _get_tags_pattern(self):
        """return the tags pattern given through `self.options`, or None"""
        options = self.options
        if options is None:
            return None
        if isinstance(options, dict):
            # dictionary options; the legacy 'tag_pattern' key is accepted
            return options.get('tags_pattern', options.get('tag_pattern'))
        return getattr(options, 'tags_pattern', None)

    def does_match_tags(self, test):
        """return True if the tags of `test` match the tags pattern (always
        True when no pattern was given)"""
        tags_pattern = self._get_tags_pattern()
        if tags_pattern is not None:
            tags = getattr(test, 'tags', Tags())
            if tags.inherit and isinstance(test, types.MethodType):
                tags = tags | getattr(test.im_class, 'tags', Tags())
            return tags.match(tags_pattern)
        return True # no pattern

    def _makeResult(self):
        return SkipAwareTestResult(self.stream, self.descriptions,
                                   self.verbosity, self.exitfirst, self.capture,
                                   self.printonly, self.pdbmode, self.cvg,
                                   self.colorize)

    def run(self, test):
        "Run the given test case or test suite."
        result = self._makeResult()
        startTime = time.time()
        test(result, self._runcondition, self.options)
        timeTaken = time.time() - startTime
        result.printErrors()
        if not self.batchmode:
            self.stream.writeln(result.separator2)
            self.stream.writeln("Ran %s in %.3fs" %
                                (_count(result.testsRun, 'test'), timeTaken))
            self.stream.writeln()
            if result.wasSuccessful():
                status, color = "OK", 'green'
            else:
                status, color = "FAILED", 'red'
            if self.colorize:
                status = textutils.colorize_ansi(status, color=color)
            self.stream.write(status)
            # detail the non-empty counters, e.g. " (failures=1, skipped=2)"
            det_results = ["%s=%i" % (name, len(value))
                           for name, value in (("failures", result.failures),
                                               ("errors", result.errors),
                                               ("skipped", result.skipped))
                           if value]
            if det_results:
                self.stream.write(" (%s)" % ', '.join(det_results))
            self.stream.writeln("")
        return result
class keywords(dict):
    """Marks the keyword arguments (**kwargs) among the parameters of a
    generative test (see parse_generative_args)."""

class starargs(tuple):
    """Marks variable positional arguments (*args) among the parameters of
    a generative test; built from its own positional arguments, e.g.
    ``starargs(1, 2) == (1, 2)``."""
    def __new__(cls, *args):
        return super(starargs, cls).__new__(cls, args)
class NonStrictTestLoader(unittest.TestLoader):
    """
    Overrides default testloader to be able to omit classname when
    specifying tests to run on command line.

    For example, if the file test_foo.py contains ::

        class FooTC(TestCase):
            def test_foo1(self): # ...
            def test_foo2(self): # ...
            def test_bar1(self): # ...

        class BarTC(TestCase):
            def test_bar2(self): # ...

    'python test_foo.py' will run the 4 tests of FooTC and BarTC
    'python test_foo.py FooTC' will run the 3 tests in FooTC
    'python test_foo.py test_foo' will run test_foo1 and test_foo2
    'python test_foo.py test_foo1' will run test_foo1
    'python test_foo.py test_bar' will run FooTC.test_bar1 and BarTC.test_bar2
    """

    def __init__(self):
        self.skipped_patterns = []

    def loadTestsFromNames(self, names, module=None):
        suites = []
        for name in names:
            suites.extend(self.loadTestsFromName(name, module))
        return self.suiteClass(suites)

    def _collect_tests(self, module):
        """return a dict mapping the name of each non-private, non-skipped
        TestCase subclass of `module` to a (class, test method names) tuple
        """
        tests = {}
        for obj in vars(module).values():
            if (issubclass(type(obj), (types.ClassType, type)) and
                 issubclass(obj, unittest.TestCase)):
                classname = obj.__name__
                if classname[0] == '_' or self._this_is_skipped(classname):
                    continue
                # obj is a TestCase class
                methodnames = [attrname for attrname in dir(obj)
                               if attrname.startswith(self.testMethodPrefix)
                               and callable(getattr(obj, attrname))]
                # keep track of class (obj) for convenience
                tests[classname] = (obj, methodnames)
        return tests

    def loadTestsFromSuite(self, module, suitename):
        """return the tests of the suite built by `module.<suitename>()`,
        or [] if the module has no such attribute"""
        # getattr with a default: an AttributeError raised while building
        # the suite must propagate instead of being taken for a missing suite
        suitefunc = getattr(module, suitename, None)
        if suitefunc is None:
            return []
        suite = suitefunc()
        assert hasattr(suite, '_tests'), \
               "%s.%s is not a valid TestSuite" % (module.__name__, suitename)
        # python2.3 does not implement __iter__ on suites, we need to return
        # _tests explicitly
        return suite._tests

    def loadTestsFromName(self, name, module=None):
        """return a list of tests for `name` (see the class docstring)"""
        parts = name.split('.')
        if module is None or len(parts) > 2:
            # let the base class do its job here
            return [super(NonStrictTestLoader, self).loadTestsFromName(name)]
        tests = self._collect_tests(module)
        collected = []
        if len(parts) == 1:
            pattern = parts[0]
            if callable(getattr(module, pattern, None)
                    ) and pattern not in tests:
                # consider it as a suite
                return self.loadTestsFromSuite(module, pattern)
            if pattern in tests:
                # case python unittest_foo.py MyTestTC
                klass, methodnames = tests[pattern]
                collected = [klass(methodname) for methodname in methodnames]
            else:
                # case python unittest_foo.py something: every test is
                # collected; presumably filtered at run time through the
                # runner's test_pattern
                for klass, methodnames in tests.values():
                    collected += [klass(methodname)
                                  for methodname in methodnames]
        elif len(parts) == 2:
            # case "MyClass.test_1": all tests of the class are collected
            classname, pattern = parts
            klass, methodnames = tests.get(classname, (None, []))
            collected = [klass(methodname) for methodname in methodnames]
        return collected

    def _this_is_skipped(self, testedname):
        """return True if `testedname` contains one of the skipped patterns"""
        return any(pat in testedname for pat in self.skipped_patterns)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass
        """
        is_skipped = self._this_is_skipped
        classname = testCaseClass.__name__
        if classname[0] == '_' or is_skipped(classname):
            return []
        testnames = super(NonStrictTestLoader, self).getTestCaseNames(
                testCaseClass)
        return [testname for testname in testnames if not is_skipped(testname)]
684 class SkipAwareTestProgram(unittest.TestProgram):
685 # XXX: don't try to stay close to unittest.py, use optparse
686 USAGE = """\
687 Usage: %(progName)s [options] [test] [...]
689 Options:
690 -h, --help Show this message
691 -v, --verbose Verbose output
692 -i, --pdb Enable test failure inspection
693 -x, --exitfirst Exit on first failure
694 -c, --capture Captures and prints standard out/err only on errors
695 -p, --printonly Only prints lines matching specified pattern
696 (implies capture)
697 -s, --skip skip test matching this pattern (no regexp for now)
698 -q, --quiet Minimal output
699 --color colorize tracebacks
701 -m, --match Run only test whose tag match this pattern
703 -P, --profile FILE: Run the tests using cProfile and saving results
704 in FILE
706 Examples:
707 %(progName)s - run default set of tests
708 %(progName)s MyTestSuite - run suite 'MyTestSuite'
709 %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething
710 %(progName)s MyTestCase - run all 'test*' test methods
711 in MyTestCase
712 """
713 def __init__(self, module='__main__', defaultTest=None, batchmode=False,
714 cvg=None, options=None, outstream=sys.stderr):
715 self.batchmode = batchmode
716 self.cvg = cvg
717 self.options = options
718 self.outstream = outstream
719 super(SkipAwareTestProgram, self).__init__(
720 module=module, defaultTest=defaultTest,
721 testLoader=NonStrictTestLoader())
723 def parseArgs(self, argv):
724 self.pdbmode = False
725 self.exitfirst = False
726 self.capture = 0
727 self.printonly = None
728 self.skipped_patterns = []
729 self.test_pattern = None
730 self.tags_pattern = None
731 self.colorize = False
732 self.profile_name = None
733 import getopt
734 try:
735 options, args = getopt.getopt(argv[1:], 'hHvixrqcp:s:m:P:',
736 ['help', 'verbose', 'quiet', 'pdb',
737 'exitfirst', 'restart', 'capture', 'printonly=',
738 'skip=', 'color', 'match=', 'profile='])
739 for opt, value in options:
740 if opt in ('-h', '-H', '--help'):
741 self.usageExit()
742 if opt in ('-i', '--pdb'):
743 self.pdbmode = True
744 if opt in ('-x', '--exitfirst'):
745 self.exitfirst = True
746 if opt in ('-r', '--restart'):
747 self.restart = True
748 self.exitfirst = True
749 if opt in ('-q', '--quiet'):
750 self.verbosity = 0
751 if opt in ('-v', '--verbose'):
752 self.verbosity = 2
753 if opt in ('-c', '--capture'):
754 self.capture += 1
755 if opt in ('-p', '--printonly'):
756 self.printonly = re.compile(value)
757 if opt in ('-s', '--skip'):
758 self.skipped_patterns = [pat.strip() for pat in
759 value.split(', ')]
760 if opt == '--color':
761 self.colorize = True
762 if opt in ('-m', '--match'):
763 #self.tags_pattern = value
764 self.options["tag_pattern"] = value
765 if opt in ('-P', '--profile'):
766 self.profile_name = value
767 self.testLoader.skipped_patterns = self.skipped_patterns
768 if self.printonly is not None:
769 self.capture += 1
770 if len(args) == 0 and self.defaultTest is None:
771 suitefunc = getattr(self.module, 'suite', None)
772 if isinstance(suitefunc, (types.FunctionType,
773 types.MethodType)):
774 self.test = self.module.suite()
775 else:
776 self.test = self.testLoader.loadTestsFromModule(self.module)
777 return
778 if len(args) > 0:
779 self.test_pattern = args[0]
780 self.testNames = args
781 else:
782 self.testNames = (self.defaultTest, )
783 self.createTests()
784 except getopt.error, msg:
785 self.usageExit(msg)
788 def runTests(self):
789 if self.profile_name:
790 import cProfile
791 cProfile.runctx('self._runTests()', globals(), locals(), self.profile_name )
792 else:
793 return self._runTests()
795 def _runTests(self):
796 if hasattr(self.module, 'setup_module'):
797 try:
798 self.module.setup_module(self.options)
799 except Exception, exc:
800 print 'setup_module error:', exc
801 sys.exit(1)
802 self.testRunner = SkipAwareTextTestRunner(verbosity=self.verbosity,
803 stream=self.outstream,
804 exitfirst=self.exitfirst,
805 capture=self.capture,
806 printonly=self.printonly,
807 pdbmode=self.pdbmode,
808 cvg=self.cvg,
809 test_pattern=self.test_pattern,
810 skipped_patterns=self.skipped_patterns,
811 colorize=self.colorize,
812 batchmode=self.batchmode,
813 options=self.options)
815 def removeSucceededTests(obj, succTests):
816 """ Recursive function that removes succTests from
817 a TestSuite or TestCase
818 """
819 if isinstance(obj, TestSuite):
820 removeSucceededTests(obj._tests, succTests)
821 if isinstance(obj, list):
822 for el in obj[:]:
823 if isinstance(el, TestSuite):
824 removeSucceededTests(el, succTests)
825 elif isinstance(el, TestCase):
826 descr = '.'.join((el.__class__.__module__,
827 el.__class__.__name__,
828 el._testMethodName))
829 if descr in succTests:
830 obj.remove(el)
831 # take care, self.options may be None
832 if getattr(self.options, 'restart', False):
833 # retrieve succeeded tests from FILE_RESTART
834 try:
835 restartfile = open(FILE_RESTART, 'r')
836 try:
837 succeededtests = list(elem.rstrip('\n\r') for elem in
838 restartfile.readlines())
839 removeSucceededTests(self.test, succeededtests)
840 finally:
841 restartfile.close()
842 except Exception, ex:
843 raise Exception("Error while reading succeeded tests into %s: %s"
844 % (osp.join(os.getcwd(), FILE_RESTART), ex))
846 result = self.testRunner.run(self.test)
847 # help garbage collection: we want TestSuite, which hold refs to every
848 # executed TestCase, to be gc'ed
849 del self.test
850 if hasattr(self.module, 'teardown_module'):
851 try:
852 self.module.teardown_module(self.options, result)
853 except Exception, exc:
854 print 'teardown_module error:', exc
855 sys.exit(1)
856 if result.debuggers and self.pdbmode:
857 start_interactive_mode(result)
858 if not self.batchmode:
859 sys.exit(not result.wasSuccessful())
860 self.result = result
class FDCapture:
    """adapted from py lib (http://codespeak.net/py)
    Capture IO to/from a given os-level filedescriptor.

    While active, the file descriptor `fd` is redirected to a temporary file
    and `sys.<attr>` is replaced by this object. Lines written through
    `write` matching the `printonly` regular expression (if given) are sent
    to the original file descriptor instead of being captured.
    """
    def __init__(self, fd, attr='stdout', printonly=None):
        self.targetfd = fd
        self.tmpfile = os.tmpfile()
        self.printonly = printonly
        # save original file descriptor
        self._savefd = os.dup(fd)
        # override original file descriptor
        os.dup2(self.tmpfile.fileno(), fd)
        # also modify sys module directly
        self.oldval = getattr(sys, attr)
        setattr(sys, attr, self) # self.tmpfile)
        self.attr = attr

    def write(self, msg):
        """write `msg` line by line; a newline is appended to every line,
        including a last line which doesn't end with one"""
        # msg might be composed of several lines
        for line in msg.splitlines():
            line += '\n' # keepdend=True is not enough
            if self.printonly is None or self.printonly.search(line) is None:
                self.tmpfile.write(line)
            else:
                os.write(self._savefd, line)

    def writelines(self, lines):
        """file-like API: write each string of `lines`"""
        for line in lines:
            self.write(line)

    def flush(self):
        """file-like API: this object replaces sys.stdout/sys.stderr, so
        calls like sys.stdout.flush() must not fail while capturing"""
        self.tmpfile.flush()

    def restore(self):
        """restore original fd and returns captured output"""
        #XXX: hack hack hack
        self.tmpfile.flush()
        try:
            ref_file = getattr(sys, '__%s__' % self.attr)
            ref_file.flush()
        except AttributeError:
            pass
        if hasattr(self.oldval, 'flush'):
            self.oldval.flush()
        # restore original file descriptor
        os.dup2(self._savefd, self.targetfd)
        # restore sys module
        setattr(sys, self.attr, self.oldval)
        # close backup descriptor
        os.close(self._savefd)
        # go to beginning of file and read it
        self.tmpfile.seek(0)
        return self.tmpfile.read()
def _capture(which='stdout', printonly=None):
    """private method, should not be called directly
    (cf. capture_stdout() and capture_stderr())

    Return an FDCapture on file descriptor 1 for 'stdout', 2 for 'stderr'.
    """
    assert which in ('stdout', 'stderr'), \
           "Can only capture stdout or stderr, not %s" % which
    fd = 2
    if which == 'stdout':
        fd = 1
    return FDCapture(fd, which, printonly)
def capture_stdout(printonly=None):
    """Start capturing the standard output.

    Return a handle object whose `restore()` method stops the capture,
    restores the standard output and returns what was captured. Lines
    matching the `printonly` regular expression, if given, are written to
    the original standard output instead of being captured.
    """
    handle = _capture('stdout', printonly)
    return handle
def capture_stderr(printonly=None):
    """Start capturing the standard error output.

    Return a handle object whose `restore()` method stops the capture,
    restores the standard error and returns what was captured. Lines
    matching the `printonly` regular expression, if given, are written to
    the original standard error instead of being captured.
    """
    handle = _capture('stderr', printonly)
    return handle
def unittest_main(module='__main__', defaultTest=None,
                  batchmode=False, cvg=None, options=None,
                  outstream=sys.stderr):
    """use this function if you want to have the same functionality
    as unittest.main: parse the command line and run the tests of
    `module` through a SkipAwareTestProgram, which is returned"""
    return SkipAwareTestProgram(module=module, defaultTest=defaultTest,
                                batchmode=batchmode, cvg=cvg,
                                options=options, outstream=outstream)
class TestSkipped(Exception):
    """raised when a test is skipped; the exception is reported as the
    skip reason"""

class InnerTestSkipped(TestSkipped):
    """raised when an inner test (one generated by a generative test)
    is skipped"""
def is_generator(function):
    """Return a true value (the CO_GENERATOR bit of the code flags) if
    `function` is a generator function, 0 otherwise."""
    return function.func_code.co_flags & CO_GENERATOR
def parse_generative_args(params):
    """Split the parameters of a generative test into positional and
    keyword arguments.

    Plain values are positional arguments; the content of a `starargs`
    instance is appended to them and a `keywords` instance provides the
    keyword arguments. TypeError is raised when a `starargs` follows a
    `starargs` or `keywords`, when several `keywords` are given, or when a
    plain value follows them.

    Return an (args, kwargs) tuple.
    """
    positional = []
    named = {}
    seen_starargs = seen_keywords = False
    for param in params:
        if isinstance(param, starargs):
            if seen_starargs or seen_keywords:
                raise TypeError('found starargs after keywords !')
            seen_starargs = True
            positional.extend(param)
        elif isinstance(param, keywords):
            if seen_keywords:
                raise TypeError('got multiple keywords parameters')
            seen_keywords = True
            named = param
        elif seen_starargs or seen_keywords:
            raise TypeError('found parameters after kwargs or args')
        else:
            positional.append(param)
    return positional, named
class InnerTest(tuple):
    """A tuple with a `name` attribute: the positional arguments following
    `name` form the tuple content. The name is what
    SkipAwareTextTestRunner matches against its selection patterns."""
    def __new__(cls, name, *data):
        self = super(InnerTest, cls).__new__(cls, data)
        self.name = name
        return self
class Tags(set):
    """A set of tags able to evaluate a boolean tag expression (see
    `match`). The `inherit` keyword (True by default) is stored as an
    attribute."""

    def __init__(self, *tags, **kwargs):
        self.inherit = kwargs.pop('inherit', True)
        if kwargs:
            raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys())

        if len(tags) == 1 and not isinstance(tags[0], basestring):
            # a single non-string argument is taken as an iterable of tags
            tags = tags[0]
        super(Tags, self).__init__(tags)

    def __getitem__(self, key):
        # namespace lookup used by `match`: every name evaluates to
        # whether it is one of the tags
        return key in self

    def match(self, exp):
        """evaluate `exp`, a python expression whose names stand for tags,
        e.g. Tags('a').match('a and not b') is True.

        WARNING: `exp` is given to eval(); never use it on untrusted input.
        """
        return eval(exp, {}, self)
class TestCase(unittest.TestCase):
    """A unittest.TestCase extension with some additional methods.

    On top of the standard unittest features it provides generative tests
    (test methods yielding callables and their arguments), optional capture
    of stdout/stderr, tags, a per-module data directory and many extra
    assertion helpers.
    """

    # capture output during tests; may also be requested by the test result
    # (see __call__)
    capture = False
    # debugger class handed over to the test result in __call__
    pdbclass = Debugger
    # tags of the test class (presumably matched against the -m tag pattern)
    tags = Tags()

    def __init__(self, methodName='runTest'):
        super(TestCase, self).__init__(methodName)
        # internal API changed in python2.5. The double underscore names are
        # mangled to _TestCase__*; before 2.5 unittest.TestCase stored the
        # method name under that same mangled name, so this class relies on
        # being named TestCase.
        if sys.version_info >= (2, 5):
            self.__exc_info = self._exc_info
            self.__testMethodName = self._testMethodName
        else:
            # let's give easier access to _testMethodName to every subclasses
            self._testMethodName = self.__testMethodName
        self._captured_stdout = ""
        self._captured_stderr = ""
        # stacks of active capture handles (see start_capture/stop_capture)
        self._out = []
        self._err = []
        self._current_test_descr = None
        self._options_ = None
        # line filter for captured output; refreshed from the result in
        # __call__, initialised here so start_capture works before that
        self._printonly = None

    def datadir(cls): # pylint: disable-msg=E0213
        """helper attribute holding the standard test's data directory

        NOTE: this is a logilab's standard
        """
        # __import__('a.b') returns the top level package 'a': fetch the
        # module itself from sys.modules so that dotted module names resolve
        # to the directory of the test module
        __import__(cls.__module__)
        mod = sys.modules[cls.__module__]
        return osp.join(osp.dirname(osp.abspath(mod.__file__)), 'data')
    # cache it (use a class method to cache on class since TestCase is
    # instantiated for each test run)
    datadir = classproperty(cached(datadir))

    @classmethod
    def datapath(cls, *fname):
        """joins the object's datadir and `fname`"""
        return osp.join(cls.datadir, *fname)

    def set_description(self, descr):
        """sets the current test's description.
        This can be useful for generative tests because it allows to specify
        a description per yield
        """
        self._current_test_descr = descr

    # override default's unittest.py feature
    def shortDescription(self):
        """override default unitest shortDescription to handle correctly
        generative tests
        """
        if self._current_test_descr is not None:
            return self._current_test_descr
        return super(TestCase, self).shortDescription()

    def captured_output(self):
        """return a two tuple with standard output and error stripped"""
        return self._captured_stdout.strip(), self._captured_stderr.strip()

    def _start_capture(self):
        """start_capture if enable"""
        if self.capture:
            warnings.simplefilter('ignore', DeprecationWarning)
            self.start_capture()

    def _stop_capture(self):
        """stop_capture and restore previous output"""
        self._force_output_restore()

    def start_capture(self, printonly=None):
        """start capturing stdout and stderr, pushing new capture handles

        :param printonly: compiled pattern of lines to print anyway
          (defaults to the pattern given by the test result)
        """
        self._out.append(capture_stdout(printonly or self._printonly))
        self._err.append(capture_stderr(printonly or self._printonly))

    def printonly(self, pattern, flags=0):
        """set the pattern of line to print"""
        rgx = re.compile(pattern, flags)
        if self._out:
            self._out[-1].printonly = rgx
            self._err[-1].printonly = rgx
        else:
            self.start_capture(printonly=rgx)

    def stop_capture(self):
        """stop output and error capture

        :return: (stdout, stderr) captured by the innermost capture, or two
          empty strings if no capture is active
        """
        if self._out:
            _out = self._out.pop()
            _err = self._err.pop()
            return _out.restore(), _err.restore()
        return '', ''

    def _force_output_restore(self):
        """remove all capture set, accumulating what they captured"""
        while self._out:
            self._captured_stdout += self._out.pop().restore()
            self._captured_stderr += self._err.pop().restore()

    def quiet_run(self, result, func, *args, **kwargs):
        """run `func` (e.g. setUp/tearDown) with output capture enabled

        KeyboardInterrupt and SystemExit propagate; any other exception is
        reported as an error on `result`.

        :return: True if `func` did not raise, False otherwise
        """
        self._start_capture()
        try:
            func(*args, **kwargs)
        except (KeyboardInterrupt, SystemExit):
            self._stop_capture()
            raise
        except:
            # deliberately broad: any fixture problem is a test error
            self._stop_capture()
            result.addError(self, self.__exc_info())
            return False
        self._stop_capture()
        return True

    def _get_test_method(self):
        """return the test method"""
        return getattr(self, self.__testMethodName)

    def optval(self, option, default=None):
        """return the option value or default if the option is not define"""
        return getattr(self._options_, option, default)

    def __call__(self, result=None, runcondition=None, options=None):
        """rewrite TestCase.__call__ to support generative tests
        This is mostly a copy/paste from unittest.py (i.e same
        variable names, same logic, except for the generative tests part)

        :param result: test result (defaultTestResult() if None)
        :param runcondition: optional predicate called with the test method;
          the test is skipped if it returns a false value
        :param options: runner options (readable through `optval`)
        """
        if result is None:
            result = self.defaultTestResult()
        result.pdbclass = self.pdbclass
        # if self.capture is True here, it means it was explicitly specified
        # in the user's TestCase class. If not, do what was asked on cmd line
        self.capture = self.capture or getattr(result, 'capture', False)
        self._options_ = options
        self._printonly = getattr(result, 'printonly', None)
        testMethod = self._get_test_method()
        if runcondition and not runcondition(testMethod):
            return # test is skipped
        result.startTest(self)
        try:
            if not self.quiet_run(result, self.setUp):
                return
            generative = is_generator(testMethod.im_func)
            success = False
            # generative tests
            if generative:
                self._proceed_generative(result, testMethod,
                                         runcondition)
            else:
                status = self._proceed(result, testMethod)
                success = (status == 0)
            if not self.quiet_run(result, self.tearDown):
                return
            if not generative and success:
                if getattr(options, 'exitfirst', False):
                    # add this test to restart file
                    try:
                        restartfile = open(FILE_RESTART, 'a')
                        try:
                            descr = '.'.join((self.__class__.__module__,
                                              self.__class__.__name__,
                                              self._testMethodName))
                            restartfile.write(descr + os.linesep)
                        finally:
                            restartfile.close()
                    except Exception:
                        sys.__stderr__.write(
                            'Error while saving succeeded test into %s\n'
                            % osp.join(os.getcwd(), FILE_RESTART))
                        # bare raise keeps the original traceback
                        raise
                result.addSuccess(self)
        finally:
            result.stopTest(self)

    def _proceed_generative(self, result, testfunc, runcondition=None):
        """run every step yielded by the generative test method `testfunc`

        Each yielded item is a callable, or a tuple/list whose first item is
        the callable and the remaining ones its arguments (see
        parse_generative_args). Each step counts as one run test.

        :return: True if every executed step succeeded, False otherwise
        """
        # cancel startTest()'s increment
        result.testsRun -= 1
        self._start_capture()
        success = True
        try:
            for params in testfunc():
                if runcondition and not runcondition(testfunc,
                                                     skipgenerator=False):
                    if not (isinstance(params, InnerTest)
                            and runcondition(params)):
                        continue
                if not isinstance(params, (tuple, list)):
                    params = (params, )
                func = params[0]
                args, kwargs = parse_generative_args(params[1:])
                # increment test counter manually
                result.testsRun += 1
                status = self._proceed(result, func, args, kwargs)
                if status == 0:
                    result.addSuccess(self)
                else:
                    # one failing step makes the whole generator unsuccessful
                    success = False
                    if status == 2:
                        result.shouldStop = True
                if result.shouldStop: # either on error or on exitfirst + error
                    break
        except KeyboardInterrupt:
            self._stop_capture()
            raise
        except:
            # if an error occurs between two yield
            result.addError(self, self.__exc_info())
            success = False
        self._stop_capture()
        return success

    def _proceed(self, result, testfunc, args=(), kwargs=None):
        """proceed the actual test
        returns 0 on success, 1 on failure (or inner skip), 2 on error

        Note: addSuccess can't be called here because we have to wait
        for tearDown to be successfully executed to declare the test as
        successful
        """
        self._start_capture()
        kwargs = kwargs or {}
        try:
            testfunc(*args, **kwargs)
        except self.failureException:
            self._stop_capture()
            result.addFailure(self, self.__exc_info())
            return 1
        except KeyboardInterrupt:
            self._stop_capture()
            raise
        except InnerTestSkipped:
            exc = sys.exc_info()[1]
            # restore output here too, the capture must not leak on skip
            self._stop_capture()
            result.addSkipped(self, exc)
            return 1
        except:
            self._stop_capture()
            result.addError(self, self.__exc_info())
            return 2
        else:
            self._stop_capture()
        return 0

    def defaultTestResult(self):
        """return a new instance of the defaultTestResult"""
        return SkipAwareTestResult()

    def skip(self, msg=None):
        """mark a test as skipped for the <msg> reason"""
        msg = msg or 'test was skipped'
        raise TestSkipped(msg)

    def innerSkip(self, msg=None):
        """mark a generative test as skipped for the <msg> reason"""
        msg = msg or 'test was skipped'
        raise InnerTestSkipped(msg)

    def assertIn(self, object, set, msg=None):
        """assert <object> is in <set>

        :param object: a Python Object
        :param set: a Python Container
        :param msg: custom message (String) in case of failure
        """
        self.assert_(object in set, msg or "%s not in %s" % (object, set))

    def assertNotIn(self, object, set, msg=None):
        """assert <object> is not in <set>

        :param object: a Python Object
        :param set: the Python container to contain <object>
        :param msg: custom message (String) in case of failure
        """
        self.assert_(object not in set, msg or "%s in %s" % (object, set))

    def assertDictEquals(self, dict1, dict2, msg=None):
        """compares two dicts

        If the two dict differ, every difference is listed in the error
        message (unless a custom message is given)

        :param dict1: a Python Dictionary
        :param dict2: a Python Dictionary
        :param msg: custom message (String) in case of failure
        """
        dict1 = dict(dict1)
        msgs = []
        for key, value in dict2.items():
            if key not in dict1:
                msgs.append('missing %r key' % (key,))
                continue
            if dict1[key] != value:
                msgs.append('%r != %r for key %r' % (dict1[key], value, key))
            del dict1[key]
        if dict1:
            msgs.append('dict2 is lacking %r' % (dict1,))
        if msgs:
            # a custom message replaces the computed one but must still fail
            self.fail(msg or '\n'.join(msgs))
    assertDictEqual = assertDictEquals

    def assertUnorderedIterableEquals(self, got, expected, msg=None):
        """compares two iterable and shows difference between both

        Both iterables must hold the same elements the same number of times;
        elements must be hashable.

        :param got: the unordered Iterable that we found
        :param expected: the expected unordered Iterable
        :param msg: custom message (String) in case of failure
        """
        got, expected = list(got), list(expected)
        self.assertSetEqual(set(got), set(expected), msg)
        got_count = {}
        for element in got:
            got_count[element] = got_count.get(element, 0) + 1
        expected_count = {}
        for element in expected:
            expected_count[element] = expected_count.get(element, 0) + 1
        # counts are compared even when lengths match: [1, 1, 2] != [1, 2, 2]
        if got_count != expected_count:
            if msg is None:
                lines = ['Iterable have the same elements but not the same number',
                         '\t<element>\t<expected>\t<got>']
                # same keys on both sides thanks to assertSetEqual above
                for element, count in got_count.items():
                    other_count = expected_count[element]
                    if other_count != count:
                        lines.append('\t%s\t%s\t%s' % (element, other_count, count))
                msg = '\n'.join(lines)
            self.fail(msg)

    assertUnorderedIterableEqual = assertUnorderedIterableEquals
    assertUnordIterEquals = assertUnordIterEqual = assertUnorderedIterableEqual

    def assertSetEquals(self, got, expected, msg=None):
        """compares two sets and shows difference between both

        Don't use it for iterables other than sets (set or frozenset).

        :param got: the Set that we found
        :param expected: the second Set to be compared to the first one
        :param msg: custom message (String) in case of failure
        """
        set_types = (set, frozenset)
        if not (isinstance(got, set_types) and isinstance(expected, set_types)):
            warnings.warn("the assertSetEquals function is now intended for set only. "
                          "use assertUnorderedIterableEquals instead.",
                          DeprecationWarning, 2)
            return self.assertUnorderedIterableEquals(got, expected, msg)

        items = [('missing', expected - got),
                 ('unexpected', got - expected)]
        if [values for _, values in items if values]:
            if msg is None:
                msg = '\n'.join('%s:\n\t%s' % (key, "\n\t".join(str(value) for value in values))
                                for key, values in items if values)
            self.fail(msg)

    assertSetEqual = assertSetEquals

    def assertListEquals(self, list_1, list_2, msg=None):
        """compares two lists

        If the two list differ, the first difference is shown in the error
        message (unless a custom message is given)

        :param list_1: a Python List
        :param list_2: a second Python List
        :param msg: custom message (String) in case of failure
        """
        len_1 = len(list_1)
        for i, value in enumerate(list_2):
            if i >= len_1:
                if msg is None:
                    msg = 'list_1 has only %d elements, not %s '\
                          '(at least %r missing)' % (i, len(list_2), value)
                self.fail(msg)
            if list_1[i] != value:
                from pprint import pprint
                pprint(list_1)
                pprint(list_2)
                self.fail(msg or '%r != %r for index %d' % (list_1[i], value, i))
        if len_1 > len(list_2):
            if msg is None:
                # tuple-wrap: the remainder may itself be a tuple
                msg = 'list_2 is lacking %r' % (list_1[len(list_2):],)
            self.fail(msg)
    assertListEqual = assertListEquals

    def assertLinesEquals(self, string1, string2, msg=None, striplines=False):
        """compare two strings and assert that the text lines of the strings
        are equal.

        :param string1: a String
        :param string2: a String
        :param msg: custom message (String) in case of failure
        :param striplines: Boolean to trigger line stripping before comparing
        """
        lines1 = string1.splitlines()
        lines2 = string2.splitlines()
        if striplines:
            lines1 = [l.strip() for l in lines1]
            lines2 = [l.strip() for l in lines2]
        self.assertListEquals(lines1, lines2, msg)
    assertLineEqual = assertLinesEquals

    def assertXMLWellFormed(self, stream, msg=None, context=2):
        """asserts the XML stream is well-formed (no DTD conformance check)

        :param stream: file-like object holding the XML document
        :param msg: custom message (String) in case of failure
        :param context: number of context lines in standard message
                        (show all data if negative).
                        Only available with element tree
        """
        try:
            from xml.etree.ElementTree import parse
        except ImportError:
            parse = None
        if parse is not None:
            self._assertETXMLWellFormed(stream, parse, msg, context)
            return
        from xml.sax import make_parser, SAXParseException
        parser = make_parser()
        try:
            parser.parse(stream)
        except SAXParseException:
            ex = sys.exc_info()[1]
            if msg is None:
                stream.seek(0)
                line = ''
                for _ in range(ex.getLineNumber()):
                    line = stream.readline()
                if not line.endswith('\n'):
                    line += '\n'
                # point at the offending column of the offending line
                pointer = (' ' * ex.getColumnNumber()) + '^'
                msg = 'XML stream not well formed: %s\n%s%s' % (ex, line, pointer)
            self.fail(msg)

    def assertXMLStringWellFormed(self, xml_string, msg=None, context=2):
        """asserts the XML string is well-formed (no DTD conformance check)

        :param xml_string: the XML document as a string
        :param msg: custom message (String) in case of failure
        :param context: number of context lines in standard message
                        (show all data if negative).
                        Only available with element tree
        """
        try:
            from xml.etree.ElementTree import fromstring
        except ImportError:
            from elementtree.ElementTree import fromstring
        self._assertETXMLWellFormed(xml_string, fromstring, msg, context)

    def _assertETXMLWellFormed(self, data, parse, msg=None, context=2):
        """internal function used by /assertXML(String)?WellFormed/ functions

        :param data: xml_data (a string or a file-like object)
        :param parse: appropriate parser function for this data
        :param msg: error message
        :param context: number of context lines in standard message
                        (show all data if negative).
                        Only available with element tree
        """
        from xml.parsers.expat import ExpatError
        try:
            # python >= 2.7 element tree raises its own ParseError
            from xml.etree.ElementTree import ParseError
        except ImportError:
            ParseError = ExpatError
        try:
            parse(data)
        except (ExpatError, ParseError):
            ex = sys.exc_info()[1]
            if msg is None:
                if hasattr(data, 'readlines'): #file like object
                    data.seek(0)
                    lines = data.readlines()
                else:
                    lines = data.splitlines(True)
                lines = [line.rstrip('\r\n') for line in lines]
                if hasattr(ex, 'position'):
                    lineno, offset = ex.position
                else:
                    lineno, offset = ex.lineno, ex.offset
                nb_lines = len(lines)
                if context < 0:
                    start = 1
                    end = nb_lines
                else:
                    start = max(lineno - context, 1)
                    end = min(lineno + context, nb_lines)
                line_number_length = len('%i' % end)
                line_pattern = " %%%ii: %%s\n" % line_number_length
                context_lines = []
                # lines beyond the data (error at EOF) are simply not shown
                for line_no in range(start, end + 1):
                    context_lines.append(line_pattern % (line_no, lines[line_no - 1]))
                    if line_no == lineno:
                        # 1 leading space + number + ': ' then the offset
                        context_lines.append('%s^\n' % (' ' * (1 + line_number_length + 2 + offset)))
                rich_context = ''.join(context_lines)
                msg = 'XML stream not well formed: %s\n%s' % (ex, rich_context)
            self.fail(msg)

    def assertXMLEqualsTuple(self, element, tup):
        """compare an ElementTree Element to a tuple formatted as follow:
        (tagname, [attrib[, children[, text[, tail]]]])"""
        # check tag
        self.assertTextEquals(element.tag, tup[0])
        # check attrib
        if len(element.attrib) or len(tup) > 1:
            if len(tup) <= 1:
                self.fail("tuple %s has no attributes (%s expected)" % (tup,
                          dict(element.attrib)))
            self.assertDictEquals(element.attrib, tup[1])
        # check children
        if len(element) or len(tup) > 2:
            if len(tup) <= 2:
                self.fail("tuple %s has no children (%i expected)" % (tup,
                          len(element)))
            if len(element) != len(tup[2]):
                self.fail("tuple %s has %i children%s (%i expected)" % (tup,
                          len(tup[2]),
                          ('', 's')[len(tup[2]) > 1], len(element)))
            for index in range(len(tup[2])):
                self.assertXMLEqualsTuple(element[index], tup[2][index])
        #check text
        if element.text or len(tup) > 3:
            if len(tup) <= 3:
                self.fail("tuple %s has no text value (%r expected)" % (tup,
                          element.text))
            self.assertTextEquals(element.text, tup[3])
        #check tail
        if element.tail or len(tup) > 4:
            if len(tup) <= 4:
                self.fail("tuple %s has no tail value (%r expected)" % (tup,
                          element.tail))
            self.assertTextEquals(element.tail, tup[4])

    def _difftext(self, lines1, lines2, junk=None, msg_prefix='Texts differ'):
        """fail with an ndiff of `lines1` / `lines2` if they differ"""
        junk = junk or (' ', '\t')
        # result is a generator
        result = difflib.ndiff(lines1, lines2, charjunk=lambda x: x in junk)
        read = []
        for line in result:
            read.append(line)
            # lines that don't start with a ' ' are diff ones
            if not line.startswith(' '):
                self.fail('\n'.join(['%s\n' % msg_prefix] + read + list(result)))

    def assertTextEquals(self, text1, text2, junk=None,
                         msg_prefix='Text differ', striplines=False):
        """compare two multiline strings (using difflib and splitlines())

        :param text1: a Python BaseString
        :param text2: a second Python Basestring
        :param junk: List of Caracters
        :param msg_prefix: String (message prefix)
        :param striplines: Boolean to trigger line stripping before comparing
        """
        msg = []
        if not isinstance(text1, basestring):
            msg.append('text1 is not a string (%s)' % (type(text1)))
        if not isinstance(text2, basestring):
            msg.append('text2 is not a string (%s)' % (type(text2)))
        if msg:
            self.fail('\n'.join(msg))
        lines1 = text1.strip().splitlines(True)
        lines2 = text2.strip().splitlines(True)
        if striplines:
            lines1 = [line.strip() for line in lines1]
            lines2 = [line.strip() for line in lines2]
        self._difftext(lines1, lines2, junk, msg_prefix)
    assertTextEqual = assertTextEquals

    def assertStreamEquals(self, stream1, stream2, junk=None,
                           msg_prefix='Stream differ'):
        """compare two streams (using difflib and readlines())"""
        # if stream2 is stream2, readlines() on stream1 will also read lines
        # in stream2, so they'll appear different, although they're not
        if stream1 is stream2:
            return
        # make sure we compare from the beginning of the stream
        stream1.seek(0)
        stream2.seek(0)
        # compare
        self._difftext(stream1.readlines(), stream2.readlines(), junk,
                       msg_prefix)

    assertStreamEqual = assertStreamEquals

    def assertFileEquals(self, fname1, fname2, junk=(' ', '\t')):
        """compares two files using difflib"""
        stream1 = open(fname1)
        try:
            stream2 = open(fname2)
            try:
                self.assertStreamEqual(stream1, stream2, junk,
                    msg_prefix='Files differs\n-:%s\n+:%s\n' % (fname1, fname2))
            finally:
                stream2.close()
        finally:
            stream1.close()
    assertFileEqual = assertFileEquals

    def assertDirEquals(self, path_a, path_b):
        """compares two directory trees

        Both trees must have the same sub-directories and file names, and
        files with the same relative path must have the same content
        (compared with assertFileEquals).
        """
        self.assert_(osp.exists(path_a), "%s doesn't exists" % path_a)
        self.assert_(osp.exists(path_b), "%s doesn't exists" % path_b)

        all_a = [(ipath[len(path_a):].lstrip(os.sep), idirs, ifiles)
                 for ipath, idirs, ifiles in os.walk(path_a)]
        all_a.sort(key=itemgetter(0))

        all_b = [(ipath[len(path_b):].lstrip(os.sep), idirs, ifiles)
                 for ipath, idirs, ifiles in os.walk(path_b)]
        all_b.sort(key=itemgetter(0))

        # parents sort before their children, so an extra/missing directory is
        # reported on its parent before the lockstep walk gets out of sync
        for (ipath_a, idirs_a, ifiles_a), (ipath_b, idirs_b, ifiles_b) \
                in zip(all_a, all_b):
            self.assert_(ipath_a == ipath_b,
                         "unexpected %s in %s while looking %s from %s" %
                         (ipath_a, path_a, ipath_b, path_b))

            sdirs_a, sdirs_b = set(idirs_a), set(idirs_b)
            sfiles_a, sfiles_b = set(ifiles_a), set(ifiles_b)
            errors = [("unexpected directories", sdirs_a - sdirs_b),
                      ("missing directories", sdirs_b - sdirs_a),
                      ("unexpected files", sfiles_a - sfiles_b),
                      ("missing files", sfiles_b - sfiles_a)]
            msgs = ["%s: %s" % (name, items) for name, items in errors if items]
            if msgs:
                msgs.insert(0, "%s and %s differ :" % (
                    osp.join(path_a, ipath_a),
                    osp.join(path_b, ipath_b),
                    ))
                self.fail("\n".join(msgs))

            # same file names on both sides at this point
            for fname in sorted(ifiles_a):
                self.assertFileEquals(osp.join(path_a, ipath_a, fname),
                                      osp.join(path_b, ipath_b, fname))

    assertDirEqual = assertDirEquals

    def assertIsInstance(self, obj, klass, msg=None, strict=False):
        """check if an object is an instance of a class

        :param obj: the Python Object to be checked
        :param klass: the target class
        :param msg: a String for a custom message
        :param strict: if True, check that the class of <obj> is <klass>;
                       else check with 'isinstance'
        """
        if msg is None:
            if strict:
                msg = '%r is not of class %s but of %s'
            else:
                msg = '%r is not an instance of %s but of %s'
            msg = msg % (obj, klass, type(obj))
        if strict:
            self.assert_(obj.__class__ is klass, msg)
        else:
            self.assert_(isinstance(obj, klass), msg)

    def assertIs(self, obj, other, msg=None):
        """compares identity of two reference

        :param obj: a Python Object
        :param other: another Python Object
        :param msg: a String for a custom message
        """
        if msg is None:
            msg = "%r is not %r" % (obj, other)
        self.assert_(obj is other, msg)

    def assertIsNot(self, obj, other, msg=None):
        """compares identity of two reference"""
        if msg is None:
            msg = "%r is %r" % (obj, other)
        self.assert_(obj is not other, msg)

    def assertNone(self, obj, msg=None):
        """assert obj is None

        :param obj: Python Object to be tested
        """
        if msg is None:
            msg = "reference to %r when None expected" % (obj,)
        self.assert_(obj is None, msg)

    def assertNotNone(self, obj, msg=None):
        """assert obj is not None"""
        if msg is None:
            msg = "unexpected reference to None"
        self.assert_(obj is not None, msg)

    def assertFloatAlmostEquals(self, obj, other, prec=1e-5, msg=None):
        """compares if two floats have a distance smaller than expected
        precision.

        :param obj: a Float
        :param other: another Float to be comparted to <obj>
        :param prec: a Float describing the precision
        :param msg: a String for a custom message
        """
        if msg is None:
            msg = "%r != %r" % (obj, other)
        self.assert_(math.fabs(obj - other) < prec, msg)

    def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
        """override default failUnlessRaise method to return the raised
        exception instance.

        Fail unless an exception of class excClass is thrown
        by callableObj when invoked with arguments args and keyword
        arguments kwargs. If a different type of exception is
        thrown, it will not be caught, and the test case will be
        deemed to have suffered an error, exactly as for an
        unexpected exception.

        :param excClass: the Exception to be raised
        :param callableObj: a callable Object which should raise <excClass>
        :param args: a List of arguments for <callableObj>
        :param kwargs: a List of keyword arguments for <callableObj>
        """
        try:
            callableObj(*args, **kwargs)
        except excClass:
            return sys.exc_info()[1]
        if hasattr(excClass, '__name__'):
            excName = excClass.__name__
        else:
            excName = str(excClass)
        raise self.failureException("%s not raised" % excName)

    assertRaises = failUnlessRaises
1756 import doctest
class SkippedSuite(unittest.TestSuite):
    """suite used by DocTest when doctest cannot build a DocTestSuite"""

    def test(self):
        """just there to trigger test execution"""
        # unittest.TestSuite has no ``skipped_test`` method (calling it raised
        # AttributeError): signal the skip with this module's TestSkipped
        raise TestSkipped('doctest module has no DocTestSuite class')
# DocTestFinder was introduced in python2.4
if sys.version_info >= (2, 4):
    class DocTestFinder(doctest.DocTestFinder):
        """doctest.DocTestFinder ignoring objects whose name is in `skipped`"""

        def __init__(self, *args, **kwargs):
            # consume our own keyword before delegating to doctest
            self.skipped = kwargs.pop('skipped', ())
            doctest.DocTestFinder.__init__(self, *args, **kwargs)

        def _get_test(self, obj, name, module, globs, source_lines):
            """return None for skipped objects, else the default DocTest

            Note: Python (<=2.4) use a _name_filter which could be used for
            that purpose but it's no longer available in 2.5
            Python 2.5 seems to have a [SKIP] flag
            """
            obj_name = getattr(obj, '__name__', '')
            if obj_name not in self.skipped:
                return doctest.DocTestFinder._get_test(self, obj, name, module,
                                                       globs, source_lines)
            return None
else:
    # this is a hack to make skipped work with python <= 2.3: the module
    # level doctest._find_tests is replaced by a filtering version
    class DocTestFinder(object):
        def __init__(self, skipped):
            self.skipped = skipped
            self.original_find_tests = doctest._find_tests
            doctest._find_tests = self._find_tests

        def _find_tests(self, module, prefix=None):
            kept = []
            for testinfo in self.original_find_tests(module, prefix):
                fullname, _, _, _ = testinfo
                # fullname looks like A.B.C.function_name
                if fullname.split('.')[-1] in self.skipped:
                    continue
                kept.append(testinfo)
            return kept
class DocTest(TestCase):
    """TestCase running the doctests of ``self.module``

    ``module`` is not defined here (presumably set by subclasses); objects
    whose name appears in ``skipped`` are ignored. This is a hack: unittest.main
    does not otherwise pick up a DocTestSuite instance.
    """
    skipped = ()

    def __call__(self, result=None, runcondition=None, options=None):
        # pylint: disable-msg=W0613
        try:
            finder = DocTestFinder(skipped=self.skipped)
            suite_kwargs = {}
            if sys.version_info >= (2, 4):
                suite_kwargs['test_finder'] = finder
            suite = doctest.DocTestSuite(self.module, **suite_kwargs)
        except AttributeError:
            # doctest module without DocTestSuite
            suite = SkippedSuite()
        return suite.run(result)
    run = __call__

    def test(self):
        """just there to trigger test execution"""
MAILBOX = None

class MockSMTP:
    """fake smtplib.SMTP

    Creating an instance resets the module level MAILBOX to a new list;
    every mail given to sendmail() is appended to it.
    """

    def __init__(self, host, port):
        global MAILBOX
        self.host = host
        self.port = port
        # `received` is the correctly spelled name; `reveived` is kept as an
        # alias of the very same list for backward compatibility
        self.received = self.reveived = MAILBOX = []

    def set_debuglevel(self, debuglevel):
        """ignore debug level"""

    def sendmail(self, fromaddr, toaddres, body):
        """push sent mail in the mailbox"""
        self.received.append((fromaddr, toaddres, body))

    def quit(self):
        """ignore quit"""
class MockConfigParser(ConfigParser):
    """fake ConfigParser.ConfigParser filled from a dict of dicts

    `options` maps section names to {option: value} dictionaries; writing
    the configuration is not supported.
    """

    def __init__(self, options):
        ConfigParser.__init__(self)
        for section, pairs in options.items():
            self.add_section(section)
            for key, value in pairs.items():
                self.set(section, key, value)

    def write(self, _):
        raise NotImplementedError()
class MockConnection:
    """fake DB-API 2.0 connexion AND cursor (i.e. cursor() return self)

    Executed (query, args) pairs are recorded in `received`; commit() and
    rollback() are recorded in `states` with the number of queries received
    so far; fetchone()/fetchall() serve the canned `results`.
    """

    def __init__(self, results):
        self.results = results
        self.received = []
        self.states = []

    def _record_state(self, state):
        """remember a transaction event and how many queries preceded it"""
        self.states.append((state, len(self.received)))

    def cursor(self):
        """the connection is its own cursor"""
        return self

    def execute(self, query, args=None):
        """record the query and its arguments"""
        self.received.append((query, args))

    def fetchone(self):
        """first canned result"""
        return self.results[0]

    def fetchall(self):
        """all canned results"""
        return self.results

    def commit(self):
        """record a commit"""
        self._record_state('commit')

    def rollback(self):
        """record a rollback"""
        self._record_state('rollback')

    def close(self):
        """nothing to close"""
def mock_object(**params):
    """return an instance of a fresh class named 'Mock' whose class
    attributes are the given keyword arguments
    >>> option = mock_object(verbose=False, index=range(5))
    >>> option.verbose
    False
    >>> option.index
    [0, 1, 2, 3, 4]
    """
    mock_class = type('Mock', (), params)
    return mock_class()
def create_files(paths, chroot):
    """Creates directories and files found in <path>.

    Paths ending with a separator are created as directories, the other
    ones as empty files (parent directories are created as needed; existing
    files are truncated).

    :param paths: list of relative paths to files or directories
    :param chroot: the root directory in which paths will be created

    >>> import shutil, tempfile
    >>> from os.path import isdir, isfile, join
    >>> root = tempfile.mkdtemp()
    >>> create_files(['a/b/foo.py', 'a/b/c/', 'a/b/c/d/e.py'], root)
    >>> isdir(join(root, 'a', 'b', 'c'))
    True
    >>> isfile(join(root, 'a', 'b', 'c', 'd', 'e.py'))
    True
    >>> isfile(join(root, 'a', 'b', 'foo.py'))
    True
    >>> shutil.rmtree(root)
    """
    dirs, files = set(), set()
    for path in paths:
        path = osp.join(chroot, path)
        if osp.basename(path):
            # path is a filename path: its directory must exist too
            dirs.add(osp.dirname(path))
            files.add(path)
        else:
            # trailing separator: path is a directory path
            dirs.add(path)
    for dirpath in dirs:
        if not osp.isdir(dirpath):
            os.makedirs(dirpath)
    for filepath in files:
        # open() instead of the python2-only file() builtin
        open(filepath, 'w').close()
def enable_dbc(*args):
    """
    Without arguments, return True if contracts can be enabled and should be
    enabled (see option -d), return False otherwise.

    With arguments, return False if contracts can't or shouldn't be enabled,
    otherwise weave ContractAspect with items passed as arguments.
    """
    if not ENABLE_DBC:
        return False
    try:
        from logilab.aspects.weaver import weaver
        from logilab.aspects.lib.contracts import ContractAspect
    except ImportError:
        # terminate the warning line so later output starts on its own line
        sys.stderr.write(
            'Warning: logilab.aspects is not available. Contracts disabled.\n')
        return False
    for arg in args:
        weaver.weave_module(arg, ContractAspect)
    return True
class AttrObject: # XXX cf mock_object
    """plain object whose instance attributes are the keyword arguments"""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)
def tag(*args, **kwargs):
    """decorator setting a `Tags` instance built from the given arguments as
    the ``tags`` attribute of the decorated function (which must not already
    have one)"""
    def set_tags(func):
        assert not hasattr(func, 'tags')
        func.tags = Tags(*args, **kwargs)
        return func
    return set_tags
def require_version(version):
    """ Compare version of python interpreter to the given one. Skip the test
    if older.

    :param version: minimal python version as a 'X.Y[.Z]' string
    :raise ValueError: (when decorating) if `version` is not made of integers

    The replacement test keeps the name, docstring and attributes (e.g.
    tags) of the decorated test method.
    """
    from functools import wraps

    def check_require_version(f):
        version_elements = version.split('.')
        try:
            compare = tuple([int(v) for v in version_elements])
        except ValueError:
            raise ValueError('%s is not a correct version : should be X.Y[.Z].' % version)
        current = sys.version_info[:3]
        if current >= compare:
            return f

        @wraps(f)
        def new_f(self, *args, **kwargs):
            self.skip('Need at least %s version of python. Current version is %s.' % (version, '.'.join([str(element) for element in current])))
        return new_f
    return check_require_version
def require_module(module):
    """ Check if the given module is loaded. Skip the test if not.

    The import is attempted when the decorator is applied. The replacement
    test keeps the name, docstring and attributes (e.g. tags) of the
    decorated test method.
    """
    from functools import wraps

    def check_require_module(f):
        try:
            __import__(module)
        except ImportError:
            @wraps(f)
            def new_f(self, *args, **kwargs):
                self.skip('%s can not be imported.' % module)
            return new_f
        return f
    return check_require_module