Skip to content

Commit

Permalink
implemented skipping generating too many test/eval/straight files
Browse files Browse the repository at this point in the history
  • Loading branch information
jzuern committed Apr 28, 2023
1 parent 2bae3e7 commit 53f82ed
Showing 1 changed file with 72 additions and 4 deletions.
76 changes: 72 additions & 4 deletions aggregation/aggregate_av2.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 146,14 @@ def process_samples(args, city_name, trajectories_vehicles_, trajectories_ped_,

print("In process_samples for city: {}".format(city_name))

num_train_samples = 0
num_eval_samples = 0
num_test_samples = 0

num_branching = 0
num_straight = 0


if args.source == "lanegraph":

edge_0_pos = np.array([G_annot.nodes[edge[0]]['pos'] for edge in G_annot.edges()])
Expand All @@ -169,14 177,12 @@ def process_samples(args, city_name, trajectories_vehicles_, trajectories_ped_,
curr_node = successors[np.random.randint(0, len(successors))]
agent_trajectory.append(curr_node)


# leave out the last nodes cause otherwise future trajectory is ending in image
agent_trajectory = agent_trajectory[0:-50]
agent_trajectory = agent_trajectory[::10]
if len(agent_trajectory) == 0:
continue


# Iterate over agent trajectory:
for t in range(0, len(agent_trajectory)-1, 2):

Expand Down Expand Up @@ -209,6 215,11 @@ def process_samples(args, city_name, trajectories_vehicles_, trajectories_ped_,
if dataset_split == "train":
continue

if num_eval_samples > 2000 and dataset_split == "eval":
continue
if num_test_samples > 2000 and dataset_split == "test":
continue

out_path = os.path.join(out_path_root, dataset_split)
sample_id = "{}-{}-{}".format(city_name, x_noise, y_noise)

Expand Down Expand Up @@ -317,6 328,14 @@ def process_samples(args, city_name, trajectories_vehicles_, trajectories_ped_,
else:
sample_type = "straight"

# Skip too many straight samples
if sample_type == "straight":
if num_straight >= 2 * num_branching:
continue




do_debugging = False
if do_debugging:
sat_image_crop_viz = cv2.cvtColor(sat_image_crop, cv2.COLOR_BGR2RGB)
Expand Down Expand Up @@ -387,8 406,6 @@ def viz(event, mouseX, mouseY, flags, param):
pos_encoding[..., 2] = np.abs((y - q[1])) / sat_image_crop.shape[0]
pos_encoding = (pos_encoding * 255).astype(np.uint8)

sample_num = 1

print("---- TID: {}/{}: Sample {}/{}/{}/{} ({}/{}) - Samples / s = {:.2f}".format(args.thread_id, args.num_parallel,
out_path, sample_type, sample_id,
i_query, sample_num, max_num_samples,
Expand All @@ -407,6 424,25 @@ def viz(event, mouseX, mouseY, flags, param):
Image.fromarray(mask_angle_colorized).save(
"{}/{}/{}-{}-angles.png".format(out_path, sample_type, sample_id, i_query))

sample_num = 1

if dataset_split == "train":
num_train_samples = 1
elif dataset_split == "eval":
num_eval_samples = 1
elif dataset_split == "test":
num_test_samples = 1
else:
continue

if sample_type == "branching":
num_branching = 1
elif sample_type == "straight":
num_straight = 1
else:
continue


elif "tracklets" in args.source:
annot_veh_ = trajectories_vehicles_
centers = [np.mean(t, axis=0) for t in annot_veh_]
Expand All @@ -433,6 469,12 @@ def viz(event, mouseX, mouseY, flags, param):
if dataset_split == "train":
continue

if num_eval_samples > 2000 and dataset_split == "eval":
continue
if num_test_samples > 2000 and dataset_split == "test":
continue


out_path = os.path.join(out_path_root, dataset_split)

sample_id = "{}-{}-{}".format(city_name, crop_center[0], crop_center[1])
Expand Down Expand Up @@ -540,6 582,15 @@ def viz(event, mouseX, mouseY, flags, param):
sample_type = "branching"
else:
sample_type = "straight"


# Skip too many straight samples
if sample_type == "straight":
if num_straight >= 2 * num_branching:
continue



# Filter out all samples that do not fit in quality criteria
if len(succ_traj) < N_MIN_SUCC_TRAJECTORIES:
logging.debug("Too few successor trajectories")
Expand Down Expand Up @@ -583,6 634,23 @@ def viz(event, mouseX, mouseY, flags, param):
Image.fromarray(drivable_gt_crop.astype(np.uint8)).save("{}/{}/{}-{}-drivable-gt.png".format(out_path, sample_type, sample_id, i_query))
Image.fromarray(mask_angle_colorized).save("{}/{}/{}-{}-angles.png".format(out_path, sample_type, sample_id, i_query))

sample_num = 1

if dataset_split == "train":
num_train_samples = 1
elif dataset_split == "eval":
num_eval_samples = 1
elif dataset_split == "test":
num_test_samples = 1
else:
continue

if sample_type == "branching":
num_branching = 1
elif sample_type == "straight":
num_straight = 1
else:
continue

else:
raise ValueError("Invalid source")
Expand Down

0 comments on commit 53f82ed

Please sign in to comment.