multi id cross id rendering enable #4

Open · wants to merge 5 commits into base: main
configs/config-4.yaml (8 changes: 4 additions & 4 deletions)

@@ -1,11 +1,11 @@
 device: "cuda"
 train:
-  dataset_dir: "/uca/julieta/oss/release/4TB/" # where the dataset is
-  data_csv: "256_ids.csv" # csv of identities to use
+  dataset_dir: "/home/emilykim/Desktop/MetaProject/AVA_dataset_8TB/" # where the dataset is
+  data_csv: "viz/1_id.csv" # csv of identities to use
   nids: 4 # number of identities to train on
   # checkpoint: "checkpoints/aeparams.pt" # checkpoint to resume training from
-  # checkpoint: ""
-  checkpoint: "/checkpoint/avatar/julietamartinez/ava-256/checkpoints/4ids/aeparams_300000.pt"
+  checkpoint: ""
+  # checkpoint: "/checkpoint/avatar/julietamartinez/ava-256/checkpoints/4ids/aeparams_300000.pt"
   maxiter: 10_000_000 # maximum number of iterations to train for
   num_epochs: 10 # number of epochs to train for
   init_learning_rate: 2.0e-4 # learning rate
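For context, the changed fields are consumed by `train_csv_loader(train_params.dataset_dir, train_params.data_csv, train_params.nids)` in render.py below. A minimal sketch of what that helper plausibly does, assuming a pandas-readable csv with mcd/mct/sid columns; the real implementation lives in utils.py and is not part of this diff:

    import pandas as pd  # assumption: the id csv is read with pandas

    from data.utils import MugsyCapture

    def train_csv_loader(dataset_dir, data_csv, nids):
        """Hypothetical sketch, not the repo's actual code: load the first
        `nids` identities from the csv and return (captures, capture dirs)."""
        df = pd.read_csv(data_csv, dtype=str).head(nids)  # dtype=str keeps ids like "0820"
        captures = [MugsyCapture(mcd=r.mcd, mct=r.mct, sid=r.sid) for r in df.itertuples()]
        dirs = [f"{dataset_dir}/{c.mcd}--{c.mct}--{c.sid}/decoder" for c in captures]
        return captures, dirs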
ddp-train.py (88 changes: 1 addition & 87 deletions)

@@ -36,7 +36,7 @@
 from data.utils import get_framelist_neuttex_and_neutvert
 from losses import mean_ell_1
 from models.bottlenecks.vae import kl_loss_stable
-from utils import get_autoencoder, load_checkpoint, render_img, tocuda, train_csv_loader
+from utils import get_autoencoder, load_checkpoint, render_img, tocuda, train_csv_loader, xid_eval

 sys.dont_write_bytecode = True

@@ -177,92 +177,6 @@ def prepare(rank, world_size, train_params):
     return dataset, all_neut_avgtex_vert, dataloader, driver_dataloader


-def xid_eval(model, driver_dataiter, all_neut_avgtex_vert, config, output_set, outpath, rank, iternum):
-    starttime = time.time()
-
-    indices_subjects = random.sample(range(0, len(all_neut_avgtex_vert)), config.progress.cross_id_n_subjects)
-    indices_subjects.sort()
-    model.eval()
-
-    with torch.no_grad():
-        driver = next(driver_dataiter)
-        while driver is None:
-            driver = next(driver_dataiter)
-
-        cudadriver: Dict[str, Union[torch.Tensor, int, str]] = tocuda(driver)
-
-        gt = cudadriver["image"].detach().cpu().numpy()
-        gt = einops.rearrange(gt, "1 c h w -> h w c")
-        renderImages_xid = []
-        renderImages_xid.append(gt)
-
-        running_avg_scale = False
-        gt_geo = None
-        residuals_weight = 1.0
-
-        output_driver = model(
-            cudadriver["camrot"],
-            cudadriver["campos"],
-            cudadriver["focal"],
-            cudadriver["princpt"],
-            cudadriver["modelmatrix"],
-            cudadriver["avgtex"],
-            cudadriver["verts"],
-            cudadriver["neut_avgtex"],
-            cudadriver["neut_verts"],
-            cudadriver["neut_avgtex"],
-            cudadriver["neut_verts"],
-            cudadriver["pixelcoords"],
-            cudadriver["idindex"],
-            cudadriver["camindex"],
-            running_avg_scale=running_avg_scale,
-            gt_geo=gt_geo,
-            residuals_weight=residuals_weight,
-            output_set=output_set,
-        )
-
-        rgb_driver = output_driver["irgbrec"].detach().cpu().numpy()
-        rgb_driver = einops.rearrange(rgb_driver, "1 c h w -> h w c")
-        del output_driver
-        renderImages_xid.append(rgb_driver)
-
-        for i in indices_subjects:
-            if i == int(cudadriver["idindex"][0]):
-                continue
-            cudadriven: Dict[str, Union[torch.Tensor, int, str]] = tocuda(all_neut_avgtex_vert[i])
-
-            output_driven = model(
-                cudadriver["camrot"],
-                cudadriver["campos"],
-                cudadriver["focal"],
-                cudadriver["princpt"],
-                cudadriver["modelmatrix"],
-                cudadriver["avgtex"],
-                cudadriver["verts"],
-                cudadriver["neut_avgtex"],
-                cudadriver["neut_verts"],
-                torch.unsqueeze(cudadriven["neut_avgtex"], 0),
-                torch.unsqueeze(cudadriven["neut_verts"], 0),
-                cudadriver["pixelcoords"],
-                cudadriver["idindex"],
-                cudadriver["camindex"],
-                running_avg_scale=running_avg_scale,
-                gt_geo=gt_geo,
-                residuals_weight=residuals_weight,
-                output_set=output_set,
-            )
-            rgb_driven = output_driven["irgbrec"].detach().cpu().numpy()
-            rgb_driven = einops.rearrange(rgb_driven, "1 c h w -> h w c")
-            del output_driven
-            renderImages_xid.append(rgb_driven)
-            del cudadriven
-        del cudadriver
-    if rank == 0:
-        render_img([renderImages_xid], f"{outpath}/x-id/progress_{iternum}.png")
-
-    print(f"Cross ID viz took {time.time() - starttime}")
-
-
 def main(rank, config, args):
     """
     Rank is normally set automatically by mp.spawn()
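The body above is not deleted outright: it moves into utils.py (see the import change at the top of this file) and, judging by the new call site in render.py below, grows two parameters. A sketch of the signature that call site implies; the actual utils.py implementation is not shown in this diff:

    def xid_eval(model, driver_dataiter, all_neut_avgtex_vert, config, output_set,
                 outpath, rank, iternum, indices_subjects=None, training=True):
        # Inferred: when indices_subjects is None (the training-time call),
        # subjects are presumably still sampled at random as in the removed
        # body above:
        #     indices_subjects = random.sample(
        #         range(len(all_neut_avgtex_vert)), config.progress.cross_id_n_subjects
        #     )
        # training=False presumably skips training-only behaviour such as the
        # rank-0 gating used under DDP.
        ...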
render.py (137 changes: 36 additions & 101 deletions)

@@ -17,23 +17,23 @@
 from data.ava_dataset import MultiCaptureDataset as AvaMultiCaptureDataset
 from data.ava_dataset import SingleCaptureDataset as AvaSingleCaptureDataset
 from data.ava_dataset import none_collate_fn
-from data.utils import MugsyCapture
-from utils import get_autoencoder, load_checkpoint, render_img, tocuda, train_csv_loader
+from data.utils import MugsyCapture, get_framelist_neuttex_and_neutvert
+from utils import get_autoencoder, load_checkpoint, train_csv_loader, xid_eval

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Visualize Cross ID driving")
-    parser.add_argument("--checkpoint", type=str, default="checkpoints/aeparams.pt", help="checkpoint location")
+    parser.add_argument("--checkpoint", type=str, default="checkpoints/aeparams_300000.pt", help="checkpoint location")
     parser.add_argument("--output-dir", type=str, default="viz/", help="output image directory")
-    parser.add_argument("--config", default="configs/config.yaml", type=str, help="config yaml file")
+    parser.add_argument("--config", default="configs/config-4.yaml", type=str, help="config yaml file")

     # Cross ID visualization configuration
-    parser.add_argument("--driver-id", type=str, default="20230324--0820--AEY864", help="id of the driver avatar")
-    parser.add_argument("--driven-id", type=str, default="20230831--0814--ADL311", help="id of the driven avatar")
+    parser.add_argument("--driver-id", type=str, default="20230405--1635--AAN112", help="id of the driver avatar")
+    parser.add_argument("--driven-id-indices", type=list, default=[1, 2, 3], help="id of the driven avatar")
     parser.add_argument("--camera-id", type=str, default="401031", help="render camera id")
     parser.add_argument(
         "--segment-id",
         type=str,
-        default="EXP_eyes_blink_light_medium_hard_wink",
+        default="SEN_all_your_wishful_thinking_wont_change_that",
         help="segment to render; render all available frames if None",
     )
     parser.add_argument("--opts", default=[], type=str, nargs="+")
@@ -46,9 +46,9 @@

     train_params = config.train

-    output_dir = args.output_dir + "/" + args.driver_id + "_" + args.driven_id + "+" + args.segment_id
+    output_dir = args.output_dir + "/" + args.driver_id + "+" + args.segment_id

-    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(output_dir + "/x-id").mkdir(parents=True, exist_ok=True)

     # Train dataset mean/std texture and vertex for normalization
     train_captures, train_dirs = train_csv_loader(train_params.dataset_dir, train_params.data_csv, train_params.nids)
@@ -83,13 +83,6 @@
     driver_dir = f"{train_params.dataset_dir}/{args.driver_id}/decoder"
     driver_dataset = AvaSingleCaptureDataset(driver_capture, driver_dir, downsample=train_params.downsample)

-    # Driven capture dataloader
-    driven_capture = MugsyCapture(
-        mcd=args.driven_id.split("--")[0], mct=args.driven_id.split("--")[1], sid=args.driven_id.split("--")[2]
-    )
-    driven_dir = f"{train_params.dataset_dir}/{args.driven_id}/decoder"
-    driven_dataset = AvaSingleCaptureDataset(driven_capture, driven_dir, downsample=train_params.downsample)
-
     texmean = dataset.texmean
     vertmean = dataset.vertmean
     texstd = dataset.texstd
@@ -99,7 +92,7 @@
     del dataset

     # Grab driven normalization stats
-    for dataset in [driver_dataset, driven_dataset]:
+    for dataset in [driver_dataset]:  # ,driven_dataset]:
         dataset.texmean = texmean
         dataset.texstd = texstd
         dataset.vertmean = vertmean
@@ -132,92 +125,34 @@
         collate_fn=none_collate_fn,
     )

-    driven_loader = torch.utils.data.DataLoader(
-        driven_dataset,
-        batch_size=batchsize,
-        shuffle=False,
-        drop_last=False,
-        num_workers=numworkers,
-        collate_fn=none_collate_fn,
-    )
-    driveniter = iter(driven_loader)
-    driven = next(driveniter)
-
-    while driven is None:
-        driven = next(driveniter)
+    driver_dataiter = iter(driver_loader)

-    it = 0
-
-    for driver in tqdm(driver_loader, desc="Rendering Frames"):
-        # Skip if any of the frames is empty
-        if driver is None:
-            continue
-
-        cudadriver: Dict[str, Union[torch.Tensor, int, str]] = tocuda(driver)
-        cudadriven: Dict[str, Union[torch.Tensor, int, str]] = tocuda(driven)
-
-        running_avg_scale = False
-        gt_geo = None
-        residuals_weight = 1.0
-        output_set = set(["irgbrec", "bg"])
-
-        # Generate image from original inputs
-        output_orig = ae(
-            camrot=cudadriver["camrot"],
-            campos=cudadriver["campos"],
-            focal=cudadriver["focal"],
-            princpt=cudadriver["princpt"],
-            modelmatrix=cudadriver["modelmatrix"],
-            avgtex=cudadriver["avgtex"],
-            verts=cudadriver["verts"],
-            neut_avgtex=cudadriver["neut_avgtex"],
-            neut_verts=cudadriver["neut_verts"],
-            target_neut_avgtex=cudadriver["neut_avgtex"],
-            target_neut_verts=cudadriver["neut_verts"],
-            pixelcoords=cudadriver["pixelcoords"],
-            idindex=cudadriver["idindex"],
-            camindex=cudadriver["camindex"],
-            running_avg_scale=running_avg_scale,
-            gt_geo=gt_geo,
-            residuals_weight=residuals_weight,
-            output_set=output_set,
-        )
-
-        # Generate image from cross id texture and vertex
-        output_driven = ae(
-            camrot=cudadriver["camrot"],
-            campos=cudadriver["campos"],
-            focal=cudadriver["focal"],
-            princpt=cudadriver["princpt"],
-            modelmatrix=cudadriver["modelmatrix"],
-            avgtex=cudadriver["avgtex"],
-            verts=cudadriver["verts"],
-            # normalized using the train data stats and driven data stats
-            neut_avgtex=cudadriver["neut_avgtex"],
-            neut_verts=cudadriver["neut_verts"],
-            target_neut_avgtex=cudadriven["neut_avgtex"],
-            target_neut_verts=cudadriven["neut_verts"],
-            pixelcoords=cudadriver["pixelcoords"],
-            idindex=cudadriver["idindex"],
-            camindex=cudadriver["camindex"],
-            running_avg_scale=running_avg_scale,
-            gt_geo=gt_geo,
-            residuals_weight=residuals_weight,
-            output_set=output_set,
-        )
+    # Store neut avgtex and neut vert from all ids for x-id check
+    all_neut_avgtex_vert = []
+
+    for directory in train_dirs:
+        _, neut_avgtex, neut_vert = get_framelist_neuttex_and_neutvert(directory)
+
+        neut_avgtex = (neut_avgtex - dataset.texmean) / dataset.texstd
+        neut_verts = (neut_vert - dataset.vertmean) / dataset.vertstd
+        all_neut_avgtex_vert.append({"neut_avgtex": torch.tensor(neut_avgtex), "neut_verts": torch.tensor(neut_verts)})
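Reviewer note: by this point `dataset` survives only as the leftover loop variable from `for dataset in [driver_dataset]` above; the multi-capture dataset these stats originally came from was deleted right after `texmean`/`texstd`/`vertmean`/`vertstd` were copied out (the `vertstd` assignment sits in a collapsed region of this diff). Normalizing against those locals would not depend on that accident of rebinding, e.g.:

    neut_avgtex = (neut_avgtex - texmean) / texstd
    neut_verts = (neut_vert - vertmean) / vertstd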

+    output_set = set(train_params.output_set)
+
+    for i in tqdm(range(len(driver_dataset.framelist.values.tolist()) - 1), desc="Rendering X-id frames"):
+        xid_eval(
+            ae,
+            driver_dataiter,
+            all_neut_avgtex_vert,
+            config,
+            output_set,
+            output_dir,
+            0,
+            i,
+            indices_subjects=args.driven_id_indices,
+            training=False,
+        )

-        # Grab ground truth frame from the driver
-        gt = cudadriver["image"].detach().cpu().numpy()
-        gt = einops.rearrange(gt, "1 c h w -> h w c")
-
-        rgb_orig = output_orig["irgbrec"].detach().cpu().numpy()
-        rgb_orig = einops.rearrange(rgb_orig, "1 c h w -> h w c")
-
-        rgb_driven = output_driven["irgbrec"].detach().cpu().numpy()
-        rgb_driven = einops.rearrange(rgb_driven, "1 c h w -> h w c")
-
-        render_img([[gt, rgb_orig, rgb_driven]], f"{output_dir}/img_{it:06d}.png")
-
-        it += 1
-
-    print(f"Done! Saved {it} images to {output_dir}")
+    print(f"Done! Saved {i} images to {output_dir}")