next
This commit is contained in:
@@ -26,6 +26,7 @@ except Exception as e:
|
|||||||
from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline
|
from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline
|
||||||
from hy3dshape.rembg import BackgroundRemover
|
from hy3dshape.rembg import BackgroundRemover
|
||||||
from hy3dshape.utils import logger
|
from hy3dshape.utils import logger
|
||||||
|
from textureGenPipeline import Hunyuan3DPaintPipeline, Hunyuan3DPaintConfig
|
||||||
|
|
||||||
|
|
||||||
def load_image_from_base64(image):
|
def load_image_from_base64(image):
|
||||||
@@ -81,6 +82,15 @@ class ModelWorker:
|
|||||||
# Initialize shape generation pipeline (matching demo.py)
|
# Initialize shape generation pipeline (matching demo.py)
|
||||||
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
|
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
|
||||||
|
|
||||||
|
# Initialize texture generation pipeline (matching demo.py)
|
||||||
|
max_num_view = 6 # can be 6 to 9
|
||||||
|
resolution = 512 # can be 768 or 512
|
||||||
|
conf = Hunyuan3DPaintConfig(max_num_view, resolution)
|
||||||
|
conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth"
|
||||||
|
conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml"
|
||||||
|
conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr"
|
||||||
|
self.paint_pipeline = Hunyuan3DPaintPipeline(conf)
|
||||||
|
|
||||||
def get_queue_length(self):
|
def get_queue_length(self):
|
||||||
"""
|
"""
|
||||||
Get the current queue length for model processing.
|
Get the current queue length for model processing.
|
||||||
@@ -140,13 +150,32 @@ class ModelWorker:
|
|||||||
logger.error(f"Shape generation failed: {e}")
|
logger.error(f"Shape generation failed: {e}")
|
||||||
raise ValueError(f"Failed to generate 3D mesh: {str(e)}")
|
raise ValueError(f"Failed to generate 3D mesh: {str(e)}")
|
||||||
|
|
||||||
# Export final mesh without texture
|
# Export initial mesh without texture
|
||||||
file_type = params.get('type', 'glb')
|
file_type = params.get('type', 'glb')
|
||||||
save_path = os.path.join(self.save_dir, f'{str(uid)}.{file_type}')
|
initial_save_path = os.path.join(self.save_dir, f'{str(uid)}_initial.{file_type}')
|
||||||
mesh.export(save_path)
|
mesh.export(initial_save_path)
|
||||||
|
|
||||||
|
# Generate textured mesh (matching demo.py)
|
||||||
|
try:
|
||||||
|
output_mesh_path = os.path.join(self.save_dir, f'{str(uid)}_textured.{file_type}')
|
||||||
|
textured_path = self.paint_pipeline(
|
||||||
|
mesh_path=initial_save_path,
|
||||||
|
image_path=image,
|
||||||
|
output_mesh_path=output_mesh_path
|
||||||
|
)
|
||||||
|
logger.info("---Texture generation takes %s seconds ---" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Use the textured GLB as the final output
|
||||||
|
final_save_path = textured_path.replace('.obj', '.glb') if textured_path.endswith('.obj') else textured_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Texture generation failed: {e}")
|
||||||
|
# Fall back to untextured mesh if texture generation fails
|
||||||
|
final_save_path = initial_save_path
|
||||||
|
logger.warning(f"Using untextured mesh as fallback: {final_save_path}")
|
||||||
|
|
||||||
if self.low_vram_mode:
|
if self.low_vram_mode:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time))
|
logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time))
|
||||||
return save_path, uid
|
return final_save_path, uid
|
||||||
@@ -11,12 +11,39 @@ import json
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
# API server URL (adjust as needed)
|
# API server URL (adjust as needed)
|
||||||
API_BASE_URL = "http://localhost:8081"
|
API_BASE_URL = "http://localhost:8081"
|
||||||
|
|
||||||
|
def load_demo_image():
|
||||||
|
"""Load the demo image from assets/demo.png"""
|
||||||
|
try:
|
||||||
|
# Load the demo image
|
||||||
|
image_path = 'assets/demo.png'
|
||||||
|
image = Image.open(image_path).convert("RGBA")
|
||||||
|
|
||||||
|
# Convert to base64
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
image.save(buffer, format='PNG')
|
||||||
|
img_base64 = base64.b64encode(buffer.getvalue()).decode()
|
||||||
|
|
||||||
|
print(f"Loaded demo image from: {image_path}")
|
||||||
|
print(f"Image size: {image.size}")
|
||||||
|
print(f"Image mode: {image.mode}")
|
||||||
|
|
||||||
|
return img_base64
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: Demo image not found at {image_path}")
|
||||||
|
print("Creating fallback test image...")
|
||||||
|
return create_test_image()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading demo image: {e}")
|
||||||
|
print("Creating fallback test image...")
|
||||||
|
return create_test_image()
|
||||||
|
|
||||||
def create_test_image():
|
def create_test_image():
|
||||||
"""Create a simple test image for API testing"""
|
"""Create a simple test image for API testing (fallback)"""
|
||||||
# Create a simple 256x256 test image
|
# Create a simple 256x256 test image
|
||||||
img = Image.new('RGB', (256, 256), color='red')
|
img = Image.new('RGB', (256, 256), color='red')
|
||||||
|
|
||||||
@@ -27,15 +54,41 @@ def create_test_image():
|
|||||||
|
|
||||||
return img_base64
|
return img_base64
|
||||||
|
|
||||||
|
def save_glb_file(response, filename):
|
||||||
|
"""Save GLB file from response content"""
|
||||||
|
try:
|
||||||
|
with open(filename, 'wb') as f:
|
||||||
|
f.write(response.content)
|
||||||
|
print(f"GLB file saved as: {filename}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error saving GLB file: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def save_base64_glb(base64_data, filename):
|
||||||
|
"""Save GLB file from base64 encoded data"""
|
||||||
|
try:
|
||||||
|
# Decode base64 data
|
||||||
|
glb_data = base64.b64decode(base64_data)
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
with open(filename, 'wb') as f:
|
||||||
|
f.write(glb_data)
|
||||||
|
print(f"GLB file saved as: {filename}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error saving GLB file from base64: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
def test_generation_request():
|
def test_generation_request():
|
||||||
"""Test the generation request with simplified parameters"""
|
"""Test the generation request with simplified parameters"""
|
||||||
print("Creating test image...")
|
print("Loading demo image...")
|
||||||
# Create test image
|
# Load demo image
|
||||||
test_image = create_test_image()
|
demo_image = load_demo_image()
|
||||||
|
|
||||||
# Simplified request payload with only the parameters the worker actually uses
|
# Simplified request payload with only the parameters the worker actually uses
|
||||||
request_data = {
|
request_data = {
|
||||||
"image": test_image,
|
"image": demo_image,
|
||||||
"type": "glb"
|
"type": "glb"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,7 +96,7 @@ def test_generation_request():
|
|||||||
print("Request parameters:")
|
print("Request parameters:")
|
||||||
for key, value in request_data.items():
|
for key, value in request_data.items():
|
||||||
if key == "image":
|
if key == "image":
|
||||||
print(f" {key}: [base64 encoded image data]")
|
print(f" {key}: [base64 encoded demo image data]")
|
||||||
else:
|
else:
|
||||||
print(f" {key}: {value}")
|
print(f" {key}: {value}")
|
||||||
|
|
||||||
@@ -52,6 +105,14 @@ def test_generation_request():
|
|||||||
print(f"Response status: {response.status_code}")
|
print(f"Response status: {response.status_code}")
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("Success! Generated 3D model file received.")
|
print("Success! Generated 3D model file received.")
|
||||||
|
|
||||||
|
# Save the GLB file
|
||||||
|
timestamp = int(time.time())
|
||||||
|
filename = f"generated_model_{timestamp}.glb"
|
||||||
|
if save_glb_file(response, filename):
|
||||||
|
print(f"Model saved successfully to: {filename}")
|
||||||
|
else:
|
||||||
|
print("Failed to save model file")
|
||||||
else:
|
else:
|
||||||
print(f"Error: {response.text}")
|
print(f"Error: {response.text}")
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
@@ -60,10 +121,10 @@ def test_generation_request():
|
|||||||
def test_async_generation():
|
def test_async_generation():
|
||||||
"""Test the asynchronous generation endpoint"""
|
"""Test the asynchronous generation endpoint"""
|
||||||
|
|
||||||
test_image = create_test_image()
|
demo_image = load_demo_image()
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"image": test_image,
|
"image": demo_image,
|
||||||
"type": "glb"
|
"type": "glb"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +141,7 @@ def test_async_generation():
|
|||||||
print("Checking task status...")
|
print("Checking task status...")
|
||||||
status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
|
status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
|
||||||
print(f"Status: {status_response.json()}")
|
print(f"Status: {status_response.json()}")
|
||||||
|
|
||||||
# Poll status until completed
|
# Poll status until completed
|
||||||
while True:
|
while True:
|
||||||
status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
|
status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
|
||||||
@@ -88,7 +150,18 @@ def test_async_generation():
|
|||||||
|
|
||||||
if status_data['status'] == 'completed':
|
if status_data['status'] == 'completed':
|
||||||
print("Generation completed!")
|
print("Generation completed!")
|
||||||
print("Model data received in base64 format")
|
|
||||||
|
# Save the GLB file from base64 data
|
||||||
|
model_base64 = status_data.get('model_base64')
|
||||||
|
if model_base64:
|
||||||
|
timestamp = int(time.time())
|
||||||
|
filename = f"async_generated_model_{uid}_{timestamp}.glb"
|
||||||
|
if save_base64_glb(model_base64, filename):
|
||||||
|
print(f"Model saved successfully to: {filename}")
|
||||||
|
else:
|
||||||
|
print("Failed to save model file")
|
||||||
|
else:
|
||||||
|
print("No model data received in response")
|
||||||
break
|
break
|
||||||
elif status_data['status'] == 'error':
|
elif status_data['status'] == 'error':
|
||||||
print(f"Error: {status_data.get('message', 'Unknown error')}")
|
print(f"Error: {status_data.get('message', 'Unknown error')}")
|
||||||
@@ -155,13 +228,18 @@ def show_api_documentation_info():
|
|||||||
print(" - image: Base64 encoded input image (required)")
|
print(" - image: Base64 encoded input image (required)")
|
||||||
print(" - type: Output file format - 'glb' or 'obj' (optional, default: 'glb')")
|
print(" - type: Output file format - 'glb' or 'obj' (optional, default: 'glb')")
|
||||||
print()
|
print()
|
||||||
|
print("7. File Saving:")
|
||||||
|
print(" - GLB files are automatically saved with timestamps")
|
||||||
|
print(" - Direct generation saves as: generated_model_{timestamp}.glb")
|
||||||
|
print(" - Async generation saves as: async_generated_model_{uid}_{timestamp}.glb")
|
||||||
|
print()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
show_api_documentation_info()
|
show_api_documentation_info()
|
||||||
|
|
||||||
# Run tests if server is available
|
# Run tests if server is available
|
||||||
test_health_check()
|
test_health_check()
|
||||||
#test_generation_request()
|
test_generation_request()
|
||||||
test_async_generation()
|
test_async_generation()
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
@@ -169,4 +247,5 @@ if __name__ == "__main__":
|
|||||||
print(f"1. Start the API server: python api_server.py")
|
print(f"1. Start the API server: python api_server.py")
|
||||||
print(f"2. Open your browser to: {API_BASE_URL}/docs")
|
print(f"2. Open your browser to: {API_BASE_URL}/docs")
|
||||||
print("3. Explore the documented endpoints and parameters")
|
print("3. Explore the documented endpoints and parameters")
|
||||||
|
print("4. Generated GLB files will be saved in the current directory")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
Reference in New Issue
Block a user