init
This commit is contained in:
384
hy3dshape/hy3dshape/data/dit_asl.py
Normal file
384
hy3dshape/hy3dshape/data/dit_asl.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
|
||||
import os
|
||||
import io
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
import traceback
|
||||
from typing import Optional, Union, List, Tuple, Dict
|
||||
|
||||
import json
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
import trimesh
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from .utils import worker_init_fn, pytorch_worker_seed, make_seed
|
||||
|
||||
|
||||
class ResampledShards(torch.utils.data.dataset.IterableDataset):
    """Iterable that repeatedly draws random entries from a shard list.

    Yields up to ``nshards`` items per pass, each picked uniformly at
    random (with replacement) from ``datalist``.  The RNG is reseeded on
    every ``__iter__`` call: deterministically from (worker seed, epoch)
    when ``deterministic`` is set, otherwise additionally mixed with the
    pid, wall-clock nanoseconds and OS entropy.
    """

    def __init__(self, datalist, nshards=sys.maxsize, worker_seed=None, deterministic=False):
        super().__init__()
        self.datalist = datalist
        self.nshards = nshards
        # Fall back to the per-worker seed helper when no explicit seed
        # source is supplied.
        self.worker_seed = worker_seed if worker_seed is not None else pytorch_worker_seed
        self.deterministic = deterministic
        self.epoch = -1

    def __iter__(self):
        self.epoch += 1
        if self.deterministic:
            # Reproducible: depends only on the worker seed and epoch.
            seed = make_seed(self.worker_seed(), self.epoch)
        else:
            # Non-reproducible: fold in process id, time and OS entropy.
            seed = make_seed(self.worker_seed(), self.epoch,
                             os.getpid(), time.time_ns(), os.urandom(4))
        self.rng = random.Random(seed)
        last = len(self.datalist) - 1
        for _ in range(self.nshards):
            yield self.datalist[self.rng.randint(0, last)]
|
||||
|
||||
|
||||
def read_npz(data):
    """Load a numpy ``.npz`` archive from a path or file-like object.

    To load from raw bytes held in memory, wrap them first:
    ``np.load(io.BytesIO(data))``.
    """
    return np.load(data)
|
||||
|
||||
|
||||
def read_json(path):
    """Parse the UTF-8 encoded JSON document at ``path`` and return it."""
    with open(path, 'r', encoding='utf-8') as fp:
        return json.load(fp)
|
||||
|
||||
|
||||
def padding(image, mask, center=True, padding_ratio_range=(1.15, 1.15)):
    """
    Pad the input image and mask onto a square canvas.

    The canvas side is ``max(h, w) * ratio``, where ``ratio`` is fixed
    when both ends of ``padding_ratio_range`` are equal, and otherwise
    drawn uniformly from the range.  The image goes on a white canvas,
    the mask on a black one.

    Args:
        image (np.ndarray): Input image array of shape (H, W, C).
        mask (np.ndarray): Corresponding mask array of shape (H, W).
        center (bool): Center the input vertically in the padded output;
            otherwise bias it towards the bottom of the canvas.
        padding_ratio_range (sequence): [min, max] bounds for the ratio.
            (Default changed from a mutable list to a tuple; values and
            behavior are identical.)

    Returns:
        newimg (np.ndarray): Padded image of shape (side, side, 3).
        newmask (np.ndarray): Padded mask of shape (side, side).
    """
    h, w = image.shape[:2]
    max_side = max(h, w)

    # Fixed ratio when the range collapses, random within it otherwise.
    low, high = padding_ratio_range[0], padding_ratio_range[1]
    padding_ratio = low if low == high else random.uniform(low, high)
    resize_side = int(max_side * padding_ratio)

    pad_h = resize_side - h
    pad_w = resize_side - w
    if center:
        start_h = pad_h // 2
    else:
        # Bias the image towards the bottom of the canvas.
        # NOTE(review): this can go negative for tall inputs / small
        # ratios, which would wrap the slice below — confirm callers only
        # use center=True or large-enough ratios.
        start_h = pad_h - resize_side // 20

    start_w = pad_w // 2

    # White canvas for the image, black canvas for the mask.
    newimg = np.ones((resize_side, resize_side, 3), dtype=np.uint8) * 255
    newmask = np.zeros((resize_side, resize_side), dtype=np.uint8)

    # Paste the originals into the padded canvases.
    newimg[start_h:start_h + h, start_w:start_w + w] = image
    newmask[start_h:start_h + h, start_w:start_w + w] = mask

    return newimg, newmask
