python – Why is tensorflow.reduce_mean hanging?

Posted by: admin, February 24, 2020

Questions:

I am writing a program that learns to solve the Towers of Hanoi problem using the REINFORCE algorithm. However, it hangs at tf.reduce_mean, where I compute the mean gradients to apply to the model.
Here is my code:

def train(model, q_message, q_update, q_done):
    for iteration in range(n_iterations):
        send_message(False, q_message, q_update, q_done)
        all_rewards, all_grads = play_multiple_episodes(n_episodes_per_update, n_max_steps, model, loss_fn, q_message, q_update, q_done)
        send_message(True, q_message, q_update, q_done)
        all_final_rewards = discount_and_normalize_rewards(all_rewards, discount_factor)
        all_mean_grads = []
        for var_index in range(len(model.trainable_variables)):
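            # Weight each step's gradient for this variable by its discounted reward,
            # then average over all steps of all episodes.
            mean_grads = tf.reduce_mean(
                [final_reward * all_grads[episode_index][step][var_index]
                 for episode_index, final_rewards in enumerate(all_final_rewards)
                 for step, final_reward in enumerate(final_rewards)],
                axis=0)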
            all_mean_grads.append(mean_grads)
        optimizer.apply_gradients(zip(all_mean_grads, model.trainable_variables))
    while True:
        again = input('Would you like to see a demo? [y/n]: ')
        if again == 'y':
            play_episode(model, q_message, q_update, q_done)
        elif again == 'n':
            break

q_message = multiprocessing.Queue()
q_update = multiprocessing.Queue()
q_done = multiprocessing.Queue()
pyglet.clock.schedule(hanoi2.update_puzzle, puzzle=hanoi2.puzzle, q_message=q_message, q_update=q_update, q_done=q_done)
p = multiprocessing.Process(target=train, args=(model, q_message, q_update, q_done))
p.start()
pyglet.app.run()
p.join()

hanoi2 is the pyglet frontend of the program, imported as a separate module. The frontend and the training process communicate through the three multiprocessing Queues so that the frontend can display the learning as it happens.

I tried moving the list comprehension out of the tf.reduce_mean call and storing it in a separate variable. It was computed without problems, and its dimensions looked correct. The program becomes very slow as soon as it calls tf.reduce_mean, and eventually hangs.
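One way to take this check further is to reproduce the averaging on a tiny hand-made input, so the expected values are known before TensorFlow is involved. The following is a minimal sketch in plain Python, not the original program: `weighted_mean_grads` is a hypothetical helper, each gradient is a flat list of floats rather than a tensor, and the nesting `all_grads[episode][step][var_index]` is assumed from the indexing used in the training loop.

```python
def weighted_mean_grads(all_final_rewards, all_grads, var_index):
    """Average the reward-weighted gradients of one variable over all steps.

    Hypothetical stand-in for the tf.reduce_mean(..., axis=0) call:
    each gradient here is a flat list of floats, not a tensor.
    """
    # Same comprehension as in the training loop, kept in its own variable
    # so its length and contents can be inspected before reducing.
    weighted = [
        [final_reward * g for g in all_grads[episode_index][step][var_index]]
        for episode_index, final_rewards in enumerate(all_final_rewards)
        for step, final_reward in enumerate(final_rewards)
    ]
    n = len(weighted)
    # Mean along axis 0: element-wise average across the list of gradients.
    return [sum(column) / n for column in zip(*weighted)]


# Two episodes of two steps each, one variable with a 2-element gradient.
all_final_rewards = [[1.0, -1.0], [1.0, 1.0]]
all_grads = [
    [[[1.0, 2.0]], [[3.0, 4.0]]],  # episode 0: step 0, step 1
    [[[5.0, 6.0]], [[7.0, 8.0]]],  # episode 1: step 0, step 1
]
result = weighted_mean_grads(all_final_rewards, all_grads, var_index=0)
print(result)  # prints [2.5, 3.0]
```

If the same small input, converted to tensors, also returns promptly from tf.reduce_mean in a single process, that would suggest the slowdown comes from the size of the real input or from the environment the call runs in, rather than from the comprehension itself.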

Any help would be much appreciated!

How to & Answers: