2 个版本,都没有解决内存持续增加,由于内存一直增加,最后因为小鸡内存不足而被杀掉。
这是一个加载训练后的模型,通过网络传入预测参数,然后返回预测 json 结果。
求大佬帮忙看看问题出在哪里,谢谢。
这是 chatgpt 4.0 给的版本
from flask import Flask, request, jsonify
import pickle
import os
import psutil
import pandas as pd
app = Flask(__name__)
class SingletonModel:
_instance = None
def __new__(cls):
if cls._instance is None:
print("Creating Singleton Instance")
cls._instance = super(SingletonModel, cls).__new__(cls)
modelName = "xgboost_model-k.pkl"
with open(modelName, "rb") as pkl_file:
loaded_data = pickle.load(pkl_file)
cls._instance.model = loaded_data['model']
cls._instance.scaler = loaded_data['scaler']
cls._instance.label_encoder = loaded_data['label_encoder']
cls._instance.feature_names = ['shortAvg','longAvg','volatility','diff']
return cls._instance
resources = SingletonModel()
model = resources.model
scaler = resources.scaler
label_encoder = resources.label_encoder
@app.route('/predict', methods=['POST'])
def predict():
global model, scaler, label_encoder
data = request.json['input']
df = pd.DataFrame([data], columns=resources.feature_names)
scaled_data = scaler.transform(df)
prediction = model.predict(scaled_data)
label_prediction = label_encoder.inverse_transform(prediction)
return jsonify([label_prediction[0]])
if __name__ == '__main__':
app.run(port=6601,debug=True)
这是 Claude 给的版本
import asyncio
from flask import Flask, request, jsonify
import pickle
import pandas as pd
app = Flask(__name__)
# 模型相关全局变量
model = None
scaler = None
label_encoder = None
async def load_model():
global model, scaler, label_encoder,feature_names
if not model:
with open('xgboost_model-k.pkl', 'rb') as f:
loaded_data = pickle.load(f)
model = loaded_data['model']
scaler = loaded_data['scaler']
label_encoder = loaded_data['label_encoder']
feature_names = ['shortAvg','longAvg','volatility','diff']
async def predict(data):
await load_model()
df = pd.DataFrame([data], columns=feature_names)
scaled_data = scaler.transform(df)
prediction = model.predict(scaled_data)
label_prediction = label_encoder.inverse_transform(prediction)
return label_prediction[0]
@app.route('/predict', methods=['POST'])
async def predict_handler():
data = request.json['input']
result = await asyncio.gather(predict(data))
return jsonify(result)
if __name__ == '__main__':
app.run(port=6601,debug=False)
这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。
V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。
V2EX is a community of developers, designers and creative people.