init
This commit is contained in:
401
hy3dpaint/train.py
Normal file
401
hy3dpaint/train.py
Normal file
@@ -0,0 +1,401 @@
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import torch
|
||||
import os, sys
|
||||
import argparse
|
||||
import shutil
|
||||
import subprocess
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
|
||||
|
||||
from src.utils.train_util import instantiate_from_config
|
||||
import warnings

# Silence all Python warnings globally to keep multi-GPU training logs readable.
warnings.filterwarnings("ignore")
from diffusers.utils import logging as diffusers_logging

# 50 == CRITICAL: suppress all diffusers log output below the critical level.
diffusers_logging.set_verbosity(50)
||||
@rank_zero_only
def rank_zero_print(*args):
    """Print *args* from the global rank-0 process only (no-op on other ranks)."""
    # Joining with a single space reproduces print's default separator behavior.
    print(" ".join(str(arg) for arg in args))
||||
|
||||
|
||||
def get_parser(**parser_kwargs):
    """Build the command-line argument parser for the training script.

    Args:
        **parser_kwargs: Forwarded verbatim to ``argparse.ArgumentParser``
            (e.g. ``description=...``).

    Returns:
        argparse.ArgumentParser: Parser exposing resume, config, experiment
        name, node/GPU, seed, and log-directory options.
    """
    # NOTE: the original defined an inner `str2bool` converter here, but no
    # argument ever used it (`--resume_weights_only` uses `action="store_true"`),
    # so the dead helper has been removed.
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        default=None,
        help="resume from checkpoint",
    )
    parser.add_argument(
        "--resume_weights_only",
        action="store_true",
        help="only resume model weights",
    )
    parser.add_argument(
        "-b",
        "--base",
        type=str,
        default="base_config.yaml",
        help="path to base configs",
    )
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        default="",
        help="experiment name",
    )
    parser.add_argument(
        "--num_nodes",
        type=int,
        default=1,
        help="number of nodes to use",
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default="0,",
        help="gpu ids to use",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging data",
    )
    return parser
||||
|
||||
|
||||
class SetupCallback(Callback):
    """Rank-0 setup hook: creates the log/checkpoint/config directories and
    writes the resolved project config to ``<cfgdir>/project.yaml``."""

    def __init__(self, resume, logdir, ckptdir, cfgdir, config):
        super().__init__()
        self.resume = resume
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config

    def on_fit_start(self, trainer, pl_module):
        # Only the global rank-0 process touches the filesystem.
        if trainer.global_rank != 0:
            return
        # Create logdirs and save configs.
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        rank_zero_print("Project config")
        rank_zero_print(OmegaConf.to_yaml(self.config))
        OmegaConf.save(self.config, os.path.join(self.cfgdir, "project.yaml"))
|
||||
class CodeSnapshot(Callback):
    """Snapshot the repository's source files into ``savedir`` at fit start.

    Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60

    Note: the actual file-copy loop is currently commented out, so only the
    snapshot directory is created.
    """

    def __init__(self, savedir):
        # Destination directory for the code snapshot.
        self.savedir = savedir

    def get_file_list(self):
        """Return de-duplicated paths of git-tracked files (excluding configs/)
        plus untracked-but-not-ignored files, as decoded strings."""
        return [
            b.decode()
            for b in set(subprocess.check_output('git ls-files -- ":!:configs/*"', shell=True).splitlines())
            | set(  # hard code, TODO: use config to exclude folders or files
                subprocess.check_output("git ls-files --others --exclude-standard", shell=True).splitlines()
            )
        ]

    @rank_zero_only
    def save_code_snapshot(self):
        """Create the snapshot directory (file copying is currently disabled)."""
        os.makedirs(self.savedir, exist_ok=True)

        # for f in self.get_file_list():
        #     if not os.path.exists(f) or os.path.isdir(f):
        #         continue
        #     os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
        #     shutil.copyfile(f, os.path.join(self.savedir, f))

    def on_fit_start(self, trainer, pl_module):
        try:
            self.save_code_snapshot()
        # FIX: was a bare `except:`, which also swallows KeyboardInterrupt and
        # SystemExit; narrow to Exception so the snapshot stays best-effort
        # without masking interpreter-level signals.
        except Exception:
            rank_zero_warn(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
            )
