feat: add AdamW optimizer #44

Merged · 33 commits · Sep 5, 2022
Changes from 1 commit
Commits (33)
70b8a47
feat(torchopt): init adamw optimizer
Benjamin-eecs Jul 27, 2022
3300142
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Aug 4, 2022
253cc2a
fix(torchopt): pass adamw tests
Benjamin-eecs Aug 4, 2022
17d5784
fix: force add adamw.py
Benjamin-eecs Aug 4, 2022
cdc3836
feat: add MetaAdamW test and pass lint
Benjamin-eecs Aug 5, 2022
cc3a3c7
feat: add MetaAdamW test and pass lint
Benjamin-eecs Aug 5, 2022
a071550
fix: pass lint and pass MetaAdamW tests
Benjamin-eecs Aug 5, 2022
89fac53
fix: rewrite MetaOptimizer test, pass MetaAdamW tests with error tol
Benjamin-eecs Aug 5, 2022
b50abe0
merge: resolve conflicts
Benjamin-eecs Aug 24, 2022
47ff9f3
merge: resolve conflicts
Benjamin-eecs Aug 24, 2022
476332e
fix: update adamw low level test
Benjamin-eecs Aug 26, 2022
8175181
merge: resolve conflicts
Benjamin-eecs Sep 1, 2022
bb82209
fix(tests): use new test
Benjamin-eecs Sep 4, 2022
4b01c7e
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Sep 4, 2022
d935014
fix: pass lint
Benjamin-eecs Sep 4, 2022
47cfa45
fix: pass test
Benjamin-eecs Sep 4, 2022
9b32e7b
Merge remote-tracking branch 'upstream/main' into feature/adamw
Benjamin-eecs Sep 4, 2022
42ed8a5
fix: pass test
Benjamin-eecs Sep 4, 2022
1e64877
fix: pass test
Benjamin-eecs Sep 4, 2022
872b8d4
fix: update docstring
Benjamin-eecs Sep 4, 2022
824d1c5
fix: update docstring
Benjamin-eecs Sep 4, 2022
e920c74
fix: update docstring
Benjamin-eecs Sep 4, 2022
8ee3c41
fix: correct already_flattened
Benjamin-eecs Sep 4, 2022
0f129c0
fix: correct weight_decay range check
Benjamin-eecs Sep 4, 2022
e75671e
fix: already_flattened of mask
Benjamin-eecs Sep 4, 2022
c791bba
style: format code
XuehaiPan Sep 5, 2022
24690a0
feat: add shortcut
XuehaiPan Sep 5, 2022
fec6f99
chore: reorganize code structure
XuehaiPan Sep 5, 2022
d3ad838
feat: inplace support for AdamW
XuehaiPan Sep 5, 2022
c685954
docs: update docstrings
XuehaiPan Sep 5, 2022
8114286
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Sep 5, 2022
c075533
docs: update docstrings
XuehaiPan Sep 5, 2022
0f5c90a
docs: update docstrings
XuehaiPan Sep 5, 2022
fix: force add adamw.py
Benjamin-eecs committed Aug 4, 2022
commit 17d5784e2b61897afc4fcdcd0b1af3818f2f5a15
84 changes: 84 additions & 0 deletions torchopt/_src/optimizer/adamw.py
@@ -0,0 +1,84 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Any, Callable, Iterable, Optional, Union

import torch

from torchopt._src import base
from torchopt._src.alias import adamw
from torchopt._src.optimizer.base import Optimizer
from torchopt._src.typing import ScalarOrSchedule


class AdamW(Optimizer):
    """The classic AdamW optimizer.

    See Also:
        - The functional AdamW optimizer: :func:`torchopt.adamw`.
        - The differentiable meta-AdamW optimizer: :class:`torchopt.MetaAdamW`.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        params: Iterable[torch.Tensor],
        lr: ScalarOrSchedule,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        use_accelerated_op: bool = False,
        weight_decay: float = 0.01,
        mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    ):
        r"""The `init` function.

        Args:
            params: an iterable of :class:`torch.Tensor`\s to optimize.
            lr: a fixed global scaling factor or a learning rate schedule.
            b1: the exponential decay rate to track the first moment of past gradients.
            b2: the exponential decay rate to track the second moment of past gradients.
            eps: a small constant applied to the denominator outside of the square root
                (as in the Adam paper) to avoid dividing by zero when rescaling.
            eps_root: (default `0`), a small constant applied to the denominator inside the
                square root (as in RMSProp), to avoid dividing by zero when rescaling.
                This is needed, for instance, when computing (meta-)gradients through Adam.
            use_accelerated_op: if `True`, use the accelerated (fused) operator implementation.
            weight_decay: strength of the weight decay regularization. Note that this
                weight decay is multiplied with the learning rate. This is consistent
                with other frameworks such as PyTorch, but different from
                (Loshchilov et al., 2019), where the weight decay is only multiplied with
                the "schedule multiplier", but not the base learning rate.
            mask: a tree with the same structure as (or a prefix of) the params PyTree,
                or a Callable that returns such a pytree given the params/updates.
                The leaves should be booleans, `True` for leaves/subtrees you want to
                apply the weight decay to, and `False` for those you want to skip. Note
                that the Adam gradient transformations are applied to all parameters.
        """
        super().__init__(
            params,
            adamw(
                lr=lr,
                b1=b1,
                b2=b2,
                eps=eps,
                eps_root=eps_root,
                moment_requires_grad=False,
                use_accelerated_op=use_accelerated_op,
                weight_decay=weight_decay,
                mask=mask,
            ),
        )
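
For reference, a minimal usage sketch of the new class, assuming it is re-exported as `torchopt.AdamW` and follows the torch.optim-style `step`/`zero_grad` API of the other classic optimizers such as `torchopt.Adam`; the toy model and data below are illustrative only:

import torch
import torch.nn as nn
import torchopt

model = nn.Linear(4, 1)
# Decoupled weight decay: in addition to the scaled Adam step, the update
# subtracts lr * weight_decay * param (as described in the docstring above).
optimizer = torchopt.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

xs, ys = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(xs), ys)

optimizer.zero_grad()
loss.backward()
optimizer.step()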
85 changes: 85 additions & 0 deletions torchopt/_src/optimizer/meta/adamw.py
@@ -0,0 +1,85 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Any, Callable, Optional, Union

import torch.nn as nn

from torchopt._src import base
from torchopt._src.alias import adamw
from torchopt._src.optimizer.meta.base import MetaOptimizer
from torchopt._src.typing import ScalarOrSchedule


class MetaAdamW(MetaOptimizer):
    """The differentiable meta-AdamW optimizer.

    See Also:
        - The functional AdamW optimizer: :func:`torchopt.adamw`.
        - The classic AdamW optimizer: :class:`torchopt.AdamW`.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        net: nn.Module,
        lr: ScalarOrSchedule,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        moment_requires_grad: bool = False,
        use_accelerated_op: bool = False,
        weight_decay: float = 0.01,
        mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    ):
        r"""The `init` function.

        Args:
            net: a network whose parameters should be optimized.
            lr: a fixed global scaling factor or a learning rate schedule.
            b1: the exponential decay rate to track the first moment of past gradients.
            b2: the exponential decay rate to track the second moment of past gradients.
            eps: a small constant applied to the denominator outside of the square root
                (as in the Adam paper) to avoid dividing by zero when rescaling.
            eps_root: (default `0`), a small constant applied to the denominator inside the
                square root (as in RMSProp), to avoid dividing by zero when rescaling.
                This is needed, for instance, when computing (meta-)gradients through Adam.
            moment_requires_grad: if `True`, create the moment accumulators with
                `requires_grad=True`, so that gradients can flow through them.
            use_accelerated_op: if `True`, use the accelerated (fused) operator implementation.
            weight_decay: strength of the weight decay regularization. Note that this
                weight decay is multiplied with the learning rate. This is consistent
                with other frameworks such as PyTorch, but different from
                (Loshchilov et al., 2019), where the weight decay is only multiplied with
                the "schedule multiplier", but not the base learning rate.
            mask: a tree with the same structure as (or a prefix of) the params PyTree,
                or a Callable that returns such a pytree given the params/updates.
                The leaves should be booleans, `True` for leaves/subtrees you want to
                apply the weight decay to, and `False` for those you want to skip. Note
                that the Adam gradient transformations are applied to all parameters.
        """
        super().__init__(
            net,
            adamw(
                lr=lr,
                b1=b1,
                b2=b2,
                eps=eps,
                eps_root=eps_root,
                moment_requires_grad=moment_requires_grad,
                use_accelerated_op=use_accelerated_op,
                weight_decay=weight_decay,
                mask=mask,
            ),
        )
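
And a minimal differentiable inner-loop sketch for the meta version, assuming the class is re-exported as `torchopt.MetaAdamW` and follows the `step(loss)` API of the other meta-optimizers such as `torchopt.MetaAdam`; the network and data are illustrative only:

import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
meta_optimizer = torchopt.MetaAdamW(net, lr=1e-2, weight_decay=0.01)

xs, ys = torch.randn(8, 4), torch.randn(8, 1)

# Inner step: `step(loss)` replaces the parameters of `net` in place with
# differentiably updated tensors, keeping the update on the autograd graph.
inner_loss = nn.functional.mse_loss(net(xs), ys)
meta_optimizer.step(inner_loss)

# Outer loss: backpropagating through the AdamW update reaches the original
# (pre-update) leaf parameters, as in MAML-style meta-learning.
outer_loss = nn.functional.mse_loss(net(xs), ys)
outer_loss.backward()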