看 RNN 的 paper，大多数集中在 RNNcell 的内部构建，少数涉及 units 之间的交互。
TensorFlow 提供了几种最流行的 RNN 变种类，但编写起来不如 CNN 方便。这里分享一段使用 tf.scan 构建 GRUcell 的代码，可以作为自定义 RNNcell 的参考。
import numpy as np
import pandas as pd
import tensorflow as tf
import pylab as pl
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data sets from /tmp/data/ with one-hot encoded labels.
# NOTE(review): the tutorial helper is expected to download the files when
# they are missing from that directory -- confirm for the installed version.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
# IPython/Jupyter magic (not plain Python): show matplotlib figures inline.
%matplotlib inline
class GRUcell(object):
    """GRU sequence classifier assembled by hand with ``tf.scan`` (TF 1.x graph mode).

    The constructor creates the trainable variables and the input placeholder;
    the other methods add ops to the default graph each time they are called,
    so call ``last_output()`` once and reuse the returned tensor.

    Attributes read by callers:
        inX: float32 placeholder of shape [batch_size, in_length, in_width].
        X: ``inX`` transposed to time-major [in_length, batch_size, in_width].
        initial_hidden: all-zero hidden state of shape [batch_size, hidden_layer_size].
    """

    def __init__(self, in_length=28, in_width=28, hidden_layer_size=2000,
                 out_classes=10):
        """Create the variables and the input placeholder.

        Args:
            in_length: number of time steps per sequence (image rows for MNIST).
            in_width: number of features per time step (image columns for MNIST).
            hidden_layer_size: size of the GRU hidden state.
            out_classes: number of output classes.
        """
        self.in_length = in_length
        self.in_width = in_width
        self.hidden_layer_size = hidden_layer_size
        self.out_classes = out_classes

        input_to_hidden = [self.in_width, self.hidden_layer_size]
        hidden_to_hidden = [self.hidden_layer_size, self.hidden_layer_size]

        # Input-to-hidden weights for the reset gate, update gate and candidate
        # state. They start at zero; the recurrent weights below are random.
        self.Wr = tf.Variable(tf.zeros(input_to_hidden))
        self.Wz = tf.Variable(tf.zeros(input_to_hidden))
        self.W_ = tf.Variable(tf.zeros(input_to_hidden))

        # Hidden-to-hidden (recurrent) weights for the same three paths.
        # NOTE(review): truncated_normal is called with its default stddev here,
        # which is large for a wide hidden layer -- consider scaling it down.
        self.Ur = tf.Variable(tf.truncated_normal(hidden_to_hidden))
        self.Uz = tf.Variable(tf.truncated_normal(hidden_to_hidden))
        self.U_ = tf.Variable(tf.truncated_normal(hidden_to_hidden))

        # Output projection from the hidden state to the class scores.
        self.Wout = tf.Variable(
            tf.truncated_normal([self.hidden_layer_size, self.out_classes],
                                mean=0., stddev=.1))
        self.bout = tf.Variable(
            tf.truncated_normal([self.out_classes], mean=0., stddev=.1))

        # Batch-major input: [batch_size, in_length, in_width].
        self.inX = tf.placeholder(
            shape=[None, self.in_length, self.in_width], dtype=tf.float32)

        # Multiplying the first time step by a zero matrix yields a zero state
        # of shape [batch_size, hidden_layer_size] without knowing the batch
        # size when the graph is built.
        self.initial_hidden = tf.matmul(
            self.inX[:, 0, :],
            tf.zeros([self.in_width, self.hidden_layer_size]))

        # tf.scan iterates over the leading axis, so put time first.
        self.X = tf.transpose(self.inX, perm=[1, 0, 2])

    def GRU(self, hidden_states_previous, current_input_X):
        """Compute one GRU step (no bias terms).

        The argument order is the one ``tf.scan`` uses for ``fn``: the value
        produced by the previous step first, the current element second.

        Args:
            hidden_states_previous: previous state, [batch_size, hidden_layer_size].
            current_input_X: input at this step, [batch_size, in_width].

        Returns:
            The new hidden state, [batch_size, hidden_layer_size].
        """
        hp = hidden_states_previous
        x = current_input_X
        # Reset gate r and update gate z.
        r = tf.sigmoid(tf.matmul(x, self.Wr) + tf.matmul(hp, self.Ur))
        z = tf.sigmoid(tf.matmul(x, self.Wz) + tf.matmul(hp, self.Uz))
        # Candidate state built from the reset-gated previous state.
        h_ = tf.tanh(tf.matmul(x, self.W_) + tf.matmul(r * hp, self.U_))
        # z keeps the old state, (1 - z) lets the candidate in.
        h = tf.multiply(hp, z) + tf.multiply((1 - z), h_)
        return h

    def PRO_TS(self):
        """Run the GRU step over the whole sequence.

        Input format  : [in_length, batch_size, in_width]  (``self.X``)
        Output format : [in_length, batch_size, hidden_layer_size]

        Returns:
            The hidden state after every time step, stacked along axis 0.
        """
        return tf.scan(fn=self.GRU, elems=self.X,
                       initializer=self.initial_hidden)

    def Full_Connection_Layer(self, batch_hidden_layer_states):
        """Project hidden states to class scores through a dense layer.

        Input format  : [batch_size, hidden_layer_size]
        Output format : [batch_size, out_classes]

        NOTE(review): the ReLU clamps the scores to be non-negative before the
        softmax applied in ``last_output``; kept as in the original model.
        """
        scores = tf.matmul(batch_hidden_layer_states, self.Wout)
        return tf.nn.relu(tf.nn.bias_add(scores, self.bout))

    def deal_hidden_layer(self):
        """Apply the dense layer to the hidden state of every time step.

        Input format  : [in_length, batch_size, hidden_layer_size]
        Output format : [in_length, batch_size, out_classes]
        """
        all_hidden_states = self.PRO_TS()
        return tf.map_fn(self.Full_Connection_Layer, all_hidden_states)

    def last_output(self):
        """Return class probabilities computed from the final time step.

        Only the last hidden state is projected; the result equals taking the
        last slice of ``deal_hidden_layer()`` but avoids running the dense
        layer on the time steps that would be discarded.

        Returns:
            Softmax probabilities, [batch_size, out_classes].
        """
        final_hidden_state = self.PRO_TS()[-1]
        return tf.nn.softmax(self.Full_Connection_Layer(final_hidden_state))
