-
Notifications
You must be signed in to change notification settings - Fork 0
/
lib.py
108 lines (74 loc) · 2.51 KB
/
lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import time
import gym
import gym.wrappers
import atari_py as ap
from stable_baselines3.common.vec_env import *
from stable_baselines3.common.vec_env.base_vec_env import VecEnvWrapper
def set_title(title):
    """Set the terminal window title to ``title``.

    Emits the escape sequence ESC ] 0 ; <title> BEL on standard output, with
    no trailing newline, and flushes immediately so the change takes effect
    right away.

    Args:
        title: Value to display; it is interpolated into the sequence with
            normal f-string formatting.
    """
    escape_sequence = f'\33]0;{title}\a'
    print(escape_sequence, end='', flush=True)
def record_video(model, env, deterministic, video_length, video_folder, name_prefix):
    """Roll out ``model`` in ``env`` and save the rollout as a video.

    The environment is wrapped in a ``VecVideoRecorder`` that starts recording
    at step 0, then stepped with the model's predicted actions.

    Args:
        model: Object exposing ``predict(obs, deterministic=...)`` that returns
            an ``(action, state)`` pair.
        env: Vectorized environment to record.
        deterministic: Forwarded to ``model.predict``.
        video_length: Number of steps to record.
        video_folder: Directory the recorder writes the video into.
        name_prefix: Prefix for the generated video file name.
    """
    env.reset()

    # Record the video starting at the first step.
    env = VecVideoRecorder(
        env,
        video_folder,
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
        name_prefix=name_prefix,
    )
    try:
        obs = env.reset()
        # One step more than video_length, mirroring the Stable-Baselines3
        # recording example, so the recorder reaches its full length.
        for _ in range(video_length + 1):
            action, _state = model.predict(obs, deterministic=deterministic)
            obs, _reward, _done, _info = env.step(action)
    finally:
        # Closing the recorder saves the video; do it even if the rollout
        # fails part-way through so the recorder is not left open.
        # NOTE(review): closing the wrapper presumably also closes the wrapped
        # env passed in by the caller - confirm callers do not reuse it.
        env.close()
class Throttle:
    """Simple rate limiter: ``tick()`` fires at most once per ``seconds``.

    Typical use is gating periodic work inside a tight loop::

        throttle = Throttle(5)
        while True:
            if throttle.tick():
                ...  # runs at most every 5 seconds
    """

    def __init__(self, seconds):
        """Create a throttle with a minimum interval of ``seconds``."""
        self.seconds = seconds
        # Wall-clock timestamp (time.time()) from which tick() may fire again.
        # Starting at 0 makes the very first tick() fire.
        self.next = 0

    def tick(self):
        """Return True if the interval has elapsed, and re-arm if so.

        Returns:
            True when at least ``seconds`` have passed since the last tick
            that returned True (or on the first call); False otherwise.
        """
        now = time.time()
        trigger = now >= self.next
        if trigger:
            # Schedule the next allowed firing relative to this one.
            self.next = now + self.seconds
        return trigger
class VecRewardOffset(VecEnvWrapper):
    """
    Adds an offset to the reward at each time step.
    For example, this can be used to deduct a penalty
    to incentivize agents to move forward quickly.
    """

    def __init__(self, venv, reward_offset):
        """Wrap ``venv`` so every step's reward is shifted by ``reward_offset``.

        Args:
            venv: Vectorized environment to wrap.
            reward_offset: Value added to the reward on every step (use a
                negative value for a per-step penalty).
        """
        super().__init__(venv)
        print("Initialize VecRewardOffset:", reward_offset)
        self.reward_offset = reward_offset

    def reset(self):
        """Reset the wrapped environment and return its observation unchanged."""
        return self.venv.reset()

    def step_wait(self):
        """Return the wrapped step result with the offset added to the reward."""
        obs, reward, done, info = self.venv.step_wait()
        # Add the offset rather than overwrite the reward, and build a new
        # value instead of using an in-place update so the object returned by
        # the wrapped env is not mutated.
        reward = reward + self.reward_offset
        return obs, reward, done, info
def get_all_atari_ids(no_frameskip = False):
    """Return the Atari environment IDs to use.

    Currently this always returns ``["Pong-v4"]``; see the note below. The
    code after the early return is the full enumeration, kept so it can be
    re-enabled by removing that return.

    Args:
        no_frameskip: When the full enumeration is enabled, selects the
            ``NoFrameskip`` variant of each ID (e.g. ``PongNoFrameskip-v4``).
            It has no effect while the early return is in place.

    Returns:
        A list of environment ID strings. The full enumeration returns them
        sorted.
    """
    # NOTE: deliberate short-circuit kept from the original - only Pong is
    # used for now. Remove this return to enumerate every registered game.
    return ["Pong-v4"]

    registered = {spec.id for spec in gym.envs.registry.all()}
    suffix = ("NoFrameskip" if no_frameskip else "") + "-v4"

    ids = []
    for game in ap.list_games():
        # atari_py names are snake_case ("space_invaders"); Gym IDs are
        # CamelCase ("SpaceInvaders-v4").
        env_id = ''.join(part.capitalize() for part in game.split("_")) + suffix
        # Doesn't work
        if env_id == "Defender-v4":
            continue
        if env_id not in registered:
            continue
        ids.append(env_id)
    ids.sort()
    return ids
def load_atari(env_id):
    """Create and return the Gym environment registered as ``env_id``.

    The environment is returned exactly as ``gym.make`` builds it; the Atari
    preprocessing wrapper is currently disabled (see the TODO below).
    """
    environment = gym.make(env_id)
    # TODO: re-enable and disable grayscale
    #environment = gym.wrappers.AtariPreprocessing(environment, frame_skip=1)
    return environment