# Extracted from the patch "Update pipelines.py" (commit b06e6ddf37a8), which
# adds this alternate constructor to Hunyuan3DDiTFlowMatchingPipeline in
# hy3dshape/hy3dshape/pipelines.py.
@classmethod
@synchronize_timer('Hunyuan3DDiTFlowMatchingPipeline from Lightning Checkpoint')
def from_lightning_checkpoint(
    cls,
    ckpt_path: str,
    config_path: str,
    device: str = 'cuda',
    dtype: torch.dtype = torch.float16,
    **kwargs,
):
    """
    Loads a model from a checkpoint created by the project's PyTorch Lightning training script.

    The denoiser weights are read from the checkpoint after stripping the first
    matching wrapper prefix from the state_dict keys; the VAE, conditioner and
    image processor are instantiated from the training config only. A directory
    path is treated as a Deepspeed checkpoint and its
    ``checkpoint/mp_rank_00_model_states.pt`` file is loaded.

    Args:
        ckpt_path (str): Path to the .ckpt checkpoint file or directory.
        config_path (str): Path to the .yaml configuration file used for training.
        device (str, optional): The device to load the model on. Defaults to 'cuda'.
        dtype (torch.dtype, optional): The data type for the model. Defaults to torch.float16.
        **kwargs: Forwarded unchanged to the pipeline constructor ``cls(...)``.

    Returns:
        An instance of ``cls`` with model, VAE and conditioner in eval mode.

    Raises:
        FileNotFoundError: If ``ckpt_path`` is a directory that does not contain
            ``checkpoint/mp_rank_00_model_states.pt``.
    """
    from omegaconf import OmegaConf
    from hy3dshape.utils.misc import instantiate_from_config
    from hy3dshape.schedulers import FlowMatchEulerDiscreteScheduler

    logger.info(f"Loading model from Lightning checkpoint: {ckpt_path}")
    logger.info(f"Using training config: {config_path}")

    config = OmegaConf.load(config_path)

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only pass checkpoints from a trusted
    # source to this method.
    if os.path.isdir(ckpt_path):
        # Assumes a Deepspeed-saved checkpoint directory
        checkpoint_dir = os.path.join(ckpt_path, 'checkpoint')
        model_state_file = os.path.join(checkpoint_dir, 'mp_rank_00_model_states.pt')
        if not os.path.exists(model_state_file):
            raise FileNotFoundError(
                f"Could not find model weights file 'mp_rank_00_model_states.pt' in Deepspeed checkpoint directory: {checkpoint_dir}"
            )

        logger.info(f"Detected Deepspeed checkpoint directory, loading weights from: '{model_state_file}'")
        ckpt = torch.load(model_state_file, map_location='cpu', weights_only=False)
        # Deepspeed weights are often nested under the 'module' key
        state_dict = ckpt.get('module', ckpt)
    else:
        # Standard .ckpt file
        logger.info("Detected standard .ckpt file.")
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        state_dict = ckpt.get('state_dict', ckpt)

    # 1. Instantiate components that were frozen during training.
    #    They are expected to load their own pretrained weights upon
    #    instantiation; nothing from the checkpoint is loaded into them here.
    logger.info("Instantiating VAE, Conditioner, and ImageProcessor...")
    vae = instantiate_from_config(config.model.params.first_stage_config)
    conditioner = instantiate_from_config(config.model.params.cond_stage_config)
    image_processor = instantiate_from_config(config.model.params.image_processor_cfg)

    # 2. Instantiate the component that was trained (the Denoiser).
    logger.info("Instantiating Denoiser...")
    denoiser = instantiate_from_config(config.model.params.denoiser_cfg)

    # 3. Load weights only for the Denoiser from our training checkpoint.
    #    Prefixes are tried in order; the first one that matches any key wins.
    possible_prefixes = ["model.model.", "_forward_module.model.", "model."]
    denoiser_dict = {}
    matched_prefix = None
    for prefix in possible_prefixes:
        # Slice off only the leading prefix. str.replace() would also delete
        # later occurrences of the same text inside the key (for example
        # "model.blocks.0.model.weight" -> "blocks.0.weight"), corrupting
        # parameter names and possibly making distinct keys collide.
        sub_dict = {
            k[len(prefix):]: v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }
        if sub_dict:
            denoiser_dict = sub_dict
            matched_prefix = prefix
            break

    if denoiser_dict:
        logger.info(f"Successfully matched Denoiser weight prefix: '{matched_prefix}'")
        # strict=False: mismatches are reported below instead of raising.
        missing_keys, unexpected_keys = denoiser.load_state_dict(denoiser_dict, strict=False)
        logger.info(" Successfully loaded weights for 'denoiser'.")
        if missing_keys:
            logger.warning(f" - Missing keys: {missing_keys}")
        if unexpected_keys:
            logger.warning(f" - Unexpected keys: {unexpected_keys}")
    else:
        # Deliberate best-effort: the pipeline is still built, but the caller
        # is warned that the denoiser carries no trained weights.
        logger.warning("Could not find weights for 'denoiser' in checkpoint. It will be randomly initialized.")

    # 4. Instantiate a new, inference-compatible scheduler.
    logger.info("Creating a new scheduler for inference...")
    scheduler = FlowMatchEulerDiscreteScheduler()

    # 5. Assemble the final pipeline.
    pipeline = cls(
        model=denoiser,
        vae=vae,
        scheduler=scheduler,
        conditioner=conditioner,
        image_processor=image_processor,
        **kwargs,
    )

    # 6. Move all model components to the correct device and set to evaluation mode.
    pipeline.to(torch.device(device), dtype=dtype)
    pipeline.model.eval()
    pipeline.vae.eval()
    pipeline.conditioner.eval()

    logger.info("\n Pipeline successfully assembled from Lightning checkpoint!")
    return pipeline