|
||||
|
||||
|
||||
def viz_pc(surface, normal, image_input, name):
    """Dump a conditioning image and a colored point cloud for inspection.

    Writes ``name + '.png'`` (de-normalized RGB image) and
    ``name + '.obj'`` (points vertex-colored by normal direction).
    """
    # Undo the (x - 0.5) / 0.5 normalization, CHW -> HWC, back to uint8.
    img = image_input.cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
    img = (img * 255).astype(np.uint8)
    cv2.imwrite(name + '.png', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

    points = surface.cpu().numpy()
    normals = normal.cpu().numpy()
    # Map normals from [-1, 1] into [0, 1] so they serve as vertex colors.
    trimesh.Trimesh(points, vertex_colors=(normals + 1) / 2).export(name + '.obj')
|
||||
|
||||
|
||||
class AlignedShapeLatentDataset(torch.utils.data.dataset.IterableDataset):
    """Iterable dataset yielding aligned surface point clouds together with
    a conditioning render image and its foreground mask.

    Each entry of ``data_list`` is an item directory expected to contain::

        <item>/render_cond/000.png .. 023.png     (RGBA renders)
        <item>/geo_data/<uid>_surface.npz         (random_surface / sharp_surface)

    Items that fail to decode are skipped; the running failure rate is
    printed every 1000 items.
    """

    def __init__(
        self,
        data_list: str = None,
        cond_stage_key: str = "image",
        image_transform=None,
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sharpedge_label: bool = False,
        return_normal: bool = False,
        deterministic=False,
        worker_seed=None,
        padding=True,
        padding_ratio_range=(1.15, 1.15)
    ):
        """
        Args:
            data_list: Path to a JSON file holding a list of item paths, a
                directory whose children are items, or an existing list.
            cond_stage_key: Name of the conditioning modality.
            image_transform: Callable applied to both image and mask.
            pc_size: Number of points sampled from the random surface.
            pc_sharpedge_size: Number of points sampled from the sharp-edge surface.
            sharpedge_label: Append a 0/1 sharp-edge indicator column to points.
            return_normal: Concatenate per-point normals onto the surface tensor.
            deterministic: Accepted for interface compatibility; not used here.
            worker_seed: Accepted for interface compatibility; not used here.
            padding: Crop to the mask bounding box and pad to a square canvas.
            padding_ratio_range: Ratio range forwarded to ``padding()``.
        """
        super().__init__()
        if isinstance(data_list, str) and data_list.endswith('.json'):
            # BUGFIX: the original called read_json(data_list_json) with an
            # undefined name, raising NameError whenever data_list was a
            # .json path.
            self.data_list = read_json(data_list)
        elif isinstance(data_list, str) and os.path.isdir(data_list):
            self.data_list = glob.glob(data_list + '/*')
        else:
            self.data_list = data_list
        assert isinstance(self.data_list, list)
        # Fixed-seed RNG used only for picking which render view to load.
        self.rng = random.Random(0)

        self.cond_stage_key = cond_stage_key
        self.image_transform = image_transform

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal

        self.padding = padding
        self.padding_ratio_range = padding_ratio_range

        rank_zero_info(f'*' * 50)
        rank_zero_info(f'Dataset Infos:')
        rank_zero_info(f'# of 3D file: {len(self.data_list)}')
        rank_zero_info(f'# of Surface Points: {self.pc_size}')
        rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
        rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
        rank_zero_info(f'*' * 50)

    def load_surface_sdf_points(self, rng, random_surface, sharpedge_surface):
        """Sample points from both surfaces and assemble the surface tensor.

        Args:
            rng: numpy Generator used for without-replacement sampling.
            random_surface: (N, >=6) array of xyz + normal columns.
            sharpedge_surface: (M, >=6) array of sharp-edge xyz + normal columns.

        Returns:
            (surface, geo_points): ``surface`` is a FloatTensor of xyz
            [+ unit normals if return_normal] [+ 0/1 sharp-edge label if
            sharpedge_label]; ``geo_points`` is a 0.0 placeholder.
        """
        surface_normal = []
        if self.pc_size > 0:
            ind = rng.choice(random_surface.shape[0], self.pc_size, replace=False)
            random_surface = random_surface[ind]
            if self.sharpedge_label:
                # Random-surface points carry label 0.
                sharpedge_label = np.zeros((self.pc_size, 1))
                random_surface = np.concatenate((random_surface, sharpedge_label), axis=1)
            surface_normal.append(random_surface)

        if self.pc_sharpedge_size > 0:
            ind_sharpedge = rng.choice(sharpedge_surface.shape[0], self.pc_sharpedge_size, replace=False)
            sharpedge_surface = sharpedge_surface[ind_sharpedge]
            if self.sharpedge_label:
                # Sharp-edge points carry label 1.
                sharpedge_label = np.ones((self.pc_sharpedge_size, 1))
                sharpedge_surface = np.concatenate((sharpedge_surface, sharpedge_label), axis=1)
            surface_normal.append(sharpedge_surface)

        surface_normal = np.concatenate(surface_normal, axis=0)
        surface_normal = torch.FloatTensor(surface_normal)
        surface = surface_normal[:, 0:3]
        normal = surface_normal[:, 3:6]
        assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size

        geo_points = 0.0
        normal = torch.nn.functional.normalize(normal, p=2, dim=1)
        if self.return_normal:
            surface = torch.cat([surface, normal], dim=-1)
        if self.sharpedge_label:
            # Label column was appended last, so it sits at index -1.
            surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)
        return surface, geo_points

    def load_render(self, imgs_path):
        """Load one randomly chosen RGBA render, composite it on white and
        return the transformed image and (single-channel) mask tensors.
        """
        imgs_choice = self.rng.sample(imgs_path, 1)
        images, masks = [], []
        for image_path in imgs_choice:
            image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            assert image.shape[2] == 4
            # Alpha-composite the foreground over a white background.
            alpha = image[:, :, 3:4].astype(np.float32) / 255
            forground = image[:, :, :3]
            background = np.ones_like(forground) * 255
            img_new = forground * alpha + background * (1 - alpha)
            image = img_new.astype(np.uint8)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            mask = (alpha[:, :, 0] * 255).astype(np.uint8)

            if self.padding:
                h, w = image.shape[:2]
                # mask is uint8, so this keeps every non-zero pixel.
                binary = mask > 0.3
                non_zero_coords = np.argwhere(binary)
                # argwhere returns (row, col); x here is the row axis.
                x_min, y_min = non_zero_coords.min(axis=0)
                x_max, y_max = non_zero_coords.max(axis=0)
                # Crop to the foreground bbox (+5 px margin), then pad square.
                image, mask = padding(
                    image[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
                    mask[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
                    center=True, padding_ratio_range=self.padding_ratio_range)

            if self.image_transform:
                image = self.image_transform(image)
                # Replicate mask to 3 channels so the same transform applies.
                mask = np.stack((mask, mask, mask), axis=-1)
                mask = self.image_transform(mask)

            images.append(image)
            masks.append(mask)

        images = torch.cat(images, dim=0)
        # Keep only one channel of the replicated mask.
        masks = torch.cat(masks, dim=0)[:1, ...]
        return images, masks

    def decode(self, item):
        """Resolve an item directory into render paths and surface arrays."""
        uid = item.split('/')[-1]
        render_img_paths = [os.path.join(item, f'render_cond/{i:03d}.png') for i in range(24)]
        # transforms_json_path = os.path.join(item, 'render_cond/transforms.json')
        surface_npz_path = os.path.join(item, f'geo_data/{uid}_surface.npz')
        # sdf_npz_path = os.path.join(item, f'geo_data/{uid}_sdf.npz')
        # watertight_obj_path = os.path.join(item, f'geo_data/{uid}_watertight.obj')
        sample = {}
        sample["image"] = render_img_paths
        surface_data = read_npz(surface_npz_path)
        sample["random_surface"] = surface_data['random_surface']
        sample["sharpedge_surface"] = surface_data['sharp_surface']
        return sample

    def transform(self, sample):
        """Turn a decoded sample into the training dict of tensors."""
        rng = np.random.default_rng()
        random_surface = sample.get("random_surface", 0)
        sharpedge_surface = sample.get("sharpedge_surface", 0)
        image_input, mask_input = self.load_render(sample['image'])
        surface, geo_points = self.load_surface_sdf_points(rng, random_surface, sharpedge_surface)
        sample = {
            "surface": surface,
            "geo_points": geo_points,
            "image": image_input,
            "mask": mask_input,
        }
        return sample

    def __iter__(self):
        """Resample items forever, skipping those that fail to load."""
        total_num = 0
        failed_num = 0
        for data in ResampledShards(self.data_list):
            total_num += 1
            if total_num % 1000 == 0:
                print(f"Current failure rate of data loading:")
                print(f"{failed_num}/{total_num}={failed_num/total_num}")
            try:
                sample = self.decode(data)
                sample = self.transform(sample)
            except Exception as err:
                # Best-effort loading: report and move on to the next item.
                print(err)
                failed_num += 1
                continue
            yield sample
|
||||
|
||||
|
||||
class AlignedShapeLatentModule(LightningDataModule):
    """LightningDataModule wiring AlignedShapeLatentDataset into train and
    validation DataLoaders.

    Train and validation use the same preprocessing pipeline
    (ToTensor -> Resize -> Normalize) and differ only in their data list
    and worker count; the shared construction lives in private helpers.
    """

    def __init__(
        self,
        batch_size: int = 1,
        num_workers: int = 4,
        val_num_workers: int = 2,
        train_data_list: str = None,
        val_data_list: str = None,
        cond_stage_key: str = "all",
        image_size: int = 224,
        mean: Union[List[float], Tuple[float]] = (0.485, 0.456, 0.406),
        std: Union[List[float], Tuple[float]] = (0.229, 0.224, 0.225),
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sharpedge_label: bool = False,
        return_normal: bool = False,
        padding=True,
        padding_ratio_range=(1.15, 1.15)
    ):
        """
        Args:
            batch_size: Samples per batch for both loaders.
            num_workers: Worker processes for the train loader.
            val_num_workers: Worker processes for the val loader.
            train_data_list: Data source for training (path/dir/list).
            val_data_list: Data source for validation (path/dir/list).
            cond_stage_key: Conditioning modality key passed to the dataset.
            image_size: Target side length for the Resize transform.
            mean: Per-channel normalization mean.
            std: Per-channel normalization std.
            pc_size: Random-surface point count per sample.
            pc_sharpedge_size: Sharp-edge point count per sample.
            sharpedge_label: Append sharp-edge indicator to surface points.
            return_normal: Include normals in the surface tensor.
            padding: Enable bbox-crop-and-pad in the dataset.
            padding_ratio_range: Padding ratio bounds forwarded to the dataset.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_num_workers = val_num_workers

        self.train_data_list = train_data_list
        self.val_data_list = val_data_list

        self.cond_stage_key = cond_stage_key
        self.image_size = image_size
        self.mean = mean
        self.std = std
        # Train and val share an identical deterministic pipeline.
        self.train_image_transform = self._build_image_transform()
        self.val_image_transform = self._build_image_transform()

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal

        self.padding = padding
        self.padding_ratio_range = padding_ratio_range

    def _build_image_transform(self):
        """Build the ToTensor -> Resize -> Normalize preprocessing pipeline."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.image_size),
            transforms.Normalize(mean=self.mean, std=self.std)])

    def _make_dataloader(self, data_list, image_transform, num_workers):
        """Construct a DataLoader over an AlignedShapeLatentDataset split."""
        dataset = AlignedShapeLatentDataset(
            data_list=data_list,
            cond_stage_key=self.cond_stage_key,
            image_transform=image_transform,
            pc_size=self.pc_size,
            pc_sharpedge_size=self.pc_sharpedge_size,
            sharpedge_label=self.sharpedge_label,
            return_normal=self.return_normal,
            padding=self.padding,
            padding_ratio_range=self.padding_ratio_range,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            worker_init_fn=worker_init_fn,
        )

    def train_dataloader(self):
        """DataLoader over the training split."""
        return self._make_dataloader(
            self.train_data_list, self.train_image_transform, self.num_workers)

    def val_dataloader(self):
        """DataLoader over the validation split."""
        return self._make_dataloader(
            self.val_data_list, self.val_image_transform, self.val_num_workers)
|
||||
Reference in New Issue
Block a user