This commit is contained in:
SyncTwin GmbH
2025-06-22 11:23:06 +02:00
committed by Michael Wagner
parent 0506243637
commit f01b8af1d3
2 changed files with 122 additions and 14 deletions

View File

@@ -26,6 +26,7 @@ except Exception as e:
from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline
from hy3dshape.rembg import BackgroundRemover from hy3dshape.rembg import BackgroundRemover
from hy3dshape.utils import logger from hy3dshape.utils import logger
from textureGenPipeline import Hunyuan3DPaintPipeline, Hunyuan3DPaintConfig
def load_image_from_base64(image): def load_image_from_base64(image):
@@ -81,6 +82,15 @@ class ModelWorker:
# Initialize shape generation pipeline (matching demo.py) # Initialize shape generation pipeline (matching demo.py)
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path) self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
# Initialize texture generation pipeline (matching demo.py)
max_num_view = 6 # can be 6 to 9
resolution = 512 # can be 768 or 512
conf = Hunyuan3DPaintConfig(max_num_view, resolution)
conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth"
conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml"
conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr"
self.paint_pipeline = Hunyuan3DPaintPipeline(conf)
def get_queue_length(self): def get_queue_length(self):
""" """
Get the current queue length for model processing. Get the current queue length for model processing.
@@ -140,13 +150,32 @@ class ModelWorker:
logger.error(f"Shape generation failed: {e}") logger.error(f"Shape generation failed: {e}")
raise ValueError(f"Failed to generate 3D mesh: {str(e)}") raise ValueError(f"Failed to generate 3D mesh: {str(e)}")
# Export final mesh without texture # Export initial mesh without texture
file_type = params.get('type', 'glb') file_type = params.get('type', 'glb')
save_path = os.path.join(self.save_dir, f'{str(uid)}.{file_type}') initial_save_path = os.path.join(self.save_dir, f'{str(uid)}_initial.{file_type}')
mesh.export(save_path) mesh.export(initial_save_path)
# Generate textured mesh (matching demo.py)
try:
output_mesh_path = os.path.join(self.save_dir, f'{str(uid)}_textured.{file_type}')
textured_path = self.paint_pipeline(
mesh_path=initial_save_path,
image_path=image,
output_mesh_path=output_mesh_path
)
logger.info("---Texture generation takes %s seconds ---" % (time.time() - start_time))
# Use the textured GLB as the final output
final_save_path = textured_path.replace('.obj', '.glb') if textured_path.endswith('.obj') else textured_path
except Exception as e:
logger.error(f"Texture generation failed: {e}")
# Fall back to untextured mesh if texture generation fails
final_save_path = initial_save_path
logger.warning(f"Using untextured mesh as fallback: {final_save_path}")
if self.low_vram_mode: if self.low_vram_mode:
torch.cuda.empty_cache() torch.cuda.empty_cache()
logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time)) logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time))
return save_path, uid return final_save_path, uid

View File

