Skip to content

Commit

Permalink
added true data/label shuffling
Browse files Browse the repository at this point in the history
  • Loading branch information
bencbartlett committed Jan 8, 2019
1 parent 9f10d29 commit 117b0bf
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions neuroptica/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 26,18 @@ def make_batches(data: np.ndarray, labels: np.ndarray, batch_size: int,
:param data: features vector, shape: (n_features, n_samples)
:param labels: labels vector, shape: (n_label_dim, n_samples)
:param batch_size: size of the batch
:param shuffle: if true, batches will be presented in random order (the data within each batch is not shuffled)
:param shuffle: if true, batches will be randomized
:return: yields a tuple (data_batch, label_batch)
'''

n_features, n_samples = data.shape

batch_indices = np.arange(0, n_samples, batch_size)
if shuffle: np.random.shuffle(batch_indices)

if shuffle:
permutation = np.random.permutation(n_samples)
data = data[:, permutation] # this doesn't overwrite data from outside function call
labels = labels[:, permutation]

for i in batch_indices:
X = data[:, i:i batch_size]
Expand Down

0 comments on commit 117b0bf

Please sign in to comment.