Files
Hunyuan3D_2.1_Low_VRAM/hy3dshape/hy3dshape/data/dit_asl.py
Huiwenshi c88bee648e init
2025-06-13 23:53:14 +08:00

385 lines
15 KiB
Python

# -*- coding: utf-8 -*-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
import io
import sys
import time
import random
import traceback
from typing import Optional, Union, List, Tuple, Dict
import json
import glob
import cv2
import numpy as np
import trimesh
import torch
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import rank_zero_info
from .utils import worker_init_fn, pytorch_worker_seed, make_seed
class ResampledShards(torch.utils.data.dataset.IterableDataset):
    """Iterable dataset yielding entries of ``datalist`` drawn with replacement.

    Every call to ``__iter__`` increments an epoch counter, derives a seed
    through ``make_seed`` and draws ``nshards`` random entries using a fresh
    ``random.Random`` instance.
    """

    def __init__(self, datalist, nshards=sys.maxsize, worker_seed=None, deterministic=False):
        """Store the sampling configuration.

        Args:
            datalist: Indexable collection to sample from.
            nshards: Number of entries yielded per iteration pass.
            worker_seed: Zero-argument callable returning a seed component;
                ``pytorch_worker_seed`` is used when this is ``None``.
            deterministic: When true, the seed is built only from the worker
                seed and the epoch counter.
        """
        super().__init__()
        self.datalist = datalist
        self.nshards = nshards
        if worker_seed is None:
            # Fall back to the helper imported from the local utils module.
            worker_seed = pytorch_worker_seed
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        # Starts at -1 so the first pass runs as epoch 0.
        self.epoch = -1

    def __iter__(self):
        """Yield ``nshards`` randomly chosen entries of ``datalist``."""
        self.epoch += 1
        seed_parts = [self.worker_seed(), self.epoch]
        if not self.deterministic:
            # Mix in process id, wall-clock time and OS entropy as well.
            seed_parts.extend((os.getpid(), time.time_ns(), os.urandom(4)))
        self.rng = random.Random(make_seed(*seed_parts))
        for _ in range(self.nshards):
            pick = self.rng.randint(0, len(self.datalist) - 1)
            yield self.datalist[pick]
def read_npz(data):
    """Load a NumPy archive through ``np.load``.

    Args:
        data: Anything ``np.load`` accepts, i.e. a file path or a file-like
            object. Raw in-memory bytes would have to be wrapped first, e.g.
            ``io.BytesIO(data)``.

    Returns:
        Whatever ``np.load(data)`` returns.
    """
    archive = np.load(data)
    return archive
