Update pipelines.py

This commit is contained in:
s572915912
2025-07-11 16:51:33 +08:00
committed by GitHub
parent dc2ea32d76
commit f0a008279e

View File

@@ -781,111 +781,3 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
box_v, mc_level, num_chunks, octree_resolution, mc_algo,
enable_pbar=enable_pbar,
)
@classmethod
@synchronize_timer('Hunyuan3DDiTFlowMatchingPipeline from Lightning Checkpoint')
def from_lightning_checkpoint(
cls,
ckpt_path: str,
config_path: str,
device: str = 'cuda',
dtype: torch.dtype = torch.float16,
**kwargs,
):
"""
Loads a model from a checkpoint created by the project's PyTorch Lightning training script.
This method correctly handles the nested configuration structure and state_dict prefixes
produced during training, and can intelligently load sharded checkpoints saved by Deepspeed.
Args:
ckpt_path (str): Path to the .ckpt checkpoint file or directory.
config_path (str): Path to the .yaml configuration file used for training.
device (str, optional): The device to load the model on. Defaults to 'cuda'.
dtype (torch.dtype, optional): The data type for the model. Defaults to torch.float16.
Returns:
Hunyuan3DDiTFlowMatchingPipeline: An instantiated pipeline ready for inference.
"""
from omegaconf import OmegaConf
from hy3dshape.utils.misc import instantiate_from_config
from hy3dshape.schedulers import FlowMatchEulerDiscreteScheduler
logger.info(f"Loading model from Lightning checkpoint: {ckpt_path}")
logger.info(f"Using training config: {config_path}")
config = OmegaConf.load(config_path)
if os.path.isdir(ckpt_path):
# Assumes a Deepspeed-saved checkpoint directory
model_state_file = os.path.join(ckpt_path, 'checkpoint', 'mp_rank_00_model_states.pt')
if not os.path.exists(model_state_file):
raise FileNotFoundError(
f"Could not find model weights file 'mp_rank_00_model_states.pt' in Deepspeed checkpoint directory: {os.path.join(ckpt_path, 'checkpoint')}"
)
logger.info(f"Detected Deepspeed checkpoint directory, loading weights from: '{model_state_file}'")
ckpt = torch.load(model_state_file, map_location='cpu', weights_only=False)
# Deepspeed weights are often nested under the 'module' key
state_dict = ckpt.get('module', ckpt)
else:
# Standard .ckpt file
logger.info("Detected standard .ckpt file.")
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
state_dict = ckpt.get('state_dict', ckpt)
# 1. Instantiate components that were frozen during training.
# They will load their own pretrained weights upon instantiation.
logger.info("Instantiating VAE, Conditioner, and ImageProcessor...")
vae = instantiate_from_config(config.model.params.first_stage_config)
conditioner = instantiate_from_config(config.model.params.cond_stage_config)
image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
# 2. Instantiate the component that was trained (the Denoiser).
logger.info("Instantiating Denoiser...")
denoiser = instantiate_from_config(config.model.params.denoiser_cfg)
# 3. Load weights only for the Denoiser from our training checkpoint.
possible_prefixes = ["model.model.", "_forward_module.model.", "model."]
denoiser_dict = {}
matched_prefix = None
for prefix in possible_prefixes:
sub_dict = {k.replace(prefix, ''): v for k, v in state_dict.items() if k.startswith(prefix)}
if sub_dict:
denoiser_dict = sub_dict
matched_prefix = prefix
break
if denoiser_dict:
logger.info(f"Successfully matched Denoiser weight prefix: '{matched_prefix}'")
missing_keys, unexpected_keys = denoiser.load_state_dict(denoiser_dict, strict=False)
logger.info(" Successfully loaded weights for 'denoiser'.")
if missing_keys:
logger.warning(f" - Missing keys: {missing_keys}")
if unexpected_keys:
logger.warning(f" - Unexpected keys: {unexpected_keys}")
else:
logger.warning("Could not find weights for 'denoiser' in checkpoint. It will be randomly initialized.")
# 4. Instantiate a new, inference-compatible scheduler.
logger.info("Creating a new scheduler for inference...")
scheduler = FlowMatchEulerDiscreteScheduler()
# 5. Assemble the final, healthy pipeline.
pipeline = cls(
model=denoiser,
vae=vae,
scheduler=scheduler,
conditioner=conditioner,
image_processor=image_processor,
**kwargs,
)
# 6. Move all model components to the correct device and set to evaluation mode.
pipeline.to(torch.device(device), dtype=dtype)
pipeline.model.eval()
pipeline.vae.eval()
pipeline.conditioner.eval()
logger.info("\n Pipeline successfully assembled from Lightning checkpoint!")
return pipeline