Commit 2eeb87cf authored by dimos

Update 04_train_trpo_Atary.py

parent f3754320
@@ -139,8 +139,6 @@ if __name__ == "__main__":
torch.save(net_act.state_dict(), fname)
best_reward = rewards
trajectory.append(exp)
trajectory.append(exp)
if len(trajectory) < TRAJECTORY_SIZE:
continue
@@ -152,7 +150,7 @@ if __name__ == "__main__":
traj_actions_v = torch.FloatTensor(traj_actions).to(device)
traj_adv_v, traj_ref_v = calc_adv_ref(trajectory, net_crt, traj_states_v, device=device)
mu_v = net_act(traj_states_v)
-mu_v = F.log_softmax(mu_v, dim=1)
+#mu_v = F.log_softmax(mu_v, dim=1)
old_logprob_v = calc_logprob(mu_v, net_act.logstd, traj_actions_v)
# normalize advantages
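Note on the second hunk: calc_logprob appears to compute the log-probability of continuous actions under a diagonal Gaussian parameterized by the actor's raw mean output mu_v and net_act.logstd, so passing mu_v through F.log_softmax (which is meant for discrete-action logits) would corrupt its input; commenting that line out keeps the raw means. Below is a minimal sketch of what such a helper typically computes. It assumes a diagonal Gaussian; the name gaussian_logprob, the clamp value, and the summation over action dimensions are illustrative assumptions, not the repository's calc_logprob.

```python
import math
import torch

def gaussian_logprob(mu_v, logstd_v, actions_v):
    # Log-density of actions under N(mu, sigma^2) with sigma = exp(logstd),
    # summed over the action dimensions. Illustrative sketch only.
    var_v = torch.exp(2 * logstd_v)                      # sigma^2 = exp(2 * logstd)
    p1 = -((actions_v - mu_v) ** 2) / (2 * var_v.clamp(min=1e-8))
    p2 = -torch.log(torch.sqrt(2 * math.pi * var_v))
    return (p1 + p2).sum(dim=-1)
```

With a helper of this form, old_logprob_v is computed directly from the actor's mean output and log-std, which is why the log_softmax transformation is no longer applied before the call.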