Skip to content

Commit

Permalink
fix issue #2
Browse files Browse the repository at this point in the history
  • Loading branch information
teowu committed Dec 19, 2022
1 parent fe28636 commit bfe0450
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 102 deletions.
132 changes: 81 additions & 51 deletions default_infer.py
Original file line number Diff line number Diff line change
@@ -1,37 1,37 @@
import torch
import cv2
import random
import os.path as osp
import dover.models as models
import dover.datasets as datasets

import argparse
import math
import os.path as osp
import pickle
import random
from time import time

from scipy.stats import spearmanr, pearsonr
from scipy.stats.stats import kendalltau as kendallr
import cv2
import numpy as np

from time import time
import torch
import yaml
from scipy.stats import kendalltau as kendallr
from scipy.stats import pearsonr, spearmanr
from thop import profile
from tqdm import tqdm
import pickle
import math

import dover.datasets as datasets
import dover.models as models
import wandb
import yaml

from thop import profile


def rescale(pr, gt=None):
if gt is None:
print(np.mean(pr), np.std(pr))
print("mean", np.mean(pr), "std", np.std(pr))
pr = (pr - np.mean(pr)) / np.std(pr)
else:
print(np.mean(pr), np.std(pr), np.std(gt), np.mean(gt))
pr = ((pr - np.mean(pr)) / np.std(pr)) * np.std(gt) np.mean(gt)
return pr

sample_types=["aesthetic", "technical"]

sample_types = ["aesthetic", "technical"]


def profile_inference(inf_set, model, device):
Expand All @@ -41,17 41,33 @@ def profile_inference(inf_set, model, device):
if key in data:
video[key] = data[key].to(device)
c, t, h, w = video[key].shape
video[key] = video[key].reshape(1, c, data["num_clips"][key], t // data["num_clips"][key], h, w).permute(0,2,1,3,4,5).reshape( data["num_clips"][key], c, t // data["num_clips"][key], h, w)
video[key] = (
video[key]
.reshape(
1, c, data["num_clips"][key], t // data["num_clips"][key], h, w
)
.permute(0, 2, 1, 3, 4, 5)
.reshape(data["num_clips"][key], c, t // data["num_clips"][key], h, w)
)
with torch.no_grad():
flops, params = profile(model, (video, ))
print(f"The FLOps of the Variant is {flops/1e9:.1f}G, with Params {params/1e6:.2f}M.")
flops, params = profile(model, (video,))
print(
f"The FLOps of the Variant is {flops/1e9:.1f}G, with Params {params/1e6:.2f}M."
)


def inference_set(inf_loader, model, device, best_, save_model=False, suffix='s', set_name="na"):
def inference_set(
inf_loader, model, device, best_, save_model=False, suffix="s", set_name="na"
):
print(f"Validating for {set_name}.")
results = []
try:
model = torch.compile(model)
except:
print("You may try to accelerate your model with torch 2.0")

best_s, best_p, best_k, best_r = best_

keys = []

for i, data in enumerate(tqdm(inf_loader, desc="Validating")):
Expand All @@ -63,9 79,18 @@ def inference_set(inf_loader, model, device, best_, save_model=False, suffix='s'
if key in data:
video[key] = data[key].to(device)
b, c, t, h, w = video[key].shape
video[key] = video[key].reshape(b, c, data["num_clips"][key], t // data["num_clips"][key], h, w).permute(0,2,1,3,4,5).reshape(b * data["num_clips"][key], c, t // data["num_clips"][key], h, w)
video[key] = (
video[key]
.reshape(
b, c, data["num_clips"][key], t // data["num_clips"][key], h, w
)
.permute(0, 2, 1, 3, 4, 5)
.reshape(
b * data["num_clips"][key], c, t // data["num_clips"][key], h, w
)
)
with torch.no_grad():
labels = model(video,reduce_scores=False)
labels = model(video, reduce_scores=False)
labels = [np.mean(l.cpu().numpy()) for l in labels]
result["pr_labels"] = labels
result["gt_label"] = data["gt_label"].item()
Expand All @@ -74,7 99,6 @@ def inference_set(inf_loader, model, device, best_, save_model=False, suffix='s'
# del data
results.append(result)


## generate the demo video for video quality localization
gt_labels = [r["gt_label"] for r in results]
pr_labels = 0
Expand All @@ -84,22 108,28 @@ def inference_set(inf_loader, model, device, best_, save_model=False, suffix='s'
key_pr_labels = rescale([np.mean(r["pr_labels"][i]) for r in results])
pr_labels = key_pr_labels * w
pr_dict[key] = key_pr_labels
#with open(f"dover_predictions/{set_name}.pkl", "wb") as f:

# with open(f"dover_predictions/{set_name}.pkl", "wb") as f:
# pickle.dump(pr_dict, f)
print(pr_labels)
pr_labels = rescale(pr_labels, gt_labels)

s = spearmanr(gt_labels, pr_labels)[0]
p = pearsonr(gt_labels, pr_labels)[0]
k = kendallr(gt_labels, pr_labels)[0]
r = np.sqrt(((gt_labels - pr_labels) ** 2).mean())



results = sorted(results, key=lambda x: x["pr_labels"])

try:
wandb.log({f"val/SRCC-{suffix}": s, f"val/PLCC-{suffix}": p, f"val/KRCC-{suffix}": k, f"val/RMSE-{suffix}": r})
wandb.log(
{
f"val/SRCC-{suffix}": s,
f"val/PLCC-{suffix}": p,
f"val/KRCC-{suffix}": k,
f"val/RMSE-{suffix}": r,
}
)
except:
pass

Expand Down Expand Up @@ -127,6 157,7 @@ def inference_set(inf_loader, model, device, best_, save_model=False, suffix='s'

return best_s, best_p, best_k, best_r, pr_labels


def main():

parser = argparse.ArgumentParser()
Expand All @@ -138,44 169,45 @@ def main():
with open(args.opt, "r") as f:
opt = yaml.safe_load(f)
print(opt)




## adaptively choose the device

device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"
# device = "cpu"

## defining model and loading checkpoint

bests_ = []

model = getattr(models, opt["model"]["type"])(**opt["model"]["args"]).to(device)

state_dict = torch.load(opt["test_load_path"], map_location=device)#["state_dict"]

state_dict = torch.load(
opt["test_load_path"], map_location=device
) # ["state_dict"]

model.load_state_dict(state_dict, strict=True)

for key in opt["data"].keys():

if "val" not in key and "test" not in key:
continue

run = wandb.init(
project=opt["wandb"]["project_name"],
name=opt["name"] "_Test_" key,
name=opt["name"] "_Test_" key,
reinit=True,
)

val_dataset = getattr(datasets, opt["data"][key]["type"])(opt["data"][key]["args"])


val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=1, num_workers=opt["num_workers"], pin_memory=True,
val_dataset = getattr(datasets, opt["data"][key]["type"])(
opt["data"][key]["args"]
)


val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=1,
num_workers=opt["num_workers"],
pin_memory=True,
)

profile_inference(val_dataset, model, device)

Expand All @@ -184,11 216,11 @@ def main():

best_ = -1, -1, -1, 1000


best_ = inference_set(
val_loader,
model,
device, best_,
device,
best_,
set_name=key,
)

Expand All @@ -199,11 231,9 @@ def main():
KROCC: {best_[2]:.4f}
RMSE: {best_[3]:.4f}."""
)


run.finish()



if __name__ == "__main__":
main()
Loading

0 comments on commit bfe0450

Please sign in to comment.