forked from google-deepmind/bsuite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
csv_logging.py
89 lines (72 loc) · 3.17 KB
/
csv_logging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Logging functionality for CSV-based experiments."""
import os
from typing import Any, Mapping
from bsuite import environments
from bsuite import sweep
from bsuite.logging import base
from bsuite.utils import wrappers
import dm_env
import pandas as pd
# Filesystem-safe replacement for sweep.SEPARATOR when a bsuite_id is used
# as part of a filename.
SAFE_SEPARATOR = '-'
# Separates the fixed 'bsuite_id' filename prefix from the id that follows it.
INITIAL_SEPARATOR = '_-_'
# Common prefix of every CSV filename written by this module.
BSUITE_PREFIX = 'bsuite_id' + INITIAL_SEPARATOR
def wrap_environment(env: environments.Environment,
                     bsuite_id: str,
                     results_dir: str,
                     overwrite: bool = False,
                     log_by_step: bool = False) -> dm_env.Environment:
  """Wraps `env` so that its results are logged to a CSV file.

  Args:
    env: The environment to wrap.
    bsuite_id: Identifier used to build the CSV filename.
    results_dir: Directory in which the CSV file is written.
    overwrite: Forwarded to `Logger`; if False, an existing CSV file for this
      bsuite_id causes `Logger` to raise ValueError.
    log_by_step: Forwarded unchanged to `wrappers.Logging`.

  Returns:
    The environment wrapped in `wrappers.Logging`, backed by a CSV `Logger`.
  """
  csv_logger = Logger(
      bsuite_id=bsuite_id,
      results_dir=results_dir,
      overwrite=overwrite,
  )
  return wrappers.Logging(env, csv_logger, log_by_step=log_by_step)
class Logger(base.Logger):
  """Saves data to a CSV file via Pandas.

  In this simplified logger, each bsuite_id logs to its own CSV file, named
  after the bsuite_id. These are saved to a single results_dir by experiment.
  We strongly suggest that you use a *fresh* folder for each bsuite run.

  The write method rewrites the entire CSV file on each call. This is not
  intended to be an optimized example. However, writes are infrequent due to
  bsuite's logarithmically-spaced logging.

  This logger, along with the corresponding load functionality, serves as a
  simple, minimal example for users who need to implement logging to a different
  storage system.
  """

  def __init__(self,
               bsuite_id: str,
               results_dir: str = '/tmp/bsuite',
               overwrite: bool = False):
    """Initializes a new CSV logger.

    Args:
      bsuite_id: Identifier of the experiment; used to build the filename.
      results_dir: Directory for the CSV file. Created if it does not exist.
      overwrite: If False, refuse to reuse a path that already exists.

    Raises:
      ValueError: If the target CSV file already exists and `overwrite` is
        False.
      OSError: If `results_dir` cannot be created.
    """
    # exist_ok=True tolerates concurrent processes creating the directory at
    # the same time, while still surfacing real failures (e.g. permissions).
    os.makedirs(results_dir, exist_ok=True)

    # The default '/' symbol is dangerous for file systems!
    safe_bsuite_id = bsuite_id.replace(sweep.SEPARATOR, SAFE_SEPARATOR)
    filename = f'{BSUITE_PREFIX}{safe_bsuite_id}.csv'
    save_path = os.path.join(results_dir, filename)

    if os.path.exists(save_path) and not overwrite:
      raise ValueError(
          f'File {save_path} already exists. Specify a different '
          'directory, or set overwrite=True to overwrite existing data.')

    self._data = []  # One dict per logged row, in insertion order.
    self._save_path = save_path

  def write(self, data: Mapping[str, Any]):
    """Adds a row to the internal list of data and saves to CSV.

    Args:
      data: Mapping from column name to value for the new row.
    """
    # Store a snapshot so that later mutation of `data` by the caller cannot
    # alter rows that were already logged.
    self._data.append(dict(data))
    # The whole file is rewritten from the accumulated rows on every call.
    df = pd.DataFrame(self._data)
    df.to_csv(self._save_path, index=False)