Skip to content

Commit

Permalink
lint: solve mypy errors introduced in the previous commit
Browse files Browse the repository at this point in the history
  • Loading branch information
AldoGl committed Sep 10, 2024
1 parent 880474a commit 6c252a2
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 16 deletions.
8 changes: 5 additions & 3 deletions black_it/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 127,11 @@ def __init__( # noqa: PLR0913
# initialize arrays
self.params_samp = np.zeros((0, self.param_grid.dims))
self.losses_samp = np.zeros(0)
self.batch_num_samp = np.zeros(0, dtype=int)
self.method_samp = np.zeros(0, dtype=int)
self.series_samp = np.zeros((0, self.ensemble_size, self.N, self.D))
self.batch_num_samp: NDArray[np.int64] = np.zeros(0, dtype=int)
self.method_samp: NDArray[np.int64] = np.zeros(0, dtype=int)
self.series_samp: NDArray[np.float64] = np.zeros(
(0, self.ensemble_size, self.N, self.D),
)

# initialize variables before calibration
self.n_sampled_params = 0
Expand Down
2 changes: 1 addition & 1 deletion black_it/loss_functions/gsl_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 270,7 @@ def get_words(time_series: NDArray[np.float64], length: int) -> NDArray:
"the chosen word length is too high",
exception_class=ValueError,
)
tsw = np.zeros(shape=(tswlen,), dtype=np.int32)
tsw: NDArray[np.float64] = np.zeros(shape=(tswlen,), dtype=np.int32)

for i in range(length):
k = 10 ** (length - i - 1)
Expand Down
12 changes: 8 additions & 4 deletions black_it/loss_functions/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 18,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, cast

import numpy as np

Expand Down Expand Up @@ -82,9 82,13 @@ def compute_loss(
Returns:
The loss value.
"""
r = sim_data_ensemble.shape[0] # number of repetitions
s = sim_data_ensemble.shape[1] # simulation length
d = sim_data_ensemble.shape[2] # number of dimensions
sim_data_ensemble_shape: tuple[int, int, int] = cast(
tuple[int, int, int],
sim_data_ensemble.shape,
)
r = sim_data_ensemble_shape[0] # number of repetitions
s = sim_data_ensemble_shape[1] # simulation length
d = sim_data_ensemble_shape[2] # time series dimension

if self.coordinate_weights is not None:
warnings.warn( # noqa: B028
Expand Down
4 changes: 3 additions & 1 deletion black_it/plot/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 32,8 @@
if TYPE_CHECKING:
import os

from numpy.typing import NDArray


def _get_samplers_id_table(saving_folder: str | os.PathLike) -> dict[str, int]:
"""Get the id table of the samplers from the checkpoint.
Expand Down Expand Up @@ -298,7 300,7 @@ def plot_sampling_interact(saving_folder: str | os.PathLike) -> None:
data_frame = pd.read_csv(calibration_results_file)

max_bn = int(max(data_frame["batch_num_samp"]))
all_bns = np.arange(max_bn 1, dtype=int)
all_bns: NDArray[np.int64] = np.arange(max_bn 1, dtype=int)
indices_bns = np.array_split(all_bns, min(max_bn, 3))

dict_bns = {}
Expand Down
6 changes: 3 additions & 3 deletions black_it/samplers/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 28,9 @@
if TYPE_CHECKING:
from numpy.typing import NDArray

MAX_FLOAT32 = np.finfo(np.float32).max
MIN_FLOAT32 = np.finfo(np.float32).min
EPS_FLOAT32 = np.finfo(np.float32).eps
MAX_FLOAT32: float = cast(float, np.finfo(np.float32).max)
MIN_FLOAT32: float = cast(float, np.finfo(np.float32).min)
EPS_FLOAT32: float = cast(float, np.finfo(np.float32).eps)


class XGBoostSampler(MLSurrogateSampler):
Expand Down
2 changes: 1 addition & 1 deletion black_it/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 72,7 @@ def __init__(
self._param_grid: list[NDArray[np.float64]] = []
self._space_size = 1
for i in range(self.dims):
new_col = np.arange(
new_col: NDArray[np.float64] = np.arange(
parameters_bounds[0][i],
parameters_bounds[1][i] 0.0000001,
parameters_precision[i],
Expand Down
7 changes: 4 additions & 3 deletions tests/test_samplers/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""This module contains tests for the xgboost sampler."""
import sys
from typing import cast

import numpy as np

Expand All @@ -34,9 35,9 @@
else:
expected_params = np.array([[0.24, 0.26], [0.37, 0.21], [0.43, 0.14], [0.11, 0.04]])

MAX_FLOAT32 = np.finfo(np.float32).max
MIN_FLOAT32 = np.finfo(np.float32).min
EPS_FLOAT32 = np.finfo(np.float32).eps
MAX_FLOAT32: float = cast(float, np.finfo(np.float32).max)
MIN_FLOAT32: float = cast(float, np.finfo(np.float32).min)
EPS_FLOAT32: float = cast(float, np.finfo(np.float32).eps)


def test_xgboost_2d() -> None:
Expand Down

0 comments on commit 6c252a2

Please sign in to comment.