Skip to content Skip to sidebar Skip to footer

TensorFlow raw_rnn: Retrieve a Tensor of Shape Batch × Dim from the Embedding Matrix

I am implementing an encoder-decoder LSTM in which I have to do a custom computation at each step of the encoder, so I am using raw_rnn. However, I am facing a problem accessing an element … (the rest of the question is truncated in this copy)

Solution 1:

I was able to fix the problem. The embedded inputs (`encoder_inputs_embedded`) have shape batch × time steps × embedding dimensionality, so I slice along the time dimension. The resulting tensor has shape (?, embedding dimensionality). I also have to set the shape of the resulting tensor explicitly to avoid this error:

ValueError: The shape for rnn/while/Merge_2:0 is not an invariant for the loop

Here is the relevant part:

def get_next_input():
    """Return the embedded inputs for the current step, shape (batch, dim).

    ``time`` is not defined in this snippet; it is expected to be the step
    counter of the enclosing raw_rnn loop function.
    """
    step_slice = encoder_inputs_embedded[:, time, :]
    # The slice comes out as (?, embedding dimensionality). Pin the static
    # shape so the while_loop shape invariant holds.
    step_slice.set_shape([batch_size, input_embedding_size])
    return step_slice

Can anyone confirm whether this is the right way to solve the problem?

Here is the complete code for reference:

import tensorflow as tf
import numpy as np

batch_size, max_time, input_embedding_size = 5, 10, 16
vocab_size, num_units = 50, 64

# Token ids, fed as (batch, time); the time dimension is left dynamic.
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
# Number of valid tokens in each batch row.
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

# vocab_size + 2 rows: presumably two extra ids for special symbols (e.g. GO).
# TODO confirm.
embeddings = tf.Variable(tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
                         dtype=tf.float32, name='embeddings')
# (batch, time, input_embedding_size) embeddings of encoder_inputs.
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)
# NOTE(review): W and b are not referenced anywhere else in this script.
W = tf.Variable(tf.random_uniform([num_units, vocab_size], -1, 1), dtype=tf.float32, name='W_reader')
b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32, name='b_reader')

# Id of the GO symbol that the reader receives as its first input.
GO_ID = 1
# One GO id per batch row. tf.fill states this directly; the original
# `tf.ones(...) * 1` added a redundant multiply op.
go_time_slice = tf.fill([batch_size], GO_ID, name='GO')
# (batch_size, input_embedding_size) embedding of the GO symbol.
go_step_embedded = tf.nn.embedding_lookup(embeddings, go_time_slice)


with tf.variable_scope('ReaderNetwork'):
    # The reader consumes the GO embedding at step 0 and then every token of
    # `encoder_inputs`, so each sequence runs for length + 1 steps.
    # Previously step t read token t. Step 0 consumed GO and step 1 already
    # read token 1, so token 0 was never fed to the cell.
    steps_per_sequence = encoder_inputs_length + 1

    def loop_fn_initial():
        """Return raw_rnn's time-0 tuple: feed GO and start from the zero state.

        Returns (elements_finished, next_input, cell_state, emit_output,
        loop_state). emit_output and loop_state are None on the initial call.
        """
        init_elements_finished = (0 >= steps_per_sequence)
        init_input = go_step_embedded
        init_cell_state = cell.zero_state(batch_size, tf.float32)
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)

    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        """Return raw_rnn's tuple for a step time >= 1.

        Emits the cell output unchanged and carries the cell state through.
        The next input is token `time - 1`, or zeros once every row is done.
        """
        def get_next_input():
            # Step 0 consumed GO, so step `time` consumes token `time - 1`.
            embedded_value = encoder_inputs_embedded[:, time - 1, :]
            # A slice at a tensor-valued index comes out as (?, dim). Restore
            # the static shape so it matches the loop's shape invariant.
            embedded_value.set_shape([batch_size, input_embedding_size])
            return embedded_value

        elements_finished = (time >= steps_per_sequence)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        # Once every row is finished, feed zeros instead of slicing past the data.
        next_input = tf.cond(finished,
                             true_fn=lambda: tf.zeros([batch_size, input_embedding_size], dtype=tf.float32),
                             false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, next_input, state, output, loop_state

    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        """Dispatch between the initial call (no state yet) and later steps."""
        if previous_state is None:  # time == 0: raw_rnn's initial call
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

    reader_loop = loop_fn
    # Build raw_rnn inside the scope so the cell's variables live under
    # 'ReaderNetwork'. The scope previously wrapped only the def statements.
    encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
    # Time-major stack of the emitted outputs: (steps, batch, num_units).
    outputs = encoder_outputs_ta.stack()


def next_batch():
    """Build a feed_dict holding one random batch of full-length sequences.

    Maps `encoder_inputs` to an int array of shape (batch_size, max_time)
    whose ids come from [0, vocab_size). Maps `encoder_inputs_length` to
    max_time for every row.
    """
    # NOTE(review): [0, vocab_size) includes id 1, which the script also uses
    # as the GO symbol. Confirm whether data ids should avoid reserved ids.
    return {
        encoder_inputs: np.random.randint(0, vocab_size, (batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size
    }


init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    # Fetch a one-element list, so outs[0] is the stacked outputs array.
    outs = s.run([outputs], feed_dict=next_batch())
    # The scraped `printlen(...)` was a broken Python 2 print statement.
    print(len(outs), outs[0].shape)

Post a Comment for "Tensorflow Raw_rnn Retrieve Tensor Of Shape Batch X Dim From Embedding Matrix"