fix(oom): use mmap=True for checkpoint loading + malloc_trim + expandable_segments
Root cause: torch.load() reads the 6.9GB .ckpt file into the Python heap while the model parameters are also held in CPU RAM, giving a ~14GB peak. Together with the OS and other processes this exhausts 16GB of system RAM and triggers the OOM killer.

Fix 1 - mmap=True on all torch.load() calls (torch 2.7 supports this): with mmap the checkpoint storage is file-backed rather than heap-allocated, so only the model parameters (~7GB) occupy physical RAM during loading. Peak RAM drops from ~14GB to ~7GB, within safe limits on 16GB machines.
Files changed: pipelines.py, hunyuan3ddit.py, model.py (×2), flow_matching_sit.py

Fix 2 - malloc_trim(0) after every gc.collect(): forces glibc to return freed heap pages to the OS immediately, so Python's memory pool does not hoard freed model memory before the next load.

Fix 3 - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True: prevents CUDA allocator fragmentation between model switches.

Fix 4 - adaptive threshold recalculated: with mmap loading, loading a model requires ~7.5GB (model parameters) instead of ~14GB, so the CPU offload threshold is lowered from 16GB to 10.5GB, enabling the fast path on machines with more headroom.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
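Fixes 2 and 3 are not visible in the hunks below. A minimal sketch of the idea, assuming a Linux/glibc host; the free_memory() helper name is illustrative and not the exact code this commit adds:

    import ctypes
    import gc
    import os

    # Fix 3: the setting is read when PyTorch's CUDA caching allocator
    # initializes, so export it before the first CUDA allocation.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


    def free_memory():
        """Return freed model memory to the OS between checkpoint loads."""
        gc.collect()
        # Fix 2: gc.collect() drops the Python objects, but glibc can keep the
        # freed pages inside the process heap; malloc_trim(0) hands them back
        # to the OS so the next torch.load() does not push the machine into OOM.
        try:
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except OSError:
            pass  # non-glibc platform (macOS, Windows): nothing to trim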
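Fix 4 is likewise outside the shown hunks. A rough sketch of the recalculated offload decision, assuming psutil is used to read system RAM and that total RAM is compared against the threshold; names and structure are illustrative only:

    import psutil

    # With mmap loading, a model needs ~7.5GB of physical RAM (its parameters),
    # not ~14GB, so ~10.5GB of system RAM is enough headroom for the fast path.
    CPU_OFFLOAD_THRESHOLD_GB = 10.5  # was 16.0 before this commit


    def should_cpu_offload() -> bool:
        total_gb = psutil.virtual_memory().total / 1024 ** 3
        return total_gb < CPU_OFFLOAD_THRESHOLD_GB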
@@ -143,7 +143,7 @@ class VectsetVAE(nn.Module):
             import safetensors.torch
             ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
         else:
-            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
+            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True, mmap=True)

         model_kwargs = config['params']
         model_kwargs.update(kwargs)

@@ -181,7 +181,7 @@ class VectsetVAE(nn.Module):
         )

     def init_from_ckpt(self, path, ignore_keys=()):
-        state_dict = torch.load(path, map_location="cpu")
+        state_dict = torch.load(path, map_location="cpu", mmap=True)
         state_dict = state_dict.get("state_dict", state_dict)
         keys = list(state_dict.keys())
         for k in keys:

@@ -358,7 +358,7 @@ class Hunyuan3DDiT(nn.Module):
         if ckpt_path is not None:
             print('restored denoiser ckpt', ckpt_path)

-            ckpt = torch.load(ckpt_path, map_location="cpu")
+            ckpt = torch.load(ckpt_path, map_location='cpu', mmap=True)
             if 'state_dict' not in ckpt:
                 # deepspeed ckpt
                 state_dict = {}

@@ -135,7 +135,7 @@ class Diffuser(pl.LightningModule):
         print(f"{context}: Restored training weights")

     def init_from_ckpt(self, path, ignore_keys=()):
-        ckpt = torch.load(path, map_location="cpu")
+        ckpt = torch.load(path, map_location="cpu", mmap=True)
         if 'state_dict' not in ckpt:
             # deepspeed ckpt
             state_dict = {}

@@ -165,7 +165,7 @@ class Hunyuan3DDiTPipeline:
                     ckpt[model_name] = {}
                 ckpt[model_name][new_key] = value
         else:
-            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
+            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True, mmap=True)
         # load model
         model = instantiate_from_config(config['model'])
         model.load_state_dict(ckpt['model'])