fix: eliminate OOM on RTX 3080 via load_state_dict(assign=True) + low-VRAM mode

Root cause: torch.load() with mmap=True returns fp16 tensors, but
load_state_dict() without assign=True widens them fp16→fp32 in-place,
doubling CPU anon-rss (7 GB fp16 ckpt → 14 GB fp32 params). Combined
with the 2 GB Gradio server baseline, this exceeded the 15 GB physical
RAM limit on the second generation request.

Fix: add assign=True to all load_state_dict calls in pipelines.py and
autoencoders/model.py. With assign=True the mmap fp16 tensors are
assigned directly as model parameters without any fp16→fp32 copy.
When model.to('cuda') is then called, the mmap pages (file-backed,
evictable) are streamed directly to VRAM — CPU anon-rss stays near 0.

Peak RSS is now ~3.9 GB instead of 14.7 GB (killed) across all rounds.

gradio_app.py changes:
- low_vram_mode always takes the full-delete path (never CPU offload)
- glibc malloc tuning at startup (MALLOC_ARENA_MAX=1, malloc_trim)
- preemptive gc.collect(2) + malloc_trim + empty_cache at generation start
- _rlog() memory logging at each major step for monitoring

pipelines.py:
- load_state_dict(..., assign=True) for model, vae, conditioner
- del ckpt after state dict assignment to release mmap fd early

autoencoders/model.py:
- load_state_dict(..., assign=True) in from_single_file
- load_state_dict(..., assign=True) in init_from_ckpt

Verified: 4 consecutive Playwright WebUI rounds (shape+texture) pass
with no OOM. API two-round test also passes.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Akasei
2026-03-17 02:03:43 +08:00
parent 5acd0a765b
commit 70289d04d7
3 changed files with 67 additions and 23 deletions

View File

@@ -149,7 +149,7 @@ class VectsetVAE(nn.Module):
model_kwargs.update(kwargs)
model = cls(**model_kwargs)
model.load_state_dict(ckpt)
model.load_state_dict(ckpt, assign=True)
model.to(device=device, dtype=dtype)
return model
@@ -189,7 +189,7 @@ class VectsetVAE(nn.Module):
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
missing, unexpected = self.load_state_dict(state_dict, strict=False, assign=True)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")

View File

@@ -166,14 +166,16 @@ class Hunyuan3DDiTPipeline:
ckpt[model_name][new_key] = value
else:
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True, mmap=True)
# load model
# load model — use assign=True so mmap fp16 tensors are assigned directly as
# parameters (no fp16→fp32 widening copy), keeping CPU anon-rss near zero.
model = instantiate_from_config(config['model'])
model.load_state_dict(ckpt['model'])
model.load_state_dict(ckpt['model'], assign=True)
vae = instantiate_from_config(config['vae'])
vae.load_state_dict(ckpt['vae'], strict=False)
vae.load_state_dict(ckpt['vae'], strict=False, assign=True)
conditioner = instantiate_from_config(config['conditioner'])
if 'conditioner' in ckpt:
conditioner.load_state_dict(ckpt['conditioner'])
conditioner.load_state_dict(ckpt['conditioner'], assign=True)
del ckpt # free mmap file-backed pages now that params hold their own refs
image_processor = instantiate_from_config(config['image_processor'])
scheduler = instantiate_from_config(config['scheduler'])