This commit is contained in:
Huiwenshi
2025-06-14 01:39:07 +08:00
parent db57a4467e
commit 4d67e18386
4 changed files with 11 additions and 4 deletions

View File

@@ -49,7 +49,10 @@ def instantiate_from_config(config, **kwargs):
cls = get_obj_from_str(config["target"])
if config.get("from_pretrained", None):
return cls.from_pretrained(config["from_pretrained"])
return cls.from_pretrained(
config["from_pretrained"],
use_safetensors=config.get('use_safetensors', False),
variant=config.get('variant', 'fp16'))
params = config.get("params", dict())
# params.update(kwargs)

View File

@@ -95,6 +95,7 @@ def smart_load_model(
original_model_path = model_path
# try local path
base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
model_fld = os.path.expanduser(os.path.join(base_dir, model_path))
model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
logger.info(f'Try to load model from local path: {model_path}')
if not os.path.exists(model_path):
@@ -105,6 +106,7 @@ def smart_load_model(
path = snapshot_download(
repo_id=original_model_path,
allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹
local_dir=model_fld
)
model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变
except ImportError: