如何优化 Python 计算超大字典的问题

2021-02-04 15:20:27 +08:00
 WilliamHL

目前遇到一个这样的问题: 从数据库中读取数据,存到字典的内存中(减少读取带来的性能消耗), 字典的键是 int 类型,字典的值是 longtext,一次读取大概 200 条左右的数据,后续可能会过千。 这个字典超级大,涉及到代码运算中还有其他计算、字典的 copy 、声明新的 list 等等操作,会存在多个这样的数据,导致虚拟内存峰值飙升到接近 50GB,mackbook 都是 oom 。

只想到是不是可以采用 redis 来代替字典操作,减少内存消耗,不知道还有没有其他方式,感谢各位 v 友~

3471 次点击
所在节点    Python
30 条回复
ml1344677
2021-02-04 15:27:18 +08:00
如果预先知道大小的话,是不是可以通过重写__hash()__来做出完美哈希
TimePPT
2021-02-04 15:30:44 +08:00
为啥不直接读数据库,还要存字典啊
linw1995
2021-02-04 16:07:27 +08:00
原始数据过千就 50 GB,不排查一下会不会有内存泄露的问题?
liprais
2021-02-04 16:08:49 +08:00
你肯定在方法签名里面初始化 list 了
shuax
2021-02-04 17:23:02 +08:00
200 条就要 50G……
Wincer
2021-02-04 17:42:55 +08:00
贴代码吧,看看有什么操作。200 个 key 占 50g 内存是不可能的,之前处理过 80w 个 key 的字典,倒是占用了几十 g 后 oom 了
firefox12
2021-02-04 17:52:26 +08:00
如果你的 200 条数据没有 50G 那就是你代码问题, 提高下 python 代码水平。 你放在 redis 里不解决任何问题。
WilliamHL
2021-02-04 18:49:29 +08:00
``` python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import copy
import json
import logging
import pymysql
import itertools
import pandas as pd
import multiprocessing as mp

PWD = os.path.dirname(os.path.realpath(__file__))
LOGPATH = os.path.join(PWD, './info.log')
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
filename=LOGPATH)


class FuncTactic:

def __init__(self, task_id):
self.task_id = task_id

def tactic(self):
case_id_path = "/Users/xxxx/Downloads/" + str(self.task_id) + ".log"
with open(case_id_path, 'r', encoding='utf-8') as content:
result_list = json.load(content)
content.close()
case_id_func = {} # 这个是 id 对应的方法变更
for i in result_list:
case_id_func[str(i.get('c_id'))] = i.get('c_c_func_origin_list')

conn = self.connect_db("localhost", 3306, "root", "root1234", "code_trees")
table_name = "android_code_tree"
sql = "SELECT case_id, trees FROM %s where case_id in (%s)" % (table_name, str(list(case_id_func.keys()))[
1:-1])
covs_pd = pd.read_sql(sql, conn)
funcs_dict = {} # 这个是 id 对应的关系树
for row in covs_pd.itertuples():
id = int(getattr(row, "case_id"))
funcs = getattr(row, "trees")
if funcs is None or len(funcs) < 1:
continue
else:
funcs_dict[id] = funcs
conn.close()

# 读取文件获取方法变更和对应的 case
func_case_id = {} # 方法对应的变更 id
conn = self.connect_db("x.x.x.x", 3306, "test", "test", "code_ing")
table_name = "task_diff_case_relation"
sql = "SELECT relation FROM %s where taskid = %s" % (table_name, self.task_id)
relation = pd.read_sql(sql, conn)
functions = ""
for row in relation.itertuples():
functions = getattr(row, "relation")
conn.close()
for i in json.loads(functions):
func_case_id[i.get('diff_code')] = i.get('caseid_list')
# 计算基类方法
pool = mp.Pool(mp.cpu_count())
jobs = []
for _dict in self.split_dict(func_case_id, mp.cpu_count()):
jobs.append(pool.apply(self.get_dict_common_func, _dict))

res = [job.get() for job in jobs]
pool.close()

common_funcs = list(itertools.chain(*map(eval, res)))

# 过滤
result_case_id_func = copy.deepcopy(case_id_func)
flitur_result = {}
for id, fun_list in case_id_func.items():
if set(fun_list) < set(common_funcs):
in_list = []
for fun in fun_list:
func_case_id[fun].remove(id)
if len(func_case_id[fun]) > 0:
if id in result_case_id_func:
del result_case_id_func[id]
in_list.append(fun)
flitur_result[id] = in_list
print(result_case_id_func)
print(flitur_result)
logging.info(str(result_case_id_func.keys()))
logging.info(str(flitur_result.keys()))
print(
len(list(flitur_result.keys())) / (len(list(flitur_result.keys())) + len(list(result_case_id_func.keys()))))

def connect_db(self, host, port, user, passwd, db):
try:
conn = pymysql.connect(host=host, port=port, user=user, passwd=passwd, db=db)
return conn
except Exception as e:
logging.error("connect db error : " + str(e))
return None

def get_dict_common_func(self, _dict):
common_funcs = []
for fun, ids in _dict.items():
if len(_dict[fun]) >= 40:
common_funcs.append(fun)
continue
else:
in_list = [self.get_funcs_lines_from_tree(_dict[int(x)], fun) for x in ids]
if len(in_list) > 100:
common_funcs.append(fun)
continue

return common_funcs

def get_funcs_lines_from_tree(self, tree_string, func):
data = tree_string.split("\n")
index_list = [data.index(i) for i in data if func in i]

result_list = []
for i in index_list:
list_in = [data[i].split(" ")[-1].lstrip("L")]
prefix = self.rreplace(data[i].split("L")[0], "| ", "")

for j in data[0:i][::-1]:
if prefix in j:
# print("old new_prefix: ", prefix)
prefix = self.rreplace(prefix, "| ", "")
# print("新建 new_prefix: ", prefix)
list_in.append(j.split(" ")[-1].lstrip("L"))
continue
else:
pass

# 查找完成所有的前向调用,之后查找后向调用,先还原 prefxi
prefix = self.rreplace(data[i].split("L")[0], "| ", "| | ")
list_in = list_in[::-1]
for j in data[i:]:
if prefix in j:
prefix = self.rreplace(prefix, "| ", "| | ")
list_in.append(j.split(" ")[-1].lstrip("L"))
continue
else:
pass
if len(list_in) > 0:
result_list.append(list_in)

if len(result_list) > 0:
return result_list
else:
return

def rreplace(self, s, old, new):
li = s.rsplit(old, 1)
return new.join(li)

def split_dict(self, x, chunks):
i = itertools.cycle(range(chunks))
split = [dict() for _ in range(chunks)]
for k, v in x.items():
split[next(i)][k] = v
return split


if __name__ == "__main__":
fun = FuncTactic(993)
fun.tactic()


```
WilliamHL
2021-02-04 18:49:55 +08:00
@Wincer 在 8 楼贴了一下代码
WilliamHL
2021-02-04 18:52:17 +08:00
@TimePPT 觉得来回读写消耗性能
WilliamHL
2021-02-04 18:52:58 +08:00
@linw1995 感谢啊,周末排查一下
laqow
2021-02-04 18:59:09 +08:00
会不会是 mp.Pool 里每个线程复制了一份字典
WilliamHL
2021-02-04 19:09:34 +08:00
@laqow Pool 是我新更新的,之前是单进程的
WilliamHL
2021-02-04 19:10:14 +08:00
@laqow 你竟然能看懂,我还没排版
skinny
2021-02-04 19:15:51 +08:00
如果你的代码没问题,数据确实很多,以我有限的数据处理经验建议你不要在 python 的基础数据结构里保存太多数据,内存会爆炸的,占用的内存会远远超过你的预期(我遇到的是内存占用十倍起步)。
WilliamHL
2021-02-04 19:25:06 +08:00
@skinny 考虑过采用 array 实现,稍微看了下貌似能降低内存占用
laqow
2021-02-04 19:47:46 +08:00
@WilliamHL 太长了看不懂,猜的,遇到过这个引起的问题
DoctorCat
2021-02-05 02:05:53 +08:00
Linux 下进程栈的默认大小是 10MB,进程是不是复制太多了没有退出工作进程?看看进程树情况,试试 close 后 join 等待进程结束。不然会产生很多僵尸进程。
todd7zhang
2021-02-05 09:20:40 +08:00
只能猜是你的 longtext 处理的时候造出太多的新的 str 了
Wincer
2021-02-05 10:45:58 +08:00
get_funcs_lines_from_tree 里面有太多针对字符串的切片操作了,Python 每一次对字符串的切片都会内存复制。按照你的说法如果这个字符串很长很长的话,确实会造成内存的飙升

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://www.v2ex.com/t/751272

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX