diff --git a/model_worker.py b/model_worker.py index 1e0abf9..cc2dc41 100644 --- a/model_worker.py +++ b/model_worker.py @@ -26,6 +26,7 @@ except Exception as e: from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline from hy3dshape.rembg import BackgroundRemover from hy3dshape.utils import logger +from textureGenPipeline import Hunyuan3DPaintPipeline, Hunyuan3DPaintConfig def load_image_from_base64(image): @@ -80,6 +81,15 @@ class ModelWorker: # Initialize shape generation pipeline (matching demo.py) self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path) + + # Initialize texture generation pipeline (matching demo.py) + max_num_view = 6 # can be 6 to 9 + resolution = 512 # can be 768 or 512 + conf = Hunyuan3DPaintConfig(max_num_view, resolution) + conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth" + conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml" + conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr" + self.paint_pipeline = Hunyuan3DPaintPipeline(conf) def get_queue_length(self): """ @@ -140,13 +150,32 @@ class ModelWorker: logger.error(f"Shape generation failed: {e}") raise ValueError(f"Failed to generate 3D mesh: {str(e)}") - # Export final mesh without texture + # Export initial mesh without texture file_type = params.get('type', 'glb') - save_path = os.path.join(self.save_dir, f'{str(uid)}.{file_type}') - mesh.export(save_path) + initial_save_path = os.path.join(self.save_dir, f'{str(uid)}_initial.{file_type}') + mesh.export(initial_save_path) + + # Generate textured mesh (matching demo.py) + try: + output_mesh_path = os.path.join(self.save_dir, f'{str(uid)}_textured.{file_type}') + textured_path = self.paint_pipeline( + mesh_path=initial_save_path, + image_path=image, + output_mesh_path=output_mesh_path + ) + logger.info("---Texture generation takes %s seconds ---" % (time.time() - start_time)) + + # Use the textured GLB as the final output + final_save_path = textured_path.replace('.obj', '.glb') if textured_path.endswith('.obj') else textured_path + + except Exception as e: + logger.error(f"Texture generation failed: {e}") + # Fall back to untextured mesh if texture generation fails + final_save_path = initial_save_path + logger.warning(f"Using untextured mesh as fallback: {final_save_path}") if self.low_vram_mode: torch.cuda.empty_cache() logger.info("---Total generation takes %s seconds ---" % (time.time() - start_time)) - return save_path, uid \ No newline at end of file + return final_save_path, uid \ No newline at end of file diff --git a/test_api_server.py b/test_api_server.py index d97ed77..bb21d40 100644 --- a/test_api_server.py +++ b/test_api_server.py @@ -11,12 +11,39 @@ import json from PIL import Image import io import time +import os # API server URL (adjust as needed) API_BASE_URL = "http://localhost:8081" +def load_demo_image(): + """Load the demo image from assets/demo.png""" + try: + # Load the demo image + image_path = 'assets/demo.png' + image = Image.open(image_path).convert("RGBA") + + # Convert to base64 + buffer = io.BytesIO() + image.save(buffer, format='PNG') + img_base64 = base64.b64encode(buffer.getvalue()).decode() + + print(f"Loaded demo image from: {image_path}") + print(f"Image size: {image.size}") + print(f"Image mode: {image.mode}") + + return img_base64 + except FileNotFoundError: + print(f"Error: Demo image not found at {image_path}") + print("Creating fallback test image...") + return create_test_image() + except Exception as e: + print(f"Error loading demo image: {e}") + print("Creating fallback test image...") + return create_test_image() + def create_test_image(): - """Create a simple test image for API testing""" + """Create a simple test image for API testing (fallback)""" # Create a simple 256x256 test image img = Image.new('RGB', (256, 256), color='red') @@ -27,15 +54,41 @@ def create_test_image(): return img_base64 +def save_glb_file(response, filename): + """Save GLB file from response content""" + try: + with open(filename, 'wb') as f: + f.write(response.content) + print(f"GLB file saved as: {filename}") + return True + except Exception as e: + print(f"Error saving GLB file: {e}") + return False + +def save_base64_glb(base64_data, filename): + """Save GLB file from base64 encoded data""" + try: + # Decode base64 data + glb_data = base64.b64decode(base64_data) + + # Save to file + with open(filename, 'wb') as f: + f.write(glb_data) + print(f"GLB file saved as: {filename}") + return True + except Exception as e: + print(f"Error saving GLB file from base64: {e}") + return False + def test_generation_request(): """Test the generation request with simplified parameters""" - print("Creating test image...") - # Create test image - test_image = create_test_image() + print("Loading demo image...") + # Load demo image + demo_image = load_demo_image() # Simplified request payload with only the parameters the worker actually uses request_data = { - "image": test_image, + "image": demo_image, "type": "glb" } @@ -43,7 +96,7 @@ def test_generation_request(): print("Request parameters:") for key, value in request_data.items(): if key == "image": - print(f" {key}: [base64 encoded image data]") + print(f" {key}: [base64 encoded demo image data]") else: print(f" {key}: {value}") @@ -52,6 +105,14 @@ def test_generation_request(): print(f"Response status: {response.status_code}") if response.status_code == 200: print("Success! Generated 3D model file received.") + + # Save the GLB file + timestamp = int(time.time()) + filename = f"generated_model_{timestamp}.glb" + if save_glb_file(response, filename): + print(f"Model saved successfully to: {filename}") + else: + print("Failed to save model file") else: print(f"Error: {response.text}") except requests.exceptions.ConnectionError: @@ -60,10 +121,10 @@ def test_generation_request(): def test_async_generation(): """Test the asynchronous generation endpoint""" - test_image = create_test_image() + demo_image = load_demo_image() request_data = { - "image": test_image, + "image": demo_image, "type": "glb" } @@ -80,6 +141,7 @@ def test_async_generation(): print("Checking task status...") status_response = requests.get(f"{API_BASE_URL}/status/{uid}") print(f"Status: {status_response.json()}") + # Poll status until completed while True: status_response = requests.get(f"{API_BASE_URL}/status/{uid}") @@ -88,7 +150,18 @@ def test_async_generation(): if status_data['status'] == 'completed': print("Generation completed!") - print("Model data received in base64 format") + + # Save the GLB file from base64 data + model_base64 = status_data.get('model_base64') + if model_base64: + timestamp = int(time.time()) + filename = f"async_generated_model_{uid}_{timestamp}.glb" + if save_base64_glb(model_base64, filename): + print(f"Model saved successfully to: {filename}") + else: + print("Failed to save model file") + else: + print("No model data received in response") break elif status_data['status'] == 'error': print(f"Error: {status_data.get('message', 'Unknown error')}") @@ -155,13 +228,18 @@ def show_api_documentation_info(): print(" - image: Base64 encoded input image (required)") print(" - type: Output file format - 'glb' or 'obj' (optional, default: 'glb')") print() + print("7. File Saving:") + print(" - GLB files are automatically saved with timestamps") + print(" - Direct generation saves as: generated_model_{timestamp}.glb") + print(" - Async generation saves as: async_generated_model_{uid}_{timestamp}.glb") + print() if __name__ == "__main__": show_api_documentation_info() # Run tests if server is available test_health_check() - #test_generation_request() + test_generation_request() test_async_generation() print("\n" + "=" * 60) @@ -169,4 +247,5 @@ if __name__ == "__main__": print(f"1. Start the API server: python api_server.py") print(f"2. Open your browser to: {API_BASE_URL}/docs") print("3. Explore the documented endpoints and parameters") + print("4. Generated GLB files will be saved in the current directory") print("=" * 60) \ No newline at end of file