fix(oom): use mmap=True for checkpoint loading + malloc_trim + expandable_segments

Root cause: torch.load() reads 6.9GB .ckpt into Python heap + model params
in CPU RAM = ~14GB peak, exceeding 16GB system RAM → OOM Killer.

Fix 1 - mmap=True on all torch.load() calls (torch 2.7 supports this):
  With mmap, checkpoint storage is file-backed (not heap). Only the model
  parameters (also ~7GB) exist in physical RAM during loading. Peak RAM
  drops from ~14GB to ~7GB — within safe limits on 16GB machines.
  Files changed: pipelines.py, hunyuan3ddit.py, model.py (×2), flow_matching_sit.py

Fix 2 - malloc_trim(0) after every gc.collect():
  Forces glibc to return freed heap pages to OS immediately, so Python's
  memory pool doesn't hoard freed model memory before the next load.

Fix 3 - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True:
  Prevents CUDA allocator fragmentation between model switches.

Fix 4 - Adaptive threshold recalculated:
  With mmap loading, loading a model requires ~7.5GB (model params) not
  14GB. CPU offload threshold lowered from 16GB → 10.5GB, enabling fast
  path on machines with more headroom.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Akasei
2026-03-16 23:18:16 +08:00
parent 6534f4ba15
commit f192c86c60
46 changed files with 334079 additions and 10 deletions

View File

@@ -143,7 +143,7 @@ class VectsetVAE(nn.Module):
import safetensors.torch
ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
else:
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True, mmap=True)
model_kwargs = config['params']
model_kwargs.update(kwargs)
@@ -181,7 +181,7 @@ class VectsetVAE(nn.Module):
)
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")
state_dict = torch.load(path, map_location="cpu", mmap=True)
state_dict = state_dict.get("state_dict", state_dict)
keys = list(state_dict.keys())
for k in keys:

View File

@@ -358,7 +358,7 @@ class Hunyuan3DDiT(nn.Module):
if ckpt_path is not None:
print('restored denoiser ckpt', ckpt_path)
ckpt = torch.load(ckpt_path, map_location="cpu")
ckpt = torch.load(ckpt_path, map_location='cpu', mmap=True)
if 'state_dict' not in ckpt:
# deepspeed ckpt
state_dict = {}

View File

@@ -135,7 +135,7 @@ class Diffuser(pl.LightningModule):
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=()):
ckpt = torch.load(path, map_location="cpu")
ckpt = torch.load(path, map_location="cpu", mmap=True)
if 'state_dict' not in ckpt:
# deepspeed ckpt
state_dict = {}

View File

@@ -165,7 +165,7 @@ class Hunyuan3DDiTPipeline:
ckpt[model_name] = {}
ckpt[model_name][new_key] = value
else:
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True, mmap=True)
# load model
model = instantiate_from_config(config['model'])
model.load_state_dict(ckpt['model'])