diff --git a/hy3dshape/hy3dshape/utils/misc.py b/hy3dshape/hy3dshape/utils/misc.py index 228e947..55e1136 100644 --- a/hy3dshape/hy3dshape/utils/misc.py +++ b/hy3dshape/hy3dshape/utils/misc.py @@ -49,7 +49,10 @@ def instantiate_from_config(config, **kwargs): cls = get_obj_from_str(config["target"]) if config.get("from_pretrained", None): - return cls.from_pretrained(config["from_pretrained"]) + return cls.from_pretrained( + config["from_pretrained"], + use_safetensors=config.get('use_safetensors', False), + variant=config.get('variant', 'fp16')) params = config.get("params", dict()) # params.update(kwargs) diff --git a/hy3dshape/hy3dshape/utils/utils.py b/hy3dshape/hy3dshape/utils/utils.py index 6ac8f5d..c88f6bb 100644 --- a/hy3dshape/hy3dshape/utils/utils.py +++ b/hy3dshape/hy3dshape/utils/utils.py @@ -95,6 +95,7 @@ def smart_load_model( original_model_path = model_path # try local path base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen') + model_fld = os.path.expanduser(os.path.join(base_dir, model_path)) model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder)) logger.info(f'Try to load model from local path: {model_path}') if not os.path.exists(model_path): @@ -105,6 +106,7 @@ def smart_load_model( path = snapshot_download( repo_id=original_model_path, allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹 + local_dir=model_fld ) model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变 except ImportError: diff --git a/hy3dshape/scripts/train_deepspeed.sh b/hy3dshape/scripts/train_deepspeed.sh index ba116e6..ed6b7c7 100644 --- a/hy3dshape/scripts/train_deepspeed.sh +++ b/hy3dshape/scripts/train_deepspeed.sh @@ -56,7 +56,7 @@ else fi NODE_RANK=$node_rank \ -HF_HUB_OFFLINE=1 \ +HF_HUB_OFFLINE=0 \ MASTER_PORT=12348 \ MASTER_ADDR=$master_ip \ NCCL_SOCKET_IFNAME=bond1 \ @@ -67,4 +67,4 @@ python3 main.py \ --num_gpus 8 \ --config $config \ --output_dir $output_dir \ - --deepspeed \ No newline at end of file + --deepspeed diff --git a/requirements.txt b/requirements.txt index abe7782..069c30b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,6 @@ torchmetrics==1.6.0 timm pythreejs -torchdiffe \ No newline at end of file +torchdiffe +deepspeed +