Skip to content

Commit

Permalink
Unify MARL Env into New Framework (#524)
Browse files Browse the repository at this point in the history
* top-down renderer to base

* verify env deterministic

* clean Environment

* remove PGDriveEnv V2

* format

* clean baseEnv

* move save policy to outside

* move some func from pgdrive to base env

* WIP

* finish

* basic marl finish

* all in spawn manager

* move udpate dest to spawn mgr

* almost finish

* format

* fix test

* move env

* fix fix bug

* unify env

* -length

* fix bugs

* format

* fix test

* not assert reward

* not assert reward

* fix bugs

* format

* vis multi agent

* fix bug

* fix copy issue

* format

* format

* discrete action

* format

Co-authored-by: PENG Zhenghao <[email protected]>
  • Loading branch information
QuanyiLi and PENG Zhenghao authored Aug 22, 2021
1 parent 9e50bde commit ad1d3ec
Show file tree
Hide file tree
Showing 85 changed files with 1,000 additions and 1,933 deletions.
3 changes: 1 addition & 2 deletions pgdrive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 1,2 @@
import pgdrive.register
from pgdrive.envs import PGDriveEnv, TopDownPGDriveEnv, TopDownSingleFramePGDriveEnv, \
PGDriveEnvV2, TopDownPGDriveEnvV2
from pgdrive.envs import PGDriveEnv, TopDownPGDriveEnv, TopDownSingleFramePGDriveEnv, TopDownPGDriveEnvV2
3 changes: 2 additions & 1 deletion pgdrive/base_class/base_runnable.py
Original file line number Diff line number Diff line change
@@ -1,3 1,4 @@
import copy
from typing import Dict

from pgdrive.base_class.configurable import Configurable
Expand Down Expand Up @@ -25,7 26,7 @@ def __init__(self, name=None, random_seed=None, config=None):
), "Using PGSpace to define parameter spaces of " self.class_name
self.sample_parameters()
# use external config update to overwrite sampled parameters, except None
self.update_config(config, allow_add_new_key=True)
self.update_config(copy.copy(config), allow_add_new_key=True)

def get_state(self) -> Dict:
"""
Expand Down
17 changes: 10 additions & 7 deletions pgdrive/component/vehicle/base_vehicle.py
Original file line number Diff line number Diff line change
@@ -1,4 1,5 @@
import math
import copy
from pgdrive.utils.space import VehicleParameterSpace, ParameterSpace
from collections import deque
from typing import Union, Optional
Expand Down Expand Up @@ -182,7 183,7 @@ def __init__(
# others
self._add_modules_for_vehicle()
self.takeover = False
self._expert_takeover = False
self.expert_takeover = False
self.energy_consumption = 0
self.action_space = self.get_action_space_before_init(extra_action_dim=self.config["extra_action_dim"])
self.break_down = False
Expand Down Expand Up @@ -228,9 229,6 @@ def _preprocess_action(self, action):
assert self.action_space.contains(action), "Input {} is not compatible with action space {}!".format(
action, self.action_space
)

# protect agent from nan error
action = safe_clip_for_small_array(action, min_val=self.action_space.low[0], max_val=self.action_space.high[0])
return action, {'raw_action': (action[0], action[1])}

def before_step(self, action):
Expand Down Expand Up @@ -632,7 630,7 @@ def _state_check(self):
res = rect_region_detection(
self.engine, self.position, np.rad2deg(self.heading_theta), self.LENGTH, self.WIDTH, CollisionGroup.Sidewalk
)
if res.hasHit():
if res.hasHit() and res.getNode().getName() == BodyName.Sidewalk:
self.crash_sidewalk = True
contacts.add(BodyName.Sidewalk)
self.contact_results = contacts
Expand Down Expand Up @@ -712,8 710,13 @@ def get_overtake_num(self):
return len(self.front_vehicles.intersection(self.back_vehicles))

@classmethod
def get_action_space_before_init(cls, extra_action_dim: int = 0):
return gym.spaces.Box(-1.0, 1.0, shape=(2 extra_action_dim, ), dtype=np.float32)
def get_action_space_before_init(
cls, extra_action_dim: int = 0, discrete_action=False, steering_dim=5, throttle_dim=5
):
if not discrete_action:
return gym.spaces.Box(-1.0, 1.0, shape=(2 extra_action_dim, ), dtype=np.float32)
else:
return gym.spaces.MultiDiscrete([steering_dim, throttle_dim])

def __del__(self):
super(BaseVehicle, self).__del__()
Expand Down
13 changes: 8 additions & 5 deletions pgdrive/component/vehicle_module/lidar.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 36,20 @@ def __init__(self, num_lasers: int = 240, distance: float = 50, enable_show=Fals
self.enable_mask = True if not engine.global_config["_disable_detector_mask"] else False

def perceive(self, base_vehicle, detector_mask=True):
lidar_mask = self._get_lidar_mask(base_vehicle)[0] if detector_mask and self.enable_mask else None
return super(Lidar, self).perceive(base_vehicle, base_vehicle.engine.physics_world.dynamic_world, lidar_mask)
res = self._get_lidar_mask(base_vehicle)
lidar_mask = res[0] if detector_mask and self.enable_mask else None
detected_objects = res[1]
return super(Lidar, self).perceive(base_vehicle, base_vehicle.engine.physics_world.dynamic_world,
lidar_mask)[0], detected_objects

@staticmethod
def get_surrounding_vehicles(detected_objects) -> Set:
# TODO this will be removed in the future and use the broad phase detection results
from pgdrive.component.vehicle.base_vehicle import BaseVehicle
vehicles = set()
objs = detected_objects
for ret in objs:
if ret.getNode().hasPythonTag(BodyName.Vehicle):
vehicles.add(get_object_from_node(ret.getNode()))
if isinstance(ret, BaseVehicle):
vehicles.add(ret)
return vehicles

def get_surrounding_vehicles_info(self, ego_vehicle, detected_objects, num_others: int = 4):
Expand Down
2 changes: 1 addition & 1 deletion pgdrive/engine/engine_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 35,7 @@ def close_engine():

def get_global_config():
engine = get_engine()
return copy.copy(engine.global_config)
return engine.global_config.copy()


def set_global_random_seed(random_seed: Optional[int]):
Expand Down
1 change: 0 additions & 1 deletion pgdrive/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 1,3 @@
from pgdrive.envs.marl_envs.multi_agent_pgdrive import MultiAgentPGDrive
from pgdrive.envs.pgdrive_env import PGDriveEnv
from pgdrive.envs.pgdrive_env_v2 import PGDriveEnvV2
from pgdrive.envs.top_down_env import TopDownSingleFramePGDriveEnv, TopDownPGDriveEnv, TopDownPGDriveEnvV2
Loading

0 comments on commit ad1d3ec

Please sign in to comment.