Tensorflow Raw_rnn Retrieve Tensor Of Shape Batch X Dim From Embedding Matrix
I am implementing an encoder-decoder LSTM, where I have to do custom computation at each step of the encoder. So, I am using raw_rnn. However, I am facing a problem accessing an element of the embedded inputs at each time step.
Solution 1:
I was able to fix the problem. Since the embeddings have shape Batch x Time steps x Embedding dimensionality, I slice along the time dimension. The resulting tensor has shape (?, embedding dimensionality). It is also required to explicitly set the shape of the resulting tensor in order to avoid the error:
ValueError: The shape for rnn/while/Merge_2:0 is not an invariant for the loop
Here is the relevant part:
def get_next_input():
    """Return the embedded encoder inputs for the current time step.

    Reads ``encoder_inputs_embedded`` and ``time`` from the enclosing
    scope and slices along the time axis (axis 1), which yields a tensor
    whose static shape is ``(?, input_embedding_size)``.
    """
    embedded_value = encoder_inputs_embedded[:, time, :]
    # The static shape must be pinned explicitly; otherwise the loop fails
    # with "ValueError: The shape for rnn/while/Merge_2:0 is not an
    # invariant for the loop".
    embedded_value.set_shape([batch_size, input_embedding_size])
    return embedded_value
Can anyone confirm if this is the right way to solve the problem?
Here is the complete code for reference:
import tensorflow as tf
import numpy as np
# Problem dimensions used throughout the example.
batch_size = 5
max_time = 10
input_embedding_size = 16
vocab_size = 50
num_units = 64

# Integer token ids and per-example sequence lengths, fed at run time.
encoder_inputs = tf.placeholder(
    dtype=tf.int32, shape=(None, None), name='encoder_inputs')
encoder_inputs_length = tf.placeholder(
    dtype=tf.int32, shape=(None,), name='encoder_inputs_length')

# Embedding table with two rows more than vocab_size
# (presumably reserved for special tokens -- TODO confirm).
embeddings = tf.Variable(
    tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
    dtype=tf.float32,
    name='embeddings',
)
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)

# Projection parameters; they are created here but not used by the reader
# loop shown in this file.
W = tf.Variable(
    tf.random_uniform([num_units, vocab_size], -1, 1),
    dtype=tf.float32,
    name='W_reader',
)
b = tf.Variable(
    tf.zeros([vocab_size]),
    dtype=tf.float32,
    name='b_reader',
)

# First input of the loop: token id 1 for every batch entry, embedded.
go_time_slice = tf.ones([batch_size], dtype=tf.int32, name='GO') * 1
go_step_embedded = tf.nn.embedding_lookup(embeddings, go_time_slice)
# NOTE(review): the original paste lost all indentation; the nesting below
# (including running raw_rnn inside the variable scope) is reconstructed and
# should be confirmed against the author's intent.
with tf.variable_scope('ReaderNetwork'):
    def loop_fn_initial():
        """Build the loop values for the very first raw_rnn call (time 0).

        Returns:
            Tuple ``(elements_finished, next_input, cell_state,
            cell_output, loop_state)`` where the input is the embedded GO
            step, the state is the cell's zero state, and both the output
            and the loop state are ``None``.
        """
        init_elements_finished = (0 >= encoder_inputs_length)
        init_input = go_step_embedded
        init_cell_state = cell.zero_state(batch_size, tf.float32)
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)

    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        """Build the loop values for every step after the first.

        Args:
            time: Current time step of the loop.
            previous_output: Cell output of the previous step; passed
                through unchanged as this step's emitted output.
            previous_state: Cell state of the previous step; passed
                through unchanged.
            previous_loop_state: Unused; the loop state is always ``None``.

        Returns:
            Tuple ``(elements_finished, next_input, state, output,
            loop_state)``.
        """
        def get_next_input():
            """Slice the embedded inputs at ``time`` along the time axis."""
            embedded_value = encoder_inputs_embedded[:, time, :]
            # Pin the static shape; without it the while-loop rejects the
            # tensor as "not an invariant for the loop".
            embedded_value.set_shape([batch_size, input_embedding_size])
            return embedded_value

        elements_finished = (time >= encoder_inputs_length)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        # Once every sequence is finished, feed zeros instead of slicing
        # (presumably to avoid indexing past the last time step).
        next_input = tf.cond(
            finished,
            true_fn=lambda: tf.zeros([batch_size, input_embedding_size], dtype=tf.float32),
            false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, next_input, state, output, loop_state

    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        """Dispatch to the initial or the transition step function.

        ``previous_state`` is ``None`` only on the first call (time 0).
        """
        if previous_state is None:  # time = 0
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

    reader_loop = loop_fn
    encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
    # raw_rnn returns the emitted outputs as a TensorArray; stack it into
    # a single tensor.
    outputs = encoder_outputs_ta.stack()
def next_batch():
    """Build a feed dict holding one random full-length batch.

    Returns:
        Dict mapping ``encoder_inputs`` to a ``(batch_size, max_time)``
        array of random token ids in ``[0, vocab_size)`` and
        ``encoder_inputs_length`` to a list of ``batch_size`` entries, each
        equal to ``max_time``.
    """
    return {
        encoder_inputs: np.random.randint(0, vocab_size, (batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size,
    }
# Initialise all variables, run the reader once on a random batch, and
# report how many results came back together with the output shape.
init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    outs = s.run([outputs], feed_dict=next_batch())
    # Function-call form of print so the script also runs on Python 3.
    print(len(outs), outs[0].shape)
Post a Comment for "Tensorflow Raw_rnn Retrieve Tensor Of Shape Batch X Dim From Embedding Matrix"