This commit is contained in:
Huiwenshi
2025-06-14 01:39:07 +08:00
parent db57a4467e
commit 4d67e18386
4 changed files with 11 additions and 4 deletions

View File

@@ -49,7 +49,10 @@ def instantiate_from_config(config, **kwargs):
cls = get_obj_from_str(config["target"]) cls = get_obj_from_str(config["target"])
if config.get("from_pretrained", None): if config.get("from_pretrained", None):
return cls.from_pretrained(config["from_pretrained"]) return cls.from_pretrained(
config["from_pretrained"],
use_safetensors=config.get('use_safetensors', False),
variant=config.get('variant', 'fp16'))
params = config.get("params", dict()) params = config.get("params", dict())
# params.update(kwargs) # params.update(kwargs)

View File

@@ -95,6 +95,7 @@ def smart_load_model(
original_model_path = model_path original_model_path = model_path
# try local path # try local path
base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen') base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
model_fld = os.path.expanduser(os.path.join(base_dir, model_path))
model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder)) model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
logger.info(f'Try to load model from local path: {model_path}') logger.info(f'Try to load model from local path: {model_path}')
if not os.path.exists(model_path): if not os.path.exists(model_path):
@@ -105,6 +106,7 @@ def smart_load_model(
path = snapshot_download( path = snapshot_download(
repo_id=original_model_path, repo_id=original_model_path,
allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹 allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹
local_dir=model_fld
) )
model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变 model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变
except ImportError: except ImportError:

View File

@@ -56,7 +56,7 @@ else
fi fi
NODE_RANK=$node_rank \ NODE_RANK=$node_rank \
HF_HUB_OFFLINE=1 \ HF_HUB_OFFLINE=0 \
MASTER_PORT=12348 \ MASTER_PORT=12348 \
MASTER_ADDR=$master_ip \ MASTER_ADDR=$master_ip \
NCCL_SOCKET_IFNAME=bond1 \ NCCL_SOCKET_IFNAME=bond1 \

View File

@@ -62,3 +62,5 @@ torchmetrics==1.6.0
timm timm
pythreejs pythreejs
torchdiffe torchdiffe
deepspeed