18273569617
7097 号测试 成员, 2023-08-24 16:53:52 加入
468
个人主页浏览
4h22m
在线时长
  • 源码教学:木森老师教你 30 行代码实现 ddt 模块

    2023-08-24 16:58

    优化了一下,可以支持传入不同的参数,和pytest.mark.`parametrize一样,通过元组自定义参数

    import pprint
    import unittest
    
    from copy import deepcopy
    
    from functools import wraps
    
    def update_func(func, name, data):
        @wraps('func')
        def wrap(self):
            if len(name) == 1:
                kwargs = {name[0]: data}
            else:
                kwargs = dict(zip(name, data))
            return func(self, **kwargs)
        return wrap
    
    def ddt(cls):
        new_cls_attrs = deepcopy(dict(cls.__dict__))
        for name, func in new_cls_attrs.items():
            if hasattr(func, 'params'):
                params = getattr(func, 'params')
                params_name, params_data = params
                for index, value in enumerate(params_data):
                    test_func = update_func(func, params_name, value)
                    setattr(cls, name + '_' + str(index), test_func)
                    pass
                else:
                    delattr(cls, name)
        return cls
    
    def data(params_name, params_data):
        def inner(func):
           setattr(func, 'params', (params_name, params_data))
           return func
        return inner
    
    
    @ddt
    class TestLogin(unittest.TestCase):
        @data(('a'), [[1, 2], [2, 4], [3, 4], [5, 6], (5, 9)])
        def test_login_1(self, a):
            print(a)
    
    pprint.pprint(list(TestLogin.__dict__))
    unittest.main()