Cache scipy compile graphs and re-use if args are same (#2074)
* Cache compile graphs and re-use if args are same

* Add tests for scipy cache

* Ignore error for python3.7

* Tidy-up comments

* Update release notes

* Add docstring for enhancement

* Add compile_cache_size arg to __init__

* Remove redundant use of size in test

* Add IDs to tests and update release notes
khurram-ghani authored Jul 11, 2023
1 parent a602aff commit cda85b0
Showing 3 changed files with 390 additions and 11 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -58,12 +58,15 @@ This release contains contributions from:
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* Scipy minimize wrapper caches compiled graphs and re-uses them if called with the same arguments.
This functionality can be disabled by setting the new `compile_cache_size` argument to 0. (#2074)
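  For example, a minimal sketch of opting out of the caching (the toy data and model here are illustrative only):

  ```python
  import numpy as np
  import gpflow

  # Toy data, purely for illustration.
  X = np.linspace(0, 1, 20).reshape(-1, 1)
  Y = np.sin(10 * X)
  model = gpflow.models.GPR((X, Y), kernel=gpflow.kernels.SquaredExponential())

  # compile_cache_size=0 disables the cache; every call to minimize
  # compiles a fresh evaluation function, as before this change.
  opt = gpflow.optimizers.Scipy(compile_cache_size=0)
  opt.minimize(model.training_loss, model.trainable_variables)
  ```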

## Thanks to our Contributors

This release contains contributions from:

<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
khurram-ghani


# Release 2.8.1
68 changes: 60 additions & 8 deletions gpflow/optimizers/scipy.py
@@ -13,7 +13,19 @@
# limitations under the License.

import warnings
from typing import Any, Callable, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from collections import OrderedDict
from typing import (
Any,
Callable,
FrozenSet,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np
import scipy.optimize
@@ -31,6 +43,31 @@


class Scipy:
def __init__(self, compile_cache_size: int = 2) -> None:
"""
Wrapper around the scipy optimizer.
:param compile_cache_size: The number of compiled evalutation functions to cache for calls
to `minimize`. Only applies when `compile` argument to `minimize` is True.
The compiled evaluation functions are cached so that subsequent calls to `minimize` with
the same `closure`, `variables`, `allow_unused_variables`, and `tf_fun_args` will reuse
a previously compiled function. Up to `compile_cache_size` most recent functions are
cached. This can be disabled by setting `compile_cache_size` to 0.
"""
self.compile_cache: OrderedDict[
Tuple[Callable[[], Any], Tuple[int, ...], FrozenSet[Tuple[str, Any]], bool],
tf.function,
] = OrderedDict()

if compile_cache_size < 0:
raise ValueError(
"The 'compile_cache_size' argument must be non-negative, got {}.".format(
compile_cache_size
)
)
self.compile_cache_size = compile_cache_size

def minimize(
self,
closure: LossClosure,
@@ -116,9 +153,8 @@ def minimize
def initial_parameters(cls, variables: Sequence[tf.Variable]) -> tf.Tensor:
return cls.pack_tensors(variables)

@classmethod
def eval_func(
cls,
self,
closure: LossClosure,
variables: Sequence[tf.Variable],
tf_fun_args: Mapping[str, Any],
@@ -130,25 +166,41 @@ def eval_func
def _tf_eval(x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
nonlocal first_call

values = cls.unpack_tensors(variables, x)
cls.assign_tensors(variables, values)
values = self.unpack_tensors(variables, x)
self.assign_tensors(variables, values)

if first_call:
# Only check for unconnected gradients on the first function evaluation.
loss, grads = _compute_loss_and_gradients(
closure, variables, tf.UnconnectedGradients.NONE
)
grads = cls._filter_unused_variables(variables, grads, allow_unused_variables)
grads = self._filter_unused_variables(variables, grads, allow_unused_variables)
first_call = False
else:
loss, grads = _compute_loss_and_gradients(
closure, variables, tf.UnconnectedGradients.ZERO
)

return loss, cls.pack_tensors(grads)
return loss, self.pack_tensors(grads)

if compile:
_tf_eval = tf.function(_tf_eval, **tf_fun_args)
# Re-use the same tf.function graph for calls to minimize, as long as the arguments
# affecting the graph are the same. This can boost performance of use cases where
# minimize is called repeatedly with the same model loss.
key = (
closure,
tuple(id(v) for v in variables),
frozenset(tf_fun_args.items()),
allow_unused_variables,
)
if self.compile_cache_size > 0:
if key not in self.compile_cache:
if len(self.compile_cache) >= self.compile_cache_size:
self.compile_cache.popitem(last=False) # Remove the oldest entry.
self.compile_cache[key] = tf.function(_tf_eval, **tf_fun_args)
_tf_eval = self.compile_cache[key]
else:
_tf_eval = tf.function(_tf_eval, **tf_fun_args)
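# Note: the cache key includes the closure object and the ids of the
# variables, so passing a different closure object (e.g. a fresh lambda)
# or different variable objects results in a cache miss and a new
# compilation.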

def _eval(x: AnyNDArray) -> Tuple[AnyNDArray, AnyNDArray]:
loss, grad = _tf_eval(tf.convert_to_tensor(x))
(Diff for the third changed file, which adds tests for the scipy compile cache, is not shown.)

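A minimal sketch of the cached path (toy model assumed; `compile=True` is shown explicitly for clarity):

```python
import numpy as np
import gpflow

X = np.linspace(0, 1, 20).reshape(-1, 1)
Y = np.sin(10 * X)
model = gpflow.models.GPR((X, Y), kernel=gpflow.kernels.SquaredExponential())

opt = gpflow.optimizers.Scipy()  # caches up to 2 compiled functions by default
loss = model.training_loss
variables = model.trainable_variables

# First call traces and compiles the evaluation function and caches it.
opt.minimize(loss, variables, compile=True)

# Same closure, same variable objects, same tf.function args: the second
# call reuses the cached compiled function instead of re-tracing.
opt.minimize(loss, variables, compile=True)
```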