This commit is contained in:
SyncTwin GmbH
2025-06-22 00:43:59 +00:00
committed by Michael Wagner
parent 2d201ec442
commit 0506243637
10 changed files with 1088 additions and 343 deletions

View File

@@ -19,351 +19,72 @@ import argparse
import asyncio
import base64
import logging
import logging.handlers
import os
import sys
import tempfile
import threading
import traceback
import uuid
import time
from io import BytesIO
# Apply torchvision compatibility fix before other imports
sys.path.insert(0, './hy3dshape')
sys.path.insert(0, './hy3dpaint')
try:
from torchvision_fix import apply_fix
apply_fix()
except ImportError:
print("Warning: torchvision_fix module not found, proceeding without compatibility fix")
except Exception as e:
print(f"Warning: Failed to apply torchvision fix: {e}")
from typing import Optional
import torch
import trimesh
import uvicorn
from PIL import Image
from fastapi import FastAPI, Request
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
# Updated imports to match gradio_app.py
from hy3dshape import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier, \
Hunyuan3DDiTFlowMatchingPipeline
from hy3dshape.pipelines import export_to_trimesh
from hy3dshape.rembg import BackgroundRemover
from hy3dshape.utils import logger
# Texture generation imports
try:
from hy3dpaint.textureGenPipeline import Hunyuan3DPaintPipeline, Hunyuan3DPaintConfig
from hy3dpaint.convert_utils import create_glb_with_pbr_materials
HAS_TEXTUREGEN = True
except ImportError:
print("Warning: Texture generation not available")
HAS_TEXTUREGEN = False
# Directory in which build_logger() creates its rotating log file.
LOGDIR = '.'
# Generic message placed in the "text" field of error responses.
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
# Message for rejected input; not referenced in the visible part of this file.
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
# Shared file handler; stays None until the first build_logger() call creates it.
handler = None
def build_logger(logger_name, logger_filename):
    """
    Configure logging for the process and return the logger ``logger_name``.

    Side effects:
        * The first root handler gets the shared "time | level | name | message"
          formatter (a root handler is created via basicConfig if none exists).
        * ``sys.stdout`` and ``sys.stderr`` are replaced by StreamToLogger
          objects feeding the "stdout" (INFO) and "stderr" (ERROR) loggers.
        * On the first call only, a daily-rotating UTF-8 file handler writing
          to ``LOGDIR/logger_filename`` is created and attached to every
          logger that exists at that moment.

    Args:
        logger_name: Name of the logger to return.
        logger_filename: File name (relative to LOGDIR) for the log file.

    Returns:
        logging.Logger: The named logger, set to INFO level.
    """
    global handler

    log_format = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Make sure the root logger has a handler, then apply the shared format.
    root_logger = logging.getLogger()
    if not root_logger.handlers:
        logging.basicConfig(level=logging.INFO)
    root_logger.handlers[0].setFormatter(log_format)

    # Redirect stdout first, then stderr. The order matters: the stderr
    # wrapper is constructed after sys.stdout has already been replaced.
    for stream_name, level in (("stdout", logging.INFO), ("stderr", logging.ERROR)):
        stream_logger = logging.getLogger(stream_name)
        stream_logger.setLevel(level)
        setattr(sys, stream_name, StreamToLogger(stream_logger, level))

    # Fetch the requested logger before attaching the file handler so that it
    # is already registered when the existing loggers are enumerated.
    named_logger = logging.getLogger(logger_name)
    named_logger.setLevel(logging.INFO)

    # Create the shared file handler once and attach it to all known loggers.
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        log_path = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            log_path, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(log_format)
        for registered in logging.root.manager.loggerDict.values():
            if isinstance(registered, logging.Logger):
                registered.addHandler(handler)

    return named_logger
