update
This commit is contained in:
@@ -49,7 +49,10 @@ def instantiate_from_config(config, **kwargs):
|
||||
cls = get_obj_from_str(config["target"])
|
||||
|
||||
if config.get("from_pretrained", None):
|
||||
return cls.from_pretrained(config["from_pretrained"])
|
||||
return cls.from_pretrained(
|
||||
config["from_pretrained"],
|
||||
use_safetensors=config.get('use_safetensors', False),
|
||||
variant=config.get('variant', 'fp16'))
|
||||
|
||||
params = config.get("params", dict())
|
||||
# params.update(kwargs)
|
||||
|
||||
@@ -95,6 +95,7 @@ def smart_load_model(
|
||||
original_model_path = model_path
|
||||
# try local path
|
||||
base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
|
||||
model_fld = os.path.expanduser(os.path.join(base_dir, model_path))
|
||||
model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
|
||||
logger.info(f'Try to load model from local path: {model_path}')
|
||||
if not os.path.exists(model_path):
|
||||
@@ -105,6 +106,7 @@ def smart_load_model(
|
||||
path = snapshot_download(
|
||||
repo_id=original_model_path,
|
||||
allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹
|
||||
local_dir=model_fld
|
||||
)
|
||||
model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变
|
||||
except ImportError:
|
||||
|
||||
@@ -56,7 +56,7 @@ else
|
||||
fi
|
||||
|
||||
NODE_RANK=$node_rank \
|
||||
HF_HUB_OFFLINE=1 \
|
||||
HF_HUB_OFFLINE=0 \
|
||||
MASTER_PORT=12348 \
|
||||
MASTER_ADDR=$master_ip \
|
||||
NCCL_SOCKET_IFNAME=bond1 \
|
||||
|
||||
@@ -62,3 +62,5 @@ torchmetrics==1.6.0
|
||||
timm
|
||||
pythreejs
|
||||
torchdiffe
|
||||
deepspeed
|
||||
|
||||
|
||||
Reference in New Issue
Block a user