feat: adagrad optimizer support #80

Merged
merged 54 commits into metaopt:main from Benjamin-eecs:feature/adagrad on Mar 22, 2023
Changes from 1 commit

Commits (54)
640e3b7
feat(torchopt): adagrad optimizer support
XuehaiPan Oct 13, 2022
ba6be61
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 15, 2023
21102bf
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Feb 15, 2023
035a429
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 17, 2023
c4a1899
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 20, 2023
bf029ae
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Feb 20, 2023
61552b2
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 23, 2023
3830947
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 27, 2023
5dcf35f
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 1, 2023
eb31c43
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
dac67fb
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
51bafa9
Merge branch 'feature/adagrad' of https://github.com/Benjamin-eecs/to…
Benjamin-eecs Mar 3, 2023
a953329
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
9786565
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
449bdb0
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
75b2bfb
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
3f28f98
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
ae56e25
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
91c7086
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
2f78e60
feat: adagrad integration
Benjamin-eecs Mar 4, 2023
c8e74f4
feat: adagrad integration
Benjamin-eecs Mar 4, 2023
fd4e257
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 7, 2023
95be0cb
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 9, 2023
9718cc0
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 9, 2023
7e76a7e
feat: adagrad integration
Benjamin-eecs Mar 11, 2023
1077916
feat: adagrad integration
Benjamin-eecs Mar 11, 2023
fc43b03
Merge branch 'feature/adagrad' of https://github.com/Benjamin-eecs/to…
Benjamin-eecs Mar 11, 2023
93d9daf
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 13, 2023
11c99d7
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 14, 2023
5e64fe1
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 17, 2023
adf641e
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Mar 17, 2023
bb50658
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 17, 2023
3ca005c
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
9a17c10
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
79036ed
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
85709e3
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
5636992
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
3ede2b4
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
4c4e1a3
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
d431749
feat(torchopt.optim): update torchopt/alias/adagrad.py
Benjamin-eecs Mar 20, 2023
03e5b42
feat: adagrad integration
Benjamin-eecs Mar 20, 2023
11d1c1f
Merge branch 'feature/adagrad' of https://github.com/Benjamin-eecs/to…
Benjamin-eecs Mar 20, 2023
ace6b5e
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
a937d6b
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
beab339
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
26afa12
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
91bb5c2
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
87b5219
Merge branch 'main' into feature/adagrad
XuehaiPan Mar 22, 2023
38482bf
fix: ca pi gu 💩 (roughly: "clean up the mess")
XuehaiPan Mar 22, 2023
61234cd
refactor: refactor scale_by_rss
XuehaiPan Mar 22, 2023
d34c534
test: fix eps value for AdaGrad
XuehaiPan Mar 22, 2023
ccf87e2
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Mar 22, 2023
c188839
revert: revert change in torchopt/version.py
XuehaiPan Mar 22, 2023
fbb68f8
chore: update CHANGELOG
Benjamin-eecs Mar 22, 2023
feat: adagrad integration
Benjamin-eecs committed Mar 21, 2023
commit ace6b5efa517c38b4b88737d08d2a9d349ac2f56
2 changes: 1 addition & 1 deletion Makefile
@@ -114,7 +114,7 @@ addlicense-install: go-install

pytest: test-install
cd tests && $(PYTHON) -c 'import $(PROJECT_NAME)' && \
$(PYTHON) -m pytest -k "test_exponential_decay" --verbose --color=yes --durations=0 \
$(PYTHON) -m pytest --verbose --color=yes --durations=0 \
--cov="$(PROJECT_NAME)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \
$(PYTESTOPTS) .

26 changes: 13 additions & 13 deletions torchopt/alias/adagrad.py
@@ -39,16 +39,15 @@
scale_by_neg_lr,
)
from torchopt.combine import chain
from torchopt.transform import scale_by_rss
from torchopt.typing import GradientTransformation, Numeric, Scalar, Schedule
from torchopt.transform import scale_by_rss, scale_by_schedule
from torchopt.typing import GradientTransformation, Numeric, Scalar, ScalarOrSchedule, Schedule


__all__ = ['adagrad']


# pylint: disable-next=too-many-arguments
def _adagrad_lr_decay(
init_value: Scalar,
def _adagrad_lr_schedule(
decay_rate: Scalar,
transition_begin: int = 0,
) -> Schedule:
@@ -61,10 +60,10 @@ def _adagrad_lr_decay(
```

Args:
init_value: the initial learning rate.
decay_rate: The decay rate.
transition_begin: must be positive. After how many steps to start annealing
decay_rate (float, optional): The decay rate.
transition_begin (int, optional): must be positive. After how many steps to start annealing
(before this many steps the scalar value is held fixed at `init_value`).
(default: :const:`1`)

Returns:
schedule: A function that maps step counts to values.
@@ -78,14 +77,14 @@ def _adagrad_lr_decay(

def schedule(count: Numeric) -> Numeric:
decreased_count = count - transition_begin
return init_value / (1 + decay_rate * decreased_count)
return 1 / (1 + decay_rate * decreased_count)

return schedule


# pylint: disable-next=too-many-arguments
def adagrad(
lr: Scalar = 1e-2,
lr: ScalarOrSchedule = 1e-2,
lr_decay: float = 0.0,
weight_decay: float = 0.0,
initial_accumulator_value: float = 0.0,
@@ -144,21 +143,22 @@ def adagrad(
flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay
adagrad_scaler_fn = scale_by_rss
scale_by_neg_lr_fn = scale_by_neg_lr
schedule_fn = _adagrad_lr_decay
step_size_fn = _adagrad_lr_schedule
scale_by_schedule_fn = scale_by_schedule

if _get_use_chain_flat(): # default behavior
chain_fn = chain_fn.flat # type: ignore[attr-defined]
flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined]
adagrad_scaler_fn = adagrad_scaler_fn.flat # type: ignore[attr-defined]
scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined]
scale_by_schedule_fn = scale_by_schedule_fn.flat # type: ignore[attr-defined]

return chain_fn(
flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize),
adagrad_scaler_fn(
initial_accumulator_value=initial_accumulator_value,
eps=eps,
),
scale_by_neg_lr_fn(
schedule_fn(init_value=lr, decay_rate=lr_decay, transition_begin=0),
),
scale_by_schedule_fn(step_size_fn=step_size_fn(decay_rate=lr_decay, transition_begin=0)),
scale_by_neg_lr_fn(lr),
)
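For context, the chain above composes three transformations: `scale_by_rss` normalizes each gradient by the square root of its running sum of squares, `scale_by_schedule` applies the `1 / (1 + lr_decay * step)` factor produced by `_adagrad_lr_schedule`, and `scale_by_neg_lr` multiplies by `-lr`. A minimal plain-Python sketch of the composed scalar update follows; the helper name and the exact placement of `eps` are assumptions for illustration, not the library's implementation.

```python
import math


def adagrad_scalar_step(param, grad, sum_sq, step, lr=1e-2, lr_decay=0.0, eps=1e-10):
    """One scalar AdaGrad update mirroring the transformation chain above (illustrative only)."""
    sum_sq += grad * grad                      # running sum of squared gradients (scale_by_rss state)
    update = grad / (math.sqrt(sum_sq) + eps)  # scale_by_rss: divide by sqrt of the accumulator
    update /= 1.0 + lr_decay * step            # scale_by_schedule with _adagrad_lr_schedule
    update *= -lr                              # scale_by_neg_lr
    return param + update, sum_sq
```

The effective step size is therefore `lr / (1 + lr_decay * step)`, the standard AdaGrad learning-rate decay; expressing the decay through `scale_by_schedule` rather than inside `scale_by_neg_lr` is what allows `lr` to be typed as `ScalarOrSchedule` above.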
2 changes: 1 addition & 1 deletion torchopt/schedule/exponential_decay.py
@@ -71,7 +71,7 @@ def exponential_decay(
entire annealing process is disabled and the value is held fixed at ``init_value``.
(default: :const:`1`)
staircase (bool): If ``True``, decay the scalar at discrete intervals.
end_value (float or Tensor): End value of the scalar to be annealed.
end_value (float or Tensor, optional): End value of the scalar to be annealed.

Returns:
schedule: A function that maps step counts to values.
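For completeness, a short usage sketch of the new functional alias, assuming it is exported at the package top level and follows the same `init`/`update`/`apply_updates` pattern as the existing `torchopt.sgd` and `torchopt.adam` aliases (the linear model and random data below are placeholders):

```python
import torch
import torchopt

model = torch.nn.Linear(4, 1)
params = tuple(model.parameters())

optimizer = torchopt.adagrad(lr=1e-2, lr_decay=1e-3)  # GradientTransformation built by the chain above
opt_state = optimizer.init(params)                    # initializes the per-parameter accumulators

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
grads = torch.autograd.grad(loss, params)

updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates)
```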