refactor: refactor scale_by_rss
XuehaiPan committed Mar 22, 2023
1 parent 9908863 commit e9611fc
Showing 2 changed files with 51 additions and 23 deletions.
34 changes: 19 additions & 15 deletions torchopt/transform/scale_by_rss.py
@@ -39,7 +39,7 @@

 from torchopt import pytree
 from torchopt.base import GradientTransformation
-from torchopt.transform.utils import tree_map_flat
+from torchopt.transform.utils import tree_map_flat, update_moment
 from torchopt.typing import OptState, Params, Updates

@@ -63,8 +63,10 @@ def scale_by_rss(
         - McMahan et al., 2010: https://arxiv.org/abs/1002.4908

     Args:
-        initial_accumulator_value: Starting value for accumulators, must be >= 0.
-        eps: A small floating point value to avoid zero denominator.
+        initial_accumulator_value (float, optional): Starting value for accumulators, must be
+            ``>= 0``. (default: :const:`0.0`)
+        eps (float, optional): A small floating point value to avoid zero denominator.
+            (default: :const:`1e-10`)

     Returns:
         An (init_fn, update_fn) tuple.
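The defaults documented above (accumulators starting at 0.0, eps of 1e-10) determine the very first scaled update. A minimal usage sketch, assuming torchopt exposes scale_by_rss under torchopt.transform and that the returned GradientTransformation follows the (init_fn, update_fn) calling convention built at the end of this file:

import torch
import torchopt

# Assumed API: torchopt.transform.scale_by_rss() returns a GradientTransformation
# whose init/update wrap the init_fn/update_fn defined in this module.
params = {'w': torch.zeros(2)}
grads = {'w': torch.tensor([3.0, 4.0])}

rss = torchopt.transform.scale_by_rss()  # initial_accumulator_value=0.0, eps=1e-10
state = rss.init(params)                 # sum_of_squares starts at 0.0 per leaf
updates, state = rss.update(grads, state, params=params, inplace=False)
# With a zero-initialized accumulator, each entry becomes g / (sqrt(g * g) + eps),
# so updates['w'] is approximately [1.0, 1.0].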
@@ -115,32 +117,34 @@ def update_fn(
         params: Params | None = None,  # pylint: disable=unused-argument
         inplace: bool = True,
     ) -> tuple[Updates, OptState]:
-        sum_of_squares = tree_map(
-            lambda g, t: t + (g.conj() * g).real,
+        sum_of_squares = update_moment.impl(  # type: ignore[attr-defined]
             updates,
             state.sum_of_squares,
+            decay=1.0,
+            order=2,
+            inplace=inplace,
+            already_flattened=already_flattened,
         )

         if inplace:

-            def f(t: torch.Tensor) -> torch.Tensor:
+            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
                 return torch.where(
-                    t > 0.0,
-                    torch.ones_like(t).div_(t.sqrt().add_(eps)),
-                    torch.tensor(0.0),
+                    sos > 0.0,
+                    g.div_(sos.sqrt().add_(eps)),
+                    0.0,
                 )

         else:

-            def f(t: torch.Tensor) -> torch.Tensor:
+            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
                 return torch.where(
-                    t > 0.0,
-                    torch.ones_like(t).div(t.sqrt().add(eps)),
-                    torch.tensor(0.0),
+                    sos > 0.0,
+                    g.div(sos.sqrt().add(eps)),
+                    0.0,
                 )

-        inv_sqrt_g_square = tree_map(f, sum_of_squares)
-        updates = tree_map(lambda scale, g: g * scale, inv_sqrt_g_square, updates)
+        updates = tree_map(f, updates, sum_of_squares)
         return updates, ScaleByRssState(sum_of_squares=sum_of_squares)

     return GradientTransformation(init_fn, update_fn)
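For reference, the refactored update_fn above implements the AdaGrad-style root-sum-of-squares rule: accumulate s = s + g * g, then emit g / (sqrt(s) + eps), with a zero fallback wherever the accumulator is still zero. A self-contained per-tensor sketch of that rule (hypothetical helper, not torchopt's API):

import torch

def rss_scale(grad, sum_of_squares, eps=1e-10):
    # Accumulate squared gradients (decay=1.0, order=2 in update_moment terms).
    sum_of_squares = sum_of_squares + (grad.conj() * grad).real
    # Scale by the inverse root, forcing entries with an all-zero history to exactly 0.0.
    scaled = torch.where(
        sum_of_squares > 0.0,
        grad / (sum_of_squares.sqrt() + eps),
        torch.zeros_like(grad),
    )
    return scaled, sum_of_squares

g = torch.tensor([0.0, 2.0])
s = torch.zeros(2)
update, s = rss_scale(g, s)  # update is approximately [0.0, 1.0]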
40 changes: 32 additions & 8 deletions torchopt/transform/utils.py
@@ -173,25 +173,49 @@ def _update_moment(

     if inplace:
         if order == 2:
+            if decay != 1.0:

-            def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
-                return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+
+            else:
+
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.addcmul_(g, g) if g is not None else t

         else:
+            if decay != 1.0:
+
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t
+
+            else:

-            def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
-                return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.add_(g) if g is not None else t

     else:
         if order == 2:
+            if decay != 1.0:

-            def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
-                return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t
+
+            else:
+
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.addcmul(g, g) if g is not None else t

         else:
+            if decay != 1.0:
+
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t
+
+            else:

-            def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
-                return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t
+                def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor:
+                    return t.add(g) if g is not None else t

     if already_flattened:
         return tree_map_flat(f, updates, moments, none_is_leaf=True)
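For orientation, _update_moment maintains a running moment: with order=1 it computes t = decay * t + (1 - decay) * g, and with order=2 it computes t = decay * t + (1 - decay) * g * g. The new decay != 1.0 check adds a fast path in which decay == 1.0 degenerates to plain accumulation (t = t + g or t = t + g * g), which is exactly what scale_by_rss now requests with decay=1.0, order=2. A simplified out-of-place sketch of the four cases (hypothetical helper, single tensor, no None handling; the real function also has in-place variants and maps over pytrees):

import torch

def update_moment_sketch(g: torch.Tensor, t: torch.Tensor, decay: float, order: int) -> torch.Tensor:
    if order == 2:
        if decay != 1.0:
            return t.mul(decay).addcmul(g, g, value=1 - decay)  # decay * t + (1 - decay) * g * g
        return t.addcmul(g, g)                                  # t + g * g (pure accumulation)
    if decay != 1.0:
        return t.mul(decay).add(g, alpha=1 - decay)             # decay * t + (1 - decay) * g
    return t.add(g)                                             # t + g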
