Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(examples/implicit): add iMAML example with OOP APIs #107

Merged
merged 4 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
- Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101).
- Bump PyTorch version to 1.13.0 by [@XuehaiPan](https://github.com/XuehaiPan) in [#104](https://github.com/metaopt/torchopt/pull/104).
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 273,10 @@ make install-editable # or run `pip3 install --no-build-isolation --editable .`
## Future Plan

- [X] CPU-accelerated optimizer
- [X] Support general implicit differentiation with functional programing
- [X] Support more optimizers such as AdamW, RMSProp
- [ ] Zero order optimization
- [ ] Distributed optimizers
- [X] Support general implicit differentiation
- [X] Zero order optimization
- [X] Distributed optimization
- [ ] Support `complex` data type

## Changelog
Expand Down
29 changes: 14 additions & 15 deletions docs/source/examples/MAML.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 99,7 @@ Define the ``train`` function:
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()

task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)
task_num = x_spt.size(0)

# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
Expand Down Expand Up @@ -129,25 128,24 @@ Define the ``train`` function:
# These will be used to update the model's meta-parameters.
qry_logits = net(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc)
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
qry_losses.append(qry_loss)
qry_accs.append(qry_acc.item())

torchopt.recover_state_dict(net, net_state_dict)
torchopt.recover_state_dict(inner_opt, optim_state_dict)

qry_losses = torch.mean(torch.stack(qry_losses))
qry_losses.backward()
meta_opt.step()
qry_losses = sum(qry_losses) / task_num
qry_accs = 100. * sum(qry_accs) / task_num
qry_losses = qry_losses.item()
qry_accs = 100.0 * np.mean(qry_accs)
i = epoch float(batch_idx) / n_train_iter
iter_time = time.time() - start_time

print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)

log.append(
{
'epoch': i,
Expand Down Expand Up @@ -181,8 179,7 @@ Define the ``test`` function:
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')

task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)
task_num = x_spt.size(0)

# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
Expand All @@ -201,15 198,17 @@ Define the ``test`` function:

# The query loss and acc induced by these parameters.
qry_logits = net(x_qry[i]).detach()
qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach())
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
qry_losses.append(qry_loss.item())
qry_accs.append(qry_acc.item())

torchopt.recover_state_dict(net, net_state_dict)
torchopt.recover_state_dict(inner_opt, optim_state_dict)

qry_losses = torch.mean(torch.stack(qry_losses)).item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
qry_losses = np.mean(qry_losses)
qry_accs = 100.0 * np.mean(qry_accs)

print(f'[Epoch {epoch 1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
Expand Down
38 changes: 19 additions & 19 deletions examples/FuncTorch/maml_omniglot_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 39,10 @@
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""


import os
import sys


cur = os.path.abspath(os.path.dirname(__file__))
root = os.path.split(cur)[0]
sys.path.append(root '/few-shot')
import argparse
import functools
import pathlib
import sys
import time

import functorch
Expand All @@ -59,12 53,17 @@
import torch
import torch.nn.functional as F
import torch.optim as optim
from support.omniglot_loaders import OmniglotNShot
from torch import nn

import torchopt


CWD = pathlib(__file__).absolute().parent
sys.path.append(str(CWD.parent / 'few-shot'))

from helpers.omniglot_loaders import OmniglotNShot


mpl.use('Agg')
plt.style.use('bmh')

Expand Down Expand Up @@ -148,8 147,6 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
opt = torchopt.sgd(lr=1e-1)
opt_state = opt.init(params)

querysz = x_qry.size(0)

def compute_loss(new_params, buffers, x, y):
logits = fnet(new_params, buffers, x)
loss = F.cross_entropy(logits, y)
Expand All @@ -167,7 164,7 @@ def compute_loss(new_params, buffers, x, y):
# These will be used to update the model's meta-parameters.
qry_logits = fnet(new_params, buffers, x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry)
qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
qry_acc = (qry_logits.argmax(dim=1) == y_qry).mean()

return qry_loss, qry_acc

Expand All @@ -192,18 189,19 @@ def train(db, net, device, meta_opt, epoch, log):
qry_losses, qry_accs = functorch.vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)

# Compute the maml loss by summing together the returned losses.
qry_losses.sum().backward()

qry_losses = torch.mean(torch.stack(qry_losses))
qry_losses.backward()
meta_opt.step()
qry_losses = qry_losses.detach().sum() / task_num
qry_accs = 100.0 * qry_accs.sum() / task_num
qry_losses = qry_losses.item()
qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item()
i = epoch float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
torch.cuda.empty_cache()

if batch_idx % 4 == 0:
print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)

