Files
Hunyuan3D_2.1_Low_VRAM/model_worker.py
SyncTwin GmbH 0506243637 next
2025-06-28 11:49:13 +02:00

152 lines
4.9 KiB
Python

"""
Model worker for Hunyuan3D API server.
"""
import os
import time
import uuid
import base64
import trimesh
from io import BytesIO
from PIL import Image
import torch
# Apply torchvision compatibility fix before other imports
import sys
# Resolve the bundled package directories relative to this file rather than
# the current working directory: relative sys.path entries are interpreted
# against the CWD, so './hy3dshape' only worked when the server was started
# from the repository root.
_REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(_REPO_ROOT, 'hy3dshape'))
sys.path.insert(0, os.path.join(_REPO_ROOT, 'hy3dpaint'))
# Best-effort compatibility patch: a missing module or a failing patch only
# prints a warning and does not stop the worker from loading.
try:
    from torchvision_fix import apply_fix
    apply_fix()
except ImportError:
    print("Warning: torchvision_fix module not found, proceeding without compatibility fix")
except Exception as e:
    print(f"Warning: Failed to apply torchvision fix: {e}")
from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline
from hy3dshape.rembg import BackgroundRemover
from hy3dshape.utils import logger
def load_image_from_base64(image):
"""
Load an image from base64 encoded string.
Args:
image (str): Base64 encoded image string
Returns:
PIL.Image: Loaded image
"""
return Image.open(BytesIO(base64.b64decode(image)))
class ModelWorker:
    """
    Worker class for handling 3D model generation tasks.

    Loads a background remover and the Hunyuan3D shape pipeline once at
    construction time, then turns base64-encoded images into untextured
    meshes written under ``save_dir``.
    """

    def __init__(self,
                 model_path='tencent/Hunyuan3D-2.1',
                 subfolder='hunyuan3d-dit-v2-1',
                 device='cuda',
                 low_vram_mode=False,
                 worker_id=None,
                 model_semaphore=None,
                 save_dir='gradio_cache'):
        """
        Initialize the model worker.

        Args:
            model_path (str): Path to the shape generation model
            subfolder (str): Subfolder containing the model files.
                NOTE(review): currently not forwarded to ``from_pretrained``.
            device (str): Device to run the model on ('cuda' or 'cpu').
                NOTE(review): stored on the worker but not forwarded to
                ``from_pretrained`` either.
            low_vram_mode (bool): If True, ``torch.cuda.empty_cache()`` is
                called after every generation attempt.
            worker_id (str): Unique identifier for this worker; a random
                6-character id is generated when omitted.
            model_semaphore: Semaphore for controlling model concurrency. Its
                free-permit count at construction time is recorded as the
                capacity used by ``get_queue_length``.
            save_dir (str): Directory to save generated files
        """
        self.model_path = model_path
        self.worker_id = worker_id or str(uuid.uuid4())[:6]
        self.device = device
        self.low_vram_mode = low_vram_mode
        self.model_semaphore = model_semaphore
        # Permits available before anyone acquired the semaphore. This is needed
        # to turn the semaphore's *free* permit count into a *busy* count.
        # Relies on the private ``_value`` attribute (asyncio/threading).
        self._semaphore_capacity = getattr(model_semaphore, '_value', 0)
        self.save_dir = save_dir
        logger.info(f"Loading the model {model_path} on worker {self.worker_id} ...")
        # Initialize background remover
        self.rembg = BackgroundRemover()
        # Initialize shape generation pipeline (matching demo.py).
        # NOTE(review): `subfolder` and `device` are not passed here, so the
        # pipeline's own defaults apply. Confirm this is intentional for the
        # low-VRAM setup.
        self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)

    def get_queue_length(self):
        """
        Get the current queue length for model processing.

        Counts the permits currently held (tasks being processed) plus the
        tasks waiting on the semaphore. Uses the private ``_value`` and
        ``_waiters`` attributes; a missing attribute counts as 0.

        Returns:
            int: Number of tasks in the queue (0 without a semaphore)
        """
        semaphore = self.model_semaphore
        if semaphore is None:
            return 0
        available = getattr(semaphore, '_value', 0)
        # Without a recorded capacity (e.g. the semaphore was attached after
        # construction), no permits are counted as busy.
        capacity = getattr(self, '_semaphore_capacity', available)
        busy = max(capacity - available, 0)
        waiters = getattr(semaphore, '_waiters', None)
        return busy + (len(waiters) if waiters is not None else 0)

    def get_status(self):
        """
        Get the current status of the worker.

        Returns:
            dict: Status information including speed and queue length
        """
        return {
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    def _prepare_image(self, params):
        """Decode ``params['image']`` and return it as an RGBA cut-out."""
        if 'image' not in params:
            raise ValueError("No input image provided")
        image = load_image_from_base64(params["image"])
        # Decide on background removal BEFORE converting. After
        # convert("RGBA") the mode can never be "RGB", so a check made
        # afterwards never triggered rembg. Inputs that already carry
        # transparency are treated as pre-segmented and left as they are.
        has_alpha = 'A' in image.getbands() or 'transparency' in image.info
        if not has_alpha:
            image = self.rembg(image.convert("RGB"))
        return image.convert("RGBA")

    @torch.inference_mode()
    def generate(self, uid, params):
        """
        Generate a 3D model from the given parameters.

        Args:
            uid: Unique identifier for this generation task; used as the
                output file's base name.
            params (dict): Generation parameters. ``'image'`` (required) is a
                base64-encoded image. ``'type'`` (optional, default ``'glb'``)
                becomes the extension of the path passed to ``mesh.export``.

        Returns:
            tuple: (file_path, uid) - Path to generated file and task ID

        Raises:
            ValueError: If no image is provided, ``type`` is not a plain
                alphanumeric extension, or shape generation fails.
        """
        start_time = time.time()
        image = self._prepare_image(params)

        # 'type' arrives with the request. Only accept a bare extension so it
        # cannot inject path separators (e.g. "../x") into save_path. This is
        # checked before the expensive pipeline run.
        file_type = str(params.get('type', 'glb'))
        if not file_type.isalnum():
            raise ValueError(f"Unsupported output file type: {file_type!r}")

        try:
            try:
                mesh = self.pipeline(image=image)[0]
            except Exception as e:
                logger.error(f"Shape generation failed: {e}")
                raise ValueError(f"Failed to generate 3D mesh: {str(e)}") from e
            logger.info("---Shape generation takes %s seconds ---" % (time.time() - start_time))

            # Export final mesh without texture. Create save_dir first so the
            # export does not fail on a missing output directory.
            os.makedirs(self.save_dir, exist_ok=True)
            save_path = os.path.join(self.save_dir, f'{str(uid)}.{file_type}')
            mesh.export(save_path)
        finally:
            # Release cached GPU memory even when generation or export fails.
            if self.low_vram_mode:
                torch.cuda.empty_cache()

        logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time))
        return save_path, uid