diff --git a/gradio_app.py b/gradio_app.py
index b300a61..c111320 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -28,6 +28,7 @@ except Exception as e:
     print(f"Warning: Failed to apply torchvision fix: {e}")
 
 
+import gc
 import os
 import random
 import shutil
@@ -213,6 +214,53 @@ height="{height}" width="100%" frameborder="0">'
 """
 
 
+# ---------------------------------------------------------------------------
+# VRAM management helpers (used when --low_vram_mode is set)
+# Models are unloaded (del'd) before the other model runs, then reloaded
+# on next request — no CPU intermediate, VRAM freed immediately.
+# ---------------------------------------------------------------------------
+
+def _unload_i23d_worker():
+    """Delete shape model from GPU and free VRAM."""
+    global i23d_worker
+    del i23d_worker
+    i23d_worker = None
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def _ensure_i23d_worker():
+    """Reload shape model to GPU if it was previously unloaded."""
+    global i23d_worker
+    if i23d_worker is None:
+        from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline
+        logger.info("Reloading shape model to GPU...")
+        i23d_worker = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
+            args.model_path,
+            subfolder=args.subfolder,
+            use_safetensors=False,
+            device=args.device,
+        )
+
+
+def _unload_tex_pipeline():
+    """Delete texture pipeline from GPU and free VRAM."""
+    global tex_pipeline
+    del tex_pipeline
+    tex_pipeline = None
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def _ensure_tex_pipeline():
+    """Load texture pipeline to GPU if not already loaded."""
+    global tex_pipeline
+    if tex_pipeline is None and tex_conf is not None:
+        from hy3dpaint.textureGenPipeline import Hunyuan3DPaintPipeline
+        logger.info("Loading texture pipeline to GPU...")
+        tex_pipeline = Hunyuan3DPaintPipeline(tex_conf)
+
+
 @spaces.GPU(duration=60)
 def _gen_shape(
     caption=None,
@@ -297,6 +345,9 @@ def _gen_shape(
     # image to white model
     start_time = time.time()
 
+    if args.low_vram_mode:
+        _ensure_i23d_worker()
+
     generator = torch.Generator()
     generator = generator.manual_seed(int(seed))
     outputs = i23d_worker(
@@ -379,18 +430,17 @@ def generation_all(
 
     text_path = os.path.join(save_folder, f'textured_mesh.obj')
 
-    # In low_vram_mode: offload shape model to CPU before texture gen to free VRAM,
-    # mirroring the sequential-load strategy in batch_generate.py.
+    # In low_vram_mode: unload shape model entirely (del, no CPU copy) to free VRAM,
+    # then load texture pipeline on demand. Shape model reloads lazily on next request.
     if args.low_vram_mode:
-        i23d_worker.to('cpu')
-        torch.cuda.empty_cache()
+        _unload_i23d_worker()
+        _ensure_tex_pipeline()
 
     path_textured = tex_pipeline(mesh_path=path, image_path=image, output_mesh_path=text_path, save_glb=False)
 
-    # Restore shape model to GPU so subsequent requests don't need to reload from disk.
+    # Unload texture pipeline after use so VRAM is free for the next shape request.
     if args.low_vram_mode:
-        i23d_worker.to('cuda')
-        torch.cuda.empty_cache()
+        _unload_tex_pipeline()
 
     logger.info("---Texture Generation takes %s seconds ---" % (time.time() - tmp_time))
     stats['time']['texture generation'] = time.time() - tmp_time
@@ -808,11 +858,13 @@ if __name__ == '__main__':
         # texgen_worker.enable_model_cpu_offload()
         from hy3dpaint.textureGenPipeline import Hunyuan3DPaintPipeline, Hunyuan3DPaintConfig
 
-        conf = Hunyuan3DPaintConfig(max_num_view=9, resolution=512)
-        conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth"
-        conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml"
-        conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr"
-        tex_pipeline = Hunyuan3DPaintPipeline(conf)
+        tex_conf = Hunyuan3DPaintConfig(max_num_view=9, resolution=512)
+        tex_conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth"
+        tex_conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml"
+        tex_conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr"
+        if not args.low_vram_mode:
+            # Load immediately; in low_vram_mode we load on-demand per request.
+            tex_pipeline = Hunyuan3DPaintPipeline(tex_conf)
 
         # Not help much, ignore for now.
         # if args.compile: