Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to test a decorator in a Python package

I am writing my first Python package and I want to write unit tests for the following decorator:

class MaxTriesExceededError(Exception):
    """Error for a call that has used up its allowed number of tries."""

def tries(max_tries=3, error_message=os.strerror(errno.ETIME)):
    """Build a decorator that retries a function when it raises.

    Each call of the decorated function invokes the original up to
    ``max_tries`` times. The first attempt that returns normally ends the
    retrying and its return value is handed back to the caller; exceptions
    raised by failed attempts are swallowed.

    Args:
        max_tries: Maximum number of attempts per call. With a value of 0 or
            less the function is never invoked and the call fails at once.
        error_message: Message carried by the ``MaxTriesExceededError`` that
            is raised once every attempt has failed.

    Returns:
        A decorator that wraps a function with the retry behaviour.

    Raises:
        MaxTriesExceededError: Raised by the wrapped call when all attempts
            failed. The last failure is attached as ``__cause__``.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            # The attempt counter lives in the loop, so it is local to each
            # call: no closure variable is rebound and separate calls do not
            # share a retry budget.
            for _ in range(max_tries):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    # Only ordinary errors are retried; KeyboardInterrupt
                    # and SystemExit propagate immediately.
                    last_error = exc
            raise MaxTriesExceededError(error_message) from last_error

        return wrapper

    return decorator

The purpose of the decorator is to throw an error if the function fails more than max_tries times, but to eat the error and try again if the maximum try count has not been exceeded. To be honest, I'm not sure that the code doesn't have bugs. My question is therefore twofold: is the code correct, and how do I write unit tests for it using unittest?

like image 530
sakurashinken Avatar asked Jan 07 '16 06:01

sakurashinken


1 Answer

Here is a corrected version, with unit tests:

class MaxTriesExceededError(Exception):
    """Raised when every permitted attempt of a retried call has failed."""

def tries(max_tries=3, error_message="failure"):
    """Build a decorator that retries a function when it raises.

    Each call of the decorated function invokes the original up to
    ``max_tries`` times and returns the result of the first attempt that
    does not raise. Exceptions from failed attempts are swallowed.

    Args:
        max_tries: Maximum number of attempts per call. With a value of 0 or
            less the function is never invoked and the call fails at once.
        error_message: Message carried by the ``MaxTriesExceededError`` that
            is raised once every attempt has failed.

    Returns:
        A decorator that wraps a function with the retry behaviour.

    Raises:
        MaxTriesExceededError: Raised by the wrapped call when all attempts
            failed. The last failure is attached as ``__cause__``.
    """
    # Imported here so the snippet does not depend on a module-level import.
    from functools import wraps

    def decorator(func):
        @wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_tries):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    # Only ordinary errors are retried; KeyboardInterrupt
                    # and SystemExit propagate immediately.
                    last_error = exc
            raise MaxTriesExceededError(error_message) from last_error

        return wrapper

    return decorator


import unittest

class TestDecorator(unittest.TestCase):
    """Unit tests for the ``tries`` retry decorator."""

    def setUp(self):
        # Number of times the decorated function ran in the current test.
        self.count = 0

    def test_success_single_try(self):
        """A succeeding function runs once when a single try is allowed."""
        @tries(1)
        def a():
            self.count += 1
            return "expected_result"

        self.assertEqual(a(), "expected_result")
        self.assertEqual(self.count, 1)

    def test_success_two_tries(self):
        """A succeeding function is not called again to use up spare tries."""
        @tries(2)
        def a():
            self.count += 1
            return "expected_result"

        self.assertEqual(a(), "expected_result")
        self.assertEqual(self.count, 1)

    def test_failure_two_tries(self):
        """A function that always fails is tried twice, then the error is raised."""
        @tries(2)
        def a():
            self.count += 1
            raise Exception()

        with self.assertRaises(MaxTriesExceededError):
            a()
        self.assertEqual(self.count, 2)

    def test_success_after_third_try(self):
        """Two failures are swallowed and the third attempt's result is returned."""
        @tries(5)
        def a():
            self.count += 1
            if self.count == 3:
                return "expected_result"
            raise Exception()

        self.assertEqual(a(), "expected_result")
        self.assertEqual(self.count, 3)

# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
like image 60
Adi Levin Avatar answered Oct 05 '22 22:10

Adi Levin