关于 Tensorflow 使用 Dataset API 占用内存高的问题

2019-02-24 21:18:58 +08:00
 tomleung1996

初学深度学习,看的是《 Hands on Machine Learning with Scikit-Learn and Tensorflow 》这本书,书中用自己定义的shuffle_batch函数实现将数据分批输入神经网络的功能,数据集用的是 MNIST。

书上的函数定义如下:

def shuffle_batch(X, y, batch_size):
    rnd_idx = np.random.permutation(len(X))
    n_batches = len(X) // batch_size
    for batch_idx in np.array_split(rnd_idx, n_batches):
        X_batch, y_batch = X[batch_idx], y[batch_idx]
        yield X_batch, y_batch

楼主上网搜索一下发现用 Dataset API 和它的shufflebatchrepeat函数可能可以更加优雅地实现分批输入的功能,于是就写了下面的代码:

train_data = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_data = train_data.shuffle(m)
train_data = train_data.batch(batch_size)
train_data = train_data.repeat()
td_iter = train_data.make_one_shot_iterator()
features, labels = td_iter.get_next()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for iteration in range(n_batchs):
            X_batch, y_batch = sess.run([features, labels])
            sess.run(training_op, feed_dict={X:X_batch, y:y_batch})
        acc_train = accuracy.eval(feed_dict={X:X_batch, y:y_batch})
        acc_test = accuracy.eval(feed_dict={X:X_test, y:y_test})
        print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
    save_path = saver.save(sess, './my_model')

但是我发现这段代码虽然也能训练出类似精度的模型,但是在打印出第一个 epoch 的输出前,内存占用极高,而且要等好久才会有第一个输出(后面的输出就花费正常时间)。

如果是按照书上的代码来训练(不使用 Dataset API ),内存几乎没有任何波动。但是我觉得就算是用了 Dataset API,MNIST 这个数据集也不大吧?要占用这么多内存么?

同样的内存占用情况也发生在下面的代码:

with tf.Session() as sess:
    sess.run(init)
    sess.run([features, labels])

我觉得是不是我代码哪里写错了?因为刚接触这个 API,是模仿人家的写法写的,希望大家解答下疑惑哈

2405 次点击
所在节点    问与答
0 条回复

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://www.v2ex.com/t/538244

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX