-
源码教学：木森老师教你用 30 行代码实现 ddt 模块
2023-08-24 16:58 更新：优化了一下，现在可以支持传入不同的参数，和 `pytest.mark.parametrize` 一样，通过元组自定义参数。
import pprint
import unittest
from copy import deepcopy
from functools import wraps


def update_func(func, name, data):
    """Build a one-case test method that calls *func* with *data* bound.

    Args:
        func: The parameterized test method; it is called as
            ``func(self, **kwargs)``.
        name: The parameter name(s). Either a string holding one name
            (``'a'``) or comma-separated names (``'a, b'``, as in
            ``pytest.mark.parametrize``), or a sequence of names such as
            ``('a', 'b')``. Note that ``('a')`` is just the string ``'a'``.
        data: The value(s) for this case. With a single name it is passed
            through unchanged; with several names it must be an iterable
            holding exactly one value per name.

    Returns:
        A function taking only ``self``, usable as a unittest test method.
        Calling it raises ``ValueError`` if *data* does not supply exactly
        one value per name.
    """
    # A plain string used to be iterated character by character, so a
    # one-parameter name like 'user' was zipped as 'u', 's', ... and broke.
    if isinstance(name, str):
        names = [part.strip() for part in name.split(',')]
    else:
        names = list(name)

    # wraps(func), not wraps('func'): wrapping the string copied str's own
    # docstring onto every generated test instead of the test's docstring.
    @wraps(func)
    def wrap(self):
        if len(names) == 1:
            kwargs = {names[0]: data}
        else:
            values = list(data)
            # zip() would silently drop surplus values; fail loudly instead.
            if len(values) != len(names):
                raise ValueError(
                    f'expected {len(names)} values for {names}, '
                    f'got {len(values)}: {data!r}'
                )
            kwargs = dict(zip(names, values))
        return func(self, **kwargs)

    # wraps() also copies func.__dict__, including the 'params' marker set
    # by data(); drop it so re-applying ddt does not try to expand this case.
    vars(wrap).pop('params', None)
    return wrap


def ddt(cls):
    """Class decorator that expands each ``@data``-marked method into cases.

    For every attribute defined directly on *cls* that carries a ``params``
    marker (set by ``data``), one test method per item of its data is added,
    named ``<method>_<index>`` (``test_login_1_0``, ``test_login_1_1``, ...).
    The original method is then deleted so unittest does not also run it
    without arguments.

    Returns:
        *cls* itself, modified in place.
    """
    # A shallow snapshot is enough to let us add/delete attributes while
    # looping. The previous deepcopy bought nothing and raised TypeError on
    # attributes that cannot be copied, e.g. a ``property``.
    for attr_name, func in list(cls.__dict__.items()):
        if not hasattr(func, 'params'):
            continue
        params_name, params_data = func.params
        for index, value in enumerate(params_data):
            case_name = f'{attr_name}_{index}'
            test_func = update_func(func, params_name, value)
            # Give each case its own name for tracebacks and reprs.
            test_func.__name__ = case_name
            test_func.__qualname__ = f'{cls.__qualname__}.{case_name}'
            setattr(cls, case_name, test_func)
        delattr(cls, attr_name)
    return cls


def data(params_name, params_data):
    """Mark a test method for expansion by ``ddt``.

    Args:
        params_name: Parameter name(s); see ``update_func`` for the forms
            accepted.
        params_data: Iterable of cases; ``ddt`` generates one test per item.

    Returns:
        A decorator that stores ``(params_name, params_data)`` on the
        method's ``params`` attribute and returns the method unchanged.
    """
    def inner(func):
        func.params = (params_name, params_data)
        return func

    return inner


@ddt
class TestLogin(unittest.TestCase):
    # 'a' names a single parameter, so each list/tuple below is passed whole
    # as ``a`` (the original spelled it ('a'), which is the same string).
    @data('a', [[1, 2], [2, 4], [3, 4], [5, 6], (5, 9)])
    def test_login_1(self, a):
        print(a)


if __name__ == '__main__':
    # Show the generated test_login_1_<n> methods, then run them. Guarded so
    # that importing this module does not run the tests and call sys.exit().
    pprint.pprint(list(TestLogin.__dict__))
    unittest.main()