def read_json(path):
    """Return the object decoded from the UTF-8 JSON file located at ``path``."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
def padding(image, mask, center=True, padding_ratio_range=(1.15, 1.15)):
    """Paste an image and its mask onto a larger square canvas.

    The canvas side is ``int(max(H, W) * ratio)``. The image canvas is filled
    with white (255) and the mask canvas with zeros before the inputs are
    copied in.

    Args:
        image (np.ndarray): Input image of shape (H, W, 3).
        mask (np.ndarray): Matching mask of shape (H, W).
        center (bool): If true the input is vertically centred. Otherwise it
            is pushed towards the bottom, leaving ``side // 20`` rows below
            it (clamped so that it never starts above row 0).
        padding_ratio_range (sequence): ``[min, max]`` of the padding ratio.
            When both ends are equal that value is used directly; otherwise
            a ratio is drawn with ``random.uniform``.

    Returns:
        newimg (np.ndarray): uint8 array of shape (side, side, 3).
        newmask (np.ndarray): uint8 array of shape (side, side).

    Raises:
        ValueError: If the chosen ratio gives a canvas smaller than the input.
    """
    h, w = image.shape[:2]
    max_side = max(h, w)

    ratio_low, ratio_high = padding_ratio_range[0], padding_ratio_range[1]
    if ratio_low == ratio_high:
        padding_ratio = ratio_low
    else:
        padding_ratio = random.uniform(ratio_low, ratio_high)

    resize_side = int(max_side * padding_ratio)
    if resize_side < max_side:
        # The input cannot fit; fail with a clear message instead of a
        # NumPy broadcasting error further down.
        raise ValueError(
            f'padding ratio {padding_ratio} yields a canvas of side {resize_side}, '
            f'smaller than the input ({h}x{w})'
        )

    pad_h = resize_side - h
    pad_w = resize_side - w
    if center:
        start_h = pad_h // 2
    else:
        # For small ratios the bottom margin can exceed the available padding;
        # clamp so the paste offset never becomes negative.
        start_h = max(pad_h - resize_side // 20, 0)
    start_w = pad_w // 2

    # White image canvas and empty mask canvas.
    newimg = np.full((resize_side, resize_side, 3), 255, dtype=np.uint8)
    newmask = np.zeros((resize_side, resize_side), dtype=np.uint8)

    newimg[start_h:start_h + h, start_w:start_w + w] = image
    newmask[start_h:start_h + h, start_w:start_w + w] = mask
    return newimg, newmask
def viz_pc(surface, normal, image_input, name):
    """Write a debug dump of one sample as ``<name>.png`` and ``<name>.obj``.

    Args:
        surface: Tensor of point positions, used as the mesh vertices.
        normal: Tensor of per-point normals; mapped via ``(n + 1) / 2`` and
            used as vertex colors.
        image_input: Image tensor; it is permuted with ``(1, 2, 0)`` and
            mapped via ``x * 0.5 + 0.5`` before saving
            (assumes a (C, H, W) RGB tensor in [-1, 1] -- TODO confirm).
        name: Output path prefix shared by both files.
    """
    # Image: move channels last, map to 0..255 and save in OpenCV's BGR order.
    pixels = image_input.cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
    pixels = (pixels * 255).astype(np.uint8)
    cv2.imwrite(name + '.png', cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR))

    # Points: export as a face-less mesh colored by the remapped normals.
    points = surface.cpu().numpy()
    colors = (normal.cpu().numpy() + 1) / 2
    trimesh.Trimesh(points, vertex_colors=colors).export(name + '.obj')
class AlignedShapeLatentDataset(torch.utils.data.dataset.IterableDataset):
    """Endless iterable dataset pairing sampled surface points with one render.

    Each entry of the data list is a directory. ``decode`` expects it to hold
    ``render_cond/000.png`` .. ``render_cond/023.png`` and
    ``geo_data/<uid>_surface.npz`` where ``<uid>`` is the directory name.
    Entries are drawn at random via ``ResampledShards``; entries that fail to
    load are reported and skipped.
    """

    def __init__(
        self,
        data_list: Union[str, List[str], None] = None,
        cond_stage_key: str = "image",
        image_transform=None,
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sharpedge_label: bool = False,
        return_normal: bool = False,
        deterministic=False,
        worker_seed=None,
        padding=True,
        padding_ratio_range=(1.15, 1.15),
    ):
        """Resolve the data list and store the sampling options.

        Args:
            data_list: A list of sample directories, a path to a ``.json``
                file containing such a list, or a directory whose direct
                children are the samples.
            cond_stage_key: Stored on the instance; not read by this class.
            image_transform: Callable applied to the RGB image and to the
                3-channel mask. Its outputs are concatenated with
                ``torch.cat``, so it has to return tensors.
            pc_size: Number of points drawn from ``random_surface``.
            pc_sharpedge_size: Number of points drawn from ``sharp_surface``.
            sharpedge_label: Append a 0/1 column marking sharp-edge points.
            return_normal: Append the normalized normals to the positions.
            deterministic: Accepted but not used anywhere in this class.
            worker_seed: Accepted but not used anywhere in this class.
            padding: Crop renders to the foreground and pad them to a square.
            padding_ratio_range: Forwarded to the module-level ``padding``.

        Raises:
            TypeError: If ``data_list`` does not resolve to a list.
        """
        super().__init__()
        if isinstance(data_list, str) and data_list.endswith('.json'):
            self.data_list = read_json(data_list)
        elif isinstance(data_list, str) and os.path.isdir(data_list):
            # Sorted so the index -> sample mapping does not depend on the
            # file-system enumeration order.
            self.data_list = sorted(glob.glob(os.path.join(data_list, '*')))
        else:
            self.data_list = data_list
        if not isinstance(self.data_list, list):
            raise TypeError(
                'data_list must be a list, a .json file holding a list, or a directory; '
                f'got {type(self.data_list).__name__}'
            )

        # NOTE(review): fixed seed, so every copy of this dataset (e.g. one per
        # DataLoader worker) draws the same sequence of view indices -- confirm
        # this is intended.
        self.rng = random.Random(0)

        self.cond_stage_key = cond_stage_key
        self.image_transform = image_transform
        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal
        self.padding = padding
        self.padding_ratio_range = padding_ratio_range

        rank_zero_info('*' * 50)
        rank_zero_info('Dataset Infos:')
        rank_zero_info(f'# of 3D file: {len(self.data_list)}')
        rank_zero_info(f'# of Surface Points: {self.pc_size}')
        rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
        rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
        rank_zero_info('*' * 50)

    def load_surface_sdf_points(self, rng, random_surface, sharpedge_surface):
        """Subsample both point sets and assemble the surface tensor.

        Columns 0:3 of each input row are used as the position and columns
        3:6 as the normal.

        Args:
            rng: Generator exposing ``choice(n, size, replace=False)``.
            random_surface: Array of uniformly sampled surface points.
            sharpedge_surface: Array of sharp-edge surface points.

        Returns:
            tuple: ``(surface, geo_points)``. ``surface`` is a float tensor
            with ``pc_size + pc_sharpedge_size`` rows holding the positions,
            followed by the unit normals if ``return_normal`` and by the
            sharp-edge flag if ``sharpedge_label``. ``geo_points`` is always
            the placeholder ``0.0``.
        """
        parts = []
        if self.pc_size > 0:
            ind = rng.choice(random_surface.shape[0], self.pc_size, replace=False)
            random_surface = random_surface[ind]
            if self.sharpedge_label:
                # Flag 0: point comes from the uniform surface samples.
                flags = np.zeros((self.pc_size, 1))
                random_surface = np.concatenate((random_surface, flags), axis=1)
            parts.append(random_surface)
        if self.pc_sharpedge_size > 0:
            ind_sharpedge = rng.choice(
                sharpedge_surface.shape[0], self.pc_sharpedge_size, replace=False)
            sharpedge_surface = sharpedge_surface[ind_sharpedge]
            if self.sharpedge_label:
                # Flag 1: point comes from the sharp-edge samples.
                flags = np.ones((self.pc_sharpedge_size, 1))
                sharpedge_surface = np.concatenate((sharpedge_surface, flags), axis=1)
            parts.append(sharpedge_surface)

        surface_normal = torch.FloatTensor(np.concatenate(parts, axis=0))
        surface = surface_normal[:, 0:3]
        normal = surface_normal[:, 3:6]
        assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size

        geo_points = 0.0
        normal = torch.nn.functional.normalize(normal, p=2, dim=1)
        if self.return_normal:
            surface = torch.cat([surface, normal], dim=-1)
        if self.sharpedge_label:
            surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)
        return surface, geo_points

    def _crop_and_pad(self, image, mask):
        """Crop to the foreground box (5 px margin) and pad to a square."""
        h, w = image.shape[:2]
        # mask is uint8 at this point, so the comparison selects every
        # non-zero pixel.
        # NOTE(review): 0.3 may have been meant for the 0..1 alpha -- confirm.
        rows, cols = np.nonzero(mask > 0.3)
        if rows.size == 0:
            raise ValueError('Render has an empty foreground mask; cannot crop it')
        top, bottom = max(rows.min() - 5, 0), min(rows.max() + 5, h)
        left, right = max(cols.min() - 5, 0), min(cols.max() + 5, w)
        return padding(
            image[top:bottom, left:right],
            mask[top:bottom, left:right],
            center=True, padding_ratio_range=self.padding_ratio_range)

    def load_render(self, imgs_path):
        """Load one randomly chosen render and its foreground mask.

        The RGBA image is composited over a white background, optionally
        cropped/padded, and passed through ``image_transform`` together with
        a 3-channel copy of the mask.

        Args:
            imgs_path: Sequence of candidate image paths; one is picked.

        Returns:
            tuple: ``(images, masks)`` where ``masks`` keeps only the first
            entry along dim 0 of the transformed mask.

        Raises:
            FileNotFoundError: If OpenCV cannot read the chosen file.
            ValueError: If the image is not 4-channel or has no foreground.
        """
        imgs_choice = self.rng.sample(imgs_path, 1)
        images, masks = [], []
        for image_path in imgs_choice:
            image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            if image is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(f'Could not read render image: {image_path}')
            if image.ndim != 3 or image.shape[2] != 4:
                raise ValueError(
                    f'Expected a 4-channel image, got shape {image.shape}: {image_path}')

            # Alpha-composite over white, then switch from BGR to RGB.
            alpha = image[:, :, 3:4].astype(np.float32) / 255
            foreground = image[:, :, :3]
            background = np.ones_like(foreground) * 255
            composited = foreground * alpha + background * (1 - alpha)
            image = cv2.cvtColor(composited.astype(np.uint8), cv2.COLOR_BGR2RGB)
            mask = (alpha[:, :, 0] * 255).astype(np.uint8)

            if self.padding:
                image, mask = self._crop_and_pad(image, mask)

            if self.image_transform:
                image = self.image_transform(image)
                # The transform is fed a 3-channel copy of the mask.
                mask = np.stack((mask, mask, mask), axis=-1)
                mask = self.image_transform(mask)
            images.append(image)
            masks.append(mask)

        images = torch.cat(images, dim=0)
        masks = torch.cat(masks, dim=0)[:1, ...]
        return images, masks

    def decode(self, item):
        """Collect the render paths and surface arrays of one sample directory.

        Args:
            item: Path of the sample directory; its last component is the uid.

        Returns:
            dict: ``image`` (list of 24 render paths), ``random_surface`` and
            ``sharpedge_surface`` (arrays read from the surface archive).
        """
        # basename(normpath(...)) tolerates a trailing separator and the
        # platform's native separator, unlike splitting on '/'.
        uid = os.path.basename(os.path.normpath(item))
        render_img_paths = [os.path.join(item, f'render_cond/{i:03d}.png') for i in range(24)]
        surface_npz_path = os.path.join(item, f'geo_data/{uid}_surface.npz')

        sample = {"image": render_img_paths}
        # Close the archive once both arrays are read so that an endless
        # loader does not accumulate open file handles.
        with read_npz(surface_npz_path) as surface_data:
            sample["random_surface"] = surface_data['random_surface']
            sample["sharpedge_surface"] = surface_data['sharp_surface']
        return sample

    def transform(self, sample):
        """Turn a decoded sample into the training dict.

        Returns:
            dict: Keys ``surface``, ``geo_points``, ``image`` and ``mask``.
        """
        rng = np.random.default_rng()
        random_surface = sample.get("random_surface", 0)
        sharpedge_surface = sample.get("sharpedge_surface", 0)
        image_input, mask_input = self.load_render(sample['image'])
        surface, geo_points = self.load_surface_sdf_points(rng, random_surface, sharpedge_surface)
        return {
            "surface": surface,
            "geo_points": geo_points,
            "image": image_input,
            "mask": mask_input,
        }

    def __iter__(self):
        """Yield transformed samples, skipping entries that fail to load.

        The running failure rate is printed every 1000 attempts.
        """
        total_num = 0
        failed_num = 0
        for data in ResampledShards(self.data_list):
            total_num += 1
            if total_num % 1000 == 0:
                print("Current failure rate of data loading:")
                print(f"{failed_num}/{total_num}={failed_num/total_num}")
            try:
                sample = self.decode(data)
                sample = self.transform(sample)
            except Exception as err:
                # Best effort by design: report the problem and move on.
                print(err)
                failed_num += 1
                continue
            yield sample
class AlignedShapeLatentModule(LightningDataModule):
    """Lightning data module building train/val loaders of ``AlignedShapeLatentDataset``."""

    def __init__(
        self,
        batch_size: int = 1,
        num_workers: int = 4,
        val_num_workers: int = 2,
        train_data_list: Optional[str] = None,
        val_data_list: Optional[str] = None,
        cond_stage_key: str = "all",
        image_size: int = 224,
        mean: Union[List[float], Tuple[float, ...]] = (0.485, 0.456, 0.406),
        std: Union[List[float], Tuple[float, ...]] = (0.229, 0.224, 0.225),
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sharpedge_label: bool = False,
        return_normal: bool = False,
        padding=True,
        padding_ratio_range=(1.15, 1.15),
    ):
        """Store the loader configuration and build the image transforms.

        Args:
            batch_size: Batch size of both loaders.
            num_workers: Worker count of the training loader.
            val_num_workers: Worker count of the validation loader.
            train_data_list: Data list handed to the training dataset.
            val_data_list: Data list handed to the validation dataset.
            cond_stage_key: Forwarded to the dataset.
            image_size: Size passed to ``transforms.Resize``.
            mean: Per-channel mean for ``transforms.Normalize``.
            std: Per-channel std for ``transforms.Normalize``.
            pc_size: Forwarded to the dataset.
            pc_sharpedge_size: Forwarded to the dataset.
            sharpedge_label: Forwarded to the dataset.
            return_normal: Forwarded to the dataset.
            padding: Forwarded to the dataset.
            padding_ratio_range: Forwarded to the dataset.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_num_workers = val_num_workers
        self.train_data_list = train_data_list
        self.val_data_list = val_data_list
        self.cond_stage_key = cond_stage_key
        self.image_size = image_size
        self.mean = mean
        self.std = std

        # Both splits use the same pipeline; separate instances are kept so
        # the two attributes stay independent.
        self.train_image_transform = self._build_image_transform()
        self.val_image_transform = self._build_image_transform()

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal
        self.padding = padding
        self.padding_ratio_range = padding_ratio_range

    def _build_image_transform(self):
        """Return the ToTensor -> Resize -> Normalize pipeline."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.image_size),
            transforms.Normalize(mean=self.mean, std=self.std)])

    def _make_dataloader(self, data_list, image_transform, num_workers):
        """Build a dataset for ``data_list`` and wrap it in a DataLoader."""
        dataset = AlignedShapeLatentDataset(
            data_list=data_list,
            cond_stage_key=self.cond_stage_key,
            image_transform=image_transform,
            pc_size=self.pc_size,
            pc_sharpedge_size=self.pc_sharpedge_size,
            sharpedge_label=self.sharpedge_label,
            return_normal=self.return_normal,
            padding=self.padding,
            padding_ratio_range=self.padding_ratio_range,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            worker_init_fn=worker_init_fn,
        )

    def train_dataloader(self):
        """Return the DataLoader over the training data list."""
        return self._make_dataloader(
            self.train_data_list, self.train_image_transform, self.num_workers)

    def val_dataloader(self):
        """Return the DataLoader over the validation data list."""
        return self._make_dataloader(
            self.val_data_list, self.val_image_transform, self.val_num_workers)