Skip to content

Commit

Permalink
sac bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
seungeunrho committed Dec 1, 2020
1 parent a88447c commit 8c364c3
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 10,12 @@
#Hyperparameters
lr_pi = 0.0005
lr_q = 0.001
init_alpha = 0.01
init_alpha = 0.01
gamma = 0.98
batch_size = 32
buffer_limit = 50000
batch_size = 32
buffer_limit = 50000
tau = 0.01 # for target network soft update
target_entropy = -1.0 # for automated alpha update
target_entropy = -1.0 # for automated alpha update
lr_alpha = 0.001 # for automated alpha update

class ReplayBuffer():
Expand Down Expand Up @@ -69,11 69,11 @@ def forward(self, x):
return real_action, real_log_prob

def train_net(self, q1, q2, mini_batch):
s, a, r, s_prime, done = mini_batch
a_prime, log_prob = self.forward(s_prime)
s, _, _, _, _ = mini_batch
a, log_prob = self.forward(s)
entropy = -self.log_alpha.exp() * log_prob

q1_val, q2_val = q1(s,a_prime), q2(s,a_prime)
q1_val, q2_val = q1(s,a), q2(s,a)
q1_q2 = torch.cat([q1_val, q2_val], dim=1)
min_q = torch.min(q1_q2, 1, keepdim=True)[0]

Expand Down

0 comments on commit 8c364c3

Please sign in to comment.