class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.

    Complete lines (terminated by '\\n') are logged immediately at
    ``log_level``; a trailing partial line is buffered in ``linebuf`` until
    it is completed by a later write() or emitted by flush(). Any attribute
    not defined here is looked up on the stream that was ``sys.stdout`` when
    the object was created.
    """

    def __init__(self, logger, log_level=logging.INFO):
        """
        Args:
            logger: Logger that receives the redirected text.
            log_level: Level used for every emitted record.
        """
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If ``terminal``
        # itself is missing (instance built without __init__, e.g. via
        # __new__/copy/pickle), reading self.terminal would re-enter this
        # method without end, so fail with a normal AttributeError instead.
        if attr == 'terminal':
            raise AttributeError(attr)
        return getattr(self.terminal, attr)

    def write(self, buf):
        """
        Log every complete line in ``buf`` and buffer the remainder.

        Returns:
            int: Number of characters accepted, as text streams conventionally
            report from write().
        """
        pending = self.linebuf + buf
        self.linebuf = ''
        for line in pending.splitlines(True):
            # From the io.TextIOWrapper docs:
            # On output, if newline is None, any '\n' characters written
            # are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line
        return len(buf)

    def flush(self):
        """Emit any buffered partial line and clear the buffer."""
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''
def pretty_print_semaphore(semaphore):
    """
    Return a short text description of ``semaphore`` for log output.

    ``None`` yields the string "None"; otherwise the object's ``_value``
    attribute and the result of its ``locked()`` method are shown.
    """
    return (
        "None"
        if semaphore is None
        else f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
    )
# Output directory for generated files; created at import time if missing.
# NOTE(review): SAVE_DIR is assigned again from DEFAULT_SAVE_DIR further down
# in this file — confirm which definition is meant to survive.
SAVE_DIR = 'gradio_cache'
os.makedirs(SAVE_DIR, exist_ok=True)
# Import from root-level modules
from api_models import GenerationRequest, GenerationResponse, StatusResponse, HealthResponse
from logger_utils import build_logger
from constants import (
SERVER_ERROR_MSG, DEFAULT_SAVE_DIR, API_TITLE, API_DESCRIPTION,
API_VERSION, API_CONTACT, API_LICENSE_INFO, API_TAGS_METADATA
)
from model_worker import ModelWorker
# Global variables
# Directory used for generated files and for the controller log below.
SAVE_DIR = DEFAULT_SAVE_DIR
# Short identifier for this process: the first six characters of a random UUID.
worker_id = str(uuid.uuid4())[:6]
# Module-wide logger named "controller", logging to controller.log under SAVE_DIR.
logger = build_logger("controller", f"{SAVE_DIR}/controller.log")
def load_image_from_base64(image):
    """
    Decode a base64-encoded image payload and open it with PIL.

    Args:
        image: Base64 text (or bytes) of an encoded image file.

    Returns:
        The object returned by ``Image.open`` for the decoded bytes.
    """
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
# Global worker and semaphore instances
# Both start as None; `worker` is assigned in the __main__ section of this file.
# NOTE(review): the place where model_semaphore is created is not visible here — confirm.
worker = None
model_semaphore = None
def export_mesh(mesh, save_folder, textured=False, type='glb'):
    """
    Write ``mesh`` into ``save_folder`` and return the path of the new file.

    The file is named ``textured_mesh.<type>`` when ``textured`` is true and
    ``white_mesh.<type>`` otherwise.

    Args:
        mesh: Object exposing ``export(path, ...)`` (a trimesh mesh in this file).
        save_folder (str): Directory that receives the exported file.
        textured (bool, optional): Selects the file name and, for 'glb'/'obj',
            is forwarded as ``include_normals``. Defaults to False.
        type (str, optional): File extension / format. For 'glb' and 'obj' the
            ``include_normals`` flag is passed; any other value is exported
            with the path only. Defaults to 'glb'.

    Returns:
        str: The full path to the exported mesh file.
    """
    stem = 'textured_mesh' if textured else 'white_mesh'
    path = os.path.join(save_folder, f'{stem}.{type}')

    if type in ('glb', 'obj'):
        mesh.export(path, include_normals=textured)
    else:
        mesh.export(path)
    return path
class ModelWorker:
    """
    Holds the loaded pipelines and turns one request into a mesh file.

    The constructor loads the background remover, the shape pipeline, the
    optional texture pipeline and the mesh clean-up helpers; ``generate``
    runs them for a single request and writes the result under SAVE_DIR.
    """

    def __init__(self,
                 model_path='tencent/Hunyuan3D-2.1',
                 tex_model_path='tencent/Hunyuan3D-2.1',
                 subfolder='hunyuan3d-dit-v2-1',
                 device='cuda',
                 enable_tex=False,
                 low_vram_mode=False):
        """
        Args:
            model_path: Identifier/path passed to the shape pipeline's from_pretrained.
            tex_model_path: Accepted for interface compatibility; not read by
                this constructor.
            subfolder: Sub-folder of ``model_path`` holding the shape model.
            device: Device used for the shape pipeline and the seed generator.
            enable_tex: Try to build the texture pipeline (also requires
                HAS_TEXTUREGEN to be true).
            low_vram_mode: When true, the CUDA cache is emptied after each generation.
        """
        self.model_path = model_path
        self.worker_id = worker_id
        self.device = device
        self.low_vram_mode = low_vram_mode
        logger.info(f"Loading the model {model_path} on worker {worker_id} ...")

        # Initialize background remover
        self.rembg = BackgroundRemover()

        # Initialize shape generation pipeline
        self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
            model_path,
            subfolder=subfolder,
            use_safetensors=False,
            device=device,
        )

        # Initialize texture generation pipeline if enabled. A failure here is
        # logged and leaves the worker in shape-only mode instead of aborting.
        self.pipeline_tex = None
        if enable_tex and HAS_TEXTUREGEN:
            try:
                conf = Hunyuan3DPaintConfig(max_num_view=8, resolution=768)
                conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth"
                conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml"
                conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr"
                self.pipeline_tex = Hunyuan3DPaintPipeline(conf)
            except Exception as e:
                logger.error(f"Failed to initialize texture pipeline: {e}")
                self.pipeline_tex = None

        # Initialize mesh processing workers
        self.floater_remove_worker = FloaterRemover()
        self.degenerate_face_remove_worker = DegenerateFaceRemover()
        self.face_reduce_worker = FaceReducer()

    def get_queue_length(self):
        """
        Return the number of requests currently holding or waiting for the
        global model semaphore (0 when no semaphore has been created).
        """
        if model_semaphore is None:
            return 0
        waiting = len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0
        return args.limit_model_concurrency - model_semaphore._value + waiting

    def get_status(self):
        """Return a dict with a fixed ``speed`` of 1 and the current ``queue_length``."""
        return {
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate(self, uid, params):
        """
        Produce a mesh file for one request.

        Args:
            uid: Identifier used to name every file written for this request.
            params: Request dict. Keys read here: ``image`` (base64, required),
                ``remove_background``, ``mesh`` (base64 GLB to reuse instead of
                generating), ``seed``, ``octree_resolution``,
                ``num_inference_steps``, ``guidance_scale``, ``num_chunks``,
                ``texture``, ``face_count`` and ``type``.

        Returns:
            tuple: ``(save_path, uid)`` where ``save_path`` is the written file.

        Raises:
            ValueError: If ``params`` has no ``image`` key.
        """
        start_time = time.time()

        # Handle input image
        if 'image' not in params:
            raise ValueError("No input image provided")
        image = load_image_from_base64(params["image"])

        # Remove background if needed
        if params.get('remove_background', True) or image.mode == "RGB":
            image = self.rembg(image.convert('RGB'))

        # Handle existing mesh or generate new one
        if 'mesh' in params:
            mesh = trimesh.load(BytesIO(base64.b64decode(params["mesh"])), file_type='glb')
        else:
            # Generate new mesh
            seed = params.get("seed", 1234)
            generator = torch.Generator(self.device).manual_seed(seed)
            octree_resolution = params.get("octree_resolution", 256)
            num_inference_steps = params.get("num_inference_steps", 5)
            guidance_scale = params.get('guidance_scale', 5.0)
            num_chunks = params.get('num_chunks', 8000)
            outputs = self.pipeline(
                image=image,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                octree_resolution=octree_resolution,
                num_chunks=num_chunks,
                output_type='mesh'
            )
            mesh = export_to_trimesh(outputs)[0]
            logger.info("---Shape generation takes %s seconds ---" % (time.time() - start_time))

        # The final output path is the same in every branch, so compute it
        # once. Previously the textured fallback paths used `save_path`
        # before it was ever assigned and raised UnboundLocalError.
        file_type = params.get('type', 'glb')
        save_path = os.path.join(SAVE_DIR, f'{str(uid)}.{file_type}')

        # Apply texture if requested
        if params.get('texture', False) and self.pipeline_tex is not None:
            # Post-process mesh for texture generation
            mesh = self.floater_remove_worker(mesh)
            mesh = self.degenerate_face_remove_worker(mesh)
            mesh = self.face_reduce_worker(mesh, max_facenum=params.get('face_count', 40000))

            # Generate texture
            tex_start_time = time.time()
            temp_obj_path = os.path.join(SAVE_DIR, f'{str(uid)}_temp.obj')
            mesh.export(temp_obj_path)
            text_path = os.path.join(SAVE_DIR, f'{str(uid)}_textured.obj')
            self.pipeline_tex(mesh_path=temp_obj_path,
                              image_path=image,
                              output_mesh_path=text_path,
                              save_glb=False)
            logger.info("---Texture generation takes %s seconds ---" % (time.time() - tex_start_time))

            if file_type == 'glb':
                # Convert to GLB with PBR materials.
                # Texture file names are derived from the textured OBJ path;
                # presumably these are what the texture pipeline writes — TODO confirm.
                textures = {
                    'albedo': text_path.replace('.obj', '_albedo.png'),
                    'metallic': text_path.replace('.obj', '_metallic.png'),
                    'roughness': text_path.replace('.obj', '_roughness.jpg')
                }
                try:
                    create_glb_with_pbr_materials(text_path, textures, save_path)
                except Exception as e:
                    # Best-effort fallback: export the textured OBJ directly.
                    logger.warning(f"Failed to create PBR GLB, using regular export: {e}")
                    mesh = trimesh.load(text_path)
                    mesh.export(save_path)
            else:
                # Load textured mesh for other formats
                mesh = trimesh.load(text_path)
                mesh.export(save_path)
        else:
            # Export final mesh without texture
            mesh.export(save_path)

        if self.low_vram_mode:
            torch.cuda.empty_cache()

        logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time))
        return save_path, uid
app = FastAPI()
from fastapi.middleware.cors import CORSMiddleware
# Application object; the title, description, version, contact, license and
# tag metadata all come from the names imported from `constants`.
app = FastAPI(
title=API_TITLE,
description=API_DESCRIPTION,
version=API_VERSION,
contact=API_CONTACT,
license_info=API_LICENSE_INFO,
tags_metadata=API_TAGS_METADATA
)
# Add CORS middleware
# NOTE(review): allow_origins, allow_methods and allow_headers each appear
# twice in this call (it looks like old and new diff lines were merged);
# Python rejects repeated keyword arguments — confirm and keep one of each.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # You can restrict this to specific origins
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"], # Allow all methods
allow_headers=["*"], # Allow all headers
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/generate")
async def generate(request: Request):
logger.info("Worker generating...")
try:
params = await request.json()
except Exception as e:
logger.error(f"Failed to parse JSON request: {e}")
return JSONResponse({"error": "Invalid JSON request"}, status_code=400)
@app.post("/generate", tags=["generation"])
async def generate_3d_model(request: GenerationRequest):
"""
Generate a 3D model from an input image.
# Validate required parameters
if not params.get('image'):
return JSONResponse({"error": "Image parameter is required"}, status_code=400)
This endpoint takes an image and generates a 3D model with optional textures.
The generation process includes background removal, mesh generation, and optional texture mapping.
Returns:
FileResponse: The generated 3D model file (GLB or OBJ format)
"""
logger.info("Worker generating...")
# Convert Pydantic model to dict for compatibility
params = request.dict()
uid = uuid.uuid4()
try:
@@ -373,14 +94,14 @@ async def generate(request: Request):
traceback.print_exc()
logger.error(f"Caught ValueError: {e}")
ret = {
"text": server_error_msg,
"text": SERVER_ERROR_MSG,
"error_code": 1,
}
return JSONResponse(ret, status_code=404)
except torch.cuda.CudaError as e:
logger.error(f"Caught torch.cuda.CudaError: {e}")
ret = {
"text": server_error_msg,
"text": SERVER_ERROR_MSG,
"error_code": 1,
}
return JSONResponse(ret, status_code=404)
@@ -388,24 +109,27 @@ async def generate(request: Request):
logger.error(f"Caught Unknown Error: {e}")
traceback.print_exc()
ret = {
"text": server_error_msg,
"text": SERVER_ERROR_MSG,
"error_code": 1,
}
return JSONResponse(ret, status_code=404)
@app.post("/send")
async def generate(request: Request):
logger.info("Worker send...")
try:
params = await request.json()
except Exception as e:
logger.error(f"Failed to parse JSON request: {e}")
return JSONResponse({"error": "Invalid JSON request"}, status_code=400)
@app.post("/send", response_model=GenerationResponse, tags=["generation"])
async def send_generation_task(request: GenerationRequest):
"""
Send a 3D generation task to be processed asynchronously.
# Validate required parameters
if not params.get('image'):
return JSONResponse({"error": "Image parameter is required"}, status_code=400)
This endpoint starts the generation process in the background and returns a task ID.
Use the /status/{uid} endpoint to check the progress and retrieve the result.
Returns:
GenerationResponse: Contains the unique task identifier
"""
logger.info("Worker send...")
# Convert Pydantic model to dict for compatibility
params = request.dict()
uid = uuid.uuid4()
try:
@@ -418,14 +142,28 @@ async def generate(request: Request):
return JSONResponse(ret, status_code=500)
@app.get("/health")
@app.get("/health", response_model=HealthResponse, tags=["status"])
async def health_check():
    """
    Health check endpoint to verify the service is running.

    Returns:
        JSONResponse: HTTP 200 with a body containing ``status`` ("healthy")
        and ``worker_id`` (this process's worker identifier).
    """
    # The function previously carried two consecutive string literals; only
    # the first was the docstring and the second was a no-op statement.
    return JSONResponse({"status": "healthy", "worker_id": worker_id}, status_code=200)
@app.get("/status/{uid}")
@app.get("/status/{uid}", response_model=StatusResponse, tags=["status"])
async def status(uid: str):
"""
Check the status of a generation task.
Args:
uid: The unique identifier of the generation task
Returns:
StatusResponse: Current status of the task and result if completed
"""
save_file_path = os.path.join(SAVE_DIR, f'{uid}.glb')
print(save_file_path, os.path.exists(save_file_path))
if not os.path.exists(save_file_path):
@@ -467,10 +205,11 @@ if __name__ == "__main__":
worker = ModelWorker(
model_path=args.model_path,
tex_model_path=args.tex_model_path,
subfolder=args.subfolder,
device=args.device,
enable_tex=args.enable_tex,
low_vram_mode=args.low_vram_mode
low_vram_mode=args.low_vram_mode,
worker_id=worker_id,
model_semaphore=model_semaphore,
save_dir=SAVE_DIR
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")