Skip to content

Commit

Permalink
Use multistep instead of raw_step in wrappers
Browse files Browse the repository at this point in the history
Fix tests
  • Loading branch information
sogartar committed Mar 14, 2022
1 parent 2ad9aad commit d0d464a
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 78 deletions.
8 changes: 6 additions & 2 deletions compiler_gym/wrappers/commandline.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 57,13 @@ def __init__(
name=f"{type(self).__name__}<{env.action_space.name}>",
)

def raw_step(
def multistep(
self,
actions: List[ActionType],
observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
) -> StepType:
terminal_action: int = len(self.action_space.flags) - 1

Expand All @@ -74,10 76,12 @@ def raw_step(
if index_of_terminal >= 0:
actions = actions[:index_of_terminal]

observation, reward, done, info = self.env.raw_step(
observation, reward, done, info = self.env.multistep(
actions,
observation_spaces=observation_spaces,
reward_spaces=reward_spaces,
observations=observations,
rewards=rewards,
)

# Communicate back to the frontend.
Expand Down
84 changes: 37 additions & 47 deletions compiler_gym/wrappers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 4,7 @@
# LICENSE file in the root directory of this source tree.
import warnings
from collections.abc import Iterable as IterableType
from typing import Iterable, Optional, Union
from typing import Iterable, List, Optional, Union

import gym

Expand Down Expand Up @@ -40,26 40,12 @@ def __init__(self, env: CompilerEnv): # pylint: disable=super-init-not-called
self.reward_range = self.env.reward_range
self.metadata = self.env.metadata

def raw_step(
self,
actions: Iterable[ActionType],
observation_spaces: Iterable[ObservationSpaceSpec],
reward_spaces: Iterable[Reward],
):
return self.env.raw_step(
actions, observation_spaces=observation_spaces, reward_spaces=reward_spaces
)

def reset(self, *args, **kwargs) -> ObservationType:
return self.env.reset(*args, **kwargs)

def fork(self) -> CompilerEnv:
return type(self)(env=self.env.fork())

# NOTE(cummins): This step() method is provided only because
# CompilerEnv.step accepts additional arguments over gym.Env.step. Users who
# wish to modify the behavior of CompilerEnv.step should overload
# raw_step().
def step( # pylint: disable=arguments-differ
self,
action: ActionType,
Expand Down Expand Up @@ -95,8 81,7 @@ def step( # pylint: disable=arguments-differ
category=DeprecationWarning,
)
reward_spaces = rewards
return self.env._multistep(
raw_step=self.raw_step,
return self.multistep(
actions=[action],
observation_spaces=observation_spaces,
reward_spaces=reward_spaces,
Expand Down Expand Up @@ -124,8 109,7 @@ def multistep(
category=DeprecationWarning,
)
reward_spaces = rewards
return self.env._multistep( # pylint: disable=protected-access
raw_step=self.raw_step,
return self.env.multistep(
actions=actions,
observation_spaces=observation_spaces,
reward_spaces=reward_spaces,
Expand Down Expand Up @@ -166,16 150,20 @@ class ActionWrapper(CompilerEnvWrapper):
to allow an action space transformation.
"""

def raw_step(
def multistep(
self,
actions: Iterable[ActionType],
observation_spaces: Iterable[ObservationSpaceSpec],
reward_spaces: Iterable[Reward],
observation_spaces: Optional[Iterable[ObservationSpaceSpec]] = None,
reward_spaces: Optional[Iterable[Reward]] = None,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
):
return self.env.raw_step(
return self.env.multistep(
[self.action(a) for a in actions],
observation_spaces=observation_spaces,
reward_spaces=reward_spaces,
observations=observations,
rewards=rewards,
)

def action(self, action: ActionType) -> ActionType:
Expand All @@ -196,22 184,23 @@ def reset(self, *args, **kwargs):
observation = self.env.reset(*args, **kwargs)
return self.observation(observation)

def raw_step(
def multistep(
self,
actions: Iterable[ActionType],
observation_spaces: Iterable[ObservationSpaceSpec],
reward_spaces: Iterable[Reward],
actions: List[ActionType],
observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
):
observation, reward, done, info = self.env.raw_step(
actions, observation_spaces=observation_spaces, reward_spaces=reward_spaces
observation, reward, done, info = self.env.multistep(
actions,
observation_spaces=observation_spaces,
reward_spaces=reward_spaces,
observations=observations,
rewards=rewards,
)

# Only apply observation transformation if we are using the default
# observation space.
if observation_spaces == [self.observation_space_spec]:
observation = [self.observation(observation)]

return observation, reward, done, info
return self.observation(observation), reward, done, info

def observation(self, observation):
"""Translate an observation to the new space."""
Expand All @@ -226,22 215,23 @@ class RewardWrapper(CompilerEnvWrapper):
def reset(self, *args, **kwargs):
return self.env.reset(*args, **kwargs)

def raw_step(
def multistep(
self,
actions: Iterable[ActionType],
observation_spaces: Iterable[ObservationSpaceSpec],
reward_spaces: Iterable[Reward],
actions: List[ActionType],
observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
):
observation, reward, done, info = self.env.step(
actions, observation_spaces=observation_spaces, reward_spaces=reward_spaces
observation, reward, done, info = self.env.multistep(
actions,
observation_spaces=observation_spaces,
reward_spaces=reward_spaces,
observations=observations,
rewards=rewards,
)

# Only apply rewards transformation if we are using the default
# reward space.
if reward_spaces == [self.reward_space]:
reward = [self.reward(reward)]

return observation, reward, done, info
return observation, self.reward(reward), done, info

def reward(self, reward):
"""Translate a reward to the new space."""
Expand Down
4 changes: 2 additions & 2 deletions compiler_gym/wrappers/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 29,13 @@ def __init__(
super().__init__(env)
self.reward_penalty = reward_penalty

def raw_step(
def multistep(
self,
actions: List[ActionType],
observation_spaces=None,
reward_spaces=None,
):
observation, reward, done, info = self.env.raw_step(
observation, reward, done, info = self.env.multistep(
actions,
observation_spaces=observation_spaces,
reward_spaces=reward_spaces,
Expand Down
2 changes: 1 addition & 1 deletion examples/llvm_autotuning/autotuners/opentuner_.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 93,7 @@ def __init__(self, data) -> None:
wrapped = DesiredResult(Configuration(manipulator.best_config))
manipulator.run(wrapped, None, None)
env.reset()
env.step(manipulator.serialize_actions(manipulator.best_config))
env.multistep(manipulator.serialize_actions(manipulator.best_config))


class LlvmOptFlagsTuner(MeasurementInterface):
Expand Down
4 changes: 2 additions & 2 deletions examples/llvm_rl/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 126,7 @@ def reset(self, *args, **kwargs):
)
return super().reset(*args, **kwargs)

def raw_step(
def multistep(
self,
actions: List[ActionType],
observation_spaces=None,
Expand All @@ -135,7 135,7 @@ def raw_step(
):
for a in actions:
self.histogram[a] = self.increment
return self.env.raw_step(actions, **kwargs)
return self.env.multistep(actions, **kwargs)

def observation(self, observation):
return np.concatenate((observation, self.histogram)).astype(
Expand Down
2 changes: 2 additions & 0 deletions examples/loop_optimizations_service/service_py/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 5,8 @@

cg_add_all_subdirs()

return()

cg_py_library(
NAME
loops_opt_service
Expand Down
2 changes: 1 addition & 1 deletion examples/op_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 267,7 @@ def get_step_times(env: CompilerEnv, num_steps: int, batched=False):
# Run all actions in a single step().
steps = [env.action_space.sample() for _ in range(num_steps)]
with Timer() as timer:
_, _, done, _ = env.step(steps)
_, _, done, _ = env.multistep(steps)
if not done:
return [timer.time / num_steps] * num_steps
env.reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 116,7 @@ def run_one_trial(
num_steps = random.randint(min_steps, max_steps)
warmup_actions = [env.action_space.sample() for _ in range(num_steps)]
env.reward_space = reward_space
_, _, done, _ = env.step(warmup_actions)
_, _, done, _ = env.multistep(warmup_actions)
if done:
return None
return env.episode_reward
Expand Down
2 changes: 1 addition & 1 deletion tests/util/minimize_trajectory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 152,7 @@ def hypothesis(env):
def test_minimize_trajectory_iteratively_llvm_crc32(env):
"""Test trajectory minimization on a real environment."""
env.reset(benchmark="cbench-v1/crc32")
env.step(
env.multistep(
[
env.action_space["-mem2reg"],
env.action_space["-gvn"],
Expand Down
22 changes: 3 additions & 19 deletions tests/wrappers/core_wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for //compiler_gym/wrappers."""
import numpy as np
import pytest

from compiler_gym.datasets import Datasets
Expand Down Expand Up @@ -108,6 107,9 @@ def observation(self, observation):
def action(self, action):
return action # pass thru

def reward(self, reward):
return reward

env = MyWrapper(env)
env.reset()
(ir, ic), (icr, icroz), _, _ = env.multistep(
Expand Down Expand Up @@ -259,22 261,6 @@ def test_wrapped_observation_missing_definition(env: LlvmEnv):
env.reset()


def test_wrapped_observation_not_applied_to_non_default_observations(env: LlvmEnv):
class MyWrapper(ObservationWrapper):
def __init__(self, env):
super().__init__(env)
self.observation_space = "Ir"

def observation(self, observation):
return len(observation)

env = MyWrapper(env)
env.reset()
(observation,), _, _, _ = env.step(0, observation_spaces=["Autophase"])
print(observation)
assert isinstance(observation, np.ndarray)


def test_wrapped_reward(env: LlvmEnv):
class MyWrapper(RewardWrapper):
def reward(self, reward):
Expand All @@ -286,11 272,9 @@ def reward(self, reward):
env.reset()
_, reward, _, _ = env.step(0)
assert reward == -5
assert env.episode_reward == -5

_, reward, _, _ = env.step(0)
assert reward == -5
assert env.episode_reward == -10


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions www/www.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 217,7 @@ def _step(request: StepRequest) -> StepReply:
if request.all_states:
# Replay actions one at a time to receive incremental rewards. The
# first item represents the state prior to any actions.
(instcount, autophase), _, done, info = env.raw_step(
(instcount, autophase), _, done, info = env.multistep(
actions=[],
observation_spaces=[
env.observation.spaces["InstCountDict"],
Expand Down Expand Up @@ -263,7 263,7 @@ def _step(request: StepRequest) -> StepReply:
)

# Perform the final action.
(ir, instcount, autophase), (reward,), done, _ = env.raw_step(
(ir, instcount, autophase), (reward,), done, _ = env.multistep(
actions=request.actions[-1:],
observation_spaces=[
env.observation.spaces["Ir"],
Expand Down

0 comments on commit d0d464a

Please sign in to comment.