Pay attention to the distinction between two different things: the number of layers in your recurrent neural network, and the number of time steps over which the Back Propagation Through Time algorithm unrolls that network to cover the sequence length.
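As a rough sketch of the difference (the sizes here are only illustrative): stacking cells controls the depth of the network, while the length of the time axis of the input you feed in controls how many steps the network gets unrolled.
import tensorflow as tf

num_units = 200
num_layers = 3   # depth of the network: three stacked GRU layers
max_time = 50    # number of time steps BPTT unrolls over (the sequence length)

cells = [tf.contrib.rnn.GRUCell(num_units) for _ in range(num_layers)]
stacked = tf.contrib.rnn.MultiRNNCell(cells)

# The second dimension (time) is what gets unrolled, independently of
# how many layers the stacked cell contains.
inputs = tf.placeholder(tf.float32, [None, max_time, 28])
outputs, state = tf.nn.dynamic_rnn(stacked, inputs, dtype=tf.float32)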
Recurrent networks like LSTM and GRU are powerful sequence models. I will explain how to create recurrent networks in TensorFlow and use them for sequence classification and labelling tasks.
If you are not familiar with recurrent networks, I suggest you take a look at Christopher Olah’s great article first. On the TensorFlow part, I also expect some basic knowledge. The official tutorials are a good place to start.
To use recurrent networks in TensorFlow, we first need to define the network architecture, consisting of one or more layers, the cell type, and possibly dropout between the layers. In TensorFlow, we build recurrent networks out of so-called cells that wrap each other.
import tensorflow as tf
num_units = 200
num_layers = 3
dropout = tf.placeholder(tf.float32)
cells = []
for _ in range(num_layers):
    cell = tf.contrib.rnn.GRUCell(num_units)  # Or LSTMCell(num_units)
    cell = tf.contrib.rnn.DropoutWrapper(
        cell, output_keep_prob=1.0 - dropout)
    cells.append(cell)
cell = tf.contrib.rnn.MultiRNNCell(cells)
We can now add the operations to the graph that simulate the recurrent network over the time steps of the input. We do this using TensorFlow's dynamic_rnn() operation. It takes a tensor holding the input sequences and returns the output activations and the last hidden state as tensors.
# Batch size x time steps x features.
data = tf.placeholder(tf.float32, [None, None, 28])
output, state = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)
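For sequence classification, one common pattern (a sketch, not part of the original code; num_classes and target are names introduced here) is to take the activation of the last time step and feed it through a softmax layer:
num_classes = 10  # assumed number of target classes

# Activation of the last time step for every sequence in the batch.
# This assumes the sequences are not padded; for padded batches you would
# gather the last relevant step using the true sequence lengths instead.
last_output = output[:, -1, :]

logits = tf.layers.dense(last_output, num_classes)
prediction = tf.nn.softmax(logits)

target = tf.placeholder(tf.float32, [None, num_classes])
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)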
You can recover the LSTM weights from your TensorFlow session "sess" as follows:
tvars = tf.trainable_variables()
trainable_vars_dict = {}
for key in tvars:
    trainable_vars_dict[key.name] = sess.run(key)
    # Checking the names of the keys
    print(key.name)
From this code you will get the key names. One of the keys corresponds to a single matrix containing all of the LSTM's weights. In your case that key should be named "LSTM/rnn/basic_lstm_cell/weights:0". Assuming the size of your input is input_size, you have to do:
import numpy as np

lstm_weight_vals = trainable_vars_dict["LSTM/rnn/basic_lstm_cell/weights:0"]

# The kernel concatenates the four gate matrices along the second axis:
# input gate, candidate cell values, forget gate, output gate.
w_i, w_C, w_f, w_o = np.split(lstm_weight_vals, 4, axis=1)

# The first input_size rows act on the input x, the remaining rows on the
# previous hidden state h.
w_xi = w_i[:input_size, :]
w_hi = w_i[input_size:, :]
w_xC = w_C[:input_size, :]
w_hC = w_C[input_size:, :]
w_xf = w_f[:input_size, :]
w_hf = w_f[input_size:, :]
w_xo = w_o[:input_size, :]
w_ho = w_o[input_size:, :]
The matrices with "h" in them should be square at the end (of size 128 × 128 in your case). I think for you the input size is 28.
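As a quick sanity check (a sketch; num_units = 128 and input_size = 28 are the sizes assumed above), the shapes should come out as:
num_units = 128
input_size = 28

# Input-to-gate matrices: input_size x num_units.
assert w_xi.shape == (input_size, num_units)
assert w_xf.shape == (input_size, num_units)

# Hidden-to-gate matrices: square, num_units x num_units.
assert w_hi.shape == (num_units, num_units)
assert w_hf.shape == (num_units, num_units)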
You can see that the weights are not shared by executing the following script:
import tensorflow as tf
import numpy as np

with tf.variable_scope("scope1") as vs:
    cell = tf.nn.rnn_cell.GRUCell(10)
    stacked_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * 2)
    stacked_cell(tf.Variable(np.zeros((100, 100), dtype=np.float32), name="moo"),
                 tf.Variable(np.zeros((100, 100), dtype=np.float32), name="bla"))

    # Retrieve just the variables created inside this scope.
    vars = [v.name for v in tf.all_variables()
            if v.name.startswith(vs.name)]
    print(vars)
You will see that, besides the dummy variables, it returns two sets of GRU weights: one under "Cell0" and one under "Cell1".
To make them shared, you can implement your own cell class that inherits from GRUCell and always reuses the weights by always using the same variable scope:
import tensorflow as tf

class SharedGRUCell(tf.nn.rnn_cell.GRUCell):
    def __init__(self, num_units, input_size=None, activation=tf.nn.tanh):
        tf.nn.rnn_cell.GRUCell.__init__(self, num_units, input_size, activation)
        self.my_scope = None

    def __call__(self, a, b):
        if self.my_scope is None:
            # Capture the variable scope on the first call ...
            self.my_scope = tf.get_variable_scope()
        else:
            # ... and reuse its variables on every later call.
            self.my_scope.reuse_variables()
        return tf.nn.rnn_cell.GRUCell.__call__(self, a, b, self.my_scope)
with tf.variable_scope("scope2") as vs:
cell = SharedGRUCell(10)
stacked_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * 2)
stacked_cell(tf.Variable(np.zeros((20, 10), dtype=np.float32), name="moo"), tf.Variable(np.zeros((20, 10), dtype=np.float32), "bla"))
# Retrieve just the LSTM variables.
vars = [v.name for v in tf.all_variables()
if v.name.startswith(vs.name)]
print vars
This way the variables between the two GRUCells are shared. Note that you need to be careful with shapes, since the same cell needs to work with both the raw input and its own output.
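For instance (a sketch with assumed sizes: input_size = 28, num_units = 10, batch_size = 20; the input_projection matrix is something introduced here, not part of the original answer), one way to satisfy this is to project the raw input down to num_units first, so the shared weights always see inputs of the same size as the cell's own output:
num_units = 10
input_size = 28
batch_size = 20

with tf.variable_scope("scope3"):
    cell = SharedGRUCell(num_units)

    # Project the raw input features down to num_units.
    x = tf.Variable(np.zeros((batch_size, input_size), dtype=np.float32), name="x")
    proj_w = tf.get_variable("input_projection", [input_size, num_units])
    x_proj = tf.matmul(x, proj_w)

    state = tf.Variable(np.zeros((batch_size, num_units), dtype=np.float32), name="h")

    # First layer: projected raw input, output has num_units features.
    out1, state1 = cell(x_proj, state)
    # Second layer: the first layer's output already has the right size,
    # so the same (shared) weights apply without any shape mismatch.
    out2, state2 = cell(out1, state1)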