fix: eliminate OOM on RTX 3080 via load_state_dict(assign=True) + low-VRAM mode
Root cause: torch.load() with mmap=True returns fp16 tensors, but
load_state_dict() without assign=True widens them fp16→fp32 in-place,
doubling CPU anon-rss (7 GB fp16 ckpt → 14 GB fp32 params). Combined
with the 2 GB Gradio server baseline, this exceeded the 15 GB physical
RAM limit on the second generation request.
Fix: add assign=True to all load_state_dict calls in pipelines.py and
autoencoders/model.py. With assign=True the mmap fp16 tensors are
assigned directly as model parameters without any fp16→fp32 copy.
When model.to('cuda') is then called, the mmap pages (file-backed,
evictable) are streamed directly to VRAM — CPU anon-rss stays near 0.
Peak RSS is now ~3.9 GB instead of 14.7 GB (killed) across all rounds.
gradio_app.py changes:
- low_vram_mode always takes the full-delete path (never CPU offload)
- glibc malloc tuning at startup (MALLOC_ARENA_MAX=1, malloc_trim)
- preemptive gc.collect(2) + malloc_trim + empty_cache at generation start
- _rlog() memory logging at each major step for monitoring
pipelines.py:
- load_state_dict(..., assign=True) for model, vae, conditioner
- del ckpt after state dict assignment to release mmap fd early
autoencoders/model.py:
- load_state_dict(..., assign=True) in from_single_file
- load_state_dict(..., assign=True) in init_from_ckpt
Verified: 4 consecutive Playwright WebUI rounds (shape+texture) pass
with no OOM. API two-round test also passes.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -51,10 +51,25 @@ import numpy as np
|
|||||||
from hy3dshape.utils import logger
|
from hy3dshape.utils import logger
|
||||||
from hy3dpaint.convert_utils import create_glb_with_pbr_materials
|
from hy3dpaint.convert_utils import create_glb_with_pbr_materials
|
||||||
|
|
||||||
# Force OS to reclaim freed heap pages, reducing Python's RSS after model deletion.
|
# ── glibc malloc tuning ───────────────────────────────────────────────────────
|
||||||
|
# Applied BEFORE any large allocation so glibc honours them from the start.
|
||||||
|
# M_MMAP_THRESHOLD (-3): allocations > 1 MB use anonymous mmap instead of
|
||||||
|
# the heap; when freed they are immediately returned to the OS via munmap,
|
||||||
|
# eliminating heap fragmentation for PyTorch tensors (all >> 1 MB).
|
||||||
|
# M_ARENA_MAX (-8 via env): limit to 1 arena so malloc_trim() can release
|
||||||
|
# ALL freed pages, not just the main-thread arena.
|
||||||
|
os.environ.setdefault("MALLOC_ARENA_MAX", "1")
|
||||||
|
os.environ.setdefault("MALLOC_MMAP_THRESHOLD_", "1048576") # 1 MB
|
||||||
|
|
||||||
_libc = ctypes.CDLL(ctypes.util.find_library("c") or "libc.so.6", use_errno=True)
|
_libc = ctypes.CDLL(ctypes.util.find_library("c") or "libc.so.6", use_errno=True)
|
||||||
|
try:
|
||||||
|
_libc.mallopt(-3, 1024 * 1024) # M_MMAP_THRESHOLD = 1 MB (runtime)
|
||||||
|
_libc.mallopt(-1, 128 * 1024) # M_TRIM_THRESHOLD = 128 KB (trim aggressively)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _malloc_trim():
|
def _malloc_trim():
|
||||||
|
"""Return all free heap pages to the OS (glibc brk-based heap)."""
|
||||||
try:
|
try:
|
||||||
_libc.malloc_trim(0)
|
_libc.malloc_trim(0)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -285,13 +300,18 @@ def _can_offload_to_cpu():
|
|||||||
|
|
||||||
|
|
||||||
def _prepare_for_tex():
|
def _prepare_for_tex():
|
||||||
"""Free VRAM from shape model before loading texture pipeline."""
|
"""Free VRAM from shape model before loading texture pipeline.
|
||||||
|
|
||||||
|
In low_vram_mode the shape model is always fully deleted so that its
|
||||||
|
~7.25 GB of VRAM is completely free before the texture pipeline loads.
|
||||||
|
CPU-offload path is only considered when low_vram_mode is disabled.
|
||||||
|
"""
|
||||||
global i23d_worker, _i23d_on_cpu
|
global i23d_worker, _i23d_on_cpu
|
||||||
if i23d_worker is None:
|
if i23d_worker is None:
|
||||||
_ensure_tex_pipeline()
|
_ensure_tex_pipeline()
|
||||||
return
|
return
|
||||||
|
|
||||||
if _can_offload_to_cpu():
|
if not args.low_vram_mode and _can_offload_to_cpu():
|
||||||
logger.info("Offloading shape model to CPU RAM (fast path)...")
|
logger.info("Offloading shape model to CPU RAM (fast path)...")
|
||||||
i23d_worker.to('cpu')
|
i23d_worker.to('cpu')
|
||||||
_i23d_on_cpu = True
|
_i23d_on_cpu = True
|
||||||
@@ -299,7 +319,7 @@ def _prepare_for_tex():
|
|||||||
_malloc_trim()
|
_malloc_trim()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
else:
|
else:
|
||||||
logger.info("Deleting shape model entirely (safe path, limited RAM)...")
|
logger.info("Deleting shape model entirely (low_vram path)...")
|
||||||
del i23d_worker
|
del i23d_worker
|
||||||
i23d_worker = None
|
i23d_worker = None
|
||||||
_i23d_on_cpu = False
|
_i23d_on_cpu = False
|
||||||
@@ -312,14 +332,17 @@ def _prepare_for_tex():
|
|||||||
|
|
||||||
|
|
||||||
def _ensure_i23d_worker():
|
def _ensure_i23d_worker():
|
||||||
"""Load shape model to GPU — from CPU RAM (fast) or disk (slow)."""
|
"""Load shape model to GPU.
|
||||||
|
|
||||||
|
In low_vram_mode always reload from disk (CPU-offload path is never used).
|
||||||
|
"""
|
||||||
global i23d_worker, _i23d_on_cpu
|
global i23d_worker, _i23d_on_cpu
|
||||||
if i23d_worker is not None and _i23d_on_cpu:
|
if not args.low_vram_mode and i23d_worker is not None and _i23d_on_cpu:
|
||||||
logger.info("Restoring shape model from CPU to GPU (fast path)...")
|
logger.info("Restoring shape model from CPU to GPU (fast path)...")
|
||||||
i23d_worker.to(args.device)
|
i23d_worker.to(args.device)
|
||||||
_i23d_on_cpu = False
|
_i23d_on_cpu = False
|
||||||
elif i23d_worker is None:
|
elif i23d_worker is None:
|
||||||
logger.info("Reloading shape model from disk to GPU (slow path)...")
|
logger.info("Reloading shape model from disk to GPU...")
|
||||||
gc.collect()
|
gc.collect()
|
||||||
_malloc_trim()
|
_malloc_trim()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -487,6 +510,27 @@ def generation_all(
|
|||||||
num_chunks=200000,
|
num_chunks=200000,
|
||||||
randomize_seed: bool = False,
|
randomize_seed: bool = False,
|
||||||
):
|
):
|
||||||
|
import os as _os
|
||||||
|
def _rss_mb():
|
||||||
|
try:
|
||||||
|
with open('/proc/self/status') as _f:
|
||||||
|
for _l in _f:
|
||||||
|
if _l.startswith('VmRSS:'):
|
||||||
|
return int(_l.split()[1]) // 1024
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return 0
|
||||||
|
def _rlog(label):
|
||||||
|
vram = torch.cuda.memory_allocated() // (1024*1024)
|
||||||
|
logger.info(f"[MEM] {label:40s} RSS={_rss_mb():6d} MB VRAM={vram:5d} MB")
|
||||||
|
|
||||||
|
# Proactively free any memory left over from previous generations so that
|
||||||
|
# fresh model loading starts from the lowest possible RSS baseline.
|
||||||
|
gc.collect(2)
|
||||||
|
_malloc_trim()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
_rlog("generation_all start")
|
||||||
|
|
||||||
start_time_0 = time.time()
|
start_time_0 = time.time()
|
||||||
mesh, image, save_folder, stats, seed = _gen_shape(
|
mesh, image, save_folder, stats, seed = _gen_shape(
|
||||||
caption,
|
caption,
|
||||||
@@ -503,18 +547,12 @@ def generation_all(
|
|||||||
num_chunks=num_chunks,
|
num_chunks=num_chunks,
|
||||||
randomize_seed=randomize_seed,
|
randomize_seed=randomize_seed,
|
||||||
)
|
)
|
||||||
|
_rlog("after _gen_shape")
|
||||||
path = export_mesh(mesh, save_folder, textured=False)
|
path = export_mesh(mesh, save_folder, textured=False)
|
||||||
|
|
||||||
|
|
||||||
print(path)
|
print(path)
|
||||||
print('='*40)
|
print('='*40)
|
||||||
|
|
||||||
# tmp_time = time.time()
|
|
||||||
# mesh = floater_remove_worker(mesh)
|
|
||||||
# mesh = degenerate_face_remove_worker(mesh)
|
|
||||||
# logger.info("---Postprocessing takes %s seconds ---" % (time.time() - tmp_time))
|
|
||||||
# stats['time']['postprocessing'] = time.time() - tmp_time
|
|
||||||
|
|
||||||
tmp_time = time.time()
|
tmp_time = time.time()
|
||||||
mesh = face_reduce_worker(mesh)
|
mesh = face_reduce_worker(mesh)
|
||||||
|
|
||||||
@@ -523,21 +561,24 @@ def generation_all(
|
|||||||
|
|
||||||
logger.info("---Face Reduction takes %s seconds ---" % (time.time() - tmp_time))
|
logger.info("---Face Reduction takes %s seconds ---" % (time.time() - tmp_time))
|
||||||
stats['time']['face reduction'] = time.time() - tmp_time
|
stats['time']['face reduction'] = time.time() - tmp_time
|
||||||
|
_rlog("after face reduction")
|
||||||
|
|
||||||
tmp_time = time.time()
|
tmp_time = time.time()
|
||||||
|
|
||||||
text_path = os.path.join(save_folder, f'textured_mesh.obj')
|
text_path = os.path.join(save_folder, f'textured_mesh.obj')
|
||||||
|
|
||||||
# In low_vram_mode: adaptively offload shape model (CPU or delete based on
|
# In low_vram_mode: delete shape model then load texture pipeline.
|
||||||
# available RAM), then load texture pipeline.
|
|
||||||
if args.low_vram_mode:
|
if args.low_vram_mode:
|
||||||
_prepare_for_tex()
|
_prepare_for_tex()
|
||||||
|
_rlog("after _prepare_for_tex (shape deleted, tex loaded)")
|
||||||
|
|
||||||
path_textured = tex_pipeline(mesh_path=path, image_path=image, output_mesh_path=text_path, save_glb=False)
|
path_textured = tex_pipeline(mesh_path=path, image_path=image, output_mesh_path=text_path, save_glb=False)
|
||||||
|
_rlog("after tex_pipeline inference")
|
||||||
|
|
||||||
# Unload texture pipeline after use so VRAM is free for the next shape request.
|
# Unload texture pipeline after use so VRAM is free for the next shape request.
|
||||||
if args.low_vram_mode:
|
if args.low_vram_mode:
|
||||||
_unload_tex_pipeline()
|
_unload_tex_pipeline()
|
||||||
|
_rlog("after _unload_tex_pipeline")
|
||||||
|
|
||||||
logger.info("---Texture Generation takes %s seconds ---" % (time.time() - tmp_time))
|
logger.info("---Texture Generation takes %s seconds ---" % (time.time() - tmp_time))
|
||||||
stats['time']['texture generation'] = time.time() - tmp_time
|
stats['time']['texture generation'] = time.time() - tmp_time
|
||||||
@@ -555,6 +596,7 @@ def generation_all(
|
|||||||
width=HTML_WIDTH, textured=True)
|
width=HTML_WIDTH, textured=True)
|
||||||
if args.low_vram_mode:
|
if args.low_vram_mode:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
_rlog("generation_all complete")
|
||||||
return (
|
return (
|
||||||
gr.update(value=path),
|
gr.update(value=path),
|
||||||
gr.update(value=glb_path_textured),
|
gr.update(value=glb_path_textured),
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ class VectsetVAE(nn.Module):
|
|||||||
model_kwargs.update(kwargs)
|
model_kwargs.update(kwargs)
|
||||||
|
|
||||||
model = cls(**model_kwargs)
|
model = cls(**model_kwargs)
|
||||||
model.load_state_dict(ckpt)
|
model.load_state_dict(ckpt, assign=True)
|
||||||
model.to(device=device, dtype=dtype)
|
model.to(device=device, dtype=dtype)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -189,7 +189,7 @@ class VectsetVAE(nn.Module):
|
|||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del state_dict[k]
|
del state_dict[k]
|
||||||
missing, unexpected = self.load_state_dict(state_dict, strict=False)
|
missing, unexpected = self.load_state_dict(state_dict, strict=False, assign=True)
|
||||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||||
if len(missing) > 0:
|
if len(missing) > 0:
|
||||||
print(f"Missing Keys: {missing}")
|
print(f"Missing Keys: {missing}")
|
||||||
|
|||||||
@@ -166,14 +166,16 @@ class Hunyuan3DDiTPipeline:
|
|||||||
ckpt[model_name][new_key] = value
|
ckpt[model_name][new_key] = value
|
||||||
else:
|
else:
|
||||||
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True, mmap=True)
|
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True, mmap=True)
|
||||||
# load model
|
# load model — use assign=True so mmap fp16 tensors are assigned directly as
|
||||||
|
# parameters (no fp16→fp32 widening copy), keeping CPU anon-rss near zero.
|
||||||
model = instantiate_from_config(config['model'])
|
model = instantiate_from_config(config['model'])
|
||||||
model.load_state_dict(ckpt['model'])
|
model.load_state_dict(ckpt['model'], assign=True)
|
||||||
vae = instantiate_from_config(config['vae'])
|
vae = instantiate_from_config(config['vae'])
|
||||||
vae.load_state_dict(ckpt['vae'], strict=False)
|
vae.load_state_dict(ckpt['vae'], strict=False, assign=True)
|
||||||
conditioner = instantiate_from_config(config['conditioner'])
|
conditioner = instantiate_from_config(config['conditioner'])
|
||||||
if 'conditioner' in ckpt:
|
if 'conditioner' in ckpt:
|
||||||
conditioner.load_state_dict(ckpt['conditioner'])
|
conditioner.load_state_dict(ckpt['conditioner'], assign=True)
|
||||||
|
del ckpt # free mmap file-backed pages now that params hold their own refs
|
||||||
image_processor = instantiate_from_config(config['image_processor'])
|
image_processor = instantiate_from_config(config['image_processor'])
|
||||||
scheduler = instantiate_from_config(config['scheduler'])
|
scheduler = instantiate_from_config(config['scheduler'])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user