Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
teowu committed Dec 20, 2022
1 parent 1726300 commit 2ed0855
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 40 deletions.
1 change: 0 additions & 1 deletion default_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 111,6 @@ def inference_set(

# with open(f"dover_predictions/{set_name}.pkl", "wb") as f:
# pickle.dump(pr_dict, f)
print(pr_labels)
pr_labels = rescale(pr_labels, gt_labels)

s = spearmanr(gt_labels, pr_labels)[0]
Expand Down
64 changes: 25 additions & 39 deletions transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 50,10 @@ def rank_loss(y_pred, y):
torch.sum(ranking_loss) / y_pred.shape[0] / (y_pred.shape[0] - 1) / scale
).float()

def gaussian(y, eps=1e-8):
return (y - y.mean()) / (y.std() 1e-8)


def plcc_loss(y_pred, y):
sigma_hat, m_hat = torch.std_mean(y_pred, unbiased=False)
y_pred = (y_pred - m_hat) / (sigma_hat 1e-8)
Expand All @@ -67,9 71,10 @@ def rescaled_l2_loss(y_pred, y):

def rplcc_loss(y_pred, y, eps=1e-8):
## Literally (1 - PLCC) / 2
cov = torch.cov(y_pred, y)
std = (torch.std(y_pred) eps) * (torch.std(y) eps)
return (1 - cov / std) / 2
y_pred, y = gaussian(y_pred), gaussian(y)
cov = torch.sum(y_pred * y) / y_pred.shape[0]
#std = (torch.std(y_pred) eps) * (torch.std(y) eps)
return (1 - cov) / 2

def self_similarity_loss(f, f_hat, f_hat_detach=False):
if f_hat_detach:
Expand Down Expand Up @@ -103,46 108,20 @@ def finetune_epoch(ft_loader, model, model_ema, optimizer, scheduler, device, ep
if key in data:
video[key] = data[key].to(device)

if need_upsampled:
up_video = {}
for key in sample_types:
if key "_up" in data:
up_video[key] = data[key "_up"].to(device)

y = data["gt_label"].float().detach().to(device).unsqueeze(-1)
if need_feat:
scores, feats = model(video, inference=False,
return_pooled_feats=True,

scores = model(video, inference=False,
reduce_scores=False)
if len(scores) > 1:
y_pred = reduce(lambda x,y:x y, scores)
else:
y_pred = scores[0]
y_pred = y_pred.mean((-3, -2, -1))
if len(scores) > 1:
y_pred = reduce(lambda x,y:x y, scores)
else:
scores = model(video, inference=False,
reduce_scores=False)
if len(scores) > 1:
y_pred = reduce(lambda x,y:x y, scores)
else:
y_pred = scores[0]
y_pred = y_pred.mean((-3, -2, -1))
if need_upsampled:
if need_feat:
scores_up, feats_up = model(up_video, inference=False,
return_pooled_feats=True,
reduce_scores=False)
if len(scores) > 1:
y_pred_up = reduce(lambda x,y:x y, scores_up)
else:
y_pred_up = scores_up[0]
y_pred_up = y_pred_up.mean((-3, -2, -1))
else:
y_pred_up = model(up_video, inference=False).mean((-3, -2, -1))
y_pred = scores[0]
y_pred = y_pred.mean((-3, -2, -1))

frame_inds = data["frame_inds"]

# Plain Supervised Loss
p_loss, r_loss = plcc_loss(y_pred, y), rank_loss(y_pred, y)
p_loss, r_loss = rplcc_loss(y_pred, y), rank_loss(y_pred, y)

loss = p_loss 0.3 * r_loss
wandb.log(
Expand Down Expand Up @@ -410,13 389,20 @@ def main():

run = wandb.init(
project=opt["wandb"]["project_name"],
name=opt["name"] f'_split_{split}' if num_splits > 1 else opt["name"],
name=opt["name"] f'_target_{args.target_set}_split_{split}' if num_splits > 1 else opt["name"],
reinit=True,
settings=wandb.Settings(start_method='thread'),
)

state_dict = torch.load(opt["test_load_path"], map_location=device)

model.load_state_dict(state_dict, strict = True)
head_removed_state_dict = OrderedDict()
for key, v in state_dict.items():
if "head" not in key:
head_removed_state_dict[key] = v

# Allowing empty head weight
model.load_state_dict(state_dict, strict=False)

if opt["ema"]:
from copy import deepcopy
Expand Down

0 comments on commit 2ed0855

Please sign in to comment.