Update pipelines.py
This commit is contained in:
@@ -781,3 +781,111 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
|
|||||||
box_v, mc_level, num_chunks, octree_resolution, mc_algo,
|
box_v, mc_level, num_chunks, octree_resolution, mc_algo,
|
||||||
enable_pbar=enable_pbar,
|
enable_pbar=enable_pbar,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@synchronize_timer('Hunyuan3DDiTFlowMatchingPipeline from Lightning Checkpoint')
|
||||||
|
def from_lightning_checkpoint(
|
||||||
|
cls,
|
||||||
|
ckpt_path: str,
|
||||||
|
config_path: str,
|
||||||
|
device: str = 'cuda',
|
||||||
|
dtype: torch.dtype = torch.float16,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Loads a model from a checkpoint created by the project's PyTorch Lightning training script.
|
||||||
|
|
||||||
|
This method correctly handles the nested configuration structure and state_dict prefixes
|
||||||
|
produced during training, and can intelligently load sharded checkpoints saved by Deepspeed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_path (str): Path to the .ckpt checkpoint file or directory.
|
||||||
|
config_path (str): Path to the .yaml configuration file used for training.
|
||||||
|
device (str, optional): The device to load the model on. Defaults to 'cuda'.
|
||||||
|
dtype (torch.dtype, optional): The data type for the model. Defaults to torch.float16.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hunyuan3DDiTFlowMatchingPipeline: An instantiated pipeline ready for inference.
|
||||||
|
"""
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from hy3dshape.utils.misc import instantiate_from_config
|
||||||
|
from hy3dshape.schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
|
||||||
|
logger.info(f"Loading model from Lightning checkpoint: {ckpt_path}")
|
||||||
|
logger.info(f"Using training config: {config_path}")
|
||||||
|
|
||||||
|
config = OmegaConf.load(config_path)
|
||||||
|
|
||||||
|
if os.path.isdir(ckpt_path):
|
||||||
|
# Assumes a Deepspeed-saved checkpoint directory
|
||||||
|
model_state_file = os.path.join(ckpt_path, 'checkpoint', 'mp_rank_00_model_states.pt')
|
||||||
|
if not os.path.exists(model_state_file):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not find model weights file 'mp_rank_00_model_states.pt' in Deepspeed checkpoint directory: {os.path.join(ckpt_path, 'checkpoint')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Detected Deepspeed checkpoint directory, loading weights from: '{model_state_file}'")
|
||||||
|
ckpt = torch.load(model_state_file, map_location='cpu', weights_only=False)
|
||||||
|
# Deepspeed weights are often nested under the 'module' key
|
||||||
|
state_dict = ckpt.get('module', ckpt)
|
||||||
|
else:
|
||||||
|
# Standard .ckpt file
|
||||||
|
logger.info("Detected standard .ckpt file.")
|
||||||
|
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
|
||||||
|
state_dict = ckpt.get('state_dict', ckpt)
|
||||||
|
|
||||||
|
# 1. Instantiate components that were frozen during training.
|
||||||
|
# They will load their own pretrained weights upon instantiation.
|
||||||
|
logger.info("Instantiating VAE, Conditioner, and ImageProcessor...")
|
||||||
|
vae = instantiate_from_config(config.model.params.first_stage_config)
|
||||||
|
conditioner = instantiate_from_config(config.model.params.cond_stage_config)
|
||||||
|
image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
|
||||||
|
|
||||||
|
# 2. Instantiate the component that was trained (the Denoiser).
|
||||||
|
logger.info("Instantiating Denoiser...")
|
||||||
|
denoiser = instantiate_from_config(config.model.params.denoiser_cfg)
|
||||||
|
|
||||||
|
# 3. Load weights only for the Denoiser from our training checkpoint.
|
||||||
|
possible_prefixes = ["model.model.", "_forward_module.model.", "model."]
|
||||||
|
denoiser_dict = {}
|
||||||
|
matched_prefix = None
|
||||||
|
for prefix in possible_prefixes:
|
||||||
|
sub_dict = {k.replace(prefix, ''): v for k, v in state_dict.items() if k.startswith(prefix)}
|
||||||
|
if sub_dict:
|
||||||
|
denoiser_dict = sub_dict
|
||||||
|
matched_prefix = prefix
|
||||||
|
break
|
||||||
|
|
||||||
|
if denoiser_dict:
|
||||||
|
logger.info(f"Successfully matched Denoiser weight prefix: '{matched_prefix}'")
|
||||||
|
missing_keys, unexpected_keys = denoiser.load_state_dict(denoiser_dict, strict=False)
|
||||||
|
logger.info(" Successfully loaded weights for 'denoiser'.")
|
||||||
|
if missing_keys:
|
||||||
|
logger.warning(f" - Missing keys: {missing_keys}")
|
||||||
|
if unexpected_keys:
|
||||||
|
logger.warning(f" - Unexpected keys: {unexpected_keys}")
|
||||||
|
else:
|
||||||
|
logger.warning("Could not find weights for 'denoiser' in checkpoint. It will be randomly initialized.")
|
||||||
|
|
||||||
|
# 4. Instantiate a new, inference-compatible scheduler.
|
||||||
|
logger.info("Creating a new scheduler for inference...")
|
||||||
|
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||||
|
|
||||||
|
# 5. Assemble the final, healthy pipeline.
|
||||||
|
pipeline = cls(
|
||||||
|
model=denoiser,
|
||||||
|
vae=vae,
|
||||||
|
scheduler=scheduler,
|
||||||
|
conditioner=conditioner,
|
||||||
|
image_processor=image_processor,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Move all model components to the correct device and set to evaluation mode.
|
||||||
|
pipeline.to(torch.device(device), dtype=dtype)
|
||||||
|
pipeline.model.eval()
|
||||||
|
pipeline.vae.eval()
|
||||||
|
pipeline.conditioner.eval()
|
||||||
|
|
||||||
|
logger.info("\n Pipeline successfully assembled from Lightning checkpoint!")
|
||||||
|
return pipeline
|
||||||
|
|||||||
Reference in New Issue
Block a user