
Commit: a2c added
seungeunrho committed Jul 21, 2019
1 parent d32e545 commit c8f286e
Showing 2 changed files with 11 additions and 25 deletions.
README.md: 6 changes (4 additions, 2 deletions)
@@ -19,8 +19,9 @@ Implementations of basic RL algorithms with minimal lines of codes! (PyTorch based)
 4. PPO (119 lines, including GAE)
 5. DDPG (147 lines, including OU noise and soft target update)
 6. A3C (129 lines)
-7. ACER added ! (149 lines)
-8. Any suggestion..?
+7. ACER (149 lines)
+8. A2C added! (188 lines)
+9. Any suggestion ..?
 
 
 ## Dependencies
@@ -37,5 +38,6 @@ python3 dqn.py
 python3 ppo.py
 python3 ddpg.py
 python3 a3c.py
+python3 a2c.py
 python3 acer.py
 ```
a2c.py: 30 changes (7 additions, 23 deletions)
@@ -8,28 +8,20 @@
 import time
 import numpy as np
 
-
 # Hyperparameters
 n_train_processes = 3
 learning_rate = 0.0002
 update_interval = 5
 gamma = 0.98
-max_train_steps = 60_000
-
-# Constants
+max_train_steps = 60000
 PRINT_INTERVAL = update_interval * 100
-DIM_STATE = 4
-DIM_HIDDEN = 256
-DIM_VALUE_OUT = 1
-DIM_PI_OUT = 2
-
 
 class ActorCritic(nn.Module):
     def __init__(self):
         super(ActorCritic, self).__init__()
-        self.fc1 = nn.Linear(DIM_STATE, DIM_HIDDEN)
-        self.fc_pi = nn.Linear(DIM_HIDDEN, DIM_PI_OUT)
-        self.fc_v = nn.Linear(DIM_HIDDEN, DIM_VALUE_OUT)
+        self.fc1 = nn.Linear(4, 256)
+        self.fc_pi = nn.Linear(256, 2)
+        self.fc_v = nn.Linear(256, 1)
 
     def pi(self, x, softmax_dim=1):
         x = F.relu(self.fc1(x))
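Note (not part of the commit): the class above keeps one shared hidden layer with separate policy and value heads. A minimal self-contained sketch of how the two heads read; the softmax body of pi is elided in the hunk above, so its completion here is an assumption from the visible signature:

```python
import torch.nn as nn
import torch.nn.functional as F

class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(4, 256)    # CartPole state has 4 dimensions
        self.fc_pi = nn.Linear(256, 2)  # one logit per discrete action
        self.fc_v = nn.Linear(256, 1)   # scalar state value

    def pi(self, x, softmax_dim=1):
        x = F.relu(self.fc1(x))
        # assumed completion: softmax over the policy head's logits
        return F.softmax(self.fc_pi(x), dim=softmax_dim)

    def v(self, x):
        x = F.relu(self.fc1(x))
        return self.fc_v(x)
```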
@@ -42,7 +34,6 @@ def v(self, x):
         v = self.fc_v(x)
         return v
 
-
 def worker(worker_id, master_end, worker_end):
     master_end.close()  # Forbid worker to use the master end for messaging
     env = gym.make('CartPole-v1')
@@ -69,7 +60,6 @@ def worker(worker_id, master_end, worker_end):
         else:
             raise NotImplementedError
 
-
class ParallelEnv:
    def __init__(self, n_train_processes):
        self.nenvs = n_train_processes
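Note (not part of the commit): the worker body is mostly collapsed in this view. A hedged sketch of the command loop it plausibly implements, inferred from the visible else: raise NotImplementedError and from ParallelEnv driving workers over multiprocessing pipes; the command names and per-worker seeding are assumptions:

```python
import gym

def worker(worker_id, master_end, worker_end):
    master_end.close()  # worker only talks through its own pipe end
    env = gym.make('CartPole-v1')
    env.seed(worker_id)  # assumed: distinct seed per worker

    while True:
        cmd, data = worker_end.recv()
        if cmd == 'step':                    # assumed command name
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()             # auto-reset on episode end
            worker_end.send((ob, reward, done, info))
        elif cmd == 'reset':                 # assumed command name
            worker_end.send(env.reset())
        elif cmd == 'close':                 # assumed command name
            worker_end.close()
            break
        else:
            raise NotImplementedError
```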
@@ -111,10 +101,7 @@ def step(self, actions):
         self.step_async(actions)
         return self.step_wait()
 
-    def close(self):
-        """
-        Clean up the environments' resources.
-        """
+    def close(self):  # For clean up resources
         if self.closed:
             return
         if self.waiting:
@@ -125,7 +112,6 @@ def close(self):
             worker.join()
         self.closed = True
 
-
def test(step_idx, model):
    env = gym.make('CartPole-v1')
    score = 0.0
@@ -145,7 +131,6 @@ def test(step_idx, model):
 
     env.close()
 
-
def compute_target(v_final, r_lst, mask_lst):
    G = v_final.reshape(-1)
    td_target = list()
@@ -156,7 +141,6 @@ def compute_target(v_final, r_lst, mask_lst):
 
     return torch.tensor(td_target[::-1]).float()
 
-
if __name__ == '__main__':
    envs = ParallelEnv(n_train_processes)

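Note (not part of the commit): the middle of compute_target is collapsed above. It accumulates discounted returns backward through the rollout, starting from the bootstrap value V(s_final); the mask (0 at episode ends, 1 otherwise) cuts the bootstrap across episode boundaries. A hedged reconstruction plus a tiny numeric check; the loop body is an assumption consistent with the visible G, td_target, and reversed return:

```python
import numpy as np
import torch

gamma = 0.98  # matches the hyperparameter in the diff

def compute_target(v_final, r_lst, mask_lst):
    G = v_final.reshape(-1)   # bootstrap from V(s_final), one entry per env
    td_target = list()
    for r, mask in zip(r_lst[::-1], mask_lst[::-1]):  # walk backward in time
        G = r + gamma * G * mask                      # mask 0 cuts the bootstrap
        td_target.append(G)
    return torch.tensor(td_target[::-1]).float()      # back to forward order

# One env, three steps of reward 1.0, episode terminating on the last step:
v_final = np.array([[0.5]], dtype=np.float32)
r_lst = [np.array([1.0]), np.array([1.0]), np.array([1.0])]
mask_lst = [np.array([1.0]), np.array([1.0]), np.array([0.0])]
print(compute_target(v_final, r_lst, mask_lst))
# -> tensor of shape (3, 1): [[2.9404], [1.9800], [1.0000]]
# last step: 1 + 0.98*0.5*0 = 1.0; then 1 + 0.98*1.0 = 1.98; then 1 + 0.98*1.98 = 2.9404
```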
@@ -185,7 +169,7 @@ def compute_target(v_final, r_lst, mask_lst):
         td_target = compute_target(v_final, r_lst, mask_lst)
 
         td_target_vec = td_target.reshape(-1)
-        s_vec = torch.tensor(s_lst).float().reshape(-1, DIM_STATE)
+        s_vec = torch.tensor(s_lst).float().reshape(-1, 4)  # 4 == Dimension of state
         a_vec = torch.tensor(a_lst).reshape(-1).unsqueeze(1)
         advantage = td_target_vec - model.v(s_vec).reshape(-1)

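Note (not part of the commit): the optimizer update that consumes this advantage sits in lines elided between the hunks. A hedged sketch of the standard A2C loss such code computes, reusing the names from the hunk above: a policy-gradient term weighted by the detached advantage plus a value-regression term (smooth-L1 here is an assumption carried over from the repo's other scripts):

```python
# Hedged sketch, not the commit's verbatim code.
probs = model.pi(s_vec, softmax_dim=1)
pi_a = probs.gather(1, a_vec).reshape(-1)   # probability of each taken action
loss = -(torch.log(pi_a) * advantage.detach()).mean() \
    + F.smooth_l1_loss(model.v(s_vec).reshape(-1), td_target_vec)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```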
@@ -201,4 +185,4 @@ def compute_target(v_final, r_lst, mask_lst):
         if step_idx % PRINT_INTERVAL == 0:
             test(step_idx, model)
 
-    envs.close()
\ No newline at end of file
+    envs.close()
