Merge pull request #104 from WncFht/feature/add-enable-flashvdm
【犀牛鸟实战issue】inference speed
This commit is contained in:
@@ -199,7 +199,10 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2.1')
|
parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2.1')
|
||||||
parser.add_argument("--subfolder", type=str, default='hunyuan3d-dit-v2-1')
|
parser.add_argument("--subfolder", type=str, default='hunyuan3d-dit-v2-1')
|
||||||
parser.add_argument("--device", type=str, default="cuda")
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
|
parser.add_argument('--mc_algo', type=str, default='mc')
|
||||||
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
||||||
|
parser.add_argument('--enable_flashvdm', action='store_true')
|
||||||
|
parser.add_argument('--compile', action='store_true')
|
||||||
parser.add_argument('--low_vram_mode', action='store_true')
|
parser.add_argument('--low_vram_mode', action='store_true')
|
||||||
parser.add_argument('--cache-path', type=str, default='./gradio_cache')
|
parser.add_argument('--cache-path', type=str, default='./gradio_cache')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -219,6 +222,9 @@ if __name__ == "__main__":
|
|||||||
low_vram_mode=args.low_vram_mode,
|
low_vram_mode=args.low_vram_mode,
|
||||||
worker_id=worker_id,
|
worker_id=worker_id,
|
||||||
model_semaphore=model_semaphore,
|
model_semaphore=model_semaphore,
|
||||||
save_dir=SAVE_DIR
|
save_dir=SAVE_DIR,
|
||||||
|
mc_algo=args.mc_algo,
|
||||||
|
enable_flashvdm=args.enable_flashvdm,
|
||||||
|
compile=args.compile
|
||||||
)
|
)
|
||||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||||
|
|||||||
@@ -748,7 +748,6 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--compile', action='store_true')
|
parser.add_argument('--compile', action='store_true')
|
||||||
parser.add_argument('--low_vram_mode', action='store_true')
|
parser.add_argument('--low_vram_mode', action='store_true')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.enable_flashvdm = False
|
|
||||||
|
|
||||||
SAVE_DIR = args.cache_path
|
SAVE_DIR = args.cache_path
|
||||||
os.makedirs(SAVE_DIR, exist_ok=True)
|
os.makedirs(SAVE_DIR, exist_ok=True)
|
||||||
|
|||||||
@@ -64,7 +64,10 @@ class ModelWorker:
|
|||||||
low_vram_mode=False,
|
low_vram_mode=False,
|
||||||
worker_id=None,
|
worker_id=None,
|
||||||
model_semaphore=None,
|
model_semaphore=None,
|
||||||
save_dir='gradio_cache'):
|
save_dir='gradio_cache',
|
||||||
|
mc_algo='mc',
|
||||||
|
enable_flashvdm=False,
|
||||||
|
compile=False):
|
||||||
"""
|
"""
|
||||||
Initialize the model worker.
|
Initialize the model worker.
|
||||||
|
|
||||||
@@ -83,6 +86,9 @@ class ModelWorker:
|
|||||||
self.low_vram_mode = low_vram_mode
|
self.low_vram_mode = low_vram_mode
|
||||||
self.model_semaphore = model_semaphore
|
self.model_semaphore = model_semaphore
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
|
self.mc_algo = mc_algo
|
||||||
|
self.enable_flashvdm = enable_flashvdm
|
||||||
|
self.compile = compile
|
||||||
|
|
||||||
logger.info(f"Loading the model {model_path} on worker {self.worker_id} ...")
|
logger.info(f"Loading the model {model_path} on worker {self.worker_id} ...")
|
||||||
|
|
||||||
@@ -91,6 +97,11 @@ class ModelWorker:
|
|||||||
|
|
||||||
# Initialize shape generation pipeline (matching demo.py)
|
# Initialize shape generation pipeline (matching demo.py)
|
||||||
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
|
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
|
||||||
|
if self.enable_flashvdm:
|
||||||
|
mc_algo = 'mc' if self.device in ['cpu', 'mps'] else self.mc_algo
|
||||||
|
self.pipeline.enable_flashvdm(mc_algo=mc_algo)
|
||||||
|
if self.compile:
|
||||||
|
self.pipeline.compile()
|
||||||
|
|
||||||
# Initialize texture generation pipeline (matching demo.py)
|
# Initialize texture generation pipeline (matching demo.py)
|
||||||
max_num_view = 6 # can be 6 to 9
|
max_num_view = 6 # can be 6 to 9
|
||||||
|
|||||||
Reference in New Issue
Block a user