I am trying to write a program that learns to solve the Towers of Hanoi problem using the REINFORCE algorithm. However, it hangs at tf.reduce_mean, where I compute the mean gradients to apply to the model.
Here is my code:
def train(model, q_message, q_update, q_done):
    for iteration in range(n_iterations):
        send_message(False, q_message, q_update, q_done)
        all_rewards, all_grads = play_multiple_episodes(
            n_episodes_per_update, n_max_steps, model, loss_fn,
            q_message, q_update, q_done)
        send_message(True, q_message, q_update, q_done)
        all_final_rewards = discount_and_normalize_rewards(all_rewards, discount_factor)
        all_mean_grads = []
        for var_index in range(len(model.trainable_variables)):
            # This is the call where the program slows down and hangs.
            mean_grads = tf.reduce_mean(
                [final_reward * all_grads[episode_index][step][var_index]
                 for episode_index, final_rewards in enumerate(all_final_rewards)
                 for step, final_reward in enumerate(final_rewards)],
                axis=0)
            all_mean_grads.append(mean_grads)
        optimizer.apply_gradients(zip(all_mean_grads, model.trainable_variables))
    while True:
        again = input('Would you like to see a demo? [y/n]: ')
        if again == 'y':
            play_episode(model, q_message, q_update, q_done)
        elif again == 'n':
            break

q_message = multiprocessing.Queue()
q_update = multiprocessing.Queue()
q_done = multiprocessing.Queue()
pyglet.clock.schedule(hanoi2.update_puzzle, puzzle=hanoi2.puzzle,
                      q_message=q_message, q_update=q_update, q_done=q_done)
p = multiprocessing.Process(target=train, args=(model, q_message, q_update, q_done))
p.start()
pyglet.app.run()
p.join()
hanoi2 is the pyglet frontend of the program. It is imported as a separate module. The frontend communicates with the training process in order to display the learning, using Queues.
I tried moving the list comprehension out of the tf.reduce_mean call and storing it in a separate variable. It was computed without problems, and its dimensions looked correct. The program becomes very slow once it calls tf.reduce_mean, and eventually hangs.
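To make clear what the reduction step is meant to compute, here is a minimal sketch in plain Python, with flat lists of floats standing in for gradient tensors. The function name `weighted_mean_grads` and the toy data are illustrative assumptions, not part of the actual program. Note that it accumulates a running sum rather than building one large list and reducing it, which is a different approach from the tf.reduce_mean call above; whether that avoids the hang has not been verified.

```python
def weighted_mean_grads(all_final_rewards, all_grads, n_vars):
    """Reward-weighted mean of per-step gradients, one result per variable.

    all_final_rewards[episode][step] is a float.
    all_grads[episode][step][var_index] is a flat list of floats
    (a stand-in for a gradient tensor; hypothetical layout).
    Assumes at least one step exists across all episodes.
    """
    all_mean_grads = []
    for var_index in range(n_vars):
        total = None  # running element-wise sum for this variable
        count = 0     # number of (episode, step) pairs seen so far
        for episode_index, final_rewards in enumerate(all_final_rewards):
            for step, final_reward in enumerate(final_rewards):
                grad = all_grads[episode_index][step][var_index]
                weighted = [final_reward * g for g in grad]
                if total is None:
                    total = weighted
                else:
                    total = [t + w for t, w in zip(total, weighted)]
                count += 1
        # Divide the accumulated sum by the number of steps to get the mean.
        all_mean_grads.append([t / count for t in total])
    return all_mean_grads


# Toy data: 2 episodes (2 steps and 1 step), 2 variables.
rewards = [[1.0, 2.0], [3.0]]
grads = [
    [[[1.0, 1.0], [2.0]], [[0.5, 0.0], [1.0]]],
    [[[1.0, 2.0], [0.0]]],
]
result = weighted_mean_grads(rewards, grads, n_vars=2)
```

Each entry of `result` has the same shape as the corresponding variable's gradient, which matches what the list comprehension followed by a mean over axis 0 is expected to produce.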
Any help would be much appreciated!