Skip to content

Commit

Permalink
Support initialising Linear when A is a Parameter (#2096)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Schmiegel authored Dec 19, 2023
1 parent 5d4c222 commit dca094b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 89,6 @@ Because GitHub's [graph of contributors](http://github.com/GPflow/GPflow/graphs/
[@khurram-ghani](https://github.com/khurram-ghani)
[@partev](https://github.com/partev)
[@uri-granta](https://github.com/uri-granta)
[@jschmiegel](https://github.com/jschmiegel)

Add yourself when you first contribute to GPflow's code, tests, or documentation!
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 56,7 @@ This release contains contributions from:
## Bug Fixes and Other Changes

* Don't round small values in kernel summary printout
* allowing to set a prior for A in the Linear() mean function

## Thanks to our Contributors

Expand Down
10 changes: 9 additions & 1 deletion gpflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 110,15 @@ def __init__(self, A: TensorType = None, b: TensorType = None) -> None:
MeanFunction.__init__(self)
A = np.ones((1, 1), dtype=default_float()) if A is None else A
b = np.zeros(1, dtype=default_float()) if b is None else b
self.A = Parameter(np.atleast_2d(A))
if isinstance(A, Parameter):
if len(A._shape) >= 2:
self.A = A
else:
raise ValueError(
"Error 'gpflow.funcitons.Linear()' mean function. A has not the correct shape (at least 2d)."
)
else:
self.A = Parameter(np.atleast_2d(A))
self.b = Parameter(b)

@inherit_check_shapes
Expand Down
23 changes: 23 additions & 0 deletions tests/gpflow/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 16,7 @@

import numpy as np
import pytest
import tensorflow_probability as tfp
from check_shapes import check_shapes
from numpy.testing import assert_allclose

Expand Down Expand Up @@ -371,3 372,25 @@ def test_models_with_mean_functions_changes(model_class: Type[Any]) -> None:
assert np.all(var_zero.numpy() == var_non_zero.numpy())
# predictive mean changes after modifying mean function
assert not np.all(mu_zero.numpy() == mu_non_zero.numpy())


class TestIssue2091EnsureParameterTypeForLinearMeanFunction:
"""
See github issue #2091. These are tests to ensure that the prior is kept if A is already given as Parameter.
"""

# test_parameter_with_correct_shape
def test_parameter_with_correct_shape(self) -> None:
"Check that Linear copies the prior from a correctly shaped A Parameter"
A = gpflow.Parameter(
np.ones((1, 1)), dtype=np.float64, prior=tfp.distributions.Normal(0.0, 1.0)
)
linear_function = Linear(A, 1)
assert linear_function.A.prior is not None

# test_parameter_with_incorrect_shape
def test_parameter_with_incorrect_shape(self) -> None:
"Check that Linear throws an error when the A Parameter is not correctly shaped"
A = gpflow.Parameter(np.zeros(20), dtype=np.float64)
with pytest.raises(ValueError):
Linear(A, 1)

0 comments on commit dca094b

Please sign in to comment.