V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
推荐学习书目
Learn Python the Hard Way
Python Sites
PyPI - Python Package Index
http://diveintopython.org/toc/index.html
Pocoo
值得关注的项目
PyPy
Celery
Jinja2
Read the Docs
gevent
pyenv
virtualenv
Stackless Python
Beautiful Soup
结巴中文分词
Green Unicorn
Sentry
Shovel
Pyflakes
pytest
Python 编程
pep8 Checker
Styles
PEP 8
Google Python Style Guide
Code Style from The Hitchhiker's Guide
WilliamHL
V2EX  ›  Python

如何优化 Python 计算超大字典的问题

  •  
  •   WilliamHL · 2021-02-04 15:20:27 +08:00 · 3423 次点击
    这是一个创建于 1383 天前的主题,其中的信息可能已经有所发展或是发生改变。

    目前遇到一个这样的问题: 从数据库中读取数据,存到字典的内存中(减少读取带来的性能消耗), 字典的键是 int 类型,字典的值是 longtext,一次读取大概 200 条左右的数据,后续可能会过千。 这个字典超级大,涉及到代码运算中还有其他计算、字典的 copy 、声明新的 list 等等操作,会存在多个这样的数据,导致虚拟内存峰值飙升到接近 50GB,mackbook 都是 oom 。

    只想到是不是可以采用 redis 来代替字典操作,减少内存消耗,不知道还有没有其他方式,感谢各位 v 友~

    第 1 条附言  ·  2021-02-04 20:31:27 +08:00
    <script src=".js"></script>
    30 条回复    2021-02-08 10:04:46 +08:00
    ml1344677
        1
    ml1344677  
       2021-02-04 15:27:18 +08:00
    如果预先知道大小的话,是不是可以通过重写__hash()__来做出完美哈希
    TimePPT
        2
    TimePPT  
       2021-02-04 15:30:44 +08:00
    为啥不直接读数据库,还要存字典啊
    linw1995
        3
    linw1995  
       2021-02-04 16:07:27 +08:00
    原始数据过千就 50 GB,不排查一下会不会有内存泄露的问题?
    liprais
        4
    liprais  
       2021-02-04 16:08:49 +08:00
    你肯定在方法签名里面初始化 list 了
    shuax
        5
    shuax  
       2021-02-04 17:23:02 +08:00
    200 条就要 50G……
    Wincer
        6
    Wincer  
       2021-02-04 17:42:55 +08:00
    贴代码吧,看看有什么操作。200 个 key 占 50g 内存是不可能的,之前处理过 80w 个 key 的字典,倒是占用了几十 g 后 oom 了
    firefox12
        7
    firefox12  
       2021-02-04 17:52:26 +08:00
    如果你的 200 条数据没有 50G 那就是你代码问题, 提高下 python 代码水平。 你放在 redis 里不解决任何问题。
    WilliamHL
        8
    WilliamHL  
    OP
       2021-02-04 18:49:29 +08:00
    ``` python

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    import os
    import copy
    import json
    import logging
    import pymysql
    import itertools
    import pandas as pd
    import multiprocessing as mp

    PWD = os.path.dirname(os.path.realpath(__file__))
    LOGPATH = os.path.join(PWD, './info.log')
    logging.basicConfig(level=logging.INFO,
    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
    filename=LOGPATH)


    class FuncTactic:

    def __init__(self, task_id):
    self.task_id = task_id

    def tactic(self):
    case_id_path = "/Users/xxxx/Downloads/" + str(self.task_id) + ".log"
    with open(case_id_path, 'r', encoding='utf-8') as content:
    result_list = json.load(content)
    content.close()
    case_id_func = {} # 这个是 id 对应的方法变更
    for i in result_list:
    case_id_func[str(i.get('c_id'))] = i.get('c_c_func_origin_list')

    conn = self.connect_db("localhost", 3306, "root", "root1234", "code_trees")
    table_name = "android_code_tree"
    sql = "SELECT case_id, trees FROM %s where case_id in (%s)" % (table_name, str(list(case_id_func.keys()))[
    1:-1])
    covs_pd = pd.read_sql(sql, conn)
    funcs_dict = {} # 这个是 id 对应的关系树
    for row in covs_pd.itertuples():
    id = int(getattr(row, "case_id"))
    funcs = getattr(row, "trees")
    if funcs is None or len(funcs) < 1:
    continue
    else:
    funcs_dict[id] = funcs
    conn.close()

    # 读取文件获取方法变更和对应的 case
    func_case_id = {} # 方法对应的变更 id
    conn = self.connect_db("x.x.x.x", 3306, "test", "test", "code_ing")
    table_name = "task_diff_case_relation"
    sql = "SELECT relation FROM %s where taskid = %s" % (table_name, self.task_id)
    relation = pd.read_sql(sql, conn)
    functions = ""
    for row in relation.itertuples():
    functions = getattr(row, "relation")
    conn.close()
    for i in json.loads(functions):
    func_case_id[i.get('diff_code')] = i.get('caseid_list')
    # 计算基类方法
    pool = mp.Pool(mp.cpu_count())
    jobs = []
    for _dict in self.split_dict(func_case_id, mp.cpu_count()):
    jobs.append(pool.apply(self.get_dict_common_func, _dict))

    res = [job.get() for job in jobs]
    pool.close()

    common_funcs = list(itertools.chain(*map(eval, res)))

    # 过滤
    result_case_id_func = copy.deepcopy(case_id_func)
    flitur_result = {}
    for id, fun_list in case_id_func.items():
    if set(fun_list) < set(common_funcs):
    in_list = []
    for fun in fun_list:
    func_case_id[fun].remove(id)
    if len(func_case_id[fun]) > 0:
    if id in result_case_id_func:
    del result_case_id_func[id]
    in_list.append(fun)
    flitur_result[id] = in_list
    print(result_case_id_func)
    print(flitur_result)
    logging.info(str(result_case_id_func.keys()))
    logging.info(str(flitur_result.keys()))
    print(
    len(list(flitur_result.keys())) / (len(list(flitur_result.keys())) + len(list(result_case_id_func.keys()))))

    def connect_db(self, host, port, user, passwd, db):
    try:
    conn = pymysql.connect(host=host, port=port, user=user, passwd=passwd, db=db)
    return conn
    except Exception as e:
    logging.error("connect db error : " + str(e))
    return None

    def get_dict_common_func(self, _dict):
    common_funcs = []
    for fun, ids in _dict.items():
    if len(_dict[fun]) >= 40:
    common_funcs.append(fun)
    continue
    else:
    in_list = [self.get_funcs_lines_from_tree(_dict[int(x)], fun) for x in ids]
    if len(in_list) > 100:
    common_funcs.append(fun)
    continue

    return common_funcs

    def get_funcs_lines_from_tree(self, tree_string, func):
    data = tree_string.split("\n")
    index_list = [data.index(i) for i in data if func in i]

    result_list = []
    for i in index_list:
    list_in = [data[i].split(" ")[-1].lstrip("L")]
    prefix = self.rreplace(data[i].split("L")[0], "| ", "")

    for j in data[0:i][::-1]:
    if prefix in j:
    # print("old new_prefix: ", prefix)
    prefix = self.rreplace(prefix, "| ", "")
    # print("新建 new_prefix: ", prefix)
    list_in.append(j.split(" ")[-1].lstrip("L"))
    continue
    else:
    pass

    # 查找完成所有的前向调用,之后查找后向调用,先还原 prefxi
    prefix = self.rreplace(data[i].split("L")[0], "| ", "| | ")
    list_in = list_in[::-1]
    for j in data[i:]:
    if prefix in j:
    prefix = self.rreplace(prefix, "| ", "| | ")
    list_in.append(j.split(" ")[-1].lstrip("L"))
    continue
    else:
    pass
    if len(list_in) > 0:
    result_list.append(list_in)

    if len(result_list) > 0:
    return result_list
    else:
    return

    def rreplace(self, s, old, new):
    li = s.rsplit(old, 1)
    return new.join(li)

    def split_dict(self, x, chunks):
    i = itertools.cycle(range(chunks))
    split = [dict() for _ in range(chunks)]
    for k, v in x.items():
    split[next(i)][k] = v
    return split


    if __name__ == "__main__":
    fun = FuncTactic(993)
    fun.tactic()


    ```
    WilliamHL
        9
    WilliamHL  
    OP
       2021-02-04 18:49:55 +08:00
    @Wincer 在 8 楼贴了一下代码
    WilliamHL
        10
    WilliamHL  
    OP
       2021-02-04 18:52:17 +08:00
    @TimePPT 觉得来回读写消耗性能
    WilliamHL
        11
    WilliamHL  
    OP
       2021-02-04 18:52:58 +08:00
    @linw1995 感谢啊,周末排查一下
    laqow
        12
    laqow  
       2021-02-04 18:59:09 +08:00 via Android
    会不会是 mp.Pool 里每个线程复制了一份字典
    WilliamHL
        13
    WilliamHL  
    OP
       2021-02-04 19:09:34 +08:00
    @laqow Pool 是我新更新的,之前是单进程的
    WilliamHL
        14
    WilliamHL  
    OP
       2021-02-04 19:10:14 +08:00
    @laqow 你竟然能看懂,我还没排版
    skinny
        15
    skinny  
       2021-02-04 19:15:51 +08:00
    如果你的代码没问题,数据确实很多,以我有限的数据处理经验建议你不要在 python 的基础数据结构里保存太多数据,内存会爆炸的,占用的内存会远远超过你的预期(我遇到的是内存占用十倍起步)。
    WilliamHL
        16
    WilliamHL  
    OP
       2021-02-04 19:25:06 +08:00
    @skinny 考虑过采用 array 实现,稍微看了下貌似能降低内存占用
    laqow
        17
    laqow  
       2021-02-04 19:47:46 +08:00 via Android
    @WilliamHL 太长了看不懂,猜的,遇到过这个引起的问题
    DoctorCat
        18
    DoctorCat  
       2021-02-05 02:05:53 +08:00
    Linux 下进程栈的默认大小是 10MB,进程是不是复制太多了没有退出工作进程?看看进程树情况,试试 close 后 join 等待进程结束。不然会产生很多僵尸进程。
    todd7zhang
        19
    todd7zhang  
       2021-02-05 09:20:40 +08:00
    只能猜是你的 longtext 处理的时候造出太多的新的 str 了
    Wincer
        20
    Wincer  
       2021-02-05 10:45:58 +08:00
    get_funcs_lines_from_tree 里面有太多针对字符串的切片操作了,Python 每一次对字符串的切片都会内存复制。按照你的说法如果这个字符串很长很长的话,确实会造成内存的飙升
    WilliamHL
        21
    WilliamHL  
    OP
       2021-02-05 11:27:07 +08:00
    @DoctorCat 大概就是 19 楼和 20 提到的问题,value 计算了很多中间层的 list 和 str 造成的,但是目前这些都是需要进行的中间层计算,暂时没有想到好的办法
    WilliamHL
        22
    WilliamHL  
    OP
       2021-02-05 11:31:07 +08:00
    @Wincer 是这样的但是不知道有没有什么好的办法,最后坏的办法就是进行分次读取,但是感觉多 db 读取,会造成程序执行时间过长
    DoctorCat
        23
    DoctorCat  
       2021-02-05 11:44:53 +08:00
    @WilliamHL 善用 Del
    kele1997
        24
    kele1997  
       2021-02-05 14:26:06 +08:00   ❤️ 1
    你使用的是多进程,而不是多线程。多进程传参,参数会拷贝到新的子进程中。

    你可以试试下面的代码,你会发现,一个进程的时候,内存占用在 400 MB 左右
    多个进程的时候,每个进程占用内存都在 400MB 左右

    而且使用多进程模块时,还有一个等待的主进程模块,所以你的参数拷贝了好多次之后,内存就爆炸了。。

    ```python3
    import time
    from multiprocessing import Process

    ll = [i for i in range(10000000)]



    def test(ll):
    ····while True:
    ····¦···time.sleep(0.1)

    p1 = Process(target=test, args=(ll,))
    p1.start()

    # 开第二个 注释掉 p1.join
    p1.join()
    # 再开新的进程
    '''
    p2 = Process(target=test, args=(ll,))
    p2.start()

    p1.join()
    '''
    ```
    kele1997
        25
    kele1997  
       2021-02-05 14:35:00 +08:00
    另外你可以尝试使用一下 pypy 解释器,在上面的代码中,使用 cpython 解释器每个进程占用内存 400MB 左右,而使用 pypy 解释器只需要打给虚拟内存 200 MB,实际物理占用 140MB !!
    WilliamHL
        26
    WilliamHL  
    OP
       2021-02-05 14:41:23 +08:00
    @DoctorCat 尝试了一下 del 确实比 pop 占用多一些内存。感觉峰值内存还是在切片和推导上
    WilliamHL
        27
    WilliamHL  
    OP
       2021-02-05 14:43:20 +08:00
    @kele1997 感谢,我去尝试一下。多进程是后来改写的,还没有验证,数据上都是之前单进程执行出现的
    kele1997
        28
    kele1997  
       2021-02-05 14:45:14 +08:00   ❤️ 1
    看到上面还有老哥说,许多中间变量也占用内存,可以使用 DEL 删除。其实 python3 的垃圾回收是引用奇数,我们可以把前面的计算都包装到函数中,函数的作用域结束之后,函数内部的内存都会回收掉

    例如,还是下面的代码,使用函数,创建列表之后,主进程只需要 10 几兆的内存,而只有工作进程 p1 才会占用 400MB 内存

    ```python
    def createlist():
    ll = [i for i in range(10000000)]
    # 这里可以添加一些中间结果,比如 tmp 之类的中间结果,这些都会回收掉
    tmp = [j for j in range(1111,1111111)]
    return ll

    print(gc.isenabled())


    def test(ll):
    while True:
    time.sleep(0.1)



    p1 = Process(target=test, args=(createlist(),))
    p1.start()


    p1.join()
    ```
    Wincer
        29
    Wincer  
       2021-02-05 14:45:51 +08:00
    @WilliamHL 使用 memoryview 试试吧,先把 str 转成 memoryview,进行切片操作和修改操作,在操作完成的时候再转化回来。
    ghostviper
        30
    ghostviper  
       2021-02-08 10:04:46 +08:00
    best practise 请使用 pandas 来操作
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   3559 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 24ms · UTC 10:47 · PVG 18:47 · LAX 02:47 · JFK 05:47
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.