next
This commit is contained in:
@@ -26,6 +26,7 @@ except Exception as e:
|
||||
from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline
|
||||
from hy3dshape.rembg import BackgroundRemover
|
||||
from hy3dshape.utils import logger
|
||||
from textureGenPipeline import Hunyuan3DPaintPipeline, Hunyuan3DPaintConfig
|
||||
|
||||
|
||||
def load_image_from_base64(image):
|
||||
@@ -80,6 +81,15 @@ class ModelWorker:
|
||||
|
||||
# Initialize shape generation pipeline (matching demo.py)
|
||||
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
|
||||
|
||||
# Initialize texture generation pipeline (matching demo.py)
|
||||
max_num_view = 6 # can be 6 to 9
|
||||
resolution = 512 # can be 768 or 512
|
||||
conf = Hunyuan3DPaintConfig(max_num_view, resolution)
|
||||
conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth"
|
||||
conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml"
|
||||
conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr"
|
||||
self.paint_pipeline = Hunyuan3DPaintPipeline(conf)
|
||||
|
||||
def get_queue_length(self):
|
||||
"""
|
||||
@@ -140,13 +150,32 @@ class ModelWorker:
|
||||
logger.error(f"Shape generation failed: {e}")
|
||||
raise ValueError(f"Failed to generate 3D mesh: {str(e)}")
|
||||
|
||||
# Export final mesh without texture
|
||||
# Export initial mesh without texture
|
||||
file_type = params.get('type', 'glb')
|
||||
save_path = os.path.join(self.save_dir, f'{str(uid)}.{file_type}')
|
||||
mesh.export(save_path)
|
||||
initial_save_path = os.path.join(self.save_dir, f'{str(uid)}_initial.{file_type}')
|
||||
mesh.export(initial_save_path)
|
||||
|
||||
# Generate textured mesh (matching demo.py)
|
||||
try:
|
||||
output_mesh_path = os.path.join(self.save_dir, f'{str(uid)}_textured.{file_type}')
|
||||
textured_path = self.paint_pipeline(
|
||||
mesh_path=initial_save_path,
|
||||
image_path=image,
|
||||
output_mesh_path=output_mesh_path
|
||||
)
|
||||
logger.info("---Texture generation takes %s seconds ---" % (time.time() - start_time))
|
||||
|
||||
# Use the textured GLB as the final output
|
||||
final_save_path = textured_path.replace('.obj', '.glb') if textured_path.endswith('.obj') else textured_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Texture generation failed: {e}")
|
||||
# Fall back to untextured mesh if texture generation fails
|
||||
final_save_path = initial_save_path
|
||||
logger.warning(f"Using untextured mesh as fallback: {final_save_path}")
|
||||
|
||||
if self.low_vram_mode:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time))
|
||||
return save_path, uid
|
||||
return final_save_path, uid
|
||||
@@ -11,12 +11,39 @@ import json
|
||||
from PIL import Image
|
||||
import io
|
||||
import time
|
||||
import os
|
||||
|
||||
# API server URL (adjust as needed)
|
||||
API_BASE_URL = "http://localhost:8081"
|
||||
|
||||
def load_demo_image():
|
||||
"""Load the demo image from assets/demo.png"""
|
||||
try:
|
||||
# Load the demo image
|
||||
image_path = 'assets/demo.png'
|
||||
image = Image.open(image_path).convert("RGBA")
|
||||
|
||||
# Convert to base64
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format='PNG')
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
print(f"Loaded demo image from: {image_path}")
|
||||
print(f"Image size: {image.size}")
|
||||
print(f"Image mode: {image.mode}")
|
||||
|
||||
return img_base64
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Demo image not found at {image_path}")
|
||||
print("Creating fallback test image...")
|
||||
return create_test_image()
|
||||
except Exception as e:
|
||||
print(f"Error loading demo image: {e}")
|
||||
print("Creating fallback test image...")
|
||||
return create_test_image()
|
||||
|
||||
def create_test_image():
|
||||
"""Create a simple test image for API testing"""
|
||||
"""Create a simple test image for API testing (fallback)"""
|
||||
# Create a simple 256x256 test image
|
||||
img = Image.new('RGB', (256, 256), color='red')
|
||||
|
||||
@@ -27,15 +54,41 @@ def create_test_image():
|
||||
|
||||
return img_base64
|
||||
|
||||
def save_glb_file(response, filename):
|
||||
"""Save GLB file from response content"""
|
||||
try:
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(response.content)
|
||||
print(f"GLB file saved as: {filename}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error saving GLB file: {e}")
|
||||
return False
|
||||
|
||||
def save_base64_glb(base64_data, filename):
|
||||
"""Save GLB file from base64 encoded data"""
|
||||
try:
|
||||
# Decode base64 data
|
||||
glb_data = base64.b64decode(base64_data)
|
||||
|
||||
# Save to file
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(glb_data)
|
||||
print(f"GLB file saved as: {filename}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error saving GLB file from base64: {e}")
|
||||
return False
|
||||
|
||||
def test_generation_request():
|
||||
"""Test the generation request with simplified parameters"""
|
||||
print("Creating test image...")
|
||||
# Create test image
|
||||
test_image = create_test_image()
|
||||
print("Loading demo image...")
|
||||
# Load demo image
|
||||
demo_image = load_demo_image()
|
||||
|
||||
# Simplified request payload with only the parameters the worker actually uses
|
||||
request_data = {
|
||||
"image": test_image,
|
||||
"image": demo_image,
|
||||
"type": "glb"
|
||||
}
|
||||
|
||||
@@ -43,7 +96,7 @@ def test_generation_request():
|
||||
print("Request parameters:")
|
||||
for key, value in request_data.items():
|
||||
if key == "image":
|
||||
print(f" {key}: [base64 encoded image data]")
|
||||
print(f" {key}: [base64 encoded demo image data]")
|
||||
else:
|
||||
print(f" {key}: {value}")
|
||||
|
||||
@@ -52,6 +105,14 @@ def test_generation_request():
|
||||
print(f"Response status: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
print("Success! Generated 3D model file received.")
|
||||
|
||||
# Save the GLB file
|
||||
timestamp = int(time.time())
|
||||
filename = f"generated_model_{timestamp}.glb"
|
||||
if save_glb_file(response, filename):
|
||||
print(f"Model saved successfully to: {filename}")
|
||||
else:
|
||||
print("Failed to save model file")
|
||||
else:
|
||||
print(f"Error: {response.text}")
|
||||
except requests.exceptions.ConnectionError:
|
||||
@@ -60,10 +121,10 @@ def test_generation_request():
|
||||
def test_async_generation():
|
||||
"""Test the asynchronous generation endpoint"""
|
||||
|
||||
test_image = create_test_image()
|
||||
demo_image = load_demo_image()
|
||||
|
||||
request_data = {
|
||||
"image": test_image,
|
||||
"image": demo_image,
|
||||
"type": "glb"
|
||||
}
|
||||
|
||||
@@ -80,6 +141,7 @@ def test_async_generation():
|
||||
print("Checking task status...")
|
||||
status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
|
||||
print(f"Status: {status_response.json()}")
|
||||
|
||||
# Poll status until completed
|
||||
while True:
|
||||
status_response = requests.get(f"{API_BASE_URL}/status/{uid}")
|
||||
@@ -88,7 +150,18 @@ def test_async_generation():
|
||||
|
||||
if status_data['status'] == 'completed':
|
||||
print("Generation completed!")
|
||||
print("Model data received in base64 format")
|
||||
|
||||
# Save the GLB file from base64 data
|
||||
model_base64 = status_data.get('model_base64')
|
||||
if model_base64:
|
||||
timestamp = int(time.time())
|
||||
filename = f"async_generated_model_{uid}_{timestamp}.glb"
|
||||
if save_base64_glb(model_base64, filename):
|
||||
print(f"Model saved successfully to: {filename}")
|
||||
else:
|
||||
print("Failed to save model file")
|
||||
else:
|
||||
print("No model data received in response")
|
||||
break
|
||||
elif status_data['status'] == 'error':
|
||||
print(f"Error: {status_data.get('message', 'Unknown error')}")
|
||||
@@ -155,13 +228,18 @@ def show_api_documentation_info():
|
||||
print(" - image: Base64 encoded input image (required)")
|
||||
print(" - type: Output file format - 'glb' or 'obj' (optional, default: 'glb')")
|
||||
print()
|
||||
print("7. File Saving:")
|
||||
print(" - GLB files are automatically saved with timestamps")
|
||||
print(" - Direct generation saves as: generated_model_{timestamp}.glb")
|
||||
print(" - Async generation saves as: async_generated_model_{uid}_{timestamp}.glb")
|
||||
print()
|
||||
|
||||
if __name__ == "__main__":
|
||||
show_api_documentation_info()
|
||||
|
||||
# Run tests if server is available
|
||||
test_health_check()
|
||||
#test_generation_request()
|
||||
test_generation_request()
|
||||
test_async_generation()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
@@ -169,4 +247,5 @@ if __name__ == "__main__":
|
||||
print(f"1. Start the API server: python api_server.py")
|
||||
print(f"2. Open your browser to: {API_BASE_URL}/docs")
|
||||
print("3. Explore the documented endpoints and parameters")
|
||||
print("4. Generated GLB files will be saved in the current directory")
|
||||
print("=" * 60)
|
||||
Reference in New Issue
Block a user