|
||||
if __name__ == "__main__":
    # Training entry point: parses CLI args, resolves the log directory and
    # auto-resume checkpoint, builds the model from the OmegaConf config,
    # performs optional checkpoint surgery (UNet / VAE / ControlNet init,
    # conv_in channel widening), then launches a PyTorch Lightning DDP fit.

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    sys.path.append(os.getcwd())
    torch.set_float32_matmul_precision("medium")

    parser = get_parser()
    opt, unknown = parser.parse_known_args()

    # Log directory name = config file stem, plus "-<name>" when given.
    cfg_fname = os.path.split(opt.base)[-1]
    cfg_name = os.path.splitext(cfg_fname)[0]
    exp_name = "-" + opt.name if opt.name != "" else ""
    logdir = os.path.join(opt.logdir, cfg_name + exp_name)

    # assert not os.path.exists(logdir) or 'test' in logdir, logdir
    # Auto-resume from last.ckpt when the run directory already exists and
    # no explicit --resume was passed.
    if os.path.exists(logdir) and opt.resume is None:
        auto_resume_path = os.path.join(logdir, "checkpoints", "last.ckpt")
        if os.path.exists(auto_resume_path):
            opt.resume = auto_resume_path
            print(f"Auto set resume ckpt {opt.resume}")

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    codedir = os.path.join(logdir, "code")

    node_rank = int(os.environ.get("NODE_RANK", 0))  # index of the current node
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on the current node
    num_gpus_per_node = torch.cuda.device_count()  # number of GPUs per node

    # Per-rank seed offset so data augmentation differs across processes.
    global_rank = node_rank * num_gpus_per_node + local_rank
    seed_everything(opt.seed + global_rank)

    # init configs
    config = OmegaConf.load(opt.base)
    lightning_config = config.lightning
    trainer_config = lightning_config.trainer

    trainer_config["accelerator"] = "gpu"
    rank_zero_print(f"Running on GPUs {opt.gpus}")
    # --gpus may be a plain count ("2") or a comma list ("0,1,"); the bare
    # except falls back to counting comma-separated ids when int() fails.
    try:
        ngpu = int(opt.gpus)
    except:
        ngpu = len(opt.gpus.strip(",").split(","))
    trainer_config["devices"] = ngpu

    # NOTE(review): trainer_opt is built but never used below — looks vestigial.
    trainer_opt = argparse.Namespace(**trainer_config)
    lightning_config.trainer = trainer_config

    # model
    model = instantiate_from_config(config.model)

    # Unwrap nested UNet wrappers; model_unet_prefix tracks the matching
    # state_dict key prefix used when slicing checkpoints in resume_from below.
    model_unet = model.unet.unet
    model_unet_prefix = "unet.unet."
    if hasattr(model_unet, "unet"):
        model_unet = model_unet.unet
        model_unet_prefix += "unet."

    # Optionally initialize the UNet weights from a standalone checkpoint.
    if getattr(config, "init_unet_from", None):
        unet_ckpt_path = config.init_unet_from
        sd = torch.load(unet_ckpt_path, map_location="cpu")
        model_unet.load_state_dict(sd, strict=True)

    # Optionally initialize the VAE, remapping legacy attention key names
    # (key/query/value/proj_attn) to the newer to_k/to_q/to_v/to_out.0 layout.
    if getattr(config, "init_vae_from", None):
        vae_ckpt_path = config.init_vae_from
        sd_vae = torch.load(vae_ckpt_path, map_location="cpu")

        def replace_key(key_str):
            # Rewrite one state_dict key from the old naming to the new one.
            replace_pairs = [("key", "to_k"), ("query", "to_q"), ("value", "to_v"), ("proj_attn", "to_out.0")]
            for replace_pair in replace_pairs:
                key_str = key_str.replace(replace_pair[0], replace_pair[1])
            return key_str

        sd_vae = {replace_key(k): v for k, v in sd_vae.items()}
        model.pipeline.vae.load_state_dict(sd_vae, strict=True)

    # Optionally initialize the ControlNet branch, if the model has one.
    if hasattr(model.unet, "controlnet"):
        if getattr(config, "init_control_from", None):
            unet_ckpt_path = config.init_control_from
            sd_control = torch.load(unet_ckpt_path, map_location="cpu")
            model.unet.controlnet.load(sd_control, strict=True)

    # Widen the UNet's input conv to accept extra noise channels: new weights
    # are zero-initialized, then the original in_channels slice is copied in,
    # so pretrained behavior is preserved on the original channels.
    noise_in_channels = config.model.params.get("noise_in_channels", None)
    if noise_in_channels is not None:
        with torch.no_grad():
            new_conv_in = torch.nn.Conv2d(
                noise_in_channels,
                model_unet.conv_in.out_channels,
                model_unet.conv_in.kernel_size,
                model_unet.conv_in.stride,
                model_unet.conv_in.padding,
            )
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, : model_unet.conv_in.in_channels, :, :].copy_(model_unet.conv_in.weight)

            new_conv_in.bias.zero_()
            new_conv_in.bias[: model_unet.conv_in.bias.size(0)].copy_(model_unet.conv_in.bias)

            model_unet.conv_in = new_conv_in

    # Same widening trick for the ControlNet conditioning embedding's conv_in.
    # NOTE(review): unlike the UNet case above, the extra weight channels here
    # are left at their random init (not zeroed) — confirm this is intended.
    if hasattr(model.unet, "controlnet"):
        if config.model.params.get("control_in_channels", None):
            control_in_channels = config.model.params.control_in_channels
            model.unet.controlnet.config["conditioning_channels"] = control_in_channels
            condition_conv_in = model.unet.controlnet.controlnet_cond_embedding.conv_in

            new_condition_conv_in = torch.nn.Conv2d(
                control_in_channels,
                condition_conv_in.out_channels,
                kernel_size=condition_conv_in.kernel_size,
                stride=condition_conv_in.stride,
                padding=condition_conv_in.padding,
            )

            with torch.no_grad():
                new_condition_conv_in.weight[:, : condition_conv_in.in_channels, :, :] = condition_conv_in.weight
                if condition_conv_in.bias is not None:
                    new_condition_conv_in.bias = condition_conv_in.bias

            model.unet.controlnet.controlnet_cond_embedding.conv_in = new_condition_conv_in

    rank_zero_print(f"Loaded Init ...")

    # Resume raw module weights from a Lightning checkpoint's state_dict,
    # slicing keys by the unet prefix computed above (and, for the controlnet,
    # by the fixed 16-char prefix "unet.controlnet.").
    if getattr(config, "resume_from", None):
        cnet_ckpt_path = config.resume_from
        sds = torch.load(cnet_ckpt_path, map_location="cpu")["state_dict"]
        sd0 = {k[len(model_unet_prefix) :]: v for k, v in sds.items() if model_unet_prefix in k}
        # model.unet.unet.unet.load_state_dict(sd0, strict=True)
        model_unet.load_state_dict(sd0, strict=True)
        if hasattr(model.unet, "controlnet"):
            # 16 == len("unet.controlnet.") — strip that prefix from each key.
            sd1 = {k[16:]: v for k, v in sds.items() if "unet.controlnet." in k}
            model.unet.controlnet.load_state_dict(sd1, strict=True)
        rank_zero_print(f"Loaded {cnet_ckpt_path} ...")

    # Weights-only resume: rebuild the module from the checkpoint but do not
    # restore optimizer/trainer state (trainer.fit below starts fresh).
    if opt.resume and opt.resume_weights_only:
        model = model.__class__.load_from_checkpoint(opt.resume, **config.model.params)

    model.logdir = logdir

    # trainer and callbacks
    trainer_kwargs = dict()

    # logger
    default_logger_cfg = {
        "target": "pytorch_lightning.loggers.TensorBoardLogger",
        "params": {
            "name": "tensorboard",
            "save_dir": logdir,
            "version": "0",
        },
    }
    logger_cfg = OmegaConf.merge(default_logger_cfg)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    # model checkpoint
    default_modelckpt_cfg = {
        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckptdir,
            "filename": "{step:08}",
            "verbose": True,
            "save_last": True,
            "every_n_train_steps": 5000,
            "save_top_k": -1,  # save all checkpoints
        },
    }

    # Config-file modelcheckpoint settings override the defaults above.
    if "modelcheckpoint" in lightning_config:
        modelckpt_cfg = lightning_config.modelcheckpoint
    else:
        modelckpt_cfg = OmegaConf.create()
    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)

    # callbacks
    default_callbacks_cfg = {
        "setup_callback": {
            "target": "train.SetupCallback",
            "params": {
                "resume": opt.resume,
                "logdir": logdir,
                "ckptdir": ckptdir,
                "cfgdir": cfgdir,
                "config": config,
            },
        },
        "learning_rate_logger": {
            "target": "pytorch_lightning.callbacks.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
            },
        },
        "code_snapshot": {
            "target": "train.CodeSnapshot",
            "params": {
                "savedir": codedir,
            },
        },
    }
    default_callbacks_cfg["checkpoint_callback"] = modelckpt_cfg

    # Config-file callbacks override/extend the defaults, then each entry is
    # instantiated via its "target" dotted path.
    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()
    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)

    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

    # bfloat16 mixed precision; DDP without unused-parameter detection.
    trainer_kwargs["precision"] = "bf16"
    trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=False)

    # trainer
    trainer = Trainer(**trainer_config, **trainer_kwargs, num_nodes=opt.num_nodes, inference_mode=False)
    trainer.logdir = logdir

    # data
    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup("fit")

    # configure learning rate
    base_lr = config.model.base_learning_rate
    if "accumulate_grad_batches" in lightning_config.trainer:
        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
    else:
        accumulate_grad_batches = 1
    rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}")
    lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
    # Deliberately no linear LR scaling by world size / batch accumulation.
    model.learning_rate = base_lr
    rank_zero_print("++++ NOT USING LR SCALING ++++")
    rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}")

    # run training loop
    # Full resume (weights + optimizer/trainer state) goes through ckpt_path;
    # the weights-only case above already reloaded the module, so fit fresh.
    if opt.resume and not opt.resume_weights_only:
        trainer.fit(model, data, ckpt_path=opt.resume)
    else:
        trainer.fit(model, data)
Reference in New Issue
Block a user