This commit is contained in:
SyncTwin GmbH
2025-06-22 11:23:06 +02:00
parent 5ffc16c1f6
commit 6693141004
2 changed files with 122 additions and 14 deletions

View File

@@ -26,6 +26,7 @@ except Exception as e:
from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline
from hy3dshape.rembg import BackgroundRemover
from hy3dshape.utils import logger
from textureGenPipeline import Hunyuan3DPaintPipeline, Hunyuan3DPaintConfig
def load_image_from_base64(image):
@@ -80,6 +81,15 @@ class ModelWorker:
# Initialize shape generation pipeline (matching demo.py)
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
# Initialize texture generation pipeline (matching demo.py)
max_num_view = 6 # can be 6 to 9
resolution = 512 # can be 768 or 512
conf = Hunyuan3DPaintConfig(max_num_view, resolution)
conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth"
conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml"
conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr"
self.paint_pipeline = Hunyuan3DPaintPipeline(conf)
def get_queue_length(self):
"""
@@ -140,13 +150,32 @@ class ModelWorker:
logger.error(f"Shape generation failed: {e}")
raise ValueError(f"Failed to generate 3D mesh: {str(e)}")
# Export final mesh without texture
# Export initial mesh without texture
file_type = params.get('type', 'glb')
save_path = os.path.join(self.save_dir, f'{str(uid)}.{file_type}')
mesh.export(save_path)
initial_save_path = os.path.join(self.save_dir, f'{str(uid)}_initial.{file_type}')
mesh.export(initial_save_path)
# Generate textured mesh (matching demo.py)
try:
output_mesh_path = os.path.join(self.save_dir, f'{str(uid)}_textured.{file_type}')
textured_path = self.paint_pipeline(
mesh_path=initial_save_path,
image_path=image,
output_mesh_path=output_mesh_path
)
logger.info("---Texture generation takes %s seconds ---" % (time.time() - start_time))
# Use the textured GLB as the final output
final_save_path = textured_path.replace('.obj', '.glb') if textured_path.endswith('.obj') else textured_path
except Exception as e:
logger.error(f"Texture generation failed: {e}")
# Fall back to untextured mesh if texture generation fails
final_save_path = initial_save_path
logger.warning(f"Using untextured mesh as fallback: {final_save_path}")
if self.low_vram_mode:
torch.cuda.empty_cache()
logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time))
return save_path, uid
return final_save_path, uid

View File

@@ -11,12 +11,39 @@ import json
from PIL import Image
import io
import time
import os
# API server URL (adjust as needed)
API_BASE_URL = "http://localhost:8081"
def load_demo_image():
"""Load the demo image from assets/demo.png"""
try:
# Load the demo image
image_path = 'assets/demo.png'
image = Image.open(image_path).convert("RGBA")
# Convert to base64
buffer = io.BytesIO()
image.save(buffer, format='PNG')
img_base64 = base64.b64encode(buffer.getvalue()).decode()
print(f"Loaded demo image from: {image_path}")
print(f"Image size: {image.size}")
print(f"Image mode: {image.mode}")
return img_base64
except FileNotFoundError:
print(f"Error: Demo image not found at {image_path}")
print("Creating fallback test image...")
return create_test_image()
except Exception as e:
print(f"Error loading demo image: {e}")
print("Creating fallback test image...")
return create_test_image()
def create_test_image():
"""Create a simple test image for API testing"""
"""Create a simple test image for API testing (fallback)"""
# Create a simple 256x256 test image
img = Image.new('RGB', (256, 256), color='red')
@@ -27,15 +54,41 @@ def create_test_image():
return img_base64
def save_glb_file(response, filename):
"""Save GLB file from response content"""
try:
with open(filename, 'wb') as f:
f.write(response.content)
print(f"GLB file saved as: {filename}")
return True
except Exception as e:
print(f"Error saving GLB file: {e}")
return False
def save_base64_glb(base64_data, filename):
"""Save GLB file from base64 encoded data"""
try:
# Decode base64 data
glb_data = base64.b64decode(base64_data)
# Save to file
with open(filename, 'wb') as f:
f.write(glb_data)
print(f"GLB file saved as: {filename}")
return True
except Exception as e:
print(f"Error saving GLB file from base64: {e}")
return False
def test_generation_request():
"""Test the generation request with simplified parameters"""
print("Creating test image...")
# Create test image
test_image = create_test_image()
print("Loading demo image...")
# Load demo image
demo_image = load_demo_image()
# Simplified request payload with only the parameters the worker actually uses
request_data = {
"image": test_image,
"image": demo_image,
"type": "glb"
}
@@ -43,7 +96,7 @@ def test_generation_request():
print("Request parameters:")
for key, value in request_data.items():
if key == "image":
print(f" {key}: [base64 encoded image data]")
print(f" {key}: [base64 encoded demo image data]")
else:
print(f" {key}: {value}")
@@ -52,6 +105,14 @@ def test_generation_request():
print(f"Response status: {response.status_code}")
if response.status_code == 200:
print("Success! Generated 3D model file received.")
# Save the GLB file
timestamp = int(time.time())
filename = f"generated_model_{timestamp}.glb"
if save_glb_file(response, filename):
print(f"Model saved successfully to: {filename}")
else:
print("Failed to save model file")
else:
print(f"Error: {response.text}")
except requests.exceptions.ConnectionError:
@@ -60,10 +121,10 @@ def test_generation_request():
def test_async_generation():
"""Test the asynchronous generation endpoint"""
test_image = create_test_image()
demo_image = load_demo_image()
request_data = {
"image": test_image,
"image": demo_image,
"type": "glb"
}
@@ -80,6 +141,7 @@ def test_async_generation():
print("Checking task status...")
status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
print(f"Status: {status_response.json()}")
# Poll status until completed
while True:
status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
@@ -88,7 +150,18 @@ def test_async_generation():
if status_data['status'] == 'completed':
print("Generation completed!")
print("Model data received in base64 format")
# Save the GLB file from base64 data
model_base64 = status_data.get('model_base64')
if model_base64:
timestamp = int(time.time())
filename = f"async_generated_model_{uid}_{timestamp}.glb"
if save_base64_glb(model_base64, filename):
print(f"Model saved successfully to: {filename}")
else:
print("Failed to save model file")
else:
print("No model data received in response")
break
elif status_data['status'] == 'error':
print(f"Error: {status_data.get('message', 'Unknown error')}")
@@ -155,13 +228,18 @@ def show_api_documentation_info():
print(" - image: Base64 encoded input image (required)")
print(" - type: Output file format - 'glb' or 'obj' (optional, default: 'glb')")
print()
print("7. File Saving:")
print(" - GLB files are automatically saved with timestamps")
print(" - Direct generation saves as: generated_model_{timestamp}.glb")
print(" - Async generation saves as: async_generated_model_{uid}_{timestamp}.glb")
print()
if __name__ == "__main__":
show_api_documentation_info()
# Run tests if server is available
test_health_check()
#test_generation_request()
test_generation_request()
test_async_generation()
print("\n" + "=" * 60)
@@ -169,4 +247,5 @@ if __name__ == "__main__":
print(f"1. Start the API server: python api_server.py")
print(f"2. Open your browser to: {API_BASE_URL}/docs")
print("3. Explore the documented endpoints and parameters")
print("4. Generated GLB files will be saved in the current directory")
print("=" * 60)