feat: 为 api_server 加上 enable_flashvdm
This commit is contained in:
@@ -64,7 +64,10 @@ class ModelWorker:
|
||||
low_vram_mode=False,
|
||||
worker_id=None,
|
||||
model_semaphore=None,
|
||||
save_dir='gradio_cache'):
|
||||
save_dir='gradio_cache',
|
||||
mc_algo='mc',
|
||||
enable_flashvdm=False,
|
||||
compile=False):
|
||||
"""
|
||||
Initialize the model worker.
|
||||
|
||||
@@ -83,6 +86,9 @@ class ModelWorker:
|
||||
self.low_vram_mode = low_vram_mode
|
||||
self.model_semaphore = model_semaphore
|
||||
self.save_dir = save_dir
|
||||
self.mc_algo = mc_algo
|
||||
self.enable_flashvdm = enable_flashvdm
|
||||
self.compile = compile
|
||||
|
||||
logger.info(f"Loading the model {model_path} on worker {self.worker_id} ...")
|
||||
|
||||
@@ -91,7 +97,12 @@ class ModelWorker:
|
||||
|
||||
# Initialize shape generation pipeline (matching demo.py)
|
||||
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
|
||||
|
||||
if self.enable_flashvdm:
|
||||
mc_algo = 'mc' if self.device in ['cpu', 'mps'] else self.mc_algo
|
||||
self.pipeline.enable_flashvdm(mc_algo=mc_algo)
|
||||
if self.compile:
|
||||
self.pipeline.compile()
|
||||
|
||||
# Initialize texture generation pipeline (matching demo.py)
|
||||
max_num_view = 6 # can be 6 to 9
|
||||
resolution = 512 # can be 768 or 512
|
||||
|
||||
Reference in New Issue
Block a user