@@ -11,12 +11,39 @@ import json
from PIL import Image from PIL import Image
import io import io
import time import time
import os
# API server URL (adjust as needed) # API server URL (adjust as needed)
API_BASE_URL = "http://localhost:8081" API_BASE_URL = "http://localhost:8081"
def load_demo_image():
"""Load the demo image from assets/demo.png"""
try:
# Load the demo image
image_path = 'assets/demo.png'
image = Image.open(image_path).convert("RGBA")
# Convert to base64
buffer = io.BytesIO()
image.save(buffer, format='PNG')
img_base64 = base64.b64encode(buffer.getvalue()).decode()
print(f"Loaded demo image from: {image_path}")
print(f"Image size: {image.size}")
print(f"Image mode: {image.mode}")
return img_base64
except FileNotFoundError:
print(f"Error: Demo image not found at {image_path}")
print("Creating fallback test image...")
return create_test_image()
except Exception as e:
print(f"Error loading demo image: {e}")
print("Creating fallback test image...")
return create_test_image()
def create_test_image(): def create_test_image():
"""Create a simple test image for API testing""" """Create a simple test image for API testing (fallback)"""
# Create a simple 256x256 test image # Create a simple 256x256 test image
img = Image.new('RGB', (256, 256), color='red') img = Image.new('RGB', (256, 256), color='red')
@@ -27,15 +54,41 @@ def create_test_image():
return img_base64 return img_base64
def save_glb_file(response, filename):
"""Save GLB file from response content"""
try:
with open(filename, 'wb') as f:
f.write(response.content)
print(f"GLB file saved as: {filename}")
return True
except Exception as e:
print(f"Error saving GLB file: {e}")
return False
def save_base64_glb(base64_data, filename):
"""Save GLB file from base64 encoded data"""
try:
# Decode base64 data
glb_data = base64.b64decode(base64_data)
# Save to file
with open(filename, 'wb') as f:
f.write(glb_data)
print(f"GLB file saved as: {filename}")
return True
except Exception as e:
print(f"Error saving GLB file from base64: {e}")
return False
def test_generation_request(): def test_generation_request():
"""Test the generation request with simplified parameters""" """Test the generation request with simplified parameters"""
print("Creating test image...") print("Loading demo image...")
# Create test image # Load demo image
test_image = create_test_image() demo_image = load_demo_image()
# Simplified request payload with only the parameters the worker actually uses # Simplified request payload with only the parameters the worker actually uses
request_data = { request_data = {
"image": test_image, "image": demo_image,
"type": "glb" "type": "glb"
} }
@@ -43,7 +96,7 @@ def test_generation_request():
print("Request parameters:") print("Request parameters:")
for key, value in request_data.items(): for key, value in request_data.items():
if key == "image": if key == "image":
print(f" {key}: [base64 encoded image data]") print(f" {key}: [base64 encoded demo image data]")
else: else:
print(f" {key}: {value}") print(f" {key}: {value}")
@@ -52,6 +105,14 @@ def test_generation_request():
print(f"Response status: {response.status_code}") print(f"Response status: {response.status_code}")
if response.status_code == 200: if response.status_code == 200:
print("Success! Generated 3D model file received.") print("Success! Generated 3D model file received.")
# Save the GLB file
timestamp = int(time.time())
filename = f"generated_model_{timestamp}.glb"
if save_glb_file(response, filename):
print(f"Model saved successfully to: {filename}")
else:
print("Failed to save model file")
else: else:
print(f"Error: {response.text}") print(f"Error: {response.text}")
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
@@ -60,10 +121,10 @@ def test_generation_request():
def test_async_generation(): def test_async_generation():
"""Test the asynchronous generation endpoint""" """Test the asynchronous generation endpoint"""
test_image = create_test_image() demo_image = load_demo_image()
request_data = { request_data = {
"image": test_image, "image": demo_image,
"type": "glb" "type": "glb"
} }
@@ -80,6 +141,7 @@ def test_async_generation():
print("Checking task status...") print("Checking task status...")
status_response = requests.get(f"{API_BASE_URL}/status/{uid}") status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
print(f"Status: {status_response.json()}") print(f"Status: {status_response.json()}")
# Poll status until completed # Poll status until completed
while True: while True:
status_response = requests.get(f"{API_BASE_URL}/status/{uid}") status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
@@ -88,7 +150,18 @@ def test_async_generation():
if status_data['status'] == 'completed': if status_data['status'] == 'completed':
print("Generation completed!") print("Generation completed!")
print("Model data received in base64 format")
# Save the GLB file from base64 data
model_base64 = status_data.get('model_base64')
if model_base64:
timestamp = int(time.time())
filename = f"async_generated_model_{uid}_{timestamp}.glb"
if save_base64_glb(model_base64, filename):
print(f"Model saved successfully to: {filename}")
else:
print("Failed to save model file")
else:
print("No model data received in response")
break break
elif status_data['status'] == 'error': elif status_data['status'] == 'error':
print(f"Error: {status_data.get('message', 'Unknown error')}") print(f"Error: {status_data.get('message', 'Unknown error')}")
@@ -155,13 +228,18 @@ def show_api_documentation_info():
print(" - image: Base64 encoded input image (required)") print(" - image: Base64 encoded input image (required)")
print(" - type: Output file format - 'glb' or 'obj' (optional, default: 'glb')") print(" - type: Output file format - 'glb' or 'obj' (optional, default: 'glb')")
print() print()
print("7. File Saving:")
print(" - GLB files are automatically saved with timestamps")
print(" - Direct generation saves as: generated_model_{timestamp}.glb")
print(" - Async generation saves as: async_generated_model_{uid}_{timestamp}.glb")
print()
if __name__ == "__main__": if __name__ == "__main__":
show_api_documentation_info() show_api_documentation_info()
# Run tests if server is available # Run tests if server is available
test_health_check() test_health_check()
#test_generation_request() test_generation_request()
test_async_generation() test_async_generation()
print("\n" + "=" * 60) print("\n" + "=" * 60)
@@ -169,4 +247,5 @@ if __name__ == "__main__":
print(f"1. Start the API server: python api_server.py") print(f"1. Start the API server: python api_server.py")
print(f"2. Open your browser to: {API_BASE_URL}/docs") print(f"2. Open your browser to: {API_BASE_URL}/docs")
print("3. Explore the documented endpoints and parameters") print("3. Explore the documented endpoints and parameters")
print("4. Generated GLB files will be saved in the current directory")
print("=" * 60) print("=" * 60)