# One-hot ground-truth labels, [batch_size, 10].
y = tf.placeholder(tf.float32, shape=[None, 10], name='inputs')
rnn = GRUcell()
output = rnn.last_output()

# Summed cross-entropy over the batch. The probabilities are clipped away from
# zero first: a softmax output that underflows to 0 would otherwise give
# 0 * log(0) = NaN and poison every later update.
cross_entropy = -tf.reduce_sum(
    y * tf.log(tf.clip_by_value(output, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cross_entropy)

# Fraction of samples whose most probable class matches the label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

batch_size = 32
ss = []  # accuracy on each training batch, one entry per iteration
for i in range(5000):
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    # Each flat image becomes a sequence of 28 rows with 28 features.
    batch_x = batch_x.reshape((batch_size, 28, 28))
    feed = {rnn.inX: batch_x, y: batch_y}
    sess.run(train_step, feed_dict=feed)
    # Measured on the batch just trained on, after the update; this is a
    # training-progress curve, not a held-out evaluation.
    t = sess.run(accuracy, feed_dict=feed)
    ss.append(t)

# Plot the per-iteration training accuracy.
ttt = pd.Series(ss)
ttt.plot()
运行环境：TensorFlow 1.0，Python 3.6。
源代码地址: https://uqer.io/community/share/58a9332bf1973300597ae209
这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。
V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。
V2EX is a community of developers, designers and creative people.