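"""GAN trainer with optional class and instance conditioning, supporting
single-GPU, single-node DDP, and SLURM-based multi-node DDP training."""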
import os
import functools
import math
from tqdm import tqdm, trange
import argparse
import time
import subprocess
import re
import sys

sys.path.insert(1, os.path.join(sys.path[0], ".."))
import numpy as np

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim

import data_utils.inception_utils as inception_utils
import utils
import train_fns
from sync_batchnorm import patch_replication_callback
from data_utils import utils as data_utils


def run(config, ddp_setup="slurm", master_node=""):
    config["n_classes"] = 1000
    config["G_activation"] = utils.activation_dict[config["G_nl"]]
    config["D_activation"] = utils.activation_dict[config["D_nl"]]
    config = utils.update_config_roots(config)

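    # Prepare root folders if necessary.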
    utils.prepare_root(config)

    if config["ddp_train"]:
        if ddp_setup == "slurm":
            # Infer the world size from the SLURM environment.
            n_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES"))
            n_gpus_per_node = int(os.environ.get("SLURM_TASKS_PER_NODE").split("(")[0])
            world_size = n_gpus_per_node * n_nodes
            print(
                "Master node is ",
                master_node,
                " World size is ",
                world_size,
                " with ",
                n_gpus_per_node,
                "gpus per node.",
            )
            port = 40000
            dist_url = "tcp://" + master_node + ":" + str(port)
            print("Dist url ", dist_url)
            # The real rank is resolved from SLURM_PROCID inside train().
            train(-1, world_size, config, dist_url)
        else:
            # Single-node DDP: spawn one training process per visible GPU.
            world_size = torch.cuda.device_count()
            dist_url = "env://"
            mp.spawn(
                train, args=(world_size, config, dist_url), nprocs=world_size, join=True
            )
    else:
        # Single-GPU, non-distributed training.
        train(0, -1, config, None)


def train(rank, world_size, config, dist_url):
    print("Rank of this job is ", rank)
    copy_locally = False
    tmp_dir = ""
    if config["ddp_train"]:
        if dist_url == "env://":
            # Single-node setup: processes were spawned by mp.spawn, so rank == local rank.
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "12355"
            local_rank = rank
        else:
            # SLURM setup: global and local ranks come from the scheduler environment.
            rank = int(os.environ.get("SLURM_PROCID"))
            local_rank = int(os.environ.get("SLURM_LOCALID"))
            copy_locally = True
            tmp_dir = "/scratch/slurm_tmpdir/" + str(os.environ.get("SLURM_JOB_ID"))

        print("Before setting process group")
        print(dist_url, rank)
        dist.init_process_group(
            backend="nccl", init_method=dist_url, rank=rank, world_size=world_size
        )
        print("After setting process group")
        device = "cuda:{}".format(local_rank)
        print(dist_url, rank, " /Device is ", device)
    else:
        device = "cuda"
        local_rank = 0  # No per-process GPU index outside DDP.

    # Seed RNGs, offset by rank so each process draws different samples.
    utils.seed_rng(config["seed"] + rank)

    # Setup cudnn.benchmark for free speed.
    torch.backends.cudnn.benchmark = True
    if config["deterministic_run"]:
        torch.backends.cudnn.deterministic = True

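    # Import the model; this allows dynamically selecting different model files.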
    model = __import__(config["model"])
    experiment_name = (
        config["experiment_name"]
        if config["experiment_name"]
        else utils.name_from_config(config)
    )
    print("Experiment name is %s" % experiment_name)

    if config["ddp_train"]:
        torch.cuda.set_device(device)

    # Build the generator and discriminator (their optimizers are built separately below).
    G = model.Generator(**{**config, "embedded_optimizers": False}).to(device)
    D = model.Discriminator(**{**config, "embedded_optimizers": False}).to(device)

    # If using EMA, prepare a second copy of G to hold the moving-average weights.
    if config["ema"]:
        print("Preparing EMA for G with decay of {}".format(config["ema_decay"]))
        G_ema = model.Generator(**{**config, "skip_init": True, "no_optim": True}).to(
            device
        )
        ema = utils.ema(G, G_ema, config["ema_decay"], config["ema_start"])
    else:
        G_ema, ema = None, None

    print(
        "Number of params in G: {} D: {}".format(
            *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]
        )
    )

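    # Build the optimizers; use an fp16-compatible Adam when D is in half precision.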
    if config["D_fp16"]:
        print("Using fp16 adam ")
        optim_type = utils.Adam16
    else:
        optim_type = optim.Adam
    optimizer_D = optim_type(
        params=D.parameters(),
        lr=config["D_lr"],
        betas=(config["D_B1"], config["D_B2"]),
        weight_decay=0,
        eps=config["adam_eps"],
    )
    optimizer_G = optim_type(
        params=G.parameters(),
        lr=config["G_lr"],
        betas=(config["G_B1"], config["G_B2"]),
        weight_decay=0,
        eps=config["adam_eps"],
    )

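    # Prepare the state dict, which holds things like the current epoch, iteration,
    # and best metrics seen so far.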
    state_dict = {
        "itr": 0,
        "epoch": 0,
        "save_num": 0,
        "save_best_num": 0,
        "best_IS": 0,
        "best_FID": 999999,
        "es_epoch": 0,
        "config": config,
    }

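    # Optionally cast the networks to half precision.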
    if config["G_fp16"]:
        print("Casting G to float16...")
        G = G.half()
        if config["ema"]:
            G_ema = G_ema.half()
    if config["D_fp16"]:
        print("Casting D to fp16...")
        D = D.half()

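    # Under DDP, wrap the networks in DistributedDataParallel.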
    if config["ddp_train"]:
        print("before G DDP ")
        G = DDP(
            G,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
        print("After G DDP ")
        D = DDP(
            D,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )

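    # Load weights (and optimizer state) from a checkpoint if one exists.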
    print("Loading weights...")
    if config["ddp_train"]:
        dist.barrier()
        map_location = device
    else:
        map_location = None

    utils.load_weights(
        G,
        D,
        state_dict,
        config["weights_root"],
        experiment_name,
        config["load_weights"] if config["load_weights"] else None,
        G_ema if config["ema"] else None,
        map_location=map_location,
        embedded_optimizers=False,
        G_optim=optimizer_G,
        D_optim=optimizer_D,
    )

    # Wrap G and D into a single module for the joint forward pass.
    GD = model.G_D(G, D, optimizer_G=optimizer_G, optimizer_D=optimizer_D)

    # DataParallel applies only in the non-DDP case (world_size == -1);
    # under DDP the networks are already wrapped above.
    if config["parallel"] and world_size == -1:
        GD = nn.DataParallel(GD)
        if config["cross_replica"]:
            patch_replication_callback(GD)

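    # Prepare loggers; only the rank-0 process writes metrics and metadata.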
    if rank == 0:
        test_metrics_fname = "%s/%s_log.jsonl" % (config["logs_root"], experiment_name)
        train_metrics_fname = "%s/%s" % (config["logs_root"], experiment_name)
        print("Inception Metrics will be saved to {}".format(test_metrics_fname))
        test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=False)
        print("Training Metrics will be saved to {}".format(train_metrics_fname))
        train_log = utils.MyLogger(
            train_metrics_fname, reinitialize=False, logstyle=config["logstyle"]
        )
        utils.write_metadata(config["logs_root"], experiment_name, config, state_dict)
    else:
        test_log = None
        train_log = None

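    # Each loader iteration provides enough data for a full D step
    # (all D sub-steps and gradient accumulations).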
    D_batch_size = (
        config["batch_size"] * config["num_D_steps"] * config["num_D_accumulations"]
    )

    # For long-tailed ImageNet, load per-class sample counts and class probabilities.
    if config["longtail"]:
        samples_per_class = np.load(
            "imagenet_lt/imagenet_lt_samples_per_class.npy", allow_pickle=True
        )
        class_probabilities = np.load(
            "imagenet_lt/imagenet_lt_class_prob.npy", allow_pickle=True
        )
    else:
        samples_per_class, class_probabilities = None, None

    train_dataset = data_utils.get_dataset_hdf5(
        **{
            **config,
            "data_path": config["data_root"],
            "batch_size": D_batch_size,
            "augment": config["hflips"],
            "local_rank": local_rank,
            "copy_locally": copy_locally,
            "tmp_dir": tmp_dir,
            "ddp": config["ddp_train"],
        }
    )
    train_loader = data_utils.get_dataloader(
        **{
            **config,
            "dataset": train_dataset,
            "batch_size": config["batch_size"],
            "start_epoch": state_dict["epoch"],
            "start_itr": state_dict["itr"],
            "longtail_temperature": config["longtail_temperature"],
            "samples_per_class": samples_per_class,
            "class_probabilities": class_probabilities,
            "rank": rank,
            "world_size": world_size,
            "shuffle": True,
            "drop_last": True,
        }
    )

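    # Select the precomputed moments file used for FID and IS.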
    is_moments_prefix = "I" if config["which_dataset"] == "imagenet" else "COCO"
    im_filename = "%s%i_%s" % (
        is_moments_prefix,
        config["resolution"],
        "" if not config["longtail"] else "longtail",
    )
    print("Using ", im_filename, "for Inception metrics.")

    # Prepare the Inception metrics: FID and IS.
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        im_filename,
        samples_per_class,
        config["parallel"],
        config["no_fid"],
        config["data_root"],
        device=device,
    )

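    # Prepare noise and label samplers for G, which may use its own batch size.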
    G_batch_size = config["G_batch_size"]
    z_, y_ = data_utils.prepare_z_y(
        G_batch_size,
        G.module.dim_z if config["ddp_train"] else G.dim_z,
        config["n_classes"],
        device=device,
        fp16=config["G_fp16"],
        longtail_gen=config["longtail_gen"],
        custom_distrib=config["custom_distrib_gen"],
        longtail_temperature=config["longtail_temperature"],
        class_probabilities=class_probabilities,
    )

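    # Optionally balance how conditioning instances/classes are sampled for long-tailed data.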
    weights_sampling = None
    if (
        config["longtail"]
        and config["use_balanced_sampler"]
        and config["instance_cond"]
    ):
        if config["which_knn_balance"] == "center_balance":
            print(
                "Balancing the instance features."
                " Using custom temperature distrib?",
                config["custom_distrib_gen"],
                " with temperature",
                config["longtail_temperature"],
            )
            weights_sampling = data_utils.make_weights_for_balanced_classes(
                samples_per_class,
                train_loader.dataset.labels,
                1000,
                config["custom_distrib_gen"],
                config["longtail_temperature"],
                class_probabilities=class_probabilities,
            )
        elif config["which_knn_balance"] == "nnclass_balance":
            print(
                "Balancing the class distribution (classes drawn from the neighbors)."
                " Using custom temperature distrib?",
                config["custom_distrib_gen"],
                " with temperature",
                config["longtail_temperature"],
            )
            # Temperature-scaled softmax over the class probabilities.
            weights_sampling = torch.exp(
                class_probabilities / config["longtail_temperature"]
            ) / torch.sum(
                torch.exp(class_probabilities / config["longtail_temperature"])
            )

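    # Partial function that samples conditioning values (noise, labels, and/or instance features).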
    sample_conditioning = functools.partial(
        data_utils.sample_conditioning_values,
        z_=z_,
        y_=y_,
        dataset=train_dataset,
        batch_size=G_batch_size,
        weights_sampling=weights_sampling,
        ddp=config["ddp_train"],
        constant_conditioning=config["constant_conditioning"],
        class_cond=config["class_cond"],
        instance_cond=config["instance_cond"],
    )

    print("G batch size ", G_batch_size)
    # Prepare the GAN training step (this rebinds the name "train" locally).
    train = train_fns.GAN_training_function(
        G,
        D,
        GD,
        ema,
        state_dict,
        config,
        sample_conditioning,
        embedded_optimizers=False,
        device=device,
        batch_size=G_batch_size,
    )

    # Prepare the sample function used by the Inception metrics.
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config["ema"] and config["use_ema"] else G),
        sample_conditioning_func=sample_conditioning,
        config=config,
        class_cond=config["class_cond"],
        instance_cond=config["instance_cond"],
    )

    print("Beginning training at epoch %d..." % state_dict["epoch"])
    # Track the best FID seen in this run for early stopping.
    best_FID_run = state_dict["best_FID"]
    FID = state_dict["best_FID"]

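    # Train for the specified number of epochs, although we mostly track G iterations.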
    for epoch in range(state_dict["epoch"], config["num_epochs"]):
        if config["ddp_train"]:
            # Make each process shuffle differently every epoch.
            train_loader.sampler.set_epoch(epoch)

        if config["deterministic_run"]:
            utils.seed_rng(config["seed"] + rank + state_dict["epoch"])

        # Which progress bar to use: the custom one or tqdm.
        if config["pbar"] == "mine":
            pbar = utils.progress(
                train_loader,
                displaytype="s1k" if config["use_multiepoch_sampler"] else "eta",
            )
        else:
            pbar = tqdm(train_loader)
        s = time.time()
        print("Before iteration, dataloader length", len(train_loader))
        for i, batch in enumerate(pbar):
            # Unpack the batch according to the conditioning setup.
            in_label, in_feat = None, None
            if config["instance_cond"] and config["class_cond"]:
                x, in_label, in_feat, _ = batch
            elif config["instance_cond"]:
                x, in_feat, _ = batch
            elif config["class_cond"]:
                x, in_label = batch
                if config["constant_conditioning"]:
                    in_label = torch.zeros_like(in_label)
            else:
                x = batch

            x = x.to(device, non_blocking=True)
            if in_label is not None:
                in_label = in_label.to(device, non_blocking=True)
            if in_feat is not None:
                in_feat = in_feat.float().to(device, non_blocking=True)

            # Increment the iteration counter.
            state_dict["itr"] += 1

            # Make sure G and D are in training mode, in case evaluation set them to eval.
            G.train()
            D.train()
            if config["ema"]:
                G_ema.train()

            # Run a single G/D training step.
            metrics = train(x, in_label, in_feat)

            if rank == 0:
                train_log.log(itr=int(state_dict["itr"]), **metrics)

            # With the custom progress bar, print the metrics inline.
            if config["pbar"] == "mine" and rank == 0:
                print(
                    ", ".join(
                        ["itr: %d" % state_dict["itr"]]
                        + ["%s : %+4.3f" % (key, metrics[key]) for key in metrics]
                    ),
                    end=" ",
                )

            print("Iteration time ", time.time() - s)
            s = time.time()

        state_dict["epoch"] += 1

        # Test every specified interval.
        if not (state_dict["epoch"] % config["test_every"]):
            if config["G_eval_mode"]:
                print("Switching G and D to eval mode...")
                G.eval()
                D.eval()

            test_time = time.time()
            IS, FID = train_fns.test(
                G,
                D,
                G_ema,
                z_,
                y_,
                state_dict,
                config,
                sample,
                get_inception_metrics,
                experiment_name,
                test_log,
                loader=None,
                embedded_optimizers=False,
                G_optim=optimizer_G,
                D_optim=optimizer_D,
                rank=rank,
            )
            print("Testing took ", time.time() - test_time)

            # Stop training if divergence is detected: IS collapsed below half the best so far.
            if 2 * IS < state_dict["best_IS"] and config["stop_when_diverge"]:
                print("Experiment diverged!")
                break
            else:
                print("IS is ", IS, " and half of best IS is ", state_dict["best_IS"] / 2)

        # Save weights at the specified interval (rank 0 only).
        if not (state_dict["epoch"] % config["save_every"]) and rank == 0:
            train_fns.save_weights(
                G,
                D,
                G_ema,
                state_dict,
                config,
                experiment_name,
                embedded_optimizers=False,
                G_optim=optimizer_G,
                D_optim=optimizer_D,
            )
        if rank == 0:
            # Early stopping: count epochs without an FID improvement.
            if FID < best_FID_run:
                best_FID_run = FID
                state_dict["es_epoch"] = 0
            else:
                state_dict["es_epoch"] += 1
            if state_dict["es_epoch"] >= config["es_patience"]:
                print("Reached early stopping!")
                return FID
    return FID
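

# Minimal CLI entry point, sketched here as an assumption: the original file may
# instead be driven by an external launcher (e.g., a SLURM submission script).
# It assumes utils.prepare_parser() builds this repo's argparse parser; adjust
# to the project's actual CLI wiring.
def main():
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    run(config)


if __name__ == "__main__":
    main()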