Update pipelines.py
This commit is contained in:
@@ -781,3 +781,111 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
|
||||
box_v, mc_level, num_chunks, octree_resolution, mc_algo,
|
||||
enable_pbar=enable_pbar,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@synchronize_timer('Hunyuan3DDiTFlowMatchingPipeline from Lightning Checkpoint')
|
||||
def from_lightning_checkpoint(
|
||||
cls,
|
||||
ckpt_path: str,
|
||||
config_path: str,
|
||||
device: str = 'cuda',
|
||||
dtype: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads a model from a checkpoint created by the project's PyTorch Lightning training script.
|
||||
|
||||
This method correctly handles the nested configuration structure and state_dict prefixes
|
||||
produced during training, and can intelligently load sharded checkpoints saved by Deepspeed.
|
||||
|
||||
Args:
|
||||
ckpt_path (str): Path to the .ckpt checkpoint file or directory.
|
||||
config_path (str): Path to the .yaml configuration file used for training.
|
||||
device (str, optional): The device to load the model on. Defaults to 'cuda'.
|
||||
dtype (torch.dtype, optional): The data type for the model. Defaults to torch.float16.
|
||||
|
||||
Returns:
|
||||
Hunyuan3DDiTFlowMatchingPipeline: An instantiated pipeline ready for inference.
|
||||
"""
|
||||
from omegaconf import OmegaConf
|
||||
from hy3dshape.utils.misc import instantiate_from_config
|
||||
from hy3dshape.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
logger.info(f"Loading model from Lightning checkpoint: {ckpt_path}")
|
||||
logger.info(f"Using training config: {config_path}")
|
||||
|
||||
config = OmegaConf.load(config_path)
|
||||
|
||||
if os.path.isdir(ckpt_path):
|
||||
# Assumes a Deepspeed-saved checkpoint directory
|
||||
model_state_file = os.path.join(ckpt_path, 'checkpoint', 'mp_rank_00_model_states.pt')
|
||||
if not os.path.exists(model_state_file):
|
||||
raise FileNotFoundError(
|
||||
f"Could not find model weights file 'mp_rank_00_model_states.pt' in Deepspeed checkpoint directory: {os.path.join(ckpt_path, 'checkpoint')}"
|
||||
)
|
||||
|
||||
logger.info(f"Detected Deepspeed checkpoint directory, loading weights from: '{model_state_file}'")
|
||||
ckpt = torch.load(model_state_file, map_location='cpu', weights_only=False)
|
||||
# Deepspeed weights are often nested under the 'module' key
|
||||
state_dict = ckpt.get('module', ckpt)
|
||||
else:
|
||||
# Standard .ckpt file
|
||||
logger.info("Detected standard .ckpt file.")
|
||||
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
|
||||
state_dict = ckpt.get('state_dict', ckpt)
|
||||
|
||||
# 1. Instantiate components that were frozen during training.
|
||||
# They will load their own pretrained weights upon instantiation.
|
||||
logger.info("Instantiating VAE, Conditioner, and ImageProcessor...")
|
||||
vae = instantiate_from_config(config.model.params.first_stage_config)
|
||||
conditioner = instantiate_from_config(config.model.params.cond_stage_config)
|
||||
image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
|
||||
|
||||
# 2. Instantiate the component that was trained (the Denoiser).
|
||||
logger.info("Instantiating Denoiser...")
|
||||
denoiser = instantiate_from_config(config.model.params.denoiser_cfg)
|
||||
|
||||
# 3. Load weights only for the Denoiser from our training checkpoint.
|
||||
possible_prefixes = ["model.model.", "_forward_module.model.", "model."]
|
||||
denoiser_dict = {}
|
||||
matched_prefix = None
|
||||
for prefix in possible_prefixes:
|
||||
sub_dict = {k.replace(prefix, ''): v for k, v in state_dict.items() if k.startswith(prefix)}
|
||||
if sub_dict:
|
||||
denoiser_dict = sub_dict
|
||||
matched_prefix = prefix
|
||||
break
|
||||
|
||||
if denoiser_dict:
|
||||
logger.info(f"Successfully matched Denoiser weight prefix: '{matched_prefix}'")
|
||||
missing_keys, unexpected_keys = denoiser.load_state_dict(denoiser_dict, strict=False)
|
||||
logger.info(" Successfully loaded weights for 'denoiser'.")
|
||||
if missing_keys:
|
||||
logger.warning(f" - Missing keys: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
logger.warning(f" - Unexpected keys: {unexpected_keys}")
|
||||
else:
|
||||
logger.warning("Could not find weights for 'denoiser' in checkpoint. It will be randomly initialized.")
|
||||
|
||||
# 4. Instantiate a new, inference-compatible scheduler.
|
||||
logger.info("Creating a new scheduler for inference...")
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
# 5. Assemble the final, healthy pipeline.
|
||||
pipeline = cls(
|
||||
model=denoiser,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
conditioner=conditioner,
|
||||
image_processor=image_processor,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# 6. Move all model components to the correct device and set to evaluation mode.
|
||||
pipeline.to(torch.device(device), dtype=dtype)
|
||||
pipeline.model.eval()
|
||||
pipeline.vae.eval()
|
||||
pipeline.conditioner.eval()
|
||||
|
||||
logger.info("\n Pipeline successfully assembled from Lightning checkpoint!")
|
||||
return pipeline
|
||||
|
||||
Reference in New Issue
Block a user