update
This commit is contained in:
@@ -49,7 +49,10 @@ def instantiate_from_config(config, **kwargs):
|
|||||||
cls = get_obj_from_str(config["target"])
|
cls = get_obj_from_str(config["target"])
|
||||||
|
|
||||||
if config.get("from_pretrained", None):
|
if config.get("from_pretrained", None):
|
||||||
return cls.from_pretrained(config["from_pretrained"])
|
return cls.from_pretrained(
|
||||||
|
config["from_pretrained"],
|
||||||
|
use_safetensors=config.get('use_safetensors', False),
|
||||||
|
variant=config.get('variant', 'fp16'))
|
||||||
|
|
||||||
params = config.get("params", dict())
|
params = config.get("params", dict())
|
||||||
# params.update(kwargs)
|
# params.update(kwargs)
|
||||||
|
|||||||
@@ -95,6 +95,7 @@ def smart_load_model(
|
|||||||
original_model_path = model_path
|
original_model_path = model_path
|
||||||
# try local path
|
# try local path
|
||||||
base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
|
base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
|
||||||
|
model_fld = os.path.expanduser(os.path.join(base_dir, model_path))
|
||||||
model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
|
model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
|
||||||
logger.info(f'Try to load model from local path: {model_path}')
|
logger.info(f'Try to load model from local path: {model_path}')
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
@@ -105,6 +106,7 @@ def smart_load_model(
|
|||||||
path = snapshot_download(
|
path = snapshot_download(
|
||||||
repo_id=original_model_path,
|
repo_id=original_model_path,
|
||||||
allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹
|
allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹
|
||||||
|
local_dir=model_fld
|
||||||
)
|
)
|
||||||
model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变
|
model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
NODE_RANK=$node_rank \
|
NODE_RANK=$node_rank \
|
||||||
HF_HUB_OFFLINE=1 \
|
HF_HUB_OFFLINE=0 \
|
||||||
MASTER_PORT=12348 \
|
MASTER_PORT=12348 \
|
||||||
MASTER_ADDR=$master_ip \
|
MASTER_ADDR=$master_ip \
|
||||||
NCCL_SOCKET_IFNAME=bond1 \
|
NCCL_SOCKET_IFNAME=bond1 \
|
||||||
|
|||||||
@@ -62,3 +62,5 @@ torchmetrics==1.6.0
|
|||||||
timm
|
timm
|
||||||
pythreejs
|
pythreejs
|
||||||
torchdiffe
|
torchdiffe
|
||||||
|
deepspeed
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user