Commit b173b190 authored by holgadoa

updates in models and scripts

parent a1eb24b5
@@ -65,7 +65,7 @@ if __name__ == "__main__":
parser.add_argument("--lr", default=LEARNING_RATE_CRITIC, type=float, help="Critic learning rate")
parser.add_argument("--maxkl", default=TRPO_MAX_KL, type=float, help="Maximum KL divergence")
args = parser.parse_args()
- device = torch.device("cuda" if args.cuda else "cpu")
+ device = torch.device("cuda")
save_path = os.path.join("saves", "trpo-" + args.name)
os.makedirs(save_path, exist_ok=True)
@@ -78,6 +78,7 @@ if __name__ == "__main__":
print(net_act)
print(net_crt)
writer = SummaryWriter(comment="-trpo_" + args.name)
agent = model.AgentA2C(net_act, device=device)
exp_source = ptan.experience.ExperienceSource(env, agent, steps_count=1)
@@ -109,16 +110,22 @@ if __name__ == "__main__":
torch.save(net_act.state_dict(), fname)
best_reward = rewards
trajectory.append(exp)
if len(trajectory) < TRAJECTORY_SIZE:
continue
traj_states = [t[0].state for t in trajectory]
traj_actions = [t[0].action for t in trajectory]
traj_states_v = torch.FloatTensor(traj_states).to(device)
traj_actions_v = torch.FloatTensor(traj_actions).to(device)
traj_adv_v, traj_ref_v = calc_adv_ref(trajectory, net_crt, traj_states_v, device=device)
mu_v = net_act(traj_states_v)
old_logprob_v = calc_logprob(mu_v, net_act.logstd, traj_actions_v)
# normalize advantages
......
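For context, the hunk above stops at the "# normalize advantages" comment. The lines below are an illustrative sketch, not part of this commit, showing what that step usually looks like with the tensor names from the diff; the stand-in tensor is only there to make the snippet runnable.

import torch

traj_adv_v = torch.randn(2049)                    # stand-in for the advantages from calc_adv_ref
traj_adv_v = traj_adv_v - torch.mean(traj_adv_v)  # zero-mean advantages
traj_adv_v /= torch.std(traj_adv_v)               # unit standard deviation
print(traj_adv_v.mean().item(), traj_adv_v.std().item())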
@@ -87,7 +87,7 @@ if __name__ == "__main__":
parser.add_argument("--lr", default=LEARNING_RATE_CRITIC, type=float, help="Critic learning rate")
parser.add_argument("--maxkl", default=TRPO_MAX_KL, type=float, help="Maximum KL divergence")
args = parser.parse_args()
- device = torch.device("cuda" if args.cuda else "cpu")
+ device = torch.device("cuda")
save_path = os.path.join("saves", "trpo-" + args.name)
os.makedirs(save_path, exist_ok=True)
@@ -98,7 +98,7 @@ if __name__ == "__main__":
make_env = lambda: ptan.common.wrappers.wrap_dqn(gym.make("PongNoFrameskip-v4"))
env = make_env()
test_env = make_env()
print(env.observation_space.shape, env.action_space.n)
net_act = modelAtary.ModelActor(env.observation_space.shape, env.action_space.n).to(device)
net_crt = modelAtary.ModelCritic(env.observation_space.shape).to(device)
......
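Both scripts now select the CUDA device unconditionally instead of honouring the --cuda flag. As an aside, not part of this commit, a common guard for machines without a GPU looks like this:

import torch

# Illustrative only; the commit itself uses torch.device("cuda") directly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)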
@@ -71,7 +71,13 @@ class AgentA2C(ptan.agent.BaseAgent):
self.device = device
def __call__(self, states, agent_states):
- states_v = ptan.agent.float32_preprocessor(states)
+ if len(states) == 1:
+     np_states = np.expand_dims(states[0], 0)
+ else:
+     np_states = np.array([np.array(s, copy=False) for s in states], copy=False)
+ states_v = ptan.agent.float32_preprocessor(np_states)
states_v = states_v.to(self.device)
mu_v = self.net(states_v)
@@ -80,6 +86,7 @@ class AgentA2C(ptan.agent.BaseAgent):
rnd = np.random.normal(size=logstd.shape)
actions = mu + np.exp(logstd) * rnd
actions = np.clip(actions, -1, 1)
return actions, agent_states
......
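The AgentA2C change builds the batch array explicitly before preprocessing: a single observation gets a leading batch dimension, while a list of observations is stacked into one array. The snippet below is an illustrative sketch of both branches plus the Gaussian action sampling from the second hunk; shapes and values are made up for the example.

import numpy as np

# Single observation -> add a leading batch dimension, as in the new if-branch.
single = np.zeros((3,), dtype=np.float32)
print(np.expand_dims(single, 0).shape)        # (1, 3)

# Several observations -> stack into one array, as in the new else-branch.
states = [single, single]
np_states = np.array([np.array(s, copy=False) for s in states])
print(np_states.shape)                        # (2, 3)

# Gaussian exploration around the actor mean, clipped to the action range,
# mirroring the mu / logstd lines in the second hunk.
mu = np.zeros((2, 3), dtype=np.float32)       # stand-in for the network output
logstd = np.full((3,), -0.5, dtype=np.float32)
rnd = np.random.normal(size=logstd.shape)
actions = mu + np.exp(logstd) * rnd
actions = np.clip(actions, -1, 1)
print(actions.shape)                          # (2, 3)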