diff --git a/neuroptica/losses.py b/neuroptica/losses.py index 10095b2..702191d 100644 --- a/neuroptica/losses.py +++ b/neuroptica/losses.py @@ -43,7 +43,7 @@ class CategoricalCrossEntropy(Loss): @staticmethod def L(X: np.ndarray, T: np.ndarray) -> np.ndarray: X_softmax = np.exp(X) / np.sum(np.exp(X), axis=0) - X_clip = np.clip(X_softmax, 1e-7, 1 - 1e-7) + X_clip = np.clip(X_softmax, 1e-9, 1 - 1e-9) return -np.sum(T * np.log(X_clip), axis=0) @staticmethod @@ -54,3 +54,4 @@ def dL(X: np.ndarray, T: np.ndarray) -> np.ndarray: return -T / X_clip else: return np.conj(X - T) + # return X - T diff --git a/tests/test_losses.py b/tests/test_losses.py new file mode 100644 index 0000000..93d7120 --- /dev/null +++ b/tests/test_losses.py @@ -0,0 +1,47 @@ +import unittest + +from neuroptica.layers import Activation, ClementsLayer +from neuroptica.losses import CategoricalCrossEntropy, MeanSquaredError +from neuroptica.models import Sequential +from neuroptica.nonlinearities import * +from neuroptica.optimizers import Optimizer +from tests.base import NeuropticaTest +from tests.test_models import TestModels + + +class TestLosses(NeuropticaTest): + '''Tests for model losses''' + + def test_loss_gradients(self): + for N in [9, 10]: + + losses = [MeanSquaredError, CategoricalCrossEntropy] + + for loss in losses: + + print("Testing loss {}".format(loss)) + + batch_size = 6 + n_samples = batch_size * 4 + + X_all = (2 * np.random.rand(N * n_samples) - 1).reshape((N, n_samples)) + Y_all = np.abs(X_all) + + # Make a single-layer model + model = Sequential([ClementsLayer(N), + Activation(Abs(N)) + ]) + + for X, Y in Optimizer.make_batches(X_all, Y_all, batch_size): + # Propagate the data forward + Y_hat = model.forward_pass(X) + d_loss = loss.dL(Y_hat, Y) + + # Compute the backpropagated signals for the model + gradients = model.backward_pass(d_loss) + + TestModels.verify_model_gradients(model, X, Y, loss.L, gradients, epsilon=1e-6) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_nonlinearities.py b/tests/test_nonlinearities.py index c94941e..02c59f4 100644 --- a/tests/test_nonlinearities.py +++ b/tests/test_nonlinearities.py @@ -41,6 +41,8 @@ def test_OpticalMesh_adjoint_optimize(self): # nonlinearities that may be applied to complex outpus nonlinearities_complex = [Abs(N, mode="full"), + Abs(N, mode="condensed"), + Abs(N, mode="polar"), SoftMax(N), AbsSquared(N), ElectroOpticActivation(N, **eo_settings),