log.append(
{
'epoch': i,
Expand Down Expand Up @@ -249,8 247,10 @@ def test(db, net, device, epoch, log):
qry_losses.append(qry_loss.detach())
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())

qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
qry_losses = torch.mean(torch.stack(qry_losses)).item()
qry_accs = 100.0 * torch.mean(torch.stack(qry_accs)).item()
torch.cuda.empty_cache()

print(f'[Epoch {epoch 1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions examples/L2R/l2r.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 39,9 @@


# isort: off
from helper.argument import parse_args
from helper.model import LeNet5
from helper.utils import get_imbalance_dataset, plot, set_seed
from helpers.argument import parse_args
from helpers.model import LeNet5
from helpers.utils import get_imbalance_dataset, plot, set_seed


def run_baseline(args, mnist_train, mnist_test):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 4 additions & 4 deletions examples/LOLA/lola_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 21,10 @@


# isort: off
from helper.agent import Agent
from helper.argument import parse_args
from helper.env import IPD
from helper.utils import sample, step
from helpers.agent import Agent
from helpers.argument import parse_args
from helpers.env import IPD
from helpers.utils import sample, step


def main(args):
Expand Down
9 changes: 5 additions & 4 deletions examples/distributed/few-shot/maml_omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 58,7 @@
import torchopt.distributed as todist


from support.omniglot_loaders import OmniglotNShot # isort: skip
from helpers.omniglot_loaders import OmniglotNShot # isort: skip


mpl.use('Agg')
Expand Down Expand Up @@ -187,7 187,6 @@ def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter):
x_qry = x_qry.to(device)
y_qry = y_qry.to(device)

querysz = x_qry.size(0)
inner_opt = torchopt.MetaSGD(net, lr=1e-1)

for _ in range(n_inner_iter):
Expand All @@ -197,7 196,7 @@ def inner_loop(net_rref, x_spt, y_spt, x_qry, y_qry, n_inner_iter):

qry_logits = net(x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry).cpu()
qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum().cpu().item() / querysz
qry_acc = (qry_logits.argmax(dim=1) == y_qry).mean().cpu().item()

return qry_loss, qry_acc

Expand Down Expand Up @@ -232,11 231,11 @@ def train(db: OmniglotNShot, net: nn.Module, meta_opt: optim.Adam, epoch: int, l
qry_acc = 100.0 * qry_acc
i = epoch float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
torch.cuda.empty_cache()

print(
f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
)

log.append(
{
'epoch': i,
Expand Down Expand Up @@ -275,6 274,8 @@ def test(db, net, epoch, log):

qry_losses = np.mean(qry_losses)
qry_accs = 100.0 * np.mean(qry_accs)
torch.cuda.empty_cache()

print(f'[Epoch {epoch 1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
Expand Down
9 changes: 5 additions & 4 deletions examples/distributed/few-shot/maml_omniglot_local_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 60,7 @@
import torchopt.distributed as todist


from support.omniglot_loaders import OmniglotNShot # isort: skip
from helpers.omniglot_loaders import OmniglotNShot # isort: skip


mpl.use('Agg')
Expand Down Expand Up @@ -228,7 228,6 @@ def inner_loop(net_rref, n_inner_iter, task_id, task_num, mode):
x_qry = x_qry.to(device)
y_qry = y_qry.to(device)

querysz = x_qry.size(0)
inner_opt = torchopt.MetaSGD(net, lr=1e-1)

for _ in range(n_inner_iter):
Expand All @@ -238,7 237,7 @@ def inner_loop(net_rref, n_inner_iter, task_id, task_num, mode):

qry_logits = net(x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry).cpu()
qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum().cpu().item() / querysz
qry_acc = (qry_logits.argmax(dim=1) == y_qry).mean().cpu().item()

return qry_loss, qry_acc

Expand Down Expand Up @@ -275,11 274,11 @@ def train(net: nn.Module, meta_opt: optim.Adam, epoch: int, log: list):
qry_acc = 100.0 * qry_acc
i = epoch float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
torch.cuda.empty_cache()

print(
f'[Epoch {i:.2f}] Train Loss: {qry_loss:.2f} | Acc: {qry_acc:.2f} | Time: {iter_time:.2f}'
)

log.append(
{
'epoch': i,
Expand Down Expand Up @@ -319,6 318,8 @@ def test(net, epoch, log):

qry_losses = np.mean(qry_losses)
qry_accs = 100.0 * np.mean(qry_accs)
torch.cuda.empty_cache()

print(f'[Epoch {epoch 1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
Expand Down
29 changes: 15 additions & 14 deletions examples/few-shot/maml_omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 54,7 @@
import torchopt


from support.omniglot_loaders import OmniglotNShot # isort: skip
from helpers.omniglot_loaders import OmniglotNShot # isort: skip


mpl.use('Agg')
Expand Down Expand Up @@ -133,8 133,7 @@ def train(db, net, meta_opt, epoch, log):
# Sample a batch of support and query images and labels.
x_spt, y_spt, x_qry, y_qry = db.next()

task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)
task_num = x_spt.size(0)

# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
Expand Down Expand Up @@ -165,9 164,9 @@ def train(db, net, meta_opt, epoch, log):
# These will be used to update the model's meta-parameters.
qry_logits = net(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
qry_losses.append(qry_loss)
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc)
qry_accs.append(qry_acc.item())

torchopt.recover_state_dict(net, net_state_dict)
torchopt.recover_state_dict(inner_opt, optim_state_dict)
Expand All @@ -176,14 175,14 @@ def train(db, net, meta_opt, epoch, log):
qry_losses.backward()
meta_opt.step()
qry_losses = qry_losses.item()
qry_accs = 100.0 * sum(qry_accs) / task_num
qry_accs = 100.0 * np.mean(qry_accs)
i = epoch float(batch_idx) / n_train_iter
iter_time = time.time() - start_time
torch.cuda.empty_cache()

print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
)

log.append(
{
'epoch': i,
Expand Down Expand Up @@ -211,8 210,7 @@ def test(db, net, epoch, log):
for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')

task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)
task_num = x_spt.size(0)

# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
Expand All @@ -231,15 229,18 @@ def test(db, net, epoch, log):

# The query loss and acc induced by these parameters.
qry_logits = net(x_qry[i]).detach()
qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach())
qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
qry_losses.append(qry_loss.item())
qry_accs.append(qry_acc.item())

torchopt.recover_state_dict(net, net_state_dict)
torchopt.recover_state_dict(inner_opt, optim_state_dict)

qry_losses = torch.mean(torch.stack(qry_losses)).item()
qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
qry_losses = np.mean(qry_losses)
qry_accs = 100.0 * np.mean(qry_accs)
torch.cuda.empty_cache()

print(f'[Epoch {epoch 1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
{
Expand Down
Loading