init

hy3dshape/hy3dshape/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from .pipelines import Hunyuan3DDiTPipeline, Hunyuan3DDiTFlowMatchingPipeline
from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier
from .preprocessors import ImageProcessorV2, IMAGE_PROCESSORS, DEFAULT_IMAGEPROCESSOR

hy3dshape/hy3dshape/data/dit_asl.py (new file, 384 lines)
@@ -0,0 +1,384 @@
# -*- coding: utf-8 -*-

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


import os
import io
import sys
import time
import random
import traceback
from typing import Optional, Union, List, Tuple, Dict

import json
import glob
import cv2
import numpy as np
import trimesh

import torch
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import rank_zero_info

from .utils import worker_init_fn, pytorch_worker_seed, make_seed


class ResampledShards(torch.utils.data.dataset.IterableDataset):
    def __init__(self, datalist, nshards=sys.maxsize, worker_seed=None, deterministic=False):
        super().__init__()
        self.datalist = datalist
        self.nshards = nshards
        # If no worker_seed is provided, use the pytorch_worker_seed function; else use the given seed
        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = -1

    def __iter__(self):
        self.epoch += 1
        if self.deterministic:
            seed = make_seed(self.worker_seed(), self.epoch)
        else:
            seed = make_seed(self.worker_seed(), self.epoch,
                             os.getpid(), time.time_ns(), os.urandom(4))
        self.rng = random.Random(seed)
        for _ in range(self.nshards):
            index = self.rng.randint(0, len(self.datalist) - 1)
            yield self.datalist[index]
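
# Usage sketch: ResampledShards samples with replacement, so with the default
# nshards=sys.maxsize it behaves as an effectively infinite stream. The shard
# paths below are hypothetical placeholders.
#     shards = ResampledShards(['/data/obj_0001', '/data/obj_0002'], nshards=4, deterministic=True)
#     for shard_path in shards:  # yields 4 randomly resampled entries
#         print(shard_path)
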
def read_npz(data):
    # Load a numpy .npz file from a file path or file-like object.
    # The commented line shows how to load from raw bytes in memory instead:
    # return np.load(io.BytesIO(data))
    return np.load(data)


def read_json(path):
    # Read and parse a JSON file from the given file path
    with open(path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data


def padding(image, mask, center=True, padding_ratio_range=[1.15, 1.15]):
    """
    Pad the input image and mask to a square shape with a padding ratio.

    Args:
        image (np.ndarray): Input image array of shape (H, W, C).
        mask (np.ndarray): Corresponding mask array of shape (H, W).
        center (bool): Whether to center the original image in the padded output.
        padding_ratio_range (list): Range [min, max] from which the padding ratio is randomly drawn.

    Returns:
        newimg (np.ndarray): Padded image of shape (resize_side, resize_side, 3).
        newmask (np.ndarray): Padded mask of shape (resize_side, resize_side).
    """
    h, w = image.shape[:2]
    max_side = max(h, w)

    # Select the padding ratio, either fixed or drawn randomly within the given range
    if padding_ratio_range[0] == padding_ratio_range[1]:
        padding_ratio = padding_ratio_range[0]
    else:
        padding_ratio = random.uniform(padding_ratio_range[0], padding_ratio_range[1])
    resize_side = int(max_side * padding_ratio)

    pad_h = resize_side - h
    pad_w = resize_side - w
    if center:
        start_h = pad_h // 2
    else:
        start_h = pad_h - resize_side // 20

    start_w = pad_w // 2

    # Create a new white image and black mask of the padded size
    newimg = np.ones((resize_side, resize_side, 3), dtype=np.uint8) * 255
    newmask = np.zeros((resize_side, resize_side), dtype=np.uint8)

    # Place the original image and mask into the padded canvas
    newimg[start_h:start_h + h, start_w:start_w + w] = image
    newmask[start_h:start_h + h, start_w:start_w + w] = mask

    return newimg, newmask
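
# Quick sanity check of the padding behavior (dummy data, illustrating the output shape):
#     img = np.zeros((100, 80, 3), dtype=np.uint8)
#     msk = np.ones((100, 80), dtype=np.uint8) * 255
#     out_img, out_mask = padding(img, msk)  # default ratio 1.15 -> side = int(100 * 1.15) = 115
#     assert out_img.shape == (115, 115, 3) and out_mask.shape == (115, 115)
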
def viz_pc(surface, normal, image_input, name):
    # De-normalize the conditioning image from [-1, 1] back to uint8 RGB and save it
    image_input = image_input.cpu().numpy()
    image_input = image_input.transpose(1, 2, 0) * 0.5 + 0.5
    image_input = (image_input * 255).astype(np.uint8)
    cv2.imwrite(name + '.png', cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR))
    # Export the surface points as a point cloud colored by their normals (mapped to [0, 1])
    surface = surface.cpu().numpy()
    normal = normal.cpu().numpy()
    surface_mesh = trimesh.Trimesh(surface, vertex_colors=(normal + 1) / 2)
    surface_mesh.export(name + '.obj')


class AlignedShapeLatentDataset(torch.utils.data.dataset.IterableDataset):
    def __init__(
        self,
        data_list: str = None,
        cond_stage_key: str = "image",
        image_transform=None,
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sharpedge_label: bool = False,
        return_normal: bool = False,
        deterministic=False,
        worker_seed=None,
        padding=True,
        padding_ratio_range=[1.15, 1.15]
    ):
        super().__init__()
        # data_list may be a path to a JSON file, a directory of samples, or an in-memory list
        if isinstance(data_list, str) and data_list.endswith('.json'):
            self.data_list = read_json(data_list)  # was read_json(data_list_json), an undefined name
        elif isinstance(data_list, str) and os.path.isdir(data_list):
            self.data_list = glob.glob(data_list + '/*')
        else:
            self.data_list = data_list
        assert isinstance(self.data_list, list)
        self.rng = random.Random(0)

        self.cond_stage_key = cond_stage_key
        self.image_transform = image_transform

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal

        self.padding = padding
        self.padding_ratio_range = padding_ratio_range

        rank_zero_info('*' * 50)
        rank_zero_info('Dataset Infos:')
        rank_zero_info(f'# of 3D files: {len(self.data_list)}')
        rank_zero_info(f'# of Surface Points: {self.pc_size}')
        rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
        rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
        rank_zero_info('*' * 50)

    def load_surface_sdf_points(self, rng, random_surface, sharpedge_surface):
        surface_normal = []
        # Subsample the uniformly sampled surface points and, if requested, tag them with label 0
        if self.pc_size > 0:
            ind = rng.choice(random_surface.shape[0], self.pc_size, replace=False)
            random_surface = random_surface[ind]
            if self.sharpedge_label:
                sharpedge_label = np.zeros((self.pc_size, 1))
                random_surface = np.concatenate((random_surface, sharpedge_label), axis=1)
            surface_normal.append(random_surface)

        # Subsample the sharp-edge surface points and, if requested, tag them with label 1
        if self.pc_sharpedge_size > 0:
            ind_sharpedge = rng.choice(sharpedge_surface.shape[0], self.pc_sharpedge_size, replace=False)
            sharpedge_surface = sharpedge_surface[ind_sharpedge]
            if self.sharpedge_label:
                sharpedge_label = np.ones((self.pc_sharpedge_size, 1))
                sharpedge_surface = np.concatenate((sharpedge_surface, sharpedge_label), axis=1)
            surface_normal.append(sharpedge_surface)

        surface_normal = np.concatenate(surface_normal, axis=0)
        surface_normal = torch.FloatTensor(surface_normal)
        surface = surface_normal[:, 0:3]
        normal = surface_normal[:, 3:6]
        assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size

        geo_points = 0.0
        normal = torch.nn.functional.normalize(normal, p=2, dim=1)
        if self.return_normal:
            surface = torch.cat([surface, normal], dim=-1)
        if self.sharpedge_label:
            surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)
        return surface, geo_points
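
    # Shape sketch: with return_normal=True and sharpedge_label=True the returned
    # surface tensor has 7 channels per point (xyz, unit normal, sharp-edge flag), e.g.:
    #     ds = AlignedShapeLatentDataset(data_list=[], sharpedge_label=True, return_normal=True)
    #     rng = np.random.default_rng(0)
    #     surface, _ = ds.load_surface_sdf_points(rng, np.random.randn(10000, 6), np.random.randn(10000, 6))
    #     assert surface.shape == (4096, 7)  # pc_size + pc_sharpedge_size = 2048 + 2048
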
    def load_render(self, imgs_path):
        # Pick one random rendered view per sample
        imgs_choice = self.rng.sample(imgs_path, 1)
        images, masks = [], []
        for image_path in imgs_choice:
            image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            assert image.shape[2] == 4
            # Composite the RGBA render over a white background
            alpha = image[:, :, 3:4].astype(np.float32) / 255
            foreground = image[:, :, :3]
            background = np.ones_like(foreground) * 255
            img_new = foreground * alpha + background * (1 - alpha)
            image = img_new.astype(np.uint8)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            mask = (alpha[:, :, 0] * 255).astype(np.uint8)

            if self.padding:
                # Crop to the foreground bounding box (mask is uint8, so > 0.3 means any non-zero pixel),
                # then pad back to a square canvas
                h, w = image.shape[:2]
                binary = mask > 0.3
                non_zero_coords = np.argwhere(binary)
                x_min, y_min = non_zero_coords.min(axis=0)
                x_max, y_max = non_zero_coords.max(axis=0)
                image, mask = padding(
                    image[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
                    mask[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
                    center=True, padding_ratio_range=self.padding_ratio_range)

            if self.image_transform:
                image = self.image_transform(image)
                mask = np.stack((mask, mask, mask), axis=-1)
                mask = self.image_transform(mask)

            images.append(image)
            masks.append(mask)

        images = torch.cat(images, dim=0)
        masks = torch.cat(masks, dim=0)[:1, ...]
        return images, masks

    def decode(self, item):
        uid = item.split('/')[-1]
        render_img_paths = [os.path.join(item, f'render_cond/{i:03d}.png') for i in range(24)]
        # transforms_json_path = os.path.join(item, 'render_cond/transforms.json')
        surface_npz_path = os.path.join(item, f'geo_data/{uid}_surface.npz')
        # sdf_npz_path = os.path.join(item, f'geo_data/{uid}_sdf.npz')
        # watertight_obj_path = os.path.join(item, f'geo_data/{uid}_watertight.obj')
        sample = {}
        sample["image"] = render_img_paths
        surface_data = read_npz(surface_npz_path)
        sample["random_surface"] = surface_data['random_surface']
        sample["sharpedge_surface"] = surface_data['sharp_surface']
        return sample

    def transform(self, sample):
        rng = np.random.default_rng()
        random_surface = sample.get("random_surface", 0)
        sharpedge_surface = sample.get("sharpedge_surface", 0)
        image_input, mask_input = self.load_render(sample['image'])
        surface, geo_points = self.load_surface_sdf_points(rng, random_surface, sharpedge_surface)
        sample = {
            "surface": surface,
            "geo_points": geo_points,
            "image": image_input,
            "mask": mask_input,
        }
        return sample

    def __iter__(self):
        total_num = 0
        failed_num = 0
        for data in ResampledShards(self.data_list):
            total_num += 1
            if total_num % 1000 == 0:
                print("Current failure rate of data loading: "
                      f"{failed_num}/{total_num} = {failed_num / total_num:.4f}")
            # Skip samples that fail to load or transform instead of crashing the worker
            try:
                sample = self.decode(data)
                sample = self.transform(sample)
            except Exception as err:
                print(err)
                failed_num += 1
                continue
            yield sample
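
# The decode() method above implies the following per-object directory layout
# for every entry in data_list:
#     <item>/render_cond/000.png ... 023.png    - 24 RGBA conditioning renders
#     <item>/geo_data/<uid>_surface.npz         - arrays 'random_surface' and 'sharp_surface'
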
class AlignedShapeLatentModule(LightningDataModule):
    def __init__(
        self,
        batch_size: int = 1,
        num_workers: int = 4,
        val_num_workers: int = 2,
        train_data_list: str = None,
        val_data_list: str = None,
        cond_stage_key: str = "all",
        image_size: int = 224,
        mean: Union[List[float], Tuple[float]] = (0.485, 0.456, 0.406),
        std: Union[List[float], Tuple[float]] = (0.229, 0.224, 0.225),
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sharpedge_label: bool = False,
        return_normal: bool = False,
        padding=True,
        padding_ratio_range=[1.15, 1.15]
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_num_workers = val_num_workers

        self.train_data_list = train_data_list
        self.val_data_list = val_data_list

        self.cond_stage_key = cond_stage_key
        self.image_size = image_size
        self.mean = mean
        self.std = std
        # Train and validation transforms are identical: ToTensor, resize, ImageNet normalization
        self.train_image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.image_size),
            transforms.Normalize(mean=self.mean, std=self.std)])
        self.val_image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.image_size),
            transforms.Normalize(mean=self.mean, std=self.std)])

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal

        self.padding = padding
        self.padding_ratio_range = padding_ratio_range

    def train_dataloader(self):
        asl_params = {
            "data_list": self.train_data_list,
            "cond_stage_key": self.cond_stage_key,
            "image_transform": self.train_image_transform,
            "pc_size": self.pc_size,
            "pc_sharpedge_size": self.pc_sharpedge_size,
            "sharpedge_label": self.sharpedge_label,
            "return_normal": self.return_normal,
            "padding": self.padding,
            "padding_ratio_range": self.padding_ratio_range
        }
        dataset = AlignedShapeLatentDataset(**asl_params)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            worker_init_fn=worker_init_fn,
        )

    def val_dataloader(self):
        asl_params = {
            "data_list": self.val_data_list,
            "cond_stage_key": self.cond_stage_key,
            "image_transform": self.val_image_transform,
            "pc_size": self.pc_size,
            "pc_sharpedge_size": self.pc_sharpedge_size,
            "sharpedge_label": self.sharpedge_label,
            "return_normal": self.return_normal,
            "padding": self.padding,
            "padding_ratio_range": self.padding_ratio_range
        }
        dataset = AlignedShapeLatentDataset(**asl_params)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.val_num_workers,
            pin_memory=True,
            drop_last=True,
            worker_init_fn=worker_init_fn,
        )
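
A minimal usage sketch of the module above (the data root is a hypothetical placeholder):

    dm = AlignedShapeLatentModule(
        batch_size=2,
        train_data_list='/data/objaverse_subset/train',
        val_data_list='/data/objaverse_subset/val',
        sharpedge_label=True,
        return_normal=True,
    )
    loader = dm.train_dataloader()
    batch = next(iter(loader))  # keys: 'surface', 'geo_points', 'image', 'mask'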

hy3dshape/hy3dshape/data/utils.py (new file, 186 lines)
@@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-

# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).


"""Miscellaneous utility functions."""

import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Union
import torch
import numpy as np


def make_seed(*args):
    # Combine arbitrary hashable arguments into a single 31-bit seed
    seed = 0
    for arg in args:
        seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
    return seed
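
# make_seed is a polynomial rolling hash over its arguments, truncated to 31 bits.
# The same inputs give the same seed within one process (note that hash() of str/bytes
# is salted per process unless PYTHONHASHSEED is fixed):
#     s1 = make_seed(0, 3)  # e.g. a rank-derived value and an epoch, as in ResampledShards
#     assert s1 == make_seed(0, 3) and 0 <= s1 < 2 ** 31
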
class PipelineStage:
    def invoke(self, *args, **kw):
        raise NotImplementedError


def identity(x: Any) -> Any:
    """Return the argument as is."""
    return x


def safe_eval(s: str, expr: str = "{}"):
    """Evaluate the given expression more safely."""
    # Reject any input containing characters outside [A-Za-z0-9_] before interpolating it
    if re.sub("[^A-Za-z0-9_]", "", s) != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    return eval(expr.format(s))


def lookup_sym(sym: str, modules: list):
    """Look up a symbol in a list of modules."""
    for mname in modules:
        module = importlib.import_module(mname, package="webdataset")
        result = getattr(module, sym, None)
        if result is not None:
            return result
    return None


def repeatedly0(
    loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
    """Repeatedly returns batches from a DataLoader."""
    for _ in range(nepochs):
        yield from itt.islice(loader, nbatches)


def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size by looking at the length of the first element in a tuple."""
    return len(batch[0])


def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield samples from an iterator, bounded by epochs, batches, or samples."""
    epoch = 0
    batch = 0
    total = 0
    while True:
        for sample in source:
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                total += batchsize(sample)  # was hard-coded to guess_batchsize, ignoring the parameter
                if total >= nsamples:
                    return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return


def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments."""
    rank = 0
    world_size = 1
    worker = 0
    num_workers = 1
    # Prefer explicit environment variables; otherwise query torch.distributed if initialized
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        try:
            import torch.distributed

            if torch.distributed.is_available() and torch.distributed.is_initialized():
                group = group or torch.distributed.group.WORLD
                rank = torch.distributed.get_rank(group=group)
                world_size = torch.distributed.get_world_size(group=group)
        except ModuleNotFoundError:
            pass
    # Same pattern for dataloader workers: env vars first, then torch worker info
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        try:
            import torch.utils.data

            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                worker = worker_info.id
                num_workers = worker_info.num_workers
        except ModuleNotFoundError:
            pass

    return rank, world_size, worker, num_workers


def pytorch_worker_seed(group=None):
    """Compute a distinct, deterministic RNG seed for each worker and node."""
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
    # Seeds are unique as long as each rank has fewer than 1000 dataloader workers
    return rank * 1000 + worker
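
# Outside any distributed context the helpers fall back to their defaults, so
# pytorch_worker_info() returns (0, 1, 0, 1) and the seed is 0:
#     assert pytorch_worker_seed() == 0
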
def worker_init_fn(_):
    # Offset numpy's global seed by the worker id so each dataloader worker
    # draws a different random stream
    worker_info = torch.utils.data.get_worker_info()
    worker_id = worker_info.id

    return np.random.seed(np.random.get_state()[1][0] + worker_id)


def collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of sample dicts into a single batched dict.

    Args:
        samples (list[dict]): samples sharing the same set of keys.
        combine_tensors: stack torch.Tensor / np.ndarray values along a new batch dim.
        combine_scalars: convert lists of int/float values to np.ndarray.

    Returns:
        dict: one entry per key, with values batched according to the flags above.
    """

    result = {}

    keys = samples[0].keys()

    for key in keys:
        result[key] = []

    for sample in samples:
        for key in keys:
            val = sample[key]
            result[key].append(val)

    for key in keys:
        val_list = result[key]
        if isinstance(val_list[0], (int, float)):
            if combine_scalars:
                result[key] = np.array(result[key])

        elif isinstance(val_list[0], torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(val_list)

        elif isinstance(val_list[0], np.ndarray):
            if combine_tensors:
                result[key] = np.stack(val_list)

    return result
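
A small check of collation_fn on synthetic samples: tensors are stacked along a new
batch dimension, scalar values become a numpy array.

    s = [{'surface': torch.zeros(4096, 7), 'geo_points': 0.0},
         {'surface': torch.ones(4096, 7), 'geo_points': 0.0}]
    batch = collation_fn(s)
    assert batch['surface'].shape == (2, 4096, 7)
    assert batch['geo_points'].shape == (2,)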

hy3dshape/hy3dshape/models/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


from .autoencoders import ShapeVAE
from .conditioner import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
from .denoisers import Hunyuan3DDiT

hy3dshape/hy3dshape/models/autoencoders/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from .attention_blocks import CrossAttentionDecoder
from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \
    FlashVDMTopMCrossAttentionProcessor
from .model import ShapeVAE, VectsetVAE
from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput
from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder

hy3dshape/hy3dshape/models/autoencoders/attention_blocks.py (new file, 716 lines)
@@ -0,0 +1,716 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


import os
from typing import Optional, Union, List

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from .attention_processors import CrossAttentionProcessor
from ...utils import logger

scaled_dot_product_attention = nn.functional.scaled_dot_product_attention

# Optionally swap in SageAttention as a drop-in replacement for PyTorch SDPA
if os.environ.get('USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError:
        raise ImportError('Please install the package "sageattention" to use USE_SAGEATTN=1.')
    scaled_dot_product_attention = sageattn


class FourierEmbedder(nn.Module):
    """The sin/cos positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(f_0 * x[..., i]),
            sin(f_1 * x[..., i]),
            ...
            sin(f_{N-1} * x[..., i]),
            cos(f_0 * x[..., i]),
            cos(f_1 * x[..., i]),
            ...
            cos(f_{N-1} * x[..., i]),
            x[..., i]  # only present if include_input is True
        ],
    where f_i is the i-th frequency and N = num_freqs.

    If logspace is True, the frequencies are f_i = 2^i for i in [0, num_freqs - 1];
    otherwise they are linearly spaced between 1.0 and 2^(num_freqs - 1).
    If include_pi is True, all frequencies are additionally multiplied by pi.

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True, f_i = 2^i, otherwise the frequencies are linearly
            spaced between 1.0 and 2^(num_freqs - 1);
        input_dim (int): the input dimension, default is 3;
        include_input (bool): whether to append the raw input tensor, default is True.

    Attributes:
        frequencies (torch.Tensor): the frequency tensor described above;
        out_dim (int): the embedding size; if include_input is True it is
            input_dim * (num_freqs * 2 + 1), otherwise input_dim * num_freqs * 2.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """The initialization"""
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)],
                where temp is 1 if include_input is True and 0 otherwise.
        """
        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x
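
# Shape check with the defaults (num_freqs=6, input_dim=3, include_input=True,
# so out_dim = 3 * (6 * 2 + 1) = 39):
#     embedder = FourierEmbedder()
#     emb = embedder(torch.rand(2, 4096, 3))
#     assert embedder.out_dim == 39 and emb.shape == (2, 4096, 39)
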
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
        the original name is misleading, as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.
        """
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with tensors of any rank, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'


class MLP(nn.Module):
    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * expand_ratio)
        self.c_proj = nn.Linear(width * expand_ratio, output_width if output_width is not None else width)
        self.gelu = nn.GELU()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        n_data: Optional[int] = None,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.n_data = n_data
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        # q: [bs, n_ctx, width]; kv: [bs, n_data, 2 * width] (keys and values packed together)
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = self.attn_processor(self, q, k, v)
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        n_data: Optional[int] = None,
        data_width: Optional[int] = None,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.n_data = n_data
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = nn.Linear(width, width, bias=qkv_bias)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            n_data=n_data,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        self.data = None

    def forward(self, x, data):
        x = self.c_q(x)
        if self.kv_cache:
            if self.data is None:
                self.data = self.c_kv(data)
                logger.info('Saving kv cache; this should be called only once per mesh')
            data = self.data
        else:
            data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x
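
# kv-cache usage sketch: with kv_cache=True the projected key/value tensor for the
# latents is computed once and reused across query chunks; clearing `data` by hand
# (a hypothetical reset, not provided by this class) starts a fresh cache for the next mesh:
#     attn = MultiheadCrossAttention(width=512, heads=8, kv_cache=True)
#     latents = torch.randn(1, 1024, 512)
#     for chunk in torch.randn(4, 1, 4096, 512):  # four query chunks
#         out = attn(chunk, latents)              # c_kv(latents) runs only on the first call
#     attn.data = None                            # manual cache reset before a new mesh
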
class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        n_data: Optional[int] = None,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            n_data=n_data,
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class QKVMultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        n_ctx: int,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.n_ctx = n_ctx
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            n_ctx=n_ctx,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    drop_path_rate=drop_path_rate
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


class CrossAttentionDecoder(nn.Module):

    def __init__(
        self,
        *,
        num_latents: int,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = nn.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            n_data=num_latents,
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        if self.enable_ln_post:
            self.ln_post = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, out_channels)
        self.label_type = label_type
        self.count = 0

    def set_cross_attention_processor(self, processor):
        self.cross_attn_decoder.attn.attention.attn_processor = processor

    def set_default_cross_attention_processor(self):
        # Was assigned the class itself; an instance is required since the processor is called directly
        self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor()

    def forward(self, queries=None, query_embeddings=None, latents=None):
        if query_embeddings is None:
            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ


def fps(
    src: torch.Tensor,
    batch: Optional[Tensor] = None,
    ratio: Optional[Union[Tensor, float]] = None,
    random_start: bool = True,
    batch_size: Optional[int] = None,
    ptr: Optional[Union[Tensor, List[int]]] = None,
):
    # Thin wrapper around torch_cluster's farthest point sampling; imported lazily
    # so the optional torch_cluster dependency is only required when fps is used
    src = src.float()
    from torch_cluster import fps as fps_fn
    output = fps_fn(src, batch, ratio, random_start, batch_size, ptr)
    return output
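
# Usage sketch (requires the optional torch_cluster package): two flattened batches
# of 1000 points each, downsampled to roughly a quarter by farthest point sampling.
#     pts = torch.rand(2 * 1000, 3)
#     batch = torch.arange(2).repeat_interleave(1000)
#     idx = fps(pts, batch, ratio=0.25)      # indices into the flattened point array
#     sampled = pts[idx].view(2, -1, 3)
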
class PointCrossAttentionEncoder(nn.Module):

    def __init__(
        self, *,
        num_latents: int,
        downsample_ratio: float,
        pc_size: int,
        pc_sharpedge_size: int,
        fourier_embedder: FourierEmbedder,
        point_feats: int,
        width: int,
        heads: int,
        layers: int,
        normal_pe: bool = False,
        qkv_bias: bool = True,
        use_ln_post: bool = False,
        use_checkpoint: bool = False,
        qk_norm: bool = False
    ):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents
        self.downsample_ratio = downsample_ratio
        self.point_feats = point_feats
        self.normal_pe = normal_pe

        if pc_sharpedge_size == 0:
            print('PointCrossAttentionEncoder INFO: pc_sharpedge_size is 0, '
                  f'using only pc_size={pc_size} uniformly sampled surface points')
        else:
            print(f'PointCrossAttentionEncoder INFO: using pc_size={pc_size}, '
                  f'pc_sharpedge_size={pc_sharpedge_size}')

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size

        self.fourier_embedder = fourier_embedder

        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
        self.cross_attn = ResidualCrossAttentionBlock(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        self.self_attn = None
        if layers > 0:
            self.self_attn = Transformer(
                n_ctx=num_latents,
                width=width,
                layers=layers,
                heads=heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm
            )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None

    def sample_points_and_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
        B, N, D = pc.shape
        num_pts = self.num_latents * self.downsample_ratio

        # Compute the number of latents
        num_latents = int(num_pts / self.downsample_ratio)

        # Split the latent budget between random and sharp-edge queries in proportion to the input sizes
        num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
        num_sharpedge_query = num_latents - num_random_query

        # Split random and sharp-edge surface points
        random_pc, sharpedge_pc = torch.split(pc, [self.pc_size, self.pc_sharpedge_size], dim=1)
        assert random_pc.shape[1] <= self.pc_size, \
            "Random surface points size must be less than or equal to pc_size"
        assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, \
            "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"

        # Randomly select random surface points, then FPS-downsample them into query points
        input_random_pc_size = int(num_random_query * self.downsample_ratio)
        random_query_ratio = num_random_query / input_random_pc_size
        idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[:input_random_pc_size]
        input_random_pc = random_pc[:, idx_random_pc, :]
        flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)
        N_down = int(flatten_input_random_pc.shape[0] / B)
        batch_down = torch.arange(B).to(pc.device)
        batch_down = torch.repeat_interleave(batch_down, N_down)
        idx_query_random = fps(flatten_input_random_pc, batch_down, ratio=random_query_ratio)
        query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)

        # Randomly select sharp-edge surface points, then FPS-downsample them into query points
        input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
        if input_sharpedge_pc_size == 0:
            input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to(pc.device)
            query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to(pc.device)
        else:
            sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
            idx_sharpedge_pc = torch.randperm(sharpedge_pc.shape[1], device=sharpedge_pc.device)[:input_sharpedge_pc_size]
            input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]
            flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(B * input_sharpedge_pc_size, D)
            N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)
            batch_down = torch.arange(B).to(pc.device)
            batch_down = torch.repeat_interleave(batch_down, N_down)
            idx_query_sharpedge = fps(flatten_input_sharpedge_surface_points, batch_down, ratio=sharpedge_query_ratio)
            query_sharpedge_pc = flatten_input_sharpedge_surface_points[idx_query_sharpedge].view(B, -1, D)

        # Concatenate random and sharp-edge surface points and query points
        query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
        input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)

        # Positional encoding
        query = self.fourier_embedder(query_pc)
        data = self.fourier_embedder(input_pc)

        # Concatenate per-point features (e.g. normals) if given
        if self.point_feats != 0:
            random_surface_feats, sharpedge_surface_feats = torch.split(
                feats, [self.pc_size, self.pc_sharpedge_size], dim=1)
            input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]
            flatten_input_random_surface_feats = input_random_surface_feats.view(B * input_random_pc_size, -1)
            query_random_feats = flatten_input_random_surface_feats[idx_query_random].view(
                B, -1, flatten_input_random_surface_feats.shape[-1])

            if input_sharpedge_pc_size == 0:
                input_sharpedge_surface_feats = torch.zeros(
                    B, 0, self.point_feats, dtype=input_random_surface_feats.dtype).to(pc.device)
                query_sharpedge_feats = torch.zeros(
                    B, 0, self.point_feats, dtype=query_random_feats.dtype).to(pc.device)
            else:
                input_sharpedge_surface_feats = sharpedge_surface_feats[:, idx_sharpedge_pc, :]
                flatten_input_sharpedge_surface_feats = input_sharpedge_surface_feats.view(
                    B * input_sharpedge_pc_size, -1)
                query_sharpedge_feats = flatten_input_sharpedge_surface_feats[idx_query_sharpedge].view(
                    B, -1, flatten_input_sharpedge_surface_feats.shape[-1])

            query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
            input_feats = torch.cat([input_random_surface_feats, input_sharpedge_surface_feats], dim=1)

            # Optionally positional-encode the normal part of the features as well
            if self.normal_pe:
                query_normal_pe = self.fourier_embedder(query_feats[..., :3])
                input_normal_pe = self.fourier_embedder(input_feats[..., :3])
                query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
                input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)

            query = torch.cat([query, query_feats], dim=-1)
            data = torch.cat([data, input_feats], dim=-1)

        if input_sharpedge_pc_size == 0:
            # Placeholder tensors so downstream consumers always receive non-empty point sets
            query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
            input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)

        return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]), \
            [query_pc, input_pc, query_random_pc, input_random_pc, query_sharpedge_pc, input_sharpedge_pc]

    def forward(self, pc, feats):
        """
        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
            pc_infos (list[torch.FloatTensor]): query/input point sets for inspection
        """
        query, data, pc_infos = self.sample_points_and_latents(pc, feats)

        query = self.input_proj(query)
        data = self.input_proj(data)

        latents = self.cross_attn(query, data)
        if self.self_attn is not None:
            latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents, pc_infos

hy3dshape/hy3dshape/models/autoencoders/attention_processors.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os

import torch
import torch.nn.functional as F

scaled_dot_product_attention = F.scaled_dot_product_attention
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError:
        raise ImportError('Please install the package "sageattention" to use CA_USE_SAGEATTN=1.')
    scaled_dot_product_attention = sageattn


class CrossAttentionProcessor:
    def __call__(self, attn, q, k, v):
        out = scaled_dot_product_attention(q, k, v)
        return out


class FlashVDMCrossAttentionProcessor:
    def __init__(self, topk=None):
        self.topk = topk

    def __call__(self, attn, q, k, v):
        # Heuristic top-k budget based on the number of latent tokens
        if k.shape[-2] == 3072:
            topk = 1024
        elif k.shape[-2] == 512:
            topk = 256
        else:
            topk = k.shape[-2] // 3

        if self.topk is True:
            # Score latents against a strided subset of the queries, then attend only to the top-k
            q1 = q[:, :, ::100, :]
            sim = q1 @ k.transpose(-1, -2)
            sim = torch.mean(sim, -2)
            topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
            topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
            v0 = torch.gather(v, dim=-2, index=topk_ind)
            k0 = torch.gather(k, dim=-2, index=topk_ind)
            out = scaled_dot_product_attention(q, k0, v0)
        elif self.topk is False:
            out = scaled_dot_product_attention(q, k, v)
        else:
            # Here self.topk carries (idx, counts): process the queries chunk by chunk,
            # selecting a separate top-k key/value subset per chunk
            idx, counts = self.topk
            start = 0
            outs = []
            for grid_coord, count in zip(idx, counts):
                end = start + count
                q_chunk = q[:, :, start:end, :]
                k0, v0 = self.select_topkv(q_chunk, k, v, topk)
                out = scaled_dot_product_attention(q_chunk, k0, v0)
                outs.append(out)
                start += count
            out = torch.cat(outs, dim=-2)
        self.topk = False
        return out

    def select_topkv(self, q_chunk, k, v, topk):
        q1 = q_chunk[:, :, ::50, :]
        sim = q1 @ k.transpose(-1, -2)
        sim = torch.mean(sim, -2)
        topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
        topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
        v0 = torch.gather(v, dim=-2, index=topk_ind)
        k0 = torch.gather(k, dim=-2, index=topk_ind)
        return k0, v0


class FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):
    def select_topkv(self, q_chunk, k, v, topk):
        # Variant: keep every latent token whose mean softmax attention mass is non-negligible
        q1 = q_chunk[:, :, ::30, :]
        sim = q1 @ k.transpose(-1, -2)
        # sim = sim.to(torch.float32)
        sim = sim.softmax(-1)
        sim = torch.mean(sim, 1)
        activated_token = torch.where(sim > 1e-6)[2]
        index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        index = index.expand(-1, v.shape[1], -1, v.shape[-1])
        v0 = torch.gather(v, dim=-2, index=index)
        k0 = torch.gather(k, dim=-2, index=index)
        return k0, v0
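
The processors above plug into CrossAttentionDecoder via set_cross_attention_processor;
a sketch of swapping them on a decoder built elsewhere (the `decoder` variable is assumed):

    processor = FlashVDMCrossAttentionProcessor(topk=True)
    decoder.set_cross_attention_processor(processor)
    # ... run volume decoding ...
    decoder.set_default_cross_attention_processor()  # restore plain SDPA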

hy3dshape/hy3dshape/models/autoencoders/model.py (new file, 339 lines)
@@ -0,0 +1,339 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


import os
from typing import Union, List

import numpy as np
import torch
import torch.nn as nn
import yaml

from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder, PointCrossAttentionEncoder
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
from ...utils import logger, synchronize_timer, smart_load_model


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
        """
        Initialize a diagonal Gaussian distribution with mean and log-variance parameters.

        Args:
            parameters (Union[torch.Tensor, List[torch.Tensor]]):
                Either a single tensor containing concatenated mean and log-variance along `feat_dim`,
                or a list of two tensors [mean, logvar].
            deterministic (bool, optional): If True, the distribution is deterministic (zero variance).
                Default is False.
            feat_dim (int, optional): Dimension along which mean and logvar are concatenated
                if parameters is a single tensor. Default is 1.
        """
        self.feat_dim = feat_dim
        self.parameters = parameters

        if isinstance(parameters, list):
            self.mean = parameters[0]
            self.logvar = parameters[1]
        else:
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """
        Sample from the diagonal Gaussian distribution.

        Returns:
            torch.Tensor: A sample tensor with the same shape as the mean.
        """
        x = self.mean + self.std * torch.randn_like(self.mean)
        return x

    def kl(self, other=None, dims=(1, 2, 3)):
        """
        Compute the Kullback-Leibler (KL) divergence between this distribution and another.

        If `other` is None, compute KL divergence to a standard normal distribution N(0, I).

        Args:
            other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution.
            dims (tuple, optional): Dimensions along which to compute the mean KL divergence.
                Default is (1, 2, 3).

        Returns:
            torch.Tensor: The mean KL divergence value.
        """
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.mean(torch.pow(self.mean, 2)
                                        + self.var - 1.0 - self.logvar,
                                        dim=dims)
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=dims)

    def nll(self, sample, dims=(1, 2, 3)):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


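# --- Illustrative sketch (not part of the original file): the closed-form KL term
# --- above is 0.5 * (mu^2 + var - 1 - logvar) per element, averaged over `dims`.
# --- Shapes below are assumptions for the example.
if __name__ == '__main__':
    moments = torch.randn(2, 256, 128)                  # (batch, tokens, 2 * embed_dim)
    posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
    z = posterior.sample()                              # stochastic latent, (2, 256, 64)
    kl = posterior.kl(dims=(1, 2))                      # KL to N(0, I) per batch element
    print(z.shape, kl.shape)                            # torch.Size([2, 256, 64]) torch.Size([2])

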
class VectsetVAE(nn.Module):

    @classmethod
    @synchronize_timer('VectsetVAE Model Loading')
    def from_single_file(
        cls,
        ckpt_path,
        config_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=None,
        **kwargs,
    ):
        # load config
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        # load ckpt
        if use_safetensors:
            ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Model file {ckpt_path} not found")

        logger.info(f"Loading model from {ckpt_path}")
        if use_safetensors:
            import safetensors.torch
            ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
        else:
            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

        model_kwargs = config['params']
        model_kwargs.update(kwargs)

        model = cls(**model_kwargs)
        model.load_state_dict(ckpt)
        model.to(device=device, dtype=dtype)
        return model

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=False,
        variant='fp16',
        subfolder='hunyuan3d-vae-v2-1',
        **kwargs,
    ):
        config_path, ckpt_path = smart_load_model(
            model_path,
            subfolder=subfolder,
            use_safetensors=use_safetensors,
            variant=variant
        )

        return cls.from_single_file(
            ckpt_path,
            config_path,
            device=device,
            dtype=dtype,
            use_safetensors=use_safetensors,
            **kwargs
        )

    def init_from_ckpt(self, path, ignore_keys=()):
        state_dict = torch.load(path, map_location="cpu")
        state_dict = state_dict.get("state_dict", state_dict)
        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def __init__(
        self,
        volume_decoder=None,
        surface_extractor=None
    ):
        super().__init__()
        if volume_decoder is None:
            volume_decoder = VanillaVolumeDecoder()
        if surface_extractor is None:
            surface_extractor = MCSurfaceExtractor()
        self.volume_decoder = volume_decoder
        self.surface_extractor = surface_extractor

    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
        with synchronize_timer('Volume decoding'):
            grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
        with synchronize_timer('Surface extraction'):
            outputs = self.surface_extractor(grid_logits, **kwargs)
        return outputs

    def enable_flashvdm_decoder(
        self,
        enabled: bool = True,
        adaptive_kv_selection=True,
        topk_mode='mean',
        mc_algo='dmc',
    ):
        if enabled:
            if adaptive_kv_selection:
                self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
            else:
                self.volume_decoder = HierarchicalVolumeDecoding()
            if mc_algo not in SurfaceExtractors.keys():
                raise ValueError(f'Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}')
            self.surface_extractor = SurfaceExtractors[mc_algo]()
        else:
            self.volume_decoder = VanillaVolumeDecoder()
            self.surface_extractor = MCSurfaceExtractor()


class ShapeVAE(VectsetVAE):
    def __init__(
        self,
        *,
        num_latents: int,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        num_encoder_layers: int = 8,
        pc_size: int = 5120,
        pc_sharpedge_size: int = 5120,
        point_feats: int = 3,
        downsample_ratio: int = 20,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
        use_ln_post: bool = True,
        ckpt_path=None
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post
        self.downsample_ratio = downsample_ratio

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.encoder = PointCrossAttentionEncoder(
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            downsample_ratio=self.downsample_ratio,
            pc_size=pc_size,
            pc_sharpedge_size=pc_sharpedge_size,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            qkv_bias=qkv_bias,
            use_ln_post=use_ln_post,
            qk_norm=qk_norm
        )

        self.pre_kl = nn.Linear(width, embed_dim * 2)
        self.post_kl = nn.Linear(embed_dim, width)

        self.transformer = Transformer(
            n_ctx=num_latents,
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.scale_factor = scale_factor
        self.latent_shape = (num_latents, embed_dim)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def forward(self, latents):
        latents = self.post_kl(latents)
        latents = self.transformer(latents)
        return latents

    def encode(self, surface, sample_posterior=True):
        pc, feats = surface[:, :, :3], surface[:, :, 3:]
        latents, _ = self.encoder(pc, feats)
        # print(latents.shape, self.pre_kl.weight.shape)
        moments = self.pre_kl(latents)
        posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
        if sample_posterior:
            latents = posterior.sample()
        else:
            latents = posterior.mode()
        return latents

    def decode(self, latents):
        latents = self.post_kl(latents)
        latents = self.transformer(latents)
        return latents
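

# --- Illustrative sketch (not part of the original file): round-tripping a surface
# --- point cloud through the VAE. The repo id, surface shape, and latents2mesh
# --- keyword arguments are assumptions for the example, not pinned-down values.
if __name__ == '__main__':
    vae = ShapeVAE.from_pretrained('tencent/Hunyuan3D-2.1')     # hypothetical repo id
    surface = torch.randn(1, 10240, 6, device='cuda', dtype=torch.float16)  # xyz + point feats
    latents = vae.encode(surface)
    latents = vae.decode(latents)                               # tokens for the geo decoder
    mesh = vae.latents2mesh(latents, octree_resolution=256, mc_level=0.0, bounds=1.01)[0]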
164
hy3dshape/hy3dshape/models/autoencoders/surface_extractors.py
Normal file
164
hy3dshape/hy3dshape/models/autoencoders/surface_extractors.py
Normal file
@@ -0,0 +1,164 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from typing import Union, Tuple, List

import numpy as np
import torch
from skimage import measure


class Latent2MeshOutput:
    def __init__(self, mesh_v=None, mesh_f=None):
        self.mesh_v = mesh_v
        self.mesh_f = mesh_f


def center_vertices(vertices):
    """Translate the vertices so that the bounding box is centered at zero."""
    vert_min = vertices.min(dim=0)[0]
    vert_max = vertices.max(dim=0)[0]
    vert_center = 0.5 * (vert_min + vert_max)
    return vertices - vert_center


class SurfaceExtractor:
    def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
        """
        Compute grid size, bounding box minimum coordinates, and bounding box size based on input
        bounds and resolution.

        Args:
            bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single
                float representing half side length.
                If float, bounds are assumed symmetric around zero in all axes.
                Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax].
            octree_resolution (int): Resolution of the octree grid.

        Returns:
            grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1.
            bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin).
            bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.).
        """
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """
        Abstract method to extract a surface mesh from grid logits.

        This method should be implemented by subclasses.

        Raises:
            NotImplementedError: Always, since this is an abstract method.
        """
        raise NotImplementedError

    def __call__(self, grid_logits, **kwargs):
        """
        Process a batch of grid logits to extract surface meshes.

        Args:
            grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...).
            **kwargs: Additional keyword arguments passed to the `run` method.

        Returns:
            List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch.
                If extraction fails for a grid, None is appended at that position.
        """
        outputs = []
        for i in range(grid_logits.shape[0]):
            try:
                vertices, faces = self.run(grid_logits[i], **kwargs)
                vertices = vertices.astype(np.float32)
                faces = np.ascontiguousarray(faces)
                outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))

            except Exception:
                import traceback
                traceback.print_exc()
                outputs.append(None)

        return outputs


class MCSurfaceExtractor(SurfaceExtractor):
    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """
        Extract a surface mesh using the Marching Cubes algorithm.

        Args:
            grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
            mc_level (float): The level (iso-value) at which to extract the surface.
            bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length.
            octree_resolution (int): Resolution of the octree grid.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Tuple[np.ndarray, np.ndarray]: Tuple containing:
                - vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding
                  box coordinates.
                - faces (np.ndarray): Extracted mesh faces (triangles).
        """
        vertices, faces, normals, _ = measure.marching_cubes(grid_logit.cpu().numpy(),
                                                             mc_level,
                                                             method="lewiner")
        grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
        vertices = vertices / grid_size * bbox_size + bbox_min
        return vertices, faces


class DMCSurfaceExtractor(SurfaceExtractor):
    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """
        Extract a surface mesh using the Differentiable Marching Cubes (DMC) algorithm.

        Args:
            grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
            octree_resolution (int): Resolution of the octree grid.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Tuple[np.ndarray, np.ndarray]: Tuple containing:
                - vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy.
                - faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order.

        Raises:
            ImportError: If the 'diso' package is not installed.
        """
        device = grid_logit.device
        if not hasattr(self, 'dmc'):
            try:
                from diso import DiffDMC
                self.dmc = DiffDMC(dtype=torch.float32).to(device)
            except ImportError:
                raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
        sdf = -grid_logit / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
        verts = center_vertices(verts)
        vertices = verts.detach().cpu().numpy()
        faces = faces.detach().cpu().numpy()[:, ::-1]
        return vertices, faces


SurfaceExtractors = {
    'mc': MCSurfaceExtractor,
    'dmc': DMCSurfaceExtractor,
}
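

# --- Illustrative sketch (not part of the original file): running an extractor on a
# --- synthetic scalar field (a sphere), picked from the registry by name. The grid
# --- shape and the keyword values are assumptions for the example.
if __name__ == '__main__':
    res = 64
    ax = torch.linspace(-1, 1, res + 1)
    x, y, z = torch.meshgrid(ax, ax, ax, indexing='ij')
    grid_logits = 0.5 - torch.sqrt(x ** 2 + y ** 2 + z ** 2)   # positive inside a sphere
    extractor = SurfaceExtractors['mc']()                       # or 'dmc' if diso is installed
    outputs = extractor(grid_logits.unsqueeze(0),               # add a batch dimension
                        mc_level=0.0, bounds=1.01, octree_resolution=res)
    print(outputs[0].mesh_v.shape, outputs[0].mesh_f.shape)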
435
hy3dshape/hy3dshape/models/autoencoders/volume_decoders.py
Normal file
435
hy3dshape/hy3dshape/models/autoencoders/volume_decoders.py
Normal file
@@ -0,0 +1,435 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from typing import Union, Tuple, List, Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from tqdm import tqdm

from .attention_blocks import CrossAttentionDecoder
from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
from ...utils import logger


def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
    device = input_tensor.device
    D = input_tensor.shape[0]
    signed_val = 0.0

    # Apply the offset and handle invalid values
    val = input_tensor + alpha
    valid_mask = val > -9000  # -9000 is assumed to mark invalid voxels

    # Neighbor lookup that keeps the output shape identical to the input
    def get_neighbor(t, shift, axis):
        """Shift the volume along the given axis while preserving its shape."""
        if shift == 0:
            return t.clone()

        # Choose the padding axis (the [D, D, D] input corresponds to the z, y, x axes)
        pad_dims = [0, 0, 0, 0, 0, 0]  # format: [x_before, x_after, y_before, y_after, z_before, z_after]

        # Set the padding according to the axis
        if axis == 0:  # x axis (last dimension)
            pad_idx = 0 if shift > 0 else 1
            pad_dims[pad_idx] = abs(shift)
        elif axis == 1:  # y axis (middle dimension)
            pad_idx = 2 if shift > 0 else 3
            pad_dims[pad_idx] = abs(shift)
        elif axis == 2:  # z axis (first dimension)
            pad_idx = 4 if shift > 0 else 5
            pad_dims[pad_idx] = abs(shift)

        # Pad (batch and channel dims are added so the input fits F.pad)
        padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')  # reversed to match F.pad's ordering

        # Build the dynamic slice indices
        slice_dims = [slice(None)] * 3  # start from full slices
        if axis == 0:  # x axis (dim=2)
            if shift > 0:
                slice_dims[0] = slice(shift, None)
            else:
                slice_dims[0] = slice(None, shift)
        elif axis == 1:  # y axis (dim=1)
            if shift > 0:
                slice_dims[1] = slice(shift, None)
            else:
                slice_dims[1] = slice(None, shift)
        elif axis == 2:  # z axis (dim=0)
            if shift > 0:
                slice_dims[2] = slice(shift, None)
            else:
                slice_dims[2] = slice(None, shift)

        # Apply the slice and restore the original dimensions
        padded = padded.squeeze(0).squeeze(0)
        sliced = padded[slice_dims]
        return sliced

    # Gather neighbors in every direction (shapes stay consistent)
    left = get_neighbor(val, 1, axis=0)  # x direction
    right = get_neighbor(val, -1, axis=0)
    back = get_neighbor(val, 1, axis=1)  # y direction
    front = get_neighbor(val, -1, axis=1)
    down = get_neighbor(val, 1, axis=2)  # z direction
    up = get_neighbor(val, -1, axis=2)

    # Handle invalid values at the boundary (torch.where keeps shapes consistent)
    def safe_where(neighbor):
        return torch.where(neighbor > -9000, neighbor, val)

    left = safe_where(left)
    right = safe_where(right)
    back = safe_where(back)
    front = safe_where(front)
    down = safe_where(down)
    up = safe_where(up)

    # Compute sign agreement (cast to float32 for precision)
    sign = torch.sign(val.to(torch.float32))
    neighbors_sign = torch.stack([
        torch.sign(left.to(torch.float32)),
        torch.sign(right.to(torch.float32)),
        torch.sign(back.to(torch.float32)),
        torch.sign(front.to(torch.float32)),
        torch.sign(down.to(torch.float32)),
        torch.sign(up.to(torch.float32))
    ], dim=0)

    # Check whether all neighbor signs agree with the center voxel
    same_sign = torch.all(neighbors_sign == sign, dim=0)

    # Build the final mask
    mask = (~same_sign).to(torch.int32)
    return mask * valid_mask.to(torch.int32)


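# --- Illustrative sketch (not part of the original file): the mask flags voxels whose
# --- sign differs from any 6-neighbor, i.e. cells the iso-surface passes through.
# --- The toy volume below is an assumption for the example.
if __name__ == '__main__':
    ax = torch.linspace(-1, 1, 65)
    x, y, z = torch.meshgrid(ax, ax, ax, indexing='ij')
    sphere = 0.5 - torch.sqrt(x ** 2 + y ** 2 + z ** 2)        # zero-crossing at r = 0.5
    mask = extract_near_surface_volume_fn(sphere, alpha=0.0)
    print(mask.sum().item(), 'near-surface voxels out of', mask.numel())

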
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    length = bbox_max - bbox_min
    num_cells = octree_resolution

    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]

    return xyz, grid_size, length


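# --- Illustrative sketch (not part of the original file): the returned grid has
# --- (octree_resolution + 1)^3 points covering the box corner-to-corner.
if __name__ == '__main__':
    xyz, grid_size, length = generate_dense_grid_points(
        bbox_min=np.array([-1.01, -1.01, -1.01]),
        bbox_max=np.array([1.01, 1.01, 1.01]),
        octree_resolution=256,
    )
    print(xyz.shape, grid_size, length)   # (257, 257, 257, 3) [257, 257, 257] [2.02 2.02 2.02]

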
class VanillaVolumeDecoder:
    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits


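# --- Illustrative sketch (not part of the original file): a volume decoder only needs
# --- a callable mapping (queries, latents) -> logits, so it can be exercised with a
# --- stand-in geo decoder. `fake_geo_decoder` is a hypothetical stand-in.
if __name__ == '__main__':
    def fake_geo_decoder(queries, latents):
        # Positive inside a sphere of radius 0.5, one logit per query point.
        return 0.5 - queries.norm(dim=-1, keepdim=True)

    latents = torch.randn(1, 256, 64)
    decoder = VanillaVolumeDecoder()
    grid_logits = decoder(latents, fake_geo_decoder, octree_resolution=64, enable_pbar=False)
    print(grid_logits.shape)   # torch.Size([1, 65, 65, 65])

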
class HierarchicalVolumeDecoding:
    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        mc_level: float = 0.0,
        octree_resolution: int = None,
        min_resolution: int = 63,
        enable_pbar: bool = True,
        **kwargs,
    ):
        device = latents.device
        dtype = latents.dtype

        resolutions = []
        if octree_resolution < min_resolution:
            resolutions.append(octree_resolution)
        while octree_resolution >= min_resolution:
            resolutions.append(octree_resolution)
            octree_resolution = octree_resolution // 2
        resolutions.reverse()

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min

        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=resolutions[0],
            indexing="ij"
        )

        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
        dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))

        grid_size = np.array(grid_size)
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        batch_size = latents.shape[0]
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
                          desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]",
                          disable=not enable_pbar):
            queries = xyz_samples[start: start + num_chunks, :]
            batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=batch_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))

        for octree_depth_now in resolutions[1:]:
            grid_size = np.array([octree_depth_now + 1] * 3)
            resolution = bbox_size / octree_depth_now
            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
            next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
            curr_points += grid_logits.squeeze(0).abs() < 0.95

            if octree_depth_now == resolutions[-1]:
                expand_num = 0
            else:
                expand_num = 1
            for i in range(expand_num):
                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
            for i in range(2 - expand_num):
                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
            nidx = torch.where(next_index > 0)

            next_points = torch.stack(nidx, dim=1)
            next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +
                           torch.tensor(bbox_min, dtype=next_points.dtype, device=device))
            batch_logits = []
            for start in tqdm(range(0, next_points.shape[0], num_chunks),
                              desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]",
                              disable=not enable_pbar):
                queries = next_points[start: start + num_chunks, :]
                batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
                logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
                batch_logits.append(logits)
            grid_logits = torch.cat(batch_logits, dim=1)
            next_logits[nidx] = grid_logits[0, ..., 0]
            grid_logits = next_logits.unsqueeze(0)
        grid_logits[grid_logits == -10000.] = float('nan')

        return grid_logits


class FlashVDMVolumeDecoding:
    def __init__(self, topk_mode='mean'):
        if topk_mode not in ['mean', 'merge']:
            raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')

        if topk_mode == 'mean':
            self.processor = FlashVDMCrossAttentionProcessor()
        else:
            self.processor = FlashVDMTopMCrossAttentionProcessor()

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: CrossAttentionDecoder,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        mc_level: float = 0.0,
        octree_resolution: int = None,
        min_resolution: int = 63,
        mini_grid_num: int = 4,
        enable_pbar: bool = True,
        **kwargs,
    ):
        processor = self.processor
        geo_decoder.set_cross_attention_processor(processor)

        device = latents.device
        dtype = latents.dtype

        resolutions = []
        if octree_resolution < min_resolution:
            resolutions.append(octree_resolution)
        while octree_resolution >= min_resolution:
            resolutions.append(octree_resolution)
            octree_resolution = octree_resolution // 2
        resolutions.reverse()
        resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
        for i, resolution in enumerate(resolutions[1:]):
            resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)

        logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min

        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=resolutions[0],
            indexing="ij"
        )

        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
        dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))

        grid_size = np.array(grid_size)

        # 2. latents to 3d volume
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
        batch_size = latents.shape[0]
        mini_grid_size = xyz_samples.shape[0] // mini_grid_num
        xyz_samples = xyz_samples.view(
            mini_grid_num, mini_grid_size,
            mini_grid_num, mini_grid_size,
            mini_grid_num, mini_grid_size, 3
        ).permute(
            0, 2, 4, 1, 3, 5, 6
        ).reshape(
            -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
        )
        batch_logits = []
        num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
        for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
                          desc="FlashVDM Volume Decoding", disable=not enable_pbar):
            queries = xyz_samples[start: start + num_batchs, :]
            batch = queries.shape[0]
            batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
            processor.topk = True
            logits = geo_decoder(queries=queries, latents=batch_latents)
            batch_logits.append(logits)
        grid_logits = torch.cat(batch_logits, dim=0).reshape(
            mini_grid_num, mini_grid_num, mini_grid_num,
            mini_grid_size, mini_grid_size,
            mini_grid_size
        ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
            (batch_size, grid_size[0], grid_size[1], grid_size[2])
        )

        for octree_depth_now in resolutions[1:]:
            grid_size = np.array([octree_depth_now + 1] * 3)
            resolution = bbox_size / octree_depth_now
            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
            next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
            curr_points += grid_logits.squeeze(0).abs() < 0.95

            if octree_depth_now == resolutions[-1]:
                expand_num = 0
            else:
                expand_num = 1
            for i in range(expand_num):
                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)

            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
            for i in range(2 - expand_num):
                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
            nidx = torch.where(next_index > 0)

            next_points = torch.stack(nidx, dim=1)
            next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
                           torch.tensor(bbox_min, dtype=torch.float32, device=device))

            query_grid_num = 6
            min_val = next_points.min(axis=0).values
            max_val = next_points.max(axis=0).values
            vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
            index = torch.floor(vol_queries_index).long()
            index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
            index = index.sort()
            next_points = next_points[index.indices].unsqueeze(0).contiguous()
            unique_values = torch.unique(index.values, return_counts=True)
            grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
            input_grid = [[], []]
            logits_grid_list = []
            start_num = 0
            sum_num = 0
            for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
                if sum_num + count < num_chunks or sum_num == 0:
                    sum_num += count
                    input_grid[0].append(grid_index)
                    input_grid[1].append(count)
                else:
                    processor.topk = input_grid
                    logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
                    start_num = start_num + sum_num
                    logits_grid_list.append(logits_grid)
                    input_grid = [[grid_index], [count]]
                    sum_num = count
            if sum_num > 0:
                processor.topk = input_grid
                logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
                logits_grid_list.append(logits_grid)
            logits_grid = torch.cat(logits_grid_list, dim=1)
            grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
            next_logits[nidx] = grid_logits
            grid_logits = next_logits.unsqueeze(0)

        grid_logits[grid_logits == -10000.] = float('nan')

        return grid_logits
257
hy3dshape/hy3dshape/models/conditioner.py
Normal file
257
hy3dshape/hy3dshape/models/conditioner.py
Normal file
@@ -0,0 +1,257 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import (
    CLIPVisionModelWithProjection,
    CLIPVisionConfig,
    Dinov2Model,
    Dinov2Config,
)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)


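# --- Illustrative sketch (not part of the original file): four view indices embedded
# --- into an 8-dim sin/cos table, as used below to tag per-view DINO tokens.
if __name__ == '__main__':
    pos = np.arange(4, dtype=np.float32)
    emb = get_1d_sincos_pos_embed_from_grid(8, pos)
    print(emb.shape)   # (4, 8): sin half followed by cos half

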
class ImageEncoder(nn.Module):
    def __init__(
        self,
        version=None,
        config=None,
        use_cls_token=True,
        image_size=224,
        **kwargs,
    ):
        super().__init__()

        if config is None:
            self.model = self.MODEL_CLASS.from_pretrained(version)
        else:
            self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))
        self.model.eval()
        self.model.requires_grad_(False)
        self.use_cls_token = use_cls_token
        self.size = image_size // 14
        self.num_patches = (image_size // 14) ** 2
        if self.use_cls_token:
            self.num_patches += 1

        self.transform = transforms.Compose(
            [
                transforms.Resize(image_size, transforms.InterpolationMode.BILINEAR, antialias=True),
                transforms.CenterCrop(image_size),
                transforms.Normalize(
                    mean=self.mean,
                    std=self.std,
                ),
            ]
        )

    def forward(self, image, mask=None, value_range=(-1, 1), **kwargs):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.model.device, dtype=self.model.dtype)
        inputs = self.transform(image)
        outputs = self.model(inputs)

        last_hidden_state = outputs.last_hidden_state
        if not self.use_cls_token:
            last_hidden_state = last_hidden_state[:, 1:, :]

        return last_hidden_state

    def unconditional_embedding(self, batch_size, **kwargs):
        device = next(self.model.parameters()).device
        dtype = next(self.model.parameters()).dtype
        zero = torch.zeros(
            batch_size,
            self.num_patches,
            self.model.config.hidden_size,
            device=device,
            dtype=dtype,
        )

        return zero


class CLIPImageEncoder(ImageEncoder):
    MODEL_CLASS = CLIPVisionModelWithProjection
    MODEL_CONFIG_CLASS = CLIPVisionConfig
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]


class DinoImageEncoder(ImageEncoder):
    MODEL_CLASS = Dinov2Model
    MODEL_CONFIG_CLASS = Dinov2Config
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]


class DinoImageEncoderMV(DinoImageEncoder):
    def __init__(
        self,
        version=None,
        config=None,
        use_cls_token=True,
        image_size=224,
        view_num=4,
        **kwargs,
    ):
        super().__init__(version, config, use_cls_token, image_size, **kwargs)
        self.view_num = view_num
        pos = np.arange(self.view_num, dtype=np.float32)
        view_embedding = torch.from_numpy(
            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()

        view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
        self.view_embed = view_embedding.unsqueeze(0)

    def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.model.device, dtype=self.model.dtype)

        bs, num_views, c, h, w = image.shape
        image = image.view(bs * num_views, c, h, w)

        inputs = self.transform(image)
        outputs = self.model(inputs)

        last_hidden_state = outputs.last_hidden_state
        last_hidden_state = last_hidden_state.view(
            bs, num_views, last_hidden_state.shape[-2],
            last_hidden_state.shape[-1]
        )

        view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
        if view_idxs is not None:
            assert len(view_idxs) == bs
            view_embeddings = []
            for i in range(bs):
                view_idx = view_idxs[i]
                assert num_views == len(view_idx)
                view_embeddings.append(self.view_embed[:, view_idx, ...])
            view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)

        if num_views != self.view_num:
            view_embedding = view_embedding[:, :num_views, ...]
        last_hidden_state = last_hidden_state + view_embedding
        last_hidden_state = last_hidden_state.view(bs, num_views * last_hidden_state.shape[-2],
                                                   last_hidden_state.shape[-1])
        return last_hidden_state

    def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
        device = next(self.model.parameters()).device
        dtype = next(self.model.parameters()).dtype
        zero = torch.zeros(
            batch_size,
            self.num_patches * len(view_idxs[0]),
            self.model.config.hidden_size,
            device=device,
            dtype=dtype,
        )
        return zero


def build_image_encoder(config):
    if config['type'] == 'CLIPImageEncoder':
        return CLIPImageEncoder(**config['kwargs'])
    elif config['type'] == 'DinoImageEncoder':
        return DinoImageEncoder(**config['kwargs'])
    elif config['type'] == 'DinoImageEncoderMV':
        return DinoImageEncoderMV(**config['kwargs'])
    else:
        raise ValueError(f'Unknown image encoder type: {config["type"]}')


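# --- Illustrative sketch (not part of the original file): the config dict shape this
# --- factory expects. The version string is one plausible checkpoint id, used here
# --- only as an assumption for the example.
if __name__ == '__main__':
    encoder = build_image_encoder({
        'type': 'DinoImageEncoder',
        'kwargs': {'version': 'facebook/dinov2-giant', 'use_cls_token': True, 'image_size': 518},
    })
    print(encoder.num_patches)

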
class DualImageEncoder(nn.Module):
    def __init__(
        self,
        main_image_encoder,
        additional_image_encoder,
    ):
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)
        self.additional_image_encoder = build_image_encoder(additional_image_encoder)

    def forward(self, image, mask=None, **kwargs):
        outputs = {
            'main': self.main_image_encoder(image, mask=mask, **kwargs),
            'additional': self.additional_image_encoder(image, mask=mask, **kwargs),
        }
        return outputs

    def unconditional_embedding(self, batch_size, **kwargs):
        outputs = {
            'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
            'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),
        }
        return outputs


class SingleImageEncoder(nn.Module):
    def __init__(
        self,
        main_image_encoder,
    ):
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)

    def forward(self, image, mask=None, **kwargs):
        outputs = {
            'main': self.main_image_encoder(image, mask=mask, **kwargs),
        }
        return outputs

    def unconditional_embedding(self, batch_size, **kwargs):
        outputs = {
            'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
        }
        return outputs
15
hy3dshape/hy3dshape/models/denoisers/__init__.py
Normal file
15
hy3dshape/hy3dshape/models/denoisers/__init__.py
Normal file
@@ -0,0 +1,15 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from .hunyuan3ddit import Hunyuan3DDiT
404
hy3dshape/hy3dshape/models/denoisers/hunyuan3ddit.py
Normal file
404
hy3dshape/hy3dshape/models/denoisers/hunyuan3ddit.py
Normal file
@@ -0,0 +1,404 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
import math
from dataclasses import dataclass
from typing import List, Tuple, Optional

import torch
from torch import Tensor, nn
from einops import rearrange

# set up attention backend
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
if os.environ.get('USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError:
        raise ImportError('Please install the "sageattention" package to use USE_SAGEATTN.')
    scaled_dot_product_attention = sageattn


def attention(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
    # `pe` may be passed via **kwargs for interface compatibility; it is unused here.
    x = scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")
    return x


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scale applied to `t` before embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
    freqs = freqs.to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


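# --- Illustrative sketch (not part of the original file): diffusion timesteps in
# --- [0, 1] are scaled by time_factor=1000 before the sin/cos embedding.
if __name__ == '__main__':
    t = torch.tensor([0.0, 0.5, 1.0])
    emb = timestep_embedding(t, dim=256)
    print(emb.shape)   # torch.Size([3, 256])

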
class GELU(nn.Module):
    def __init__(self, approximate='tanh'):
        super().__init__()
        self.approximate = approximate

    def forward(self, x: Tensor) -> Tensor:
        return nn.functional.gelu(x.contiguous(), approximate=self.approximate)


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
        out = self.lin(nn.functional.silu(vec))[:, None, :]
        out = out.chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
    ):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> Tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: Optional[float] = None,
    ):
        super().__init__()

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)

        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class Hunyuan3DDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 64,
|
||||
context_in_dim: int = 1536,
|
||||
hidden_size: int = 1024,
|
||||
mlp_ratio: float = 4.0,
|
||||
num_heads: int = 16,
|
||||
depth: int = 16,
|
||||
depth_single_blocks: int = 32,
|
||||
axes_dim: List[int] = [64],
|
||||
theta: int = 10_000,
|
||||
qkv_bias: bool = True,
|
||||
time_factor: float = 1000,
|
||||
guidance_embed: bool = False,
|
||||
ckpt_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.context_in_dim = context_in_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.num_heads = num_heads
|
||||
self.depth = depth
|
||||
self.depth_single_blocks = depth_single_blocks
|
||||
self.axes_dim = axes_dim
|
||||
self.theta = theta
|
||||
self.qkv_bias = qkv_bias
|
||||
self.time_factor = time_factor
|
||||
self.out_channels = self.in_channels
|
||||
self.guidance_embed = guidance_embed
|
||||
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
||||
)
|
||||
pe_dim = hidden_size // num_heads
|
||||
if sum(axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.latent_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.cond_in = nn.Linear(context_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else nn.Identity()
|
||||
)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
)
|
||||
for _ in range(depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
if ckpt_path is not None:
|
||||
print('restored denoiser ckpt', ckpt_path)
|
||||
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
if 'state_dict' not in ckpt:
|
||||
# deepspeed ckpt
|
||||
state_dict = {}
|
||||
for k in ckpt.keys():
|
||||
new_k = k.replace('_forward_module.', '')
|
||||
state_dict[new_k] = ckpt[k]
|
||||
else:
|
||||
state_dict = ckpt["state_dict"]
|
||||
|
||||
final_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('model.'):
|
||||
final_state_dict[k.replace('model.', '')] = v
|
||||
else:
|
||||
final_state_dict[k] = v
|
||||
missing, unexpected = self.load_state_dict(final_state_dict, strict=False)
|
||||
print('unexpected keys:', unexpected)
|
||||
print('missing keys:', missing)
|
||||
|
||||
def forward(self, x, t, contexts, **kwargs) -> Tensor:
|
||||
cond = contexts['main']
|
||||
latent = self.latent_in(x)
|
||||
|
||||
vec = self.time_in(timestep_embedding(t, 256, self.time_factor).to(dtype=latent.dtype))
|
||||
if self.guidance_embed:
|
||||
guidance = kwargs.get('guidance', None)
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.time_factor))
|
||||
|
||||
cond = self.cond_in(cond)
|
||||
pe = None
|
||||
|
||||
for block in self.double_blocks:
|
||||
latent, cond = block(img=latent, txt=cond, vec=vec, pe=pe)
|
||||
|
||||
latent = torch.cat((cond, latent), 1)
|
||||
for block in self.single_blocks:
|
||||
latent = block(latent, vec=vec, pe=pe)
|
||||
|
||||
latent = latent[:, cond.shape[1]:, ...]
|
||||
latent = self.final_layer(latent, vec)
|
||||
return latent
|
||||
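A minimal forward-pass sketch for Hunyuan3DDiT, added here as an illustration (not part of the committed file). The token counts and tiny block depths are assumptions chosen so it runs quickly on CPU, and it relies on the `attention` and `timestep_embedding` helpers defined earlier in this module accepting `pe=None`, as the model itself always passes.

if __name__ == "__main__":
    # Hedged smoke test: verify the denoiser maps latents back to latent shape.
    model = Hunyuan3DDiT(depth=2, depth_single_blocks=2)
    x = torch.randn(2, 512, 64)                      # (batch, latent tokens, in_channels)
    t = torch.rand(2)                                # flow-matching timesteps
    contexts = {'main': torch.randn(2, 257, 1536)}   # conditioning tokens from the image encoder
    out = model(x, t, contexts)
    print(out.shape)                                 # torch.Size([2, 512, 64])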
596
hy3dshape/hy3dshape/models/denoisers/hunyuandit.py
Normal file
@@ -0,0 +1,596 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .moe_layers import MoEBlock


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)
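A quick numeric shape check for the embedding helper above (an added illustration, not part of the file):

# Hedged example: 8 positions embedded into 16 dims; the first 8 columns
# are sines, the last 8 are cosines.
pos = np.arange(8, dtype=np.float32)
emb = get_1d_sincos_pos_embed_from_grid(16, pos)
print(emb.shape)  # (8, 16)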
class Timesteps(nn.Module):
    def __init__(self,
                 num_channels: int,
                 downscale_freq_shift: float = 0.0,
                 scale: int = 1,
                 max_period: int = 10000
                 ):
        super().__init__()
        self.num_channels = num_channels
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale
        self.max_period = max_period

    def forward(self, timesteps):
        assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
        embedding_dim = self.num_channels
        half_dim = embedding_dim // 2
        exponent = -math.log(self.max_period) * torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - self.downscale_freq_shift)
        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]
        emb = self.scale * emb
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        if embedding_dim % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
        return emb
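An added shape check for Timesteps (illustrative only):

# Hedged example: classic sinusoidal features for a batch of three timesteps.
te = Timesteps(num_channels=256)
t = torch.tensor([0.0, 500.0, 999.0])
print(te(t).shape)  # torch.Size([3, 256])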
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, cond_proj_dim=None, out_size=None):
        super().__init__()
        if out_size is None:
            out_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, frequency_embedding_size, bias=True),
            nn.GELU(),
            nn.Linear(frequency_embedding_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, frequency_embedding_size, bias=False)

        self.time_embed = Timesteps(hidden_size)

    def forward(self, t, condition):
        t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype)

        # t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
        if condition is not None:
            # note: cond_proj outputs frequency_embedding_size features, so this
            # branch assumes that width matches the hidden_size of t_freq
            t_freq = t_freq + self.cond_proj(condition)

        t = self.mlp(t_freq)
        t = t.unsqueeze(dim=1)
        return t


class MLP(nn.Module):
    def __init__(self, *, width: int):
        super().__init__()
        self.width = width
        self.fc1 = nn.Linear(width, width * 4)
        self.fc2 = nn.Linear(width * 4, width)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))
class CrossAttention(nn.Module):
    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        with_decoupled_ca=False,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        **kwargs,
    ):
        super().__init__()
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)
        self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)
        self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.out_proj = nn.Linear(qdim, qdim, bias=True)

        self.with_dca = with_decoupled_ca
        if self.with_dca:
            self.kv_proj_dca = nn.Linear(kdim, 2 * qdim, bias=qkv_bias)
            self.k_norm_dca = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
            self.dca_dim = decoupled_ca_dim
            self.dca_weight = decoupled_ca_weight

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        """
        b, s1, c = x.shape  # [b, s1, D]

        # decoupled cross-attention: the last dca_dim tokens of y form a second
        # context that is attended to separately and blended in with dca_weight
        if self.with_dca:
            token_len = y.shape[1]
            context_dca = y[:, -self.dca_dim:, :]
            kv_dca = self.kv_proj_dca(context_dca).view(b, self.dca_dim, 2, self.num_heads, self.head_dim)
            k_dca, v_dca = kv_dca.unbind(dim=2)  # [b, s, h, d]
            k_dca = self.k_norm_dca(k_dca)
            y = y[:, :(token_len - self.dca_dim), :]

        _, s2, c = y.shape  # [b, s2, 1024]
        q = self.to_q(x)
        k = self.to_k(y)
        v = self.to_v(y)

        kv = torch.cat((k, v), dim=-1)
        split_size = kv.shape[-1] // self.num_heads // 2
        kv = kv.view(1, -1, self.num_heads, split_size * 2)
        k, v = torch.split(kv, split_size, dim=-1)

        q = q.view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]
        k = k.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]
        v = v.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]

        q = self.q_norm(q)
        k = self.k_norm(k)

        with torch.backends.cuda.sdp_kernel(
                enable_flash=True,
                enable_math=False,
                enable_mem_efficient=True
        ):
            q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.num_heads), (q, k, v))
            context = F.scaled_dot_product_attention(
                q, k, v
            ).transpose(1, 2).reshape(b, s1, -1)

        if self.with_dca:
            with torch.backends.cuda.sdp_kernel(
                    enable_flash=True,
                    enable_math=False,
                    enable_mem_efficient=True
            ):
                k_dca, v_dca = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.num_heads),
                                   (k_dca, v_dca))
                context_dca = F.scaled_dot_product_attention(
                    q, k_dca, v_dca).transpose(1, 2).reshape(b, s1, -1)

            context = context + self.dca_weight * context_dca

        out = self.out_proj(context)  # context.reshape - B, L1, -1

        return out
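An added sketch of the decoupled cross-attention call pattern (illustrative; the shapes are assumptions, and running `scaled_dot_product_attention` under the `sdp_kernel` context may require a recent PyTorch):

# Hedged example: the conditioning sequence y packs [main tokens | dca tokens];
# the last decoupled_ca_dim tokens are split off and attended to separately.
ca = CrossAttention(qdim=1024, kdim=1024, num_heads=16,
                    with_decoupled_ca=True, decoupled_ca_dim=16)
x = torch.randn(2, 100, 1024)           # queries (latent tokens)
y = torch.randn(2, 257 + 16, 1024)      # 257 main context tokens + 16 dca tokens
print(ca(x, y).shape)                   # torch.Size([2, 100, 1024])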
class Attention(nn.Module):
    """
    We rename some layer names to align with flash attention
    """

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        qkv = torch.cat((q, k, v), dim=-1)
        split_size = qkv.shape[-1] // self.num_heads // 3
        qkv = qkv.view(1, -1, self.num_heads, split_size * 3)
        q, k, v = torch.split(qkv, split_size, dim=-1)

        q = q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # [b, h, s, d]
        k = k.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # [b, h, s, d]
        v = v.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)  # [b, h, s, d]
        k = self.k_norm(k)  # [b, h, s, d]

        with torch.backends.cuda.sdp_kernel(
                enable_flash=True,
                enable_math=False,
                enable_mem_efficient=True
        ):
            x = F.scaled_dot_product_attention(q, k, v)
            x = x.transpose(1, 2).reshape(B, N, -1)

        x = self.out_proj(x)
        return x
class HunYuanDiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        c_emb_size,
        num_heads,
        text_states_dim=1024,
        use_flash_attn=False,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        qk_norm_layer=nn.RMSNorm,
        with_decoupled_ca=False,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        init_scale=1.0,
        qkv_bias=True,
        skip_connection=True,
        timested_modulate=False,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.use_flash_attn = use_flash_attn
        use_ele_affine = True

        # ========================= Self-Attention =========================
        self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
                               norm_layer=qk_norm_layer)

        # ========================= FFN =========================
        self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)

        # ========================= Add =========================
        # Simply use add like SDXL.
        self.timested_modulate = timested_modulate
        if self.timested_modulate:
            self.default_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(c_emb_size, hidden_size, bias=True)
            )

        # ========================= Cross-Attention =========================
        self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                    qk_norm=qk_norm, norm_layer=qk_norm_layer,
                                    with_decoupled_ca=with_decoupled_ca, decoupled_ca_dim=decoupled_ca_dim,
                                    decoupled_ca_weight=decoupled_ca_weight, init_scale=init_scale,
                                    )
        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)

        # long (U-ViT style) skip connection from the mirrored encoder-side block
        if skip_connection:
            self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
            self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
        else:
            self.skip_linear = None

        self.use_moe = use_moe
        if self.use_moe:
            print("using moe")
            self.moe = MoEBlock(
                hidden_size,
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                dropout=0.0,
                activation_fn="gelu",
                final_dropout=False,
                ff_inner_dim=int(hidden_size * 4.0),
                ff_bias=True,
            )
        else:
            self.mlp = MLP(width=hidden_size)

    def forward(self, x, c=None, text_states=None, skip_value=None):
        # fuse the long skip connection first
        if self.skip_linear is not None:
            cat = torch.cat([skip_value, x], dim=-1)
            x = self.skip_linear(cat)
            x = self.skip_norm(x)

        # Self-Attention
        if self.timested_modulate:
            shift_msa = self.default_modulation(c).unsqueeze(dim=1)
            x = x + shift_msa

        attn_out = self.attn1(self.norm1(x))
        x = x + attn_out

        # Cross-Attention
        x = x + self.attn2(self.norm2(x), text_states)

        # FFN layer (dense MLP or mixture-of-experts)
        mlp_inputs = self.norm3(x)
        if self.use_moe:
            x = x + self.moe(mlp_inputs)
        else:
            x = x + self.mlp(mlp_inputs)

        return x
class AttentionPool(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x, attention_mask=None):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        # prepend a (masked) mean token that serves as the pooling query
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2)
            global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0)
            x = torch.cat([global_emb[None,], x], dim=0)
        else:
            x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)
class FinalLayer(nn.Module):
    """
    The final layer of HunYuanDiT.
    """

    def __init__(self, final_hidden_size, out_channels):
        super().__init__()
        self.final_hidden_size = final_hidden_size
        self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=True, eps=1e-6)
        self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)

    def forward(self, x):
        x = self.norm_final(x)
        x = x[:, 1:]  # drop the prepended condition token, keep latent tokens only
        x = self.linear(x)
        return x
class HunYuanDiTPlain(nn.Module):

    def __init__(
        self,
        input_size=1024,
        in_channels=4,
        hidden_size=1024,
        context_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        norm_type='layer',
        qk_norm_type='rms',
        qk_norm=False,
        text_len=257,
        with_decoupled_ca=False,
        additional_cond_hidden_state=768,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        use_pos_emb=False,
        use_attention_pooling=True,
        guidance_cond_proj_dim=None,
        qkv_bias=True,
        num_moe_layers: int = 6,
        num_experts: int = 8,
        moe_top_k: int = 2,
        **kwargs
    ):
        super().__init__()
        self.input_size = input_size
        self.depth = depth
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads

        self.hidden_size = hidden_size
        self.norm = nn.LayerNorm if norm_type == 'layer' else nn.RMSNorm
        self.qk_norm = nn.RMSNorm if qk_norm_type == 'rms' else nn.LayerNorm
        self.context_dim = context_dim

        self.with_decoupled_ca = with_decoupled_ca
        self.decoupled_ca_dim = decoupled_ca_dim
        self.decoupled_ca_weight = decoupled_ca_weight
        self.use_pos_emb = use_pos_emb
        self.use_attention_pooling = use_attention_pooling
        self.guidance_cond_proj_dim = guidance_cond_proj_dim

        self.text_len = text_len

        self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim)

        # Will use fixed sin-cos embedding:
        if self.use_pos_emb:
            self.register_buffer("pos_embed", torch.zeros(1, input_size, hidden_size))
            pos = np.arange(self.input_size, dtype=np.float32)
            pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], pos)
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if use_attention_pooling:
            self.pooler = AttentionPool(self.text_len, context_dim, num_heads=8, output_dim=1024)
            self.extra_embedder = nn.Sequential(
                nn.Linear(1024, hidden_size * 4),
                nn.SiLU(),
                nn.Linear(hidden_size * 4, hidden_size, bias=True),
            )

        if with_decoupled_ca:
            self.additional_cond_hidden_state = additional_cond_hidden_state
            self.additional_cond_proj = nn.Sequential(
                nn.Linear(additional_cond_hidden_state, hidden_size * 4),
                nn.SiLU(),
                nn.Linear(hidden_size * 4, 1024, bias=True),
            )

        # HunYuanDiT blocks: the second half carries long skip connections, and
        # the last num_moe_layers blocks swap their dense FFN for a MoE FFN
        self.blocks = nn.ModuleList([
            HunYuanDiTBlock(hidden_size=hidden_size,
                            c_emb_size=hidden_size,
                            num_heads=num_heads,
                            mlp_ratio=mlp_ratio,
                            text_states_dim=context_dim,
                            qk_norm=qk_norm,
                            norm_layer=self.norm,
                            qk_norm_layer=self.qk_norm,
                            skip_connection=layer > depth // 2,
                            with_decoupled_ca=with_decoupled_ca,
                            decoupled_ca_dim=decoupled_ca_dim,
                            decoupled_ca_weight=decoupled_ca_weight,
                            qkv_bias=qkv_bias,
                            use_moe=depth - layer <= num_moe_layers,
                            num_experts=num_experts,
                            moe_top_k=moe_top_k
                            )
            for layer in range(depth)
        ])

        self.final_layer = FinalLayer(hidden_size, self.out_channels)

    def forward(self, x, t, contexts, **kwargs):
        cond = contexts['main']

        t = self.t_embedder(t, condition=kwargs.get('guidance_cond'))  # [B, 1, D]
        x = self.x_embedder(x)

        if self.use_pos_emb:
            pos_embed = self.pos_embed.to(x.dtype)
            x = x + pos_embed

        if self.use_attention_pooling:
            extra_vec = self.pooler(cond, None)
            # lift the pooled vector to a one-token sequence so it adds onto t
            c = t + self.extra_embedder(extra_vec).unsqueeze(1)  # [B, 1, D]
        else:
            c = t

        if self.with_decoupled_ca:
            additional_cond = self.additional_cond_proj(contexts['additional'])
            cond = torch.cat([cond, additional_cond], dim=1)

        # prepend the conditioning token to the latent sequence
        x = torch.cat([c, x], dim=1)

        # U-ViT style long skips: first-half outputs are stacked and consumed
        # by the mirrored second-half blocks
        skip_value_list = []
        for layer, block in enumerate(self.blocks):
            skip_value = None if layer <= self.depth // 2 else skip_value_list.pop()
            x = block(x, c, cond, skip_value=skip_value)
            if layer < self.depth // 2:
                skip_value_list.append(x)

        x = self.final_layer(x)
        return x
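An added end-to-end sketch for HunYuanDiTPlain (illustrative; the shapes, the tiny depth, and num_moe_layers=2 are assumptions, and the MoE path needs the `diffusers` dependency imported by moe_layers.py):

if __name__ == "__main__":
    # Hedged smoke test: condition token is prepended, processed, then dropped,
    # so the output matches the latent sequence shape.
    model = HunYuanDiTPlain(in_channels=4, hidden_size=1024, context_dim=1024,
                            depth=4, num_heads=16, num_moe_layers=2)
    x = torch.randn(2, 256, 4)                       # (batch, latent tokens, in_channels)
    t = torch.tensor([10.0, 500.0])                  # timesteps
    contexts = {'main': torch.randn(2, 257, 1024)}   # condition tokens
    print(model(x, t, contexts).shape)               # torch.Size([2, 256, 4])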
177
hy3dshape/hy3dshape/models/denoisers/moe_layers.py
Normal file
@@ -0,0 +1,177 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
from diffusers.models.attention import FeedForward


class AddAuxiliaryLoss(torch.autograd.Function):
    """
    The trick function for adding an auxiliary (aux) loss:
    the forward pass returns x unchanged, while the backward pass
    also injects the gradient of the aux loss.
    """
    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss
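An added illustration of the straight-through trick above (not part of the file): the output is numerically identical to the input, but the backward pass also feeds a unit gradient into the aux-loss branch, so the gating network still receives load-balancing gradients.

# Hedged example: aux is a stand-in for the MoE balance loss.
x = torch.randn(4, 8, requires_grad=True)
aux = x.pow(2).mean() * 0.01
y = AddAuxiliaryLoss.apply(x, aux.reshape(1))
y.sum().backward()   # x.grad now carries gradients from both y and aux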
class MoEGate(nn.Module):
    def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = num_experts

        self.scoring_func = 'softmax'
        self.alpha = aux_loss_alpha
        self.seq_aux = False

        # topk selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape

        ### compute gating scores
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### normalize gate weights to sum to 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        ### expert-level load-balancing auxiliary loss
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            # always compute aux loss based on the naive greedy topk method
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(
                    1,
                    topk_idx_for_aux_loss,
                    torch.ones(
                        bsz, seq_len * aux_topk,
                        device=hidden_states.device
                    )
                ).div_(seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean()
                aux_loss = aux_loss * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1),
                                    num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)  # fraction of routed slots per expert
                Pi = scores_for_aux.mean(0)   # mean gate probability per expert
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss
class MoEBlock(nn.Module):
    def __init__(self, dim, num_experts=8, moe_top_k=2,
                 activation_fn="gelu", dropout=0.0, final_dropout=False,
                 ff_inner_dim=None, ff_bias=True):
        super().__init__()
        self.moe_top_k = moe_top_k
        self.experts = nn.ModuleList([
            FeedForward(dim, dropout=dropout,
                        activation_fn=activation_fn,
                        final_dropout=final_dropout,
                        inner_dim=ff_inner_dim,
                        bias=ff_bias)
            for i in range(num_experts)])
        self.gate = MoEGate(embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k)

        # a dense expert that every token passes through, in addition to the routed ones
        self.shared_experts = FeedForward(dim, dropout=dropout, activation_fn=activation_fn,
                                          final_dropout=final_dropout, inner_dim=ff_inner_dim,
                                          bias=ff_bias)

    def initialize_weight(self):
        pass

    def forward(self, hidden_states):
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # duplicate each token top_k times so every routed copy is processed
            hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
            y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
            for i, expert in enumerate(self.experts):
                tmp = expert(hidden_states[flat_topk_idx == i])
                y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
            y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.moe_top_k
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # for fp16 and other dtypes
            expert_cache = expert_cache.to(expert_out.dtype)
            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
                                         expert_out,
                                         reduce='sum')
        return expert_cache
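An added usage sketch for MoEBlock (illustrative; requires the `diffusers` FeedForward import above):

# Hedged example: top-2 routing over 4 experts plus the shared expert.
moe = MoEBlock(dim=64, num_experts=4, moe_top_k=2, ff_inner_dim=256)
tokens = torch.randn(2, 16, 64)       # (batch, sequence, dim)
out = moe(tokens)                     # training path also attaches the aux loss
print(out.shape)                      # torch.Size([2, 16, 64])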
354
hy3dshape/hy3dshape/models/diffusion/flow_matching_sit.py
Normal file
@@ -0,0 +1,354 @@
import os
from contextlib import contextmanager
from typing import List, Tuple, Optional, Union

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities import rank_zero_only

from ...utils.ema import LitEma
from ...utils.misc import instantiate_from_config, instantiate_non_trainable_model


class Diffuser(pl.LightningModule):
    def __init__(
        self,
        *,
        first_stage_config,
        cond_stage_config,
        denoiser_cfg,
        scheduler_cfg,
        optimizer_cfg,
        pipeline_cfg=None,
        image_processor_cfg=None,
        lora_config=None,
        ema_config=None,
        first_stage_key: str = "surface",
        cond_stage_key: str = "image",
        scale_by_std: bool = False,
        z_scale_factor: float = 1.0,
        ckpt_path: Optional[str] = None,
        ignore_keys: Union[Tuple[str], List[str]] = (),
        torch_compile: bool = False,
    ):
        super().__init__()
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key

        # ========= init optimizer config ========= #
        self.optimizer_cfg = optimizer_cfg

        # ========= init diffusion scheduler ========= #
        self.scheduler_cfg = scheduler_cfg
        self.sampler = None
        if 'transport' in scheduler_cfg:
            self.transport = instantiate_from_config(scheduler_cfg.transport)
            self.sampler = instantiate_from_config(scheduler_cfg.sampler, transport=self.transport)
            self.sample_fn = self.sampler.sample_ode(**scheduler_cfg.sampler.ode_params)

        # ========= init the model ========= #
        self.denoiser_cfg = denoiser_cfg
        self.model = instantiate_from_config(denoiser_cfg, device=None, dtype=None)
        self.cond_stage_model = instantiate_from_config(cond_stage_config)

        self.ckpt_path = ckpt_path
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        # ========= config lora model ========= #
        if lora_config is not None:
            from peft import LoraConfig, get_peft_model
            loraconfig = LoraConfig(
                r=lora_config.rank,
                lora_alpha=lora_config.rank,
                target_modules=lora_config.get('target_modules')
            )
            self.model = get_peft_model(self.model, loraconfig)

        # ========= config ema model ========= #
        self.ema_config = ema_config
        if self.ema_config is not None:
            if self.ema_config.ema_model == 'DSEma':
                # from michelangelo.models.modules.ema_deepspeed import DSEma
                from ..utils.ema_deepspeed import DSEma
                self.model_ema = DSEma(self.model, decay=self.ema_config.ema_decay)
            else:
                self.model_ema = LitEma(self.model, decay=self.ema_config.ema_decay)
            # do not initialize the EMA weights from the ckpt path, since the MoE layers may have changed
            if ckpt_path is not None:
                self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        # ========= init the vae last to prevent it being overridden by the loaded ckpt ========= #
        self.first_stage_model = instantiate_non_trainable_model(first_stage_config)

        self.scale_by_std = scale_by_std
        if scale_by_std:
            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
        else:
            self.z_scale_factor = z_scale_factor

        # ========= init pipeline for inference ========= #
        self.image_processor_cfg = image_processor_cfg
        self.image_processor = None
        if self.image_processor_cfg is not None:
            self.image_processor = instantiate_from_config(self.image_processor_cfg)
        self.pipeline_cfg = pipeline_cfg
        from ...schedulers import FlowMatchEulerDiscreteScheduler
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
        self.pipeline = instantiate_from_config(
            pipeline_cfg,
            vae=self.first_stage_model,
            model=self.model,
            scheduler=scheduler,  # self.sampler,
            conditioner=self.cond_stage_model,
            image_processor=self.image_processor,
        )

        # ========= torch compile to accelerate ========= #
        self.torch_compile = torch_compile
        if self.torch_compile:
            torch.nn.Module.compile(self.model)
            torch.nn.Module.compile(self.first_stage_model)
            torch.nn.Module.compile(self.cond_stage_model)
            print('*' * 100)
            print('Compile model for acceleration')
            print('*' * 100)

    @contextmanager
    def ema_scope(self, context=None):
        if self.ema_config is not None and self.ema_config.get('ema_inference', False):
            self.model_ema.store(self.model)
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.ema_config is not None and self.ema_config.get('ema_inference', False):
                self.model_ema.restore(self.model)
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=()):
        ckpt = torch.load(path, map_location="cpu")
        if 'state_dict' not in ckpt:
            # deepspeed ckpt
            state_dict = {}
            for k in ckpt.keys():
                new_k = k.replace('_forward_module.', '')
                state_dict[new_k] = ckpt[k]
        else:
            state_dict = ckpt["state_dict"]

        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if ik in k:
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]

        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_load_checkpoint(self, checkpoint):
        """
        The EMA weights are handled separately, but the PL Trainer is strict
        about checkpoint loading (not configurable) and expects the loaded
        state_dict to match the model state_dict exactly.

        So, before keys are matched, we add any missing `model_ema` keys from
        self.state_dict() to the checkpoint state dict so that they line up.
        """
        for key in self.state_dict().keys():
            if key.startswith("model_ema") and key not in checkpoint["state_dict"]:
                checkpoint["state_dict"][key] = self.state_dict()[key]

    def configure_optimizers(self) -> Tuple[List, List]:
        lr = self.learning_rate

        params_list = []
        trainable_parameters = list(self.model.parameters())
        params_list.append({'params': trainable_parameters, 'lr': lr})

        no_decay = ['bias', 'norm.weight', 'norm.bias', 'norm1.weight', 'norm1.bias', 'norm2.weight', 'norm2.bias']

        if self.optimizer_cfg.get('train_image_encoder', False):
            image_encoder_parameters = list(self.cond_stage_model.named_parameters())
            image_encoder_parameters_decay = [param for name, param in image_encoder_parameters if
                                              not any((no_decay_name in name) for no_decay_name in no_decay)]
            image_encoder_parameters_nodecay = [param for name, param in image_encoder_parameters if
                                                any((no_decay_name in name) for no_decay_name in no_decay)]
            # filter trainable params
            image_encoder_parameters_decay = [param for param in image_encoder_parameters_decay if
                                              param.requires_grad]
            image_encoder_parameters_nodecay = [param for param in image_encoder_parameters_nodecay if
                                                param.requires_grad]

            print(f"Image Encoder Params: {len(image_encoder_parameters_decay)} decay, ")
            print(f"Image Encoder Params: {len(image_encoder_parameters_nodecay)} nodecay, ")

            image_encoder_lr = self.optimizer_cfg['image_encoder_lr']
            image_encoder_lr_multiply = self.optimizer_cfg.get('image_encoder_lr_multiply', 1.0)
            image_encoder_lr = image_encoder_lr if image_encoder_lr is not None else lr * image_encoder_lr_multiply
            params_list.append(
                {'params': image_encoder_parameters_decay, 'lr': image_encoder_lr,
                 'weight_decay': 0.05})
            params_list.append(
                {'params': image_encoder_parameters_nodecay, 'lr': image_encoder_lr,
                 'weight_decay': 0.})

        optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)
        if hasattr(self.optimizer_cfg, 'scheduler'):
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            schedulers = [scheduler]
        else:
            schedulers = []
        optimizers = [optimizer]

        return optimizers, schedulers

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx):
        # only for the very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
                and batch_idx == 0 and self.ckpt_path is None:
            # set the rescale weight to 1./std of the encodings
            print("### USING STD-RESCALING ###")

            z_q = self.encode_first_stage(batch[self.first_stage_key])
            z = z_q.detach()

            del self.z_scale_factor
            self.register_buffer("z_scale_factor", 1. / z.flatten().std())
            print(f"setting self.z_scale_factor to {self.z_scale_factor}")

            print("### USING STD-RESCALING ###")

    def on_train_batch_end(self, *args, **kwargs):
        if self.ema_config is not None:
            self.model_ema(self.model)

    def on_train_epoch_start(self) -> None:
        pl.seed_everything(self.trainer.global_rank)

    def forward(self, batch):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):  # float32 for text
            contexts = self.cond_stage_model(image=batch.get('image'), text=batch.get('text'), mask=batch.get('mask'))
            # t5_text = contexts['t5_text']['prompt_embeds']
            # nan_count = torch.isnan(t5_text).sum()
            # if nan_count > 0:
            #     print("t5_text has %d NaN values" % (nan_count))
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            with torch.no_grad():
                latents = self.first_stage_model.encode(batch[self.first_stage_key], sample_posterior=True)
            latents = self.z_scale_factor * latents

            # sanity check that the vae encode/decode round trip is ok (verified):
            # import time
            # from hy3dshape.pipelines import export_to_trimesh
            # latents = 1. / self.z_scale_factor * latents
            # latents = self.first_stage_model(latents)
            # outputs = self.first_stage_model.latents2mesh(
            #     latents,
            #     bounds=1.01,
            #     mc_level=0.0,
            #     num_chunks=20000,
            #     octree_resolution=256,
            #     mc_algo='mc',
            #     enable_pbar=True
            # )
            # mesh = export_to_trimesh(outputs)
            # if isinstance(mesh, list):
            #     for midx, m in enumerate(mesh):
            #         m.export(f"check_{midx}_{time.time()}.glb")
            # else:
            #     mesh.export(f"check_{time.time()}.glb")

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = self.transport.training_losses(self.model, latents, dict(contexts=contexts))["loss"].mean()
        return loss

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        loss = self.forward(batch)
        split = 'train'
        loss_dict = {
            f"{split}/simple": loss.detach(),
            f"{split}/total_loss": loss.detach(),
            f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
        }
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        loss = self.forward(batch)
        split = 'val'
        loss_dict = {
            f"{split}/simple": loss.detach(),
            f"{split}/total_loss": loss.detach(),
            f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
        }
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    @torch.no_grad()
    def sample(self, batch, output_type='trimesh', **kwargs):
        self.cond_stage_model.disable_drop = True

        generator = torch.Generator().manual_seed(0)

        with self.ema_scope("Sample"):
            with torch.amp.autocast(device_type='cuda'):
                try:
                    self.pipeline.device = self.device
                    self.pipeline.dtype = self.dtype
                    print("### USING PIPELINE ###")
                    print(f'device: {self.device} dtype : {self.dtype}')
                    additional_params = {'output_type': output_type}

                    image = batch.get("image", None)
                    mask = batch.get('mask', None)

                    outputs = self.pipeline(image=image,
                                            mask=mask,
                                            generator=generator,
                                            **additional_params)

                except Exception as e:
                    import traceback
                    traceback.print_exc()
                    print(f"Unexpected {e=}, {type(e)=}")
                    with open("error.txt", "a") as f:
                        f.write(str(e))
                        f.write(traceback.format_exc())
                        f.write("\n")
                    outputs = [None]
        self.cond_stage_model.disable_drop = False
        return [outputs]
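For readers unfamiliar with the EMA weight swap that `ema_scope` performs, here is a condensed, self-contained sketch of the store/copy_to/restore pattern. `SimpleEma` is a hypothetical stand-in, not the project's LitEma API:

# Hedged sketch: the EMA pattern reduced to plain state_dict tensors.
class SimpleEma:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
        self.backup = {}

    def update(self, model):   # called after each optimizer step
        for k, v in model.state_dict().items():
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)

    def store(self, model):    # stash the training weights
        self.backup = {k: v.detach().clone() for k, v in model.state_dict().items()}

    def copy_to(self, model):  # evaluate with the EMA weights
        model.load_state_dict(self.shadow)

    def restore(self, model):  # put the training weights back
        model.load_state_dict(self.backup)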
97
hy3dshape/hy3dshape/models/diffusion/transport/__init__.py
Executable file
@@ -0,0 +1,97 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from .transport import Transport, ModelType, WeightType, PathType, Sampler


def create_transport(
    path_type='Linear',
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    train_sample_type="uniform",
    mean=0.0,
    std=1.0,
    shift_scale=1.0,
):
    """Factory function for creating a Transport object.
    **Note**: model prediction defaults to velocity
    Args:
    - path_type: type of interpolation path to use; defaults to linear
    - prediction: model prediction type ("velocity", "score", or "noise")
    - loss_weight: optional loss weighting ("velocity" or "likelihood")
    - train_eps: small epsilon for avoiding instability during training
    - sample_eps: small epsilon for avoiding instability during sampling
    """

    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    # note: the defaulting below originally tested train_eps in both branches,
    # which left sample_eps unset; it now tests sample_eps itself
    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        train_sample_type=train_sample_type,
        mean=mean,
        std=std,
        shift_scale=shift_scale,
    )

    return state
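An added usage note (illustrative): the configuration the trainer relies on corresponds to a linear interpolation path with velocity prediction and no loss weighting, which is also the function's default.

transport = create_transport(path_type='Linear', prediction='velocity')
# Diffuser.forward above then calls
# transport.training_losses(model, latents, dict(contexts=...))["loss"]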
142
hy3dshape/hy3dshape/models/diffusion/transport/integrators.py
Executable file
@@ -0,0 +1,142 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import torch as th
import torch.nn as nn
from torchdiffeq import odeint
from functools import partial
from tqdm import tqdm


class sde:
    """SDE solver class"""

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return xhat + 0.5 * self.dt * (K1 + K2), xhat  # at the last time point we do not perform the Heun step

    def __forward_fn(self):
        """TODO: generalize here by adding all private functions ending with steps to it"""
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            sampler = sampler_dict[self.sampler_type]
        except KeyError:
            raise NotImplementedError("Sampler type not implemented.")

        return sampler

    def sample(self, init, model, **model_kwargs):
        """forward loop of sde"""
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples
class ode:
|
||||
"""ODE solver class"""
|
||||
def __init__(
|
||||
self,
|
||||
drift,
|
||||
*,
|
||||
t0,
|
||||
t1,
|
||||
sampler_type,
|
||||
num_steps,
|
||||
atol,
|
||||
rtol,
|
||||
):
|
||||
assert t0 < t1, "ODE sampler has to be in forward time"
|
||||
|
||||
self.drift = drift
|
||||
self.t = th.linspace(t0, t1, num_steps)
|
||||
self.atol = atol
|
||||
self.rtol = rtol
|
||||
self.sampler_type = sampler_type
|
||||
|
||||
def sample(self, x, model, **model_kwargs):
|
||||
|
||||
device = x[0].device if isinstance(x, tuple) else x.device
|
||||
def _fn(t, x):
|
||||
t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
|
||||
model_output = self.drift(x, t, model, **model_kwargs)
|
||||
return model_output
|
||||
|
||||
t = self.t.to(device)
|
||||
atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
|
||||
rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
|
||||
samples = odeint(
|
||||
_fn,
|
||||
x,
|
||||
t,
|
||||
method=self.sampler_type,
|
||||
atol=atol,
|
||||
rtol=rtol
|
||||
)
|
||||
return samples
|
||||
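For orientation, a minimal sketch of driving the `ode` solver above with a stand-in drift; the toy drift, shapes, and `model=None` placeholder are illustrative assumptions, not part of the commit:

import torch as th

# toy drift: relax every coordinate toward the origin; the real pipeline
# passes a neural network as `model` and a wrapper around it as `drift`
toy_drift = lambda x, t, model, **kw: -x

solver = ode(drift=toy_drift, t0=0.0, t1=1.0,
             sampler_type="dopri5", num_steps=10, atol=1e-6, rtol=1e-3)
xs = solver.sample(th.randn(4, 8), model=None)  # 10 saved states, shape [10, 4, 8]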
220
hy3dshape/hy3dshape/models/diffusion/transport/path.py
Executable file
@@ -0,0 +1,220 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th
import numpy as np
from functools import partial


def expand_t_like_x(t, x):
    """Reshape time t to the broadcastable dimensions of x

    Args:
      t: [batch_dim,], time vector
      x: [batch_dim,...], data point
    """
    dims = [1] * (len(x.size()) - 1)
    t = t.view(t.size(0), *dims)
    return t


#################### Coupling Plans ####################

class ICPlan:
    """Linear Coupling Plan"""

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path"""
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path"""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha"""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output the SDE according to the score parametrization"""
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE

        Args:
          x: [batch_dim, ...], data point
          t: [batch_dim,], time vector
          form: str, form of the diffusion term
          norm: float, norm of the diffusion term
        """
        t = expand_t_like_x(t, x)
        choices = {
            "constant": norm,
            "SBDM": norm * self.compute_drift(x, t)[1],
            "sigma": norm * self.compute_sigma_t(t)[0],
            "linear": norm * (1 - t),
            "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "increasing-decreasing": norm * th.sin(np.pi * t) ** 2,
        }

        try:
            diffusion = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented")

        return diffusion

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model into a score model

        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model into a denoiser

        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform a score prediction model into a velocity model

        Args:
          score: [batch_dim, ...] shaped tensor; score model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of the time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        # t*x1 + (1-t)*x0 ; t=0 -> x0; t=1 -> x1
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Sample xt from the time-dependent density p_t; rng is required"""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t"""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut


class VPCPlan(ICPlan):
    """class for VP path flow matching"""

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * \
            (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
        self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * \
            (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min

    def compute_alpha_t(self, t):
        """Compute coefficient of x1"""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute coefficient of x0"""
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift term of the SDE"""
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2


class GVPCPlan(ICPlan):
    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Compute coefficient of x1"""
        alpha_t = th.sin(t * np.pi / 2)
        d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute coefficient of x0"""
        sigma_t = th.cos(t * np.pi / 2)
        d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
        return np.pi / (2 * th.tan(t * np.pi / 2))
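As a quick sanity check on the linear plan above (illustrative only, not part of the commit): for `ICPlan`, `plan` reduces to straight-line interpolation x_t = t*x1 + (1-t)*x0 with constant velocity u_t = x1 - x0:

import torch as th

plan = ICPlan()
x0, x1 = th.zeros(2, 3), th.ones(2, 3)
t = th.tensor([0.25, 0.75])
_, xt, ut = plan.plan(t, x0, x1)
# rows of xt are filled with 0.25 and 0.75 respectively; ut is x1 - x0 = all ones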
534
hy3dshape/hy3dshape/models/diffusion/transport/transport.py
Executable file
@@ -0,0 +1,534 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th
import numpy as np
import logging

import enum

from . import path
from .utils import EasyDict, log_state, mean_flat
from .integrators import ode, sde


class ModelType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts \nabla \log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)


class PathType(enum.Enum):
    """
    Which type of path to use.
    """

    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()


class WeightType(enum.Enum):
    """
    Which type of weighting to use.
    """

    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()


class Transport:

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        train_eps,
        sample_eps,
        train_sample_type="uniform",
        **kwargs,
    ):
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps
        self.train_sample_type = train_sample_type
        if self.train_sample_type == "logit_normal":
            self.mean = kwargs['mean']
            self.std = kwargs['std']
            self.shift_scale = kwargs['shift_scale']
            print(f"using logit normal sample, shift scale is {self.shift_scale}")

    def prior_logp(self, z):
        '''
        Standard multivariate normal prior
        Assume z is batched
        '''
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])
        _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
        return th.vmap(_fn)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        if type(self.path_sampler) in [path.VPCPlan]:

            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
                and (self.model_type != ModelType.VELOCITY or sde):
            # avoid numerical issues by taking a first semi-implicit step
            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1

    def sample(self, x1):
        """Sampling x0 & t based on the shape of x1 (if needed)

        Args:
          x1 - data point; [batch, *dim]
        """

        x0 = th.randn_like(x1)
        if self.train_sample_type == "uniform":
            t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
            t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
            t = t.to(x1)
        elif self.train_sample_type == "logit_normal":
            t = th.randn((x1.shape[0],)) * self.std + self.mean
            t = t.to(x1)
            t = 1 / (1 + th.exp(-t))

            t = np.sqrt(self.shift_scale) * t / (1 + (np.sqrt(self.shift_scale) - 1) * t)

        return t, x0, x1

    def training_losses(
        self,
        model,
        x1,
        model_kwargs=None
    ):
        """Loss for training the score model

        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        """
        if model_kwargs is None:
            model_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        model_output = model(xt, t, **model_kwargs)
        B, *_, C = xt.shape
        assert model_output.size() == (B, *xt.size()[1:-1], C)

        terms = {}
        terms['pred'] = model_output
        if self.model_type == ModelType.VELOCITY:
            terms['loss'] = mean_flat(((model_output - ut) ** 2))
        else:
            _, drift_var = self.path_sampler.compute_drift(xt, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
            if self.loss_type in [WeightType.VELOCITY]:
                weight = (drift_var / sigma_t) ** 2
            elif self.loss_type in [WeightType.LIKELIHOOD]:
                weight = drift_var / (sigma_t ** 2)
            elif self.loss_type in [WeightType.NONE]:
                weight = 1
            else:
                raise NotImplementedError()

            if self.model_type == ModelType.NOISE:
                terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
            else:
                terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))

        return terms

    def get_drift(
        self
    ):
        """member function for obtaining the drift of the probability flow ODE"""

        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return (-drift_mean + drift_var * model_output)  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            score = model_output / -sigma_t
            return (-drift_mean + drift_var * score)

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining the score of
        x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:
            score_fn = lambda x, t, model, **kwargs: \
                model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
        elif self.model_type == ModelType.SCORE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = lambda x, t, model, **kwargs: \
                self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
        else:
            raise NotImplementedError()

        return score_fn


class Sampler:
    """Sampler class for the transport model"""

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supports different sampling methods

        Args:
        - transport: a transport object specifying the model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):

        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
            return diffusion

        sde_drift = \
            lambda x, t, model, **kwargs: \
                self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver"""

        if last_step is None:
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x
        elif last_step == "Mean":
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x + sde_drift(x, t, model, **model_kwargs) * last_step_size
        elif last_step == "Tweedie":
            alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model,
                                                                                             **model_kwargs)
        elif last_step == "Euler":
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x + self.drift(x, t, model, **model_kwargs) * last_step_size
        else:
            raise NotImplementedError()

        return last_step_fn

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with the given SDE settings

        Args:
        - sampling_method: type of sampler used in solving the SDE; defaults to Euler-Maruyama
        - diffusion_form: functional form of the diffusion coefficient; defaults to matching SBDM
        - diffusion_norm: magnitude of the diffusion coefficient; defaults to 1
        - last_step: type of the last step; defaults to identity
        - last_step_size: size of the last step; defaults to matching the stride of 250 steps over [0,1]
        - num_steps: total number of integration steps of the SDE
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            assert len(xs) == num_steps, "Number of samples does not match the number of steps"

            return xs

        return _sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
    ):
        """returns a sampling function with the given ODE settings

        Args:
        - sampling_method: type of sampler used in solving the ODE; defaults to Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - reverse: whether to solve the ODE in reverse (data to noise); defaults to False
        """
        if reverse:
            drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
        else:
            drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        return _ode.sample

    def sample_ode_intermediate(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        t=0.5,
        reverse=False,
    ):
        """returns a sampling function that integrates from the intermediate time t with the given ODE settings

        Args:
        - sampling_method: type of sampler used in solving the ODE; defaults to Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - reverse: whether to solve the ODE in reverse (data to noise); defaults to False
        """
        if reverse:
            drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
        else:
            drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        return _ode.sample

    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):

        """returns a sampling function for calculating likelihood with the given ODE settings

        Args:
        - sampling_method: type of sampler used in solving the ODE; defaults to Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        """

        def _likelihood_drift(x, t, model, **model_kwargs):
            x, _ = x
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            init_logp = th.zeros(x.size(0)).to(x)
            input = (x, init_logp)
            drift, delta_logp = _ode.sample(input, model, **model_kwargs)
            drift, delta_logp = drift[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(drift)
            logp = prior_logp - delta_logp
            return logp, drift

        return _sample_fn
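Putting the three transport modules together, a minimal sketch of the intended flow; the stand-in velocity model and the argument values below are assumptions for illustration, not part of the commit:

import torch as th

transport = Transport(
    model_type=ModelType.VELOCITY,   # velocity parametrization, as in flow matching
    path_type=PathType.LINEAR,
    loss_type=WeightType.NONE,
    train_eps=0.0,
    sample_eps=0.0,
)
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(sampling_method="dopri5", num_steps=5)

toy_model = lambda x, t, **kw: -x           # stand-in for the real network
xs = sample_fn(th.randn(2, 4), toy_model)   # 5 states saved along the ODE trajectory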
54
hy3dshape/hy3dshape/models/diffusion/transport/utils.py
Executable file
@@ -0,0 +1,54 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th


class EasyDict:

    def __init__(self, sub_dict):
        for k, v in sub_dict.items():
            setattr(self, k, v)

    def __getitem__(self, key):
        return getattr(self, key)


def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return th.mean(x, dim=list(range(1, len(x.size()))))


def log_state(state):
    result = []

    sorted_state = dict(sorted(state.items()))
    for key, value in sorted_state.items():
        # Check if the value is an instance of a class
        if "<object" in str(value) or "object at" in str(value):
            result.append(f"{key}: [{value.__class__.__name__}]")
        else:
            result.append(f"{key}: {value}")

    return '\n'.join(result)
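For reference, a short illustration of these helpers (purely illustrative, not part of the commit):

import torch as th

cfg = EasyDict({"lr": 1e-4, "steps": 10})
assert cfg["lr"] == cfg.lr == 1e-4          # item and attribute access are interchangeable
per_sample = mean_flat(th.ones(4, 3, 2))    # shape [4]: mean over all non-batch dims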
783
hy3dshape/hy3dshape/pipelines.py
Normal file
@@ -0,0 +1,783 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import copy
import importlib
import inspect
import os
from typing import List, Optional, Union

import numpy as np
import torch
import trimesh
import yaml
from PIL import Image
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available
from tqdm import tqdm

from .models.autoencoders import ShapeVAE
from .models.autoencoders import SurfaceExtractors
from .utils import logger, synchronize_timer, smart_load_model


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


@synchronize_timer('Export to trimesh')
def export_to_trimesh(mesh_output):
    if isinstance(mesh_output, list):
        outputs = []
        for mesh in mesh_output:
            if mesh is None:
                outputs.append(None)
            else:
                mesh.mesh_f = mesh.mesh_f[:, ::-1]
                # build a new Trimesh without rebinding the list we are iterating over
                outputs.append(trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f))
        return outputs
    else:
        mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]
        return trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    params = config.get("params", dict())
    kwargs.update(params)
    instance = cls(**kwargs)
    return instance
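A sketch of how `get_obj_from_str` / `instantiate_from_config` resolve a config entry; the `torch.nn.Linear` target below is just a stand-in for the model/vae/conditioner entries in the real YAML config:

config = {
    "target": "torch.nn.Linear",                      # any importable "module.Class" path
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)               # equivalent to torch.nn.Linear(8, 4)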
|
||||
|
||||
class Hunyuan3DDiTPipeline:
|
||||
model_cpu_offload_seq = "conditioner->model->vae"
|
||||
_exclude_from_cpu_offload = []
|
||||
|
||||
@classmethod
|
||||
@synchronize_timer('Hunyuan3DDiTPipeline Model Loading')
|
||||
def from_single_file(
|
||||
cls,
|
||||
ckpt_path,
|
||||
config_path,
|
||||
device='cuda',
|
||||
dtype=torch.float16,
|
||||
use_safetensors=None,
|
||||
**kwargs,
|
||||
):
|
||||
# load config
|
||||
with open(config_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# load ckpt
|
||||
if use_safetensors:
|
||||
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
|
||||
if not os.path.exists(ckpt_path):
|
||||
raise FileNotFoundError(f"Model file {ckpt_path} not found")
|
||||
logger.info(f"Loading model from {ckpt_path}")
|
||||
|
||||
if use_safetensors:
|
||||
# parse safetensors
|
||||
import safetensors.torch
|
||||
safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
|
||||
ckpt = {}
|
||||
for key, value in safetensors_ckpt.items():
|
||||
model_name = key.split('.')[0]
|
||||
new_key = key[len(model_name) + 1:]
|
||||
if model_name not in ckpt:
|
||||
ckpt[model_name] = {}
|
||||
ckpt[model_name][new_key] = value
|
||||
else:
|
||||
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
|
||||
# load model
|
||||
model = instantiate_from_config(config['model'])
|
||||
model.load_state_dict(ckpt['model'])
|
||||
vae = instantiate_from_config(config['vae'])
|
||||
vae.load_state_dict(ckpt['vae'], strict=False)
|
||||
conditioner = instantiate_from_config(config['conditioner'])
|
||||
if 'conditioner' in ckpt:
|
||||
conditioner.load_state_dict(ckpt['conditioner'])
|
||||
image_processor = instantiate_from_config(config['image_processor'])
|
||||
scheduler = instantiate_from_config(config['scheduler'])
|
||||
|
||||
model_kwargs = dict(
|
||||
vae=vae,
|
||||
model=model,
|
||||
scheduler=scheduler,
|
||||
conditioner=conditioner,
|
||||
image_processor=image_processor,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
model_kwargs.update(kwargs)
|
||||
|
||||
return cls(
|
||||
**model_kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_path,
|
||||
device='cuda',
|
||||
dtype=torch.float16,
|
||||
use_safetensors=False,
|
||||
variant='fp16',
|
||||
subfolder='hunyuan3d-dit-v2-1',
|
||||
**kwargs,
|
||||
):
|
||||
kwargs['from_pretrained_kwargs'] = dict(
|
||||
model_path=model_path,
|
||||
subfolder=subfolder,
|
||||
use_safetensors=use_safetensors,
|
||||
variant=variant,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
config_path, ckpt_path = smart_load_model(
|
||||
model_path,
|
||||
subfolder=subfolder,
|
||||
use_safetensors=use_safetensors,
|
||||
variant=variant
|
||||
)
|
||||
return cls.from_single_file(
|
||||
ckpt_path,
|
||||
config_path,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
use_safetensors=use_safetensors,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
model,
|
||||
scheduler,
|
||||
conditioner,
|
||||
image_processor,
|
||||
device='cuda',
|
||||
dtype=torch.float16,
|
||||
**kwargs
|
||||
):
|
||||
self.vae = vae
|
||||
self.model = model
|
||||
self.scheduler = scheduler
|
||||
self.conditioner = conditioner
|
||||
self.image_processor = image_processor
|
||||
self.kwargs = kwargs
|
||||
self.to(device, dtype)
|
||||
|
||||
def compile(self):
|
||||
self.vae = torch.compile(self.vae)
|
||||
self.model = torch.compile(self.model)
|
||||
self.conditioner = torch.compile(self.conditioner)
|
||||
|
||||
def enable_flashvdm(
|
||||
self,
|
||||
enabled: bool = True,
|
||||
adaptive_kv_selection=True,
|
||||
topk_mode='mean',
|
||||
mc_algo='mc',
|
||||
replace_vae=True,
|
||||
):
|
||||
if enabled:
|
||||
model_path = self.kwargs['from_pretrained_kwargs']['model_path']
|
||||
turbo_vae_mapping = {
|
||||
'Hunyuan3D-2': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0-turbo'),
|
||||
'Hunyuan3D-2mv': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0-turbo'),
|
||||
'Hunyuan3D-2mini': ('tencent/Hunyuan3D-2mini', 'hunyuan3d-vae-v2-mini-turbo'),
|
||||
}
|
||||
model_name = model_path.split('/')[-1]
|
||||
if replace_vae and model_name in turbo_vae_mapping:
|
||||
model_path, subfolder = turbo_vae_mapping[model_name]
|
||||
self.vae = ShapeVAE.from_pretrained(
|
||||
model_path, subfolder=subfolder,
|
||||
use_safetensors=self.kwargs['from_pretrained_kwargs']['use_safetensors'],
|
||||
device=self.device,
|
||||
)
|
||||
self.vae.enable_flashvdm_decoder(
|
||||
enabled=enabled,
|
||||
adaptive_kv_selection=adaptive_kv_selection,
|
||||
topk_mode=topk_mode,
|
||||
mc_algo=mc_algo
|
||||
)
|
||||
else:
|
||||
model_path = self.kwargs['from_pretrained_kwargs']['model_path']
|
||||
vae_mapping = {
|
||||
'Hunyuan3D-2': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0'),
|
||||
'Hunyuan3D-2mv': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0'),
|
||||
'Hunyuan3D-2mini': ('tencent/Hunyuan3D-2mini', 'hunyuan3d-vae-v2-mini'),
|
||||
}
|
||||
model_name = model_path.split('/')[-1]
|
||||
if model_name in vae_mapping:
|
||||
model_path, subfolder = vae_mapping[model_name]
|
||||
self.vae = ShapeVAE.from_pretrained(model_path, subfolder=subfolder)
|
||||
self.vae.enable_flashvdm_decoder(enabled=False)
|
||||
|
||||
def to(self, device=None, dtype=None):
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self.vae.to(dtype=dtype)
|
||||
self.model.to(dtype=dtype)
|
||||
self.conditioner.to(dtype=dtype)
|
||||
if device is not None:
|
||||
self.device = torch.device(device)
|
||||
self.vae.to(device)
|
||||
self.model.to(device)
|
||||
self.conditioner.to(device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
|
||||
Accelerate's module hooks.
|
||||
"""
|
||||
for name, model in self.components.items():
|
||||
if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
|
||||
continue
|
||||
|
||||
if not hasattr(model, "_hf_hook"):
|
||||
return self.device
|
||||
for module in model.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
|
||||
Arguments:
|
||||
gpu_id (`int`, *optional*):
|
||||
The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
|
||||
device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
|
||||
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
||||
default to "cuda".
|
||||
"""
|
||||
if self.model_cpu_offload_seq is None:
|
||||
raise ValueError(
|
||||
"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
|
||||
)
|
||||
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
torch_device = torch.device(device)
|
||||
device_index = torch_device.index
|
||||
|
||||
if gpu_id is not None and device_index is not None:
|
||||
raise ValueError(
|
||||
f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
|
||||
f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of "
|
||||
f"the device: `device`={torch_device.type}"
|
||||
)
|
||||
|
||||
# _offload_gpu_id should be set to passed gpu_id (or id in passed `device`)
|
||||
# or default to previously set id or default to 0
|
||||
self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)
|
||||
|
||||
device_type = torch_device.type
|
||||
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu")
|
||||
device_mod = getattr(torch, self.device.type, None)
|
||||
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
|
||||
device_mod.empty_cache()
|
||||
# otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
|
||||
|
||||
self._all_hooks = []
|
||||
hook = None
|
||||
for model_str in self.model_cpu_offload_seq.split("->"):
|
||||
model = all_model_components.pop(model_str, None)
|
||||
if not isinstance(model, torch.nn.Module):
|
||||
continue
|
||||
|
||||
_, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
|
||||
self._all_hooks.append(hook)
|
||||
|
||||
# CPU offload models that are not in the seq chain unless they are explicitly excluded
|
||||
# these models will stay on CPU until maybe_free_model_hooks is called
|
||||
# some models cannot be in the seq chain because they are iteratively called,
|
||||
# such as controlnet
|
||||
for name, model in all_model_components.items():
|
||||
if not isinstance(model, torch.nn.Module):
|
||||
continue
|
||||
|
||||
if name in self._exclude_from_cpu_offload:
|
||||
model.to(device)
|
||||
else:
|
||||
_, hook = cpu_offload_with_hook(model, device)
|
||||
self._all_hooks.append(hook)
|
||||
|
||||
def maybe_free_model_hooks(self):
|
||||
r"""
|
||||
Function that offloads all components, removes all model hooks that were added when using
|
||||
`enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
|
||||
is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
|
||||
functions correctly when applying enable_model_cpu_offload.
|
||||
"""
|
||||
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
|
||||
# `enable_model_cpu_offload` has not be called, so silently do nothing
|
||||
return
|
||||
|
||||
for hook in self._all_hooks:
|
||||
# offload model and remove hook from model
|
||||
hook.offload()
|
||||
hook.remove()
|
||||
|
||||
# make sure the model is in the same state as before calling it
|
||||
self.enable_model_cpu_offload()
|
||||
|
||||
@synchronize_timer('Encode cond')
|
||||
def encode_cond(self, image, additional_cond_inputs, do_classifier_free_guidance, dual_guidance):
|
||||
bsz = image.shape[0]
|
||||
cond = self.conditioner(image=image, **additional_cond_inputs)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
un_cond = self.conditioner.unconditional_embedding(bsz, **additional_cond_inputs)
|
||||
|
||||
if dual_guidance:
|
||||
un_cond_drop_main = copy.deepcopy(un_cond)
|
||||
un_cond_drop_main['additional'] = cond['additional']
|
||||
|
||||
def cat_recursive(a, b, c):
|
||||
if isinstance(a, torch.Tensor):
|
||||
return torch.cat([a, b, c], dim=0).to(self.dtype)
|
||||
out = {}
|
||||
for k in a.keys():
|
||||
out[k] = cat_recursive(a[k], b[k], c[k])
|
||||
return out
|
||||
|
||||
cond = cat_recursive(cond, un_cond_drop_main, un_cond)
|
||||
else:
|
||||
def cat_recursive(a, b):
|
||||
if isinstance(a, torch.Tensor):
|
||||
return torch.cat([a, b], dim=0).to(self.dtype)
|
||||
out = {}
|
||||
for k in a.keys():
|
||||
out[k] = cat_recursive(a[k], b[k])
|
||||
return out
|
||||
|
||||
cond = cat_recursive(cond, un_cond)
|
||||
return cond
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def prepare_latents(self, batch_size, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, *self.vae.latent_shape)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * getattr(self.scheduler, 'init_noise_sigma', 1.0)
|
||||
return latents
|
||||
|
||||
def prepare_image(self, image, mask=None) -> dict:
|
||||
if isinstance(image, torch.Tensor) and isinstance(mask, torch.Tensor):
|
||||
outputs = {
|
||||
'image': image,
|
||||
'mask': mask
|
||||
}
|
||||
return outputs
|
||||
|
||||
if isinstance(image, str) and not os.path.exists(image):
|
||||
raise FileNotFoundError(f"Couldn't find image at path {image}")
|
||||
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
outputs = []
|
||||
for img in image:
|
||||
output = self.image_processor(img)
|
||||
outputs.append(output)
|
||||
|
||||
cond_input = {k: [] for k in outputs[0].keys()}
|
||||
for output in outputs:
|
||||
for key, value in output.items():
|
||||
cond_input[key].append(value)
|
||||
for key, value in cond_input.items():
|
||||
if isinstance(value[0], torch.Tensor):
|
||||
cond_input[key] = torch.cat(value, dim=0)
|
||||
|
||||
return cond_input
|
||||
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
Args:
|
||||
timesteps (`torch.Tensor`):
|
||||
generate embedding vectors at these timesteps
|
||||
embedding_dim (`int`, *optional*, defaults to 512):
|
||||
dimension of the embeddings to generate
|
||||
dtype:
|
||||
data type of the generated embeddings
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
||||
"""
|
||||
assert len(w.shape) == 1
|
||||
w = w * 1000.0
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
def set_surface_extractor(self, mc_algo):
|
||||
if mc_algo is None:
|
||||
return
|
||||
logger.info('The parameters `mc_algo` is deprecated, and will be removed in future versions.\n'
|
||||
'Please use: \n'
|
||||
'from hy3dshape.models.autoencoders import SurfaceExtractors\n'
|
||||
'pipeline.vae.surface_extractor = SurfaceExtractors[mc_algo]() instead\n')
|
||||
if mc_algo not in SurfaceExtractors.keys():
|
||||
raise ValueError(f"Unknown mc_algo {mc_algo}")
|
||||
self.vae.surface_extractor = SurfaceExtractors[mc_algo]()
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
image: Union[str, List[str], Image.Image] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
eta: float = 0.0,
|
||||
guidance_scale: float = 7.5,
|
||||
dual_guidance_scale: float = 10.5,
|
||||
dual_guidance: bool = True,
|
||||
generator=None,
|
||||
box_v=1.01,
|
||||
octree_resolution=384,
|
||||
        mc_level=-1 / 512,
        num_chunks=8000,
        mc_algo=None,
        output_type: Optional[str] = "trimesh",
        enable_pbar=True,
        **kwargs,
    ) -> List[List[trimesh.Trimesh]]:
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", 1)

        self.set_surface_extractor(mc_algo)

        device = self.device
        dtype = self.dtype
        do_classifier_free_guidance = guidance_scale >= 0 and \
            getattr(self.model, 'guidance_cond_proj_dim', None) is None
        dual_guidance = dual_guidance_scale >= 0 and dual_guidance

        if isinstance(image, torch.Tensor):
            # NOTE: the original left `cond_inputs` undefined for tensor inputs;
            # default to an empty dict so encode_cond below does not fail
            cond_inputs = {}
        else:
            cond_inputs = self.prepare_image(image)
            image = cond_inputs.pop('image')

        cond = self.encode_cond(
            image=image,
            additional_cond_inputs=cond_inputs,
            do_classifier_free_guidance=do_classifier_free_guidance,
            dual_guidance=False,
        )
        batch_size = image.shape[0]

        t_dtype = torch.long
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas)

        latents = self.prepare_latents(batch_size, dtype, device, generator)
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        guidance_cond = None
        if getattr(self.model, 'guidance_cond_proj_dim', None) is not None:
            logger.info('Using lcm guidance scale')
            guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)
            guidance_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)
        with synchronize_timer('Diffusion Sampling'):
            for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
                # expand the latents if we are doing classifier-free guidance
                if do_classifier_free_guidance:
                    latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
                else:
                    latent_model_input = latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
                timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
                noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)

                # no drop, drop clip, all drop
                if do_classifier_free_guidance:
                    if dual_guidance:
                        noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)
                        noise_pred = (
                            noise_pred_uncond
                            + guidance_scale * (noise_pred_clip - noise_pred_dino)
                            + dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)
                        )
                    else:
                        noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
                latents = outputs.prev_sample

                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, outputs)

        return self._export(
            latents,
            output_type,
            box_v, mc_level, num_chunks, octree_resolution, mc_algo,
        )

    def _export(
        self,
        latents,
        output_type='trimesh',
        box_v=1.01,
        mc_level=0.0,
        num_chunks=20000,
        octree_resolution=256,
        mc_algo='mc',
        enable_pbar=True
    ):
        if output_type != "latent":
            latents = 1. / self.vae.scale_factor * latents
            latents = self.vae(latents)
            outputs = self.vae.latents2mesh(
                latents,
                bounds=box_v,
                mc_level=mc_level,
                num_chunks=num_chunks,
                octree_resolution=octree_resolution,
                mc_algo=mc_algo,
                enable_pbar=enable_pbar,
            )
        else:
            outputs = latents

        if output_type == 'trimesh':
            outputs = export_to_trimesh(outputs)

        return outputs


class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):

    @torch.inference_mode()
    def __call__(
        self,
        image: Union[str, List[str], Image.Image, dict, List[dict], torch.Tensor] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        eta: float = 0.0,
        guidance_scale: float = 5.0,
        generator=None,
        box_v=1.01,
        octree_resolution=384,
        mc_level=0.0,
        mc_algo=None,
        num_chunks=8000,
        output_type: Optional[str] = "trimesh",
        enable_pbar=True,
        mask=None,
        **kwargs,
    ) -> List[List[trimesh.Trimesh]]:
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", 1)

        self.set_surface_extractor(mc_algo)

        device = self.device
        dtype = self.dtype
        do_classifier_free_guidance = guidance_scale >= 0 and not (
            hasattr(self.model, 'guidance_embed') and
            self.model.guidance_embed is True
        )

        cond_inputs = self.prepare_image(image, mask)
        image = cond_inputs.pop('image')
        cond = self.encode_cond(
            image=image,
            additional_cond_inputs=cond_inputs,
            do_classifier_free_guidance=do_classifier_free_guidance,
            dual_guidance=False,
        )

        batch_size = image.shape[0]

        # Prepare timesteps.
        # NOTE: this is slightly different from common usage: we start from 0.
        sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
        )
        latents = self.prepare_latents(batch_size, dtype, device, generator)

        guidance = None
        if hasattr(self.model, 'guidance_embed') and \
            self.model.guidance_embed is True:
            guidance = torch.tensor([guidance_scale] * batch_size, device=device, dtype=dtype)

        with synchronize_timer('Diffusion Sampling'):
            for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:")):
                # expand the latents if we are doing classifier-free guidance
                if do_classifier_free_guidance:
                    latent_model_input = torch.cat([latents] * 2)
                else:
                    latent_model_input = latents

                # NOTE: we assume the model takes timesteps ranged from 0 to 1
                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
                timestep = timestep / self.scheduler.config.num_train_timesteps
                noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)

                if do_classifier_free_guidance:
                    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                outputs = self.scheduler.step(noise_pred, t, latents)
                latents = outputs.prev_sample

                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, outputs)

        return self._export(
            latents,
            output_type,
            box_v, mc_level, num_chunks, octree_resolution, mc_algo,
            enable_pbar=enable_pbar,
        )
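
For orientation, here is a minimal sketch of driving the flow-matching pipeline defined above. The loading step is an assumption (a `from_pretrained` constructor and the checkpoint id are not shown in this commit, and the import path assumes the inner `hy3dshape` package is installed); the call arguments are exactly the parameters of the signature above.

from PIL import Image

from hy3dshape import Hunyuan3DDiTFlowMatchingPipeline

# assumed loading entry point and checkpoint id; substitute your own
pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained('tencent/Hunyuan3D-2')

image = Image.open('input.png')  # hypothetical input image
meshes = pipeline(
    image=image,
    num_inference_steps=50,
    guidance_scale=5.0,
    octree_resolution=384,
    output_type='trimesh',
)
meshes[0].export('output.glb')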
202
hy3dshape/hy3dshape/postprocessors.py
Normal file
@@ -0,0 +1,202 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
import tempfile
from typing import Union

import numpy as np
import pymeshlab
import torch
import trimesh

from .models.autoencoders import Latent2MeshOutput
from .utils import synchronize_timer


def load_mesh(path):
    if path.endswith(".glb"):
        mesh = trimesh.load(path)
    else:
        mesh = pymeshlab.MeshSet()
        mesh.load_new_mesh(path)
    return mesh


def reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000):
    if max_facenum > mesh.current_mesh().face_number():
        return mesh

    mesh.apply_filter(
        "meshing_decimation_quadric_edge_collapse",
        targetfacenum=max_facenum,
        qualitythr=1.0,
        preserveboundary=True,
        boundaryweight=3,
        preservenormal=True,
        preservetopology=True,
        autoclean=True
    )
    return mesh


def remove_floater(mesh: pymeshlab.MeshSet):
    mesh.apply_filter("compute_selection_by_small_disconnected_components_per_face",
                      nbfaceratio=0.005)
    mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False)
    mesh.apply_filter("meshing_remove_selected_vertices_and_faces")
    return mesh


def pymeshlab2trimesh(mesh: pymeshlab.MeshSet):
    with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
        mesh.save_current_mesh(temp_file.name)
        mesh = trimesh.load(temp_file.name)
    # check the type of the loaded object
    if isinstance(mesh, trimesh.Scene):
        combined_mesh = trimesh.Trimesh()
        # if it is a Scene, iterate over all geometries and merge them
        for geom in mesh.geometry.values():
            combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
        mesh = combined_mesh
    return mesh


def trimesh2pymeshlab(mesh: trimesh.Trimesh):
    with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
        if isinstance(mesh, trimesh.scene.Scene):
            for idx, obj in enumerate(mesh.geometry.values()):
                if idx == 0:
                    temp_mesh = obj
                else:
                    temp_mesh = temp_mesh + obj
            mesh = temp_mesh
        mesh.export(temp_file.name)
        mesh = pymeshlab.MeshSet()
        mesh.load_new_mesh(temp_file.name)
    return mesh


def export_mesh(input, output):
    if isinstance(input, pymeshlab.MeshSet):
        mesh = output
    elif isinstance(input, Latent2MeshOutput):
        # read the processed MeshSet before rebinding `output`; the original
        # overwrote `output` first and then read from the empty container
        ms = output
        output = Latent2MeshOutput()
        output.mesh_v = ms.current_mesh().vertex_matrix()
        output.mesh_f = ms.current_mesh().face_matrix()
        mesh = output
    else:
        mesh = pymeshlab2trimesh(output)
    return mesh


def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str]) -> pymeshlab.MeshSet:
    if isinstance(mesh, str):
        mesh = load_mesh(mesh)
    elif isinstance(mesh, Latent2MeshOutput):
        # build the pymeshlab.Mesh from the output's arrays before rebinding `mesh`
        mesh_pymeshlab = pymeshlab.Mesh(vertex_matrix=mesh.mesh_v, face_matrix=mesh.mesh_f)
        mesh = pymeshlab.MeshSet()
        mesh.add_mesh(mesh_pymeshlab, "converted_mesh")

    if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):
        mesh = trimesh2pymeshlab(mesh)

    return mesh


class FaceReducer:
    @synchronize_timer('FaceReducer')
    def __call__(
        self,
        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
        max_facenum: int = 40000
    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:
        ms = import_mesh(mesh)
        ms = reduce_face(ms, max_facenum=max_facenum)
        mesh = export_mesh(mesh, ms)
        return mesh


class FloaterRemover:
    @synchronize_timer('FloaterRemover')
    def __call__(
        self,
        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:
        ms = import_mesh(mesh)
        ms = remove_floater(ms)
        mesh = export_mesh(mesh, ms)
        return mesh


class DegenerateFaceRemover:
    @synchronize_timer('DegenerateFaceRemover')
    def __call__(
        self,
        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:
        ms = import_mesh(mesh)

        # round-tripping through a .ply file lets pymeshlab drop degenerate faces on load
        with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
            ms.save_current_mesh(temp_file.name)
            ms = pymeshlab.MeshSet()
            ms.load_new_mesh(temp_file.name)

        mesh = export_mesh(mesh, ms)
        return mesh


def mesh_normalize(mesh):
    """
    Normalize mesh vertices to fit inside a sphere of radius `scale_factor`.
    """
    scale_factor = 1.2
    vtx_pos = np.asarray(mesh.vertices)
    # per-axis bounding-box extremes (the original indexed [0], which
    # collapsed the bounds to the x-axis only)
    max_bb = vtx_pos.max(0)
    min_bb = vtx_pos.min(0)

    center = (max_bb + min_bb) / 2

    scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0

    vtx_pos = (vtx_pos - center) * (scale_factor / float(scale))
    mesh.vertices = vtx_pos

    return mesh


class MeshSimplifier:
    def __init__(self, executable: str = None):
        if executable is None:
            CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
            executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin")
        self.executable = executable

    @synchronize_timer('MeshSimplifier')
    def __call__(
        self,
        mesh: trimesh.Trimesh,
    ) -> trimesh.Trimesh:
        with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_input:
            with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_output:
                mesh.export(temp_input.name)
                os.system(f'{self.executable} {temp_input.name} {temp_output.name}')
                ms = trimesh.load(temp_output.name, process=False)
                if isinstance(ms, trimesh.Scene):
                    combined_mesh = trimesh.Trimesh()
                    for geom in ms.geometry.values():
                        combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
                    ms = combined_mesh
        ms = mesh_normalize(ms)
        return ms
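
A typical cleanup chain applies the three passes above in sequence on a freshly generated mesh. A minimal sketch, assuming the inner package is importable as `hy3dshape` and a `mesh.glb` exists on disk; the class names are the ones defined in this file.

import trimesh

from hy3dshape import FaceReducer, FloaterRemover, DegenerateFaceRemover

mesh = trimesh.load('mesh.glb', force='mesh')   # hypothetical input path
mesh = FloaterRemover()(mesh)                   # drop small disconnected components
mesh = DegenerateFaceRemover()(mesh)            # clean up degenerate faces
mesh = FaceReducer()(mesh, max_facenum=40000)   # quadric edge-collapse decimation
mesh.export('mesh_clean.glb')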
167
hy3dshape/hy3dshape/preprocessors.py
Normal file
@@ -0,0 +1,167 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import cv2
import numpy as np
import torch
from PIL import Image
from einops import repeat, rearrange


def array_to_tensor(np_array):
    image_pt = torch.tensor(np_array).float()
    image_pt = image_pt / 255 * 2 - 1
    image_pt = rearrange(image_pt, "h w c -> c h w")
    image_pts = repeat(image_pt, "c h w -> b c h w", b=1)
    return image_pts


class ImageProcessorV2:
    def __init__(self, size=512, border_ratio=None):
        self.size = size
        self.border_ratio = border_ratio

    @staticmethod
    def recenter(image, border_ratio: float = 0.2):
        """ Recenter an image, leaving some empty space at the image border.

        Args:
            image (ndarray): input image, float/uint8 [H, W, 3/4]; if a fourth
                channel is present it is used as the alpha mask
            border_ratio (float, optional): border ratio; the content is resized
                to fill (1 - border_ratio) of the frame. Defaults to 0.2.

        Returns:
            Tuple[ndarray, ndarray]: output image [H, W, 3] and mask [H, W, 1], both uint8
        """
        if image.shape[-1] == 4:
            mask = image[..., 3]
        else:
            mask = np.ones_like(image[..., 0:1]) * 255
            image = np.concatenate([image, mask], axis=-1)
            mask = mask[..., 0]

        H, W, C = image.shape

        size = max(H, W)
        result = np.zeros((size, size, C), dtype=np.uint8)

        coords = np.nonzero(mask)
        x_min, x_max = coords[0].min(), coords[0].max()
        y_min, y_max = coords[1].min(), coords[1].max()
        h = x_max - x_min
        w = y_max - y_min
        if h == 0 or w == 0:
            raise ValueError('input image is empty')
        desired_size = int(size * (1 - border_ratio))
        scale = desired_size / max(h, w)
        h2 = int(h * scale)
        w2 = int(w * scale)
        x2_min = (size - h2) // 2
        x2_max = x2_min + h2

        y2_min = (size - w2) // 2
        y2_max = y2_min + w2

        result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(image[x_min:x_max, y_min:y_max], (w2, h2),
                                                          interpolation=cv2.INTER_AREA)

        bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255

        # composite the foreground over a white background using the alpha mask
        mask = result[..., 3:].astype(np.float32) / 255
        result = result[..., :3] * mask + bg * (1 - mask)

        mask = mask * 255
        result = result.clip(0, 255).astype(np.uint8)
        mask = mask.clip(0, 255).astype(np.uint8)
        return result, mask

    def load_image(self, image, border_ratio=0.15, to_tensor=True):
        if isinstance(image, str):
            image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
            image, mask = self.recenter(image, border_ratio=border_ratio)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif isinstance(image, Image.Image):
            image = image.convert("RGBA")
            image = np.asarray(image)
            image, mask = self.recenter(image, border_ratio=border_ratio)

        image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
        mask = mask[..., np.newaxis]

        if to_tensor:
            image = array_to_tensor(image)
            mask = array_to_tensor(mask)
        return image, mask

    def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
        if self.border_ratio is not None:
            border_ratio = self.border_ratio
        image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
        outputs = {
            'image': image,
            'mask': mask
        }
        return outputs


class MVImageProcessorV2(ImageProcessorV2):
    """
    View order: front, left (front rotated clockwise 90°), back, right (front rotated clockwise 270°).
    """
    return_view_idx = True

    def __init__(self, size=512, border_ratio=None):
        super().__init__(size, border_ratio)
        self.view2idx = {
            'front': 0,
            'left': 1,
            'back': 2,
            'right': 3
        }

    def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
        if self.border_ratio is not None:
            border_ratio = self.border_ratio

        images = []
        masks = []
        view_idxs = []
        for idx, (view_tag, image) in enumerate(image_dict.items()):
            view_idxs.append(self.view2idx[view_tag])
            image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
            images.append(image)
            masks.append(mask)

        # reorder the views into the canonical front/left/back/right order
        zipped_lists = zip(view_idxs, images, masks)
        sorted_zipped_lists = sorted(zipped_lists)
        view_idxs, images, masks = zip(*sorted_zipped_lists)

        image = torch.cat(images, 0).unsqueeze(0)
        mask = torch.cat(masks, 0).unsqueeze(0)
        outputs = {
            'image': image,
            'mask': mask,
            'view_idxs': view_idxs
        }
        return outputs


IMAGE_PROCESSORS = {
    "v2": ImageProcessorV2,
    'mv_v2': MVImageProcessorV2,
}

DEFAULT_IMAGEPROCESSOR = 'v2'
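
A short sketch of the single-view processor in isolation, assuming a local `input.png` and the `hy3dshape` import path: `image` comes back as a (1, 3, 512, 512) tensor scaled to [-1, 1], and `mask` as (1, 1, 512, 512).

from PIL import Image

from hy3dshape import ImageProcessorV2

processor = ImageProcessorV2(size=512)
outputs = processor(Image.open('input.png'), border_ratio=0.15)  # hypothetical input
image, mask = outputs['image'], outputs['mask']
print(image.shape, mask.shape)  # torch.Size([1, 3, 512, 512]) torch.Size([1, 1, 512, 512])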
25
hy3dshape/hy3dshape/rembg.py
Normal file
@@ -0,0 +1,25 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from PIL import Image
from rembg import remove, new_session


class BackgroundRemover:
    def __init__(self):
        self.session = new_session()

    def __call__(self, image: Image.Image):
        output = remove(image, session=self.session, bgcolor=[255, 255, 255, 0])
        return output
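
BackgroundRemover is the natural first step before ImageProcessorV2, since the alpha channel it produces drives the recentering above. A minimal sketch, assuming `rembg` is installed, a `photo.jpg` exists, and the `hy3dshape.rembg` import path:

from PIL import Image

from hy3dshape.rembg import BackgroundRemover

remover = BackgroundRemover()
image = Image.open('photo.jpg').convert('RGB')  # hypothetical input
image = remover(image)                          # RGBA output with transparent background
image.save('photo_rgba.png')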
480
hy3dshape/hy3dshape/schedulers.py
Normal file
@@ -0,0 +1,480 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    NOTE: this is very similar to diffusers.FlowMatchEulerDiscreteScheduler, except that our timesteps are reversed.

    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
    ):
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32).copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add_noise is called before first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`")

        if sigmas is None:
            self.num_inference_steps = num_inference_steps
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
            )

            sigmas = timesteps / self.config.num_train_timesteps

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        self.timesteps = timesteps.to(device=device)
        self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        prev_sample = sample + (sigma_next - sigma) * model_output

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps


@dataclass
class ConsistencyFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    prev_sample: torch.FloatTensor
    pred_original_sample: torch.FloatTensor


class ConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        pcm_timesteps: int = 50,
    ):
        sigmas = np.linspace(0, 1, num_train_timesteps)
        step_ratio = num_train_timesteps // pcm_timesteps

        euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(np.int64) - 1
        euler_timesteps = np.asarray([0] + euler_timesteps.tolist())

        self.euler_timesteps = euler_timesteps
        self.sigmas = sigmas[self.euler_timesteps]
        self.sigmas = torch.from_numpy(self.sigmas.copy()).to(dtype=torch.float32)
        self.timesteps = self.sigmas * num_train_timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps if num_inference_steps is not None else len(sigmas)
        inference_indices = np.linspace(
            0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False
        )
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = torch.from_numpy(inference_indices).long()

        self.sigmas_ = self.sigmas[inference_indices]
        timesteps = self.sigmas_ * self.config.num_train_timesteps
        self.timesteps = timesteps.to(device=device)
        self.sigmas_ = torch.cat(
            [self.sigmas_, torch.ones(1, device=self.sigmas_.device)]
        )

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[ConsistencyFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sample = sample.to(torch.float32)

        sigma = self.sigmas_[self.step_index]
        sigma_next = self.sigmas_[self.step_index + 1]

        prev_sample = sample + (sigma_next - sigma) * model_output
        prev_sample = prev_sample.to(model_output.dtype)

        pred_original_sample = sample + (1.0 - sigma) * model_output
        pred_original_sample = pred_original_sample.to(model_output.dtype)

        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return ConsistencyFlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample,
                                                                pred_original_sample=pred_original_sample)

    def __len__(self):
        return self.config.num_train_timesteps
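
Both schedulers above share the same first-order Euler update, x_next = x + (sigma_next - sigma) * v, with the model output v read as a velocity field. A self-contained sketch of a sampling loop against FlowMatchEulerDiscreteScheduler; the toy velocity model and latent shape are assumptions, purely for illustration.

import torch

from hy3dshape.schedulers import FlowMatchEulerDiscreteScheduler


def toy_velocity_model(x, t):
    # stands in for the diffusion model; predicts a velocity field
    return -x


scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=10)

latents = torch.randn(1, 512, 64)  # hypothetical latent shape
for t in scheduler.timesteps:
    velocity = toy_velocity_model(latents, t)
    latents = scheduler.step(velocity, t, latents).prev_sample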
234
hy3dshape/hy3dshape/surface_loaders.py
Normal file
@@ -0,0 +1,234 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


import numpy as np

import torch
import trimesh


def normalize_mesh(mesh, scale=0.9999):
    """
    Normalize the mesh to fit inside a centered cube with a specified scale.

    The mesh is translated so that its bounding box center is at the origin,
    then uniformly scaled so that the longest side of the bounding box fits within [-scale, scale].

    Args:
        mesh (trimesh.Trimesh): Input mesh to normalize.
        scale (float, optional): Scaling factor to slightly shrink the mesh inside the unit cube. Default is 0.9999.

    Returns:
        trimesh.Trimesh: The normalized mesh with applied translation and scaling.
    """
    bbox = mesh.bounds
    center = (bbox[1] + bbox[0]) / 2
    scale_ = (bbox[1] - bbox[0]).max()

    mesh.apply_translation(-center)
    mesh.apply_scale(1 / scale_ * 2 * scale)

    return mesh


def sample_pointcloud(mesh, num=200000):
    """
    Sample points uniformly from the surface of the mesh along with their corresponding face normals.

    Args:
        mesh (trimesh.Trimesh): Input mesh to sample from.
        num (int, optional): Number of points to sample. Default is 200000.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - points: Sampled points as a float tensor of shape (num, 3).
            - normals: Corresponding normals as a float tensor of shape (num, 3).
    """
    points, face_idx = mesh.sample(num, return_index=True)
    normals = mesh.face_normals[face_idx]
    points = torch.from_numpy(points.astype(np.float32))
    normals = torch.from_numpy(normals.astype(np.float32))
    return points, normals


def load_surface(mesh, num_points=8192):
    """
    Normalize the mesh, sample points and normals from its surface, and randomly select a subset.

    Args:
        mesh (trimesh.Trimesh): Input mesh to process.
        num_points (int, optional): Number of points to randomly select
            from the sampled surface points. Default is 8192.

    Returns:
        Tuple[torch.Tensor, trimesh.Trimesh]:
            - surface: Tensor of shape (1, num_points, 6), concatenating points and normals.
            - mesh: The normalized mesh.
    """
    mesh = normalize_mesh(mesh, scale=0.98)
    surface, normal = sample_pointcloud(mesh)

    rng = np.random.default_rng(seed=0)
    ind = rng.choice(surface.shape[0], num_points, replace=False)
    surface = torch.FloatTensor(surface[ind])
    normal = torch.FloatTensor(normal[ind])

    surface = torch.cat([surface, normal], dim=-1).unsqueeze(0)

    return surface, mesh


def sharp_sample_pointcloud(mesh, num=16384):
    """
    Sample points and normals preferentially from sharp edges of the mesh.

    Sharp edges are detected based on the angle between vertex normals and face normals.
    Points are sampled along these edges proportionally to edge length.

    Args:
        mesh (trimesh.Trimesh): Input mesh to sample from.
        num (int, optional): Number of points to sample from sharp edges. Default is 16384.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - samples: Sampled points along sharp edges, shape (num, 3).
            - normals: Corresponding interpolated normals, shape (num, 3).
    """
    V = mesh.vertices
    N = mesh.face_normals
    VN = mesh.vertex_normals
    F = mesh.faces
    VN2 = np.ones(V.shape[0])
    for i in range(3):
        dot = np.stack((VN2[F[:, i]], np.sum(VN[F[:, i]] * N, axis=-1)), axis=-1)
        VN2[F[:, i]] = np.min(dot, axis=-1)

    sharp_mask = VN2 < 0.985
    # collect edges whose endpoints are both marked sharp
    edge_a = np.concatenate((F[:, 0], F[:, 1], F[:, 2]))
    edge_b = np.concatenate((F[:, 1], F[:, 2], F[:, 0]))
    sharp_edge = sharp_mask[edge_a] * sharp_mask[edge_b]
    edge_a = edge_a[sharp_edge > 0]
    edge_b = edge_b[sharp_edge > 0]

    sharp_verts_a = V[edge_a]
    sharp_verts_b = V[edge_b]
    sharp_verts_an = VN[edge_a]
    sharp_verts_bn = VN[edge_b]

    # pick edges proportionally to their length, then lerp along each edge
    weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1)
    weights /= np.sum(weights)

    random_number = np.random.rand(num)
    w = np.random.rand(num, 1)
    index = np.searchsorted(weights.cumsum(), random_number)
    samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index]
    normals = w * sharp_verts_an[index] + (1 - w) * sharp_verts_bn[index]
    return samples, normals


def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True):
    try:
        mesh_full = trimesh.util.concatenate(mesh.dump())
    except Exception:
        mesh_full = trimesh.util.concatenate(mesh)
    mesh_full = normalize_mesh(mesh_full)

    origin_num = mesh_full.faces.shape[0]
    original_vertices = mesh_full.vertices
    original_faces = mesh_full.faces

    mesh = trimesh.Trimesh(vertices=original_vertices, faces=original_faces[:origin_num])
    mesh_fill = trimesh.Trimesh(vertices=original_vertices, faces=original_faces[origin_num:])
    area = mesh.area
    area_fill = mesh_fill.area
    sample_num = 499712 // 2
    num_fill = int(sample_num * (area_fill / (area + area_fill)))
    num = sample_num - num_fill

    random_surface, random_normal = sample_pointcloud(mesh, num=num)
    if num_fill == 0:
        random_surface_fill, random_normal_fill = np.zeros((0, 3)), np.zeros((0, 3))
    else:
        random_surface_fill, random_normal_fill = sample_pointcloud(mesh_fill, num=num_fill)
    random_sharp_surface, sharp_normal = sharp_sample_pointcloud(mesh, num=sample_num)

    # pack points and normals; append a 0/1 sharp-edge label channel if requested
    surface = np.concatenate((random_surface, random_normal), axis=1).astype(np.float16)
    surface_fill = np.concatenate((random_surface_fill, random_normal_fill), axis=1).astype(np.float16)
    sharp_surface = np.concatenate((random_sharp_surface, sharp_normal), axis=1).astype(np.float16)
    surface = np.concatenate((surface, surface_fill), axis=0)
    if sharpedge_flag:
        sharpedge_label = np.zeros((surface.shape[0], 1))
        surface = np.concatenate((surface, sharpedge_label), axis=1)
        sharpedge_label = np.ones((sharp_surface.shape[0], 1))
        sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
    rng = np.random.default_rng()
    ind = rng.choice(surface.shape[0], num_points, replace=False)
    surface = torch.FloatTensor(surface[ind])
    ind = rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)
    sharp_surface = torch.FloatTensor(sharp_surface[ind])

    return torch.cat([surface, sharp_surface], dim=0).unsqueeze(0), mesh_full


class SurfaceLoader:
    def __init__(self, num_points=8192):
        self.num_points = num_points

    def __call__(self, mesh_or_mesh_path, num_points=None):
        if num_points is None:
            num_points = self.num_points

        mesh = mesh_or_mesh_path
        if isinstance(mesh, str):
            mesh = trimesh.load(mesh, force="mesh", merge_primitives=True)
        if isinstance(mesh, trimesh.scene.Scene):
            for idx, obj in enumerate(mesh.geometry.values()):
                if idx == 0:
                    temp_mesh = obj
                else:
                    temp_mesh = temp_mesh + obj
            mesh = temp_mesh
        surface, mesh = load_surface(mesh, num_points=num_points)
        return surface


class SharpEdgeSurfaceLoader:
    def __init__(self, num_uniform_points=8192, num_sharp_points=8192, **kwargs):
        self.num_uniform_points = num_uniform_points
        self.num_sharp_points = num_sharp_points
        self.num_points = num_uniform_points + num_sharp_points

    def __call__(self, mesh_or_mesh_path, num_uniform_points=None, num_sharp_points=None):
        if num_uniform_points is None:
            num_uniform_points = self.num_uniform_points
        if num_sharp_points is None:
            num_sharp_points = self.num_sharp_points

        mesh = mesh_or_mesh_path
        if isinstance(mesh, str):
            mesh = trimesh.load(mesh, force="mesh", merge_primitives=True)
        if isinstance(mesh, trimesh.scene.Scene):
            for idx, obj in enumerate(mesh.geometry.values()):
                if idx == 0:
                    temp_mesh = obj
                else:
                    temp_mesh = temp_mesh + obj
            mesh = temp_mesh
        surface, mesh = load_surface_sharpedge(mesh, num_points=num_uniform_points,
                                               num_sharp_points=num_sharp_points)
        return surface
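
A short sketch of extracting a conditioning point cloud with the sharp-edge loader; the mesh path and `hy3dshape.surface_loaders` import path are assumptions. With the defaults below the result is a (1, 16384, 7) tensor: xyz, normal, and the sharp-edge label channel appended by load_surface_sharpedge.

from hy3dshape.surface_loaders import SharpEdgeSurfaceLoader

loader = SharpEdgeSurfaceLoader(num_uniform_points=8192, num_sharp_points=8192)
surface = loader('mesh.glb')  # hypothetical input mesh
print(surface.shape)          # torch.Size([1, 16384, 7])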
5
hy3dshape/hy3dshape/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-

from .misc import get_config_from_file
from .misc import instantiate_from_config
from .utils import get_logger, logger, synchronize_timer, smart_load_model
76
hy3dshape/hy3dshape/utils/ema.py
Normal file
@@ -0,0 +1,76 @@
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove dots, as the '.' character is not allowed in buffer names
                s_name = name.replace('.', '_____')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, model):
        """
        Save the current model parameters for restoring later.

        Args:
            model: the module whose parameters are temporarily stored.
        """
        self.collected_params = [param.clone() for param in model.parameters()]

    def restore(self, model):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            model: the module whose parameters are updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, model.parameters()):
            param.data.copy_(c_param.data)
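
The intended calling pattern for LitEma, matching the store/restore docstrings above: fold new weights into the shadow copy after each optimizer step, then swap the EMA weights in around evaluation. A minimal sketch with a throwaway linear layer; the import path is an assumption.

import torch
from torch import nn

from hy3dshape.utils.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)

# training: after each optimizer step, update the shadow weights
for _ in range(10):
    # ... optimizer.step() would go here ...
    ema(model)

# evaluation: swap the EMA weights in, restore the raw weights afterwards
ema.store(model)
ema.copy_to(model)
# ... run validation ...
ema.restore(model)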
122
hy3dshape/hy3dshape/utils/misc.py
Normal file
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-

import importlib
from typing import Union

import torch
import torch.distributed as dist
from omegaconf import OmegaConf, DictConfig, ListConfig


def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:
    config_file = OmegaConf.load(config_file)

    if 'base_config' in config_file.keys():
        if config_file['base_config'] == "default_base":
            base_config = OmegaConf.create()
            # base_config = get_default_config()
        elif config_file['base_config'].endswith(".yaml"):
            base_config = get_config_from_file(config_file['base_config'])
        else:
            raise ValueError(f"{config_file} must be a `.yaml` file or contain a `base_config` key.")

        # drop the `base_config` key before merging; note that iterating the
        # DictConfig itself yields keys only, so .items() is required here
        config_file = {key: value for key, value in config_file.items() if key != "base_config"}

        return OmegaConf.merge(base_config, config_file)

    return config_file


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def get_obj_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    return get_obj_from_str(config["target"])


def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    cls = get_obj_from_str(config["target"])

    if config.get("from_pretrained", None):
        return cls.from_pretrained(config["from_pretrained"])

    params = config.get("params", dict())
    # config params take precedence over caller kwargs
    kwargs.update(params)
    instance = cls(**kwargs)

    return instance


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def instantiate_non_trainable_model(config):
    model = instantiate_from_config(config)
    model = model.eval()
    model.train = disabled_train
    for param in model.parameters():
        param.requires_grad = False

    return model


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def all_gather_batch(tensors):
    """
    Performs the all_gather operation on the provided tensors.
    """
    # Queue the gathered tensors
    world_size = get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []
    for tensor in tensors:
        tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
        dist.all_gather(
            tensor_all,
            tensor,
            async_op=False  # performance opt
        )

        tensor_list.append(tensor_all)

    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor
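
instantiate_from_config resolves `target` to a dotted import path and passes `params` as constructor kwargs; because of the update order, config params override caller kwargs. A minimal sketch against a plain torch class, so it runs without any project assets; the import path is an assumption.

from omegaconf import OmegaConf

from hy3dshape.utils.misc import instantiate_from_config

config = OmegaConf.create({
    'target': 'torch.nn.Linear',
    'params': {'in_features': 8, 'out_features': 2},
})
layer = instantiate_from_config(config)  # equivalent to torch.nn.Linear(8, 2)
print(layer)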
1
hy3dshape/hy3dshape/utils/trainings/__init__.py
Executable file
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
213
hy3dshape/hy3dshape/utils/trainings/callback.py
Executable file
@@ -0,0 +1,213 @@
|
||||
# ------------------------------------------------------------------------------------
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

import os
import time
import wandb
import numpy as np
from PIL import Image
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
from typing import Tuple, Generic, Dict, Callable, Optional, Any
from pprint import pprint

import torch
import torchvision
import pytorch_lightning as pl
import pytorch_lightning.loggers
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers.logger import DummyLogger
from pytorch_lightning.utilities import rank_zero_only, rank_zero_info
from pytorch_lightning.callbacks import Callback

from functools import wraps


def node_zero_only(fn: Callable) -> Callable:
    @wraps(fn)
    def wrapped_fn(*args, **kwargs) -> Optional[Any]:
        if node_zero_only.node == 0:
            return fn(*args, **kwargs)
        return None
    return wrapped_fn


node_zero_only.node = getattr(node_zero_only, 'node', int(os.environ.get('NODE_RANK', 0)))


def node_zero_experiment(fn: Callable) -> Callable:
    """Returns the real experiment on node 0 and otherwise the DummyExperiment."""
    @wraps(fn)
    def experiment(self):
        @node_zero_only
        def get_experiment():
            return fn(self)
        return get_experiment() or DummyLogger.experiment
    return experiment


# customize wandb for node 0 only
class MyWandbLogger(WandbLogger):
    @WandbLogger.experiment.getter
    @node_zero_experiment
    def experiment(self):
        return super().experiment


class SetupCallback(Callback):
    def __init__(self, config: DictConfig, exp_config: DictConfig,
                 basedir: Path, logdir: str = "log", ckptdir: str = "ckpt") -> None:
        super().__init__()
        self.logdir = basedir / logdir
        self.ckptdir = basedir / ckptdir
        self.config = config
        self.exp_config = exp_config

    def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)


class ImageLogger(Callback):
    def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True,
                 increase_log_steps: bool = True) -> None:
        super().__init__()
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.WandbLogger: self._wandb,
            pl.loggers.TestTubeLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp

    @rank_zero_only
    def _wandb(self, pl_module, images, batch_idx, split):
        grids = dict()
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grids[f"{split}/{k}"] = wandb.Image(grid)
        pl_module.logger.experiment.log(grids)

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir: str, split: str, images: Dict,
                  global_step: int, current_epoch: int, batch_idx: int) -> None:
        root = os.path.join(save_dir, "results", split)
        os.makedirs(root, exist_ok=True)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)

            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module: pl.LightningModule, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int,
                split: str = "train") -> None:
        if (self.check_frequency(batch_idx) and
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, pl_module=pl_module)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N].detach().cpu()
                if self.clamp:
                    images[k] = images[k].clamp(0, 1)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, batch_idx: int) -> bool:
        if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
            try:
                self.log_steps.pop(0)
            except IndexError:
                pass
            return True
        return False

    def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
                           outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int) -> None:
        self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
                                outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor],
                                batch_idx: int, dataloader_idx: int) -> None:
        # Lightning passes batch_idx before dataloader_idx; the swapped order
        # in the original signature would bind the two indices incorrectly.
        self.log_img(pl_module, batch, batch_idx, split="val")


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
        epoch_time = time.time() - self.start_time

        try:
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass
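The `log_steps` list gives `ImageLogger` an exponential warm-up: early batches are logged at powers of two before the fixed `batch_frequency` cadence takes over. A small sketch of the schedule for `batch_frequency=1000` (an illustrative value):

```python
import numpy as np

batch_freq = 1000
log_steps = [2 ** n for n in range(int(np.log2(batch_freq)) + 1)]
print(log_steps)  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
# check_frequency(b) fires when b % batch_freq == 0 or b is in log_steps,
# popping one warm-up step off the list each time it fires.
```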
53
hy3dshape/hy3dshape/utils/trainings/lr_scheduler.py
Executable file
@@ -0,0 +1,53 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import numpy as np


class BaseScheduler(object):

    def schedule(self, n, **kwargs):
        raise NotImplementedError


class LambdaWarmUpCosineFactorScheduler(BaseScheduler):
    """
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, max_decay_steps, verbosity_interval=0, **ignore_kwargs):
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                # report the most recent factor, not the constant f_start
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}")
        if n < self.lr_warm_up_steps:
            # linear warm-up from f_start to f_max
            f = (self.f_max - self.f_start) / self.lr_warm_up_steps * n + self.f_start
            self.last_f = f
            return f
        else:
            # cosine decay from f_max down to f_min
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            f = self.f_min + 0.5 * (self.f_max - self.f_min) * (1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
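A hedged sketch of driving this factor scheduler through `torch.optim.lr_scheduler.LambdaLR`; `model` and every hyperparameter below are placeholders:

```python
import torch

opt = torch.optim.AdamW(model.parameters(), lr=1e-4)  # `model` is hypothetical
factor = LambdaWarmUpCosineFactorScheduler(
    warm_up_steps=1_000, f_min=0.1, f_max=1.0, f_start=1e-3,
    max_decay_steps=100_000)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=factor)
# per training step: opt.step(); sched.step()  ->  lr = 1e-4 * factor(step)
```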
128
hy3dshape/hy3dshape/utils/trainings/mesh.py
Executable file
@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
import cv2
import numpy as np
import PIL.Image
from typing import Optional

import trimesh


def save_obj(pointnp_px3, facenp_fx3, fname):
    fid = open(fname, "w")
    write_str = ""
    for pidx, p in enumerate(pointnp_px3):
        pp = p
        write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])

    for i, f in enumerate(facenp_fx3):
        f1 = f + 1  # OBJ face indices are 1-based
        write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
    fid.write(write_str)
    fid.close()
    return


def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
    fol, na = os.path.split(fname)
    na, _ = os.path.splitext(na)

    # write the companion material file referencing <name>.png
    matname = "%s/%s.mtl" % (fol, na)
    fid = open(matname, "w")
    fid.write("newmtl material_0\n")
    fid.write("Kd 1 1 1\n")
    fid.write("Ka 0 0 0\n")
    fid.write("Ks 0.4 0.4 0.4\n")
    fid.write("Ns 10\n")
    fid.write("illum 2\n")
    fid.write("map_Kd %s.png\n" % na)
    fid.close()

    fid = open(fname, "w")
    fid.write("mtllib %s.mtl\n" % na)

    for pidx, p3 in enumerate(pointnp_px3):
        pp = p3
        fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))

    for pidx, p2 in enumerate(tcoords_px2):
        pp = p2
        fid.write("vt %f %f\n" % (pp[0], pp[1]))

    fid.write("usemtl material_0\n")
    for i, f in enumerate(facenp_fx3):
        f1 = f + 1
        f2 = facetex_fx3[i] + 1
        fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
    fid.close()

    PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
        os.path.join(fol, "%s.png" % na))

    return


class MeshOutput(object):

    def __init__(self,
                 mesh_v: np.ndarray,
                 mesh_f: np.ndarray,
                 vertex_colors: Optional[np.ndarray] = None,
                 uvs: Optional[np.ndarray] = None,
                 mesh_tex_idx: Optional[np.ndarray] = None,
                 tex_map: Optional[np.ndarray] = None):

        self.mesh_v = mesh_v
        self.mesh_f = mesh_f
        self.vertex_colors = vertex_colors
        self.uvs = uvs
        self.mesh_tex_idx = mesh_tex_idx
        self.tex_map = tex_map

    def contain_uv_texture(self):
        return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)

    def contain_vertex_colors(self):
        return self.vertex_colors is not None

    def export(self, fname):
        if self.contain_uv_texture():
            savemeshtes2(
                self.mesh_v,
                self.uvs,
                self.mesh_f,
                self.mesh_tex_idx,
                self.tex_map,
                fname
            )
        elif self.contain_vertex_colors():
            mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
            mesh_obj.export(fname)
        else:
            save_obj(
                self.mesh_v,
                self.mesh_f,
                fname
            )
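A minimal usage sketch for `MeshOutput.export` with a toy triangle; without UVs or vertex colors it falls through to `save_obj`:

```python
import numpy as np

v = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
f = np.array([[0, 1, 2]], dtype=np.int64)

out = MeshOutput(mesh_v=v, mesh_f=f)
out.export("triangle.obj")  # plain OBJ; textured meshes go through savemeshtes2
```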
336
hy3dshape/hy3dshape/utils/trainings/mesh_log_callback.py
Executable file
@@ -0,0 +1,336 @@
# -*- coding: utf-8 -*-

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import json
import math
import os
from typing import Tuple, Generic, Dict, List, Union, Optional

import wandb  # needed by _wandb below
import trimesh
import numpy as np
import pytorch_lightning as pl
import pytorch_lightning.loggers
import torch
import torchvision
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

from hy3dshape.pipelines import export_to_trimesh
from hy3dshape.utils.trainings.mesh import MeshOutput
from hy3dshape.utils.visualizers import html_util
from hy3dshape.utils.visualizers.pythreejs_viewer import PyThreeJSViewer


class ImageConditionalASLDiffuserLogger(Callback):
    def __init__(self,
                 step_frequency: int,
                 num_samples: int = 1,
                 mean: Optional[Union[List[float], Tuple[float]]] = None,
                 std: Optional[Union[List[float], Tuple[float]]] = None,
                 bounds: Union[List[float], Tuple[float]] = (-1.1, -1.1, -1.1, 1.1, 1.1, 1.1),
                 **kwargs) -> None:

        super().__init__()
        self.bbox_size = np.array(bounds[3:6]) - np.array(bounds[0:3])

        if mean is not None:
            mean = np.asarray(mean)

        if std is not None:
            std = np.asarray(std)

        self.mean = mean
        self.std = std

        self.step_freq = step_frequency
        self.num_samples = num_samples
        self.has_train_logged = False
        self.logger_log_images = {
            pl.loggers.WandbLogger: self._wandb,
        }

        self.viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")

    @rank_zero_only
    def _wandb(self, pl_module, images, batch_idx, split):
        grids = dict()
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grids[f"{split}/{k}"] = wandb.Image(grid)
        pl_module.logger.experiment.log(grids)

    def log_local(self,
                  outputs: List[List['Latent2MeshOutput']],
                  images: Union[np.ndarray, List[np.ndarray]],
                  description: List[str],
                  keys: List[str],
                  save_dir: str, split: str,
                  global_step: int, current_epoch: int, batch_idx: int,
                  prog_bar: bool = False,
                  multi_views=None,
                  ) -> None:

        folder = "gs-{:010}_e-{:06}_b-{:06}".format(global_step, current_epoch, batch_idx)
        visual_dir = os.path.join(save_dir, "visuals", split, folder)
        os.makedirs(visual_dir, exist_ok=True)

        num_samples = len(images)

        for i in range(num_samples):
            key_i = keys[i]
            image_i = self.denormalize_image(images[i])
            shape_tag_i = description[i]

            for j in range(1):  # single output list; j offsets copies along x if extended
                mesh = outputs[j][i]
                if mesh is None:
                    continue

                mesh_v = mesh.mesh_v.copy()
                mesh_v[:, 0] += j * np.max(self.bbox_size)
                self.viewer.add_mesh(mesh_v, mesh.mesh_f)

            image_tag = html_util.to_image_embed_tag(image_i)
            mesh_tag = self.viewer.to_html(html_frame=False)

            table_tag = f"""
                <table border = "1">
                    <caption> {shape_tag_i} - {key_i} </caption>
                    <caption> Input Image | Generated Mesh </caption>
                    <tr>
                        <td>{image_tag}</td>
                        <td>{mesh_tag}</td>
                    </tr>
                </table>
            """

            if multi_views is not None:
                multi_views_i = self.make_grid(multi_views[i])
                views_tag = html_util.to_image_embed_tag(self.denormalize_image(multi_views_i))
                table_tag = f"""
                    <table border = "1">
                        <caption> {shape_tag_i} - {key_i} </caption>
                        <caption> Input Image | Generated Mesh </caption>
                        <tr>
                            <td>{image_tag}</td>
                            <td>{views_tag}</td>
                            <td>{mesh_tag}</td>
                        </tr>
                    </table>
                """

            html_frame = html_util.to_html_frame(table_tag)
            if len(key_i) > 100:
                key_i = key_i[:100]
            with open(os.path.join(visual_dir, f"{key_i}.html"), "w") as writer:
                writer.write(html_frame)

            self.viewer.reset()

    def log_sample(self,
                   pl_module: pl.LightningModule,
                   batch: Dict[str, torch.FloatTensor],
                   batch_idx: int,
                   split: str = "train") -> None:
        """
        Args:
            pl_module: the Lightning module that exposes `sample`.
            batch (dict): the batch sample information; it contains:
                - surface (torch.FloatTensor)
                - image (torch.FloatTensor)
            batch_idx (int): index of the current batch.
            split (str): "train" or "val".
        """

        is_train = pl_module.training
        if is_train:
            pl_module.eval()

        batch_size = len(batch["surface"])
        replace = batch_size < self.num_samples
        ids = np.random.choice(batch_size, self.num_samples, replace=replace)

        with torch.no_grad():
            # placeholder keys/texts/descriptions; real batches may carry
            # __key__/text/description fields instead
            keys = [f'key_{i}' for i in ids]
            texts = [f'text_{i}' for i in ids]
            description = [f'desc_{i}' for i in ids]
            images = batch["image"][ids]
            mask_input = batch["mask"][ids] if 'mask' in batch else None
            sample_batch = {
                "__key__": keys,
                "image": images,
                'text': texts,
                'mask': mask_input,
            }

            outputs = pl_module.sample(
                batch=sample_batch,
                output_type='latents2mesh'
            )

        images = images.cpu().float().numpy()

        self.log_local(outputs, images, description, keys, pl_module.logger.save_dir, split,
                       pl_module.global_step, pl_module.current_epoch, batch_idx, prog_bar=False,
                       multi_views=sample_batch.get('multi_views'))

        if is_train:
            pl_module.train()

    def make_grid(self, images):
        # returns a (3, h, w) grid in [0, 1]
        images_resized = []
        for img in images:
            img_resized = torchvision.transforms.functional.resize(img, (320, 320))
            images_resized.append(img_resized)
        image = torchvision.utils.make_grid(images_resized, nrow=2, padding=5, pad_value=255)

        image = image.cpu().numpy()

        return image

    def check_frequency(self, step: int) -> bool:
        if step % self.step_freq == 0:
            return True
        return False

    def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
                           outputs: Generic, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> None:

        if (self.check_frequency(pl_module.global_step) and
                hasattr(pl_module, "sample") and
                callable(pl_module.sample) and
                self.num_samples > 0):
            self.log_sample(pl_module, batch, batch_idx, split="train")
            self.has_train_logged = True

    def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
                                outputs: Generic, batch: Dict[str, torch.FloatTensor],
                                batch_idx: int, dataloader_idx: int) -> None:
        # batch_idx comes before dataloader_idx in the Lightning hook signature

        if self.has_train_logged:
            self.log_sample(pl_module, batch, batch_idx, split="val")
            self.has_train_logged = False

    def denormalize_image(self, image):
        """
        Args:
            image (np.ndarray): [3, h, w]

        Returns:
            image (np.ndarray): [h, w, 3], np.uint8, [0, 255].
        """
        image = np.transpose(image, (1, 2, 0))

        if self.std is not None:
            image = image * self.std

        if self.mean is not None:
            image = image + self.mean

        image = (image * 255).astype(np.uint8)

        return image


class ImageConditionalFixASLDiffuserLogger(Callback):
    def __init__(
        self,
        step_frequency: int,
        test_data_path: str,
        max_size: Optional[int] = None,
        save_dir: str = 'infer',
        **kwargs,
    ) -> None:
        super().__init__()
        self.step_freq = step_frequency
        self.viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")

        self.test_data_path = test_data_path
        with open(self.test_data_path, 'r') as f:
            data = json.load(f)
        self.file_list = data['file_list']
        self.file_folder = data['file_folder']
        if max_size is not None:
            self.file_list = self.file_list[:max_size]
        self.kwargs = kwargs
        self.save_dir = save_dir

    def on_train_batch_end(
        self,
        trainer: pl.trainer.Trainer,
        pl_module: pl.LightningModule,
        outputs: Generic,
        batch: Dict[str, torch.FloatTensor],
        batch_idx: int,
    ):
        if pl_module.global_step % self.step_freq == 0:
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            folder_path = self.file_folder
            folder_name = os.path.basename(folder_path)
            folder = "gs-{:010}_e-{:06}_b-{:06}".format(pl_module.global_step, pl_module.current_epoch, batch_idx)
            visual_dir = os.path.join(pl_module.logger.save_dir, self.save_dir, folder, folder_name)
            os.makedirs(visual_dir, exist_ok=True)

            # shard the fixed test set across ranks
            image_paths = self.file_list
            chunk_size = math.ceil(len(image_paths) / trainer.world_size)
            if pl_module.global_rank == trainer.world_size - 1:
                image_paths = image_paths[pl_module.global_rank * chunk_size:]
            else:
                image_paths = image_paths[pl_module.global_rank * chunk_size:(pl_module.global_rank + 1) * chunk_size]

            print(f'Rank{pl_module.global_rank}: processing {len(image_paths)}|{len(self.file_list)} images')
            for image_path in image_paths:
                if folder_path in image_path:
                    save_path = image_path.replace(folder_path, visual_dir)
                else:
                    save_path = os.path.join(visual_dir, os.path.basename(image_path))
                save_path = os.path.splitext(save_path)[0] + '.glb'

                print(image_path)
                with torch.no_grad():
                    mesh = pl_module.sample(batch={"image": image_path}, **self.kwargs)[0][0]
                if isinstance(mesh, tuple) and len(mesh) == 2:
                    mesh = export_to_trimesh(mesh)
                # `if`, not `elif`: a (vertices, faces) tuple converted above
                # still needs to be written out
                if isinstance(mesh, trimesh.Trimesh):
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                    mesh.export(save_path)

            if is_train:
                pl_module.train()
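A hedged sketch of wiring these callbacks into a Lightning `Trainer`; the frequencies and the module are placeholders, and the module must expose the `sample(...)` method the callbacks probe for:

```python
import pytorch_lightning as pl

sample_logger = ImageConditionalASLDiffuserLogger(step_frequency=2_000, num_samples=2)
trainer = pl.Trainer(callbacks=[sample_logger], max_epochs=10)
# trainer.fit(module, datamodule)  # both arguments are hypothetical here
```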
78
hy3dshape/hy3dshape/utils/trainings/peft.py
Normal file
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
from typing import Optional

from pytorch_lightning.callbacks import Callback
from omegaconf import OmegaConf, DictConfig, ListConfig


class PeftSaveCallback(Callback):
    def __init__(self, peft_model, save_dir: str, save_every_n_steps: Optional[int] = None):
        super().__init__()
        self.peft_model = peft_model
        self.save_dir = save_dir
        self.save_every_n_steps = save_every_n_steps
        os.makedirs(self.save_dir, exist_ok=True)

    def recursive_convert(self, obj):
        # OmegaConf nodes are DictConfig/ListConfig instances (not OmegaConf
        # itself); convert them to plain containers so peft can serialize them
        if isinstance(obj, (DictConfig, ListConfig)):
            return OmegaConf.to_container(obj, resolve=True)
        elif isinstance(obj, dict):
            return {k: self.recursive_convert(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self.recursive_convert(i) for i in obj]
        elif isinstance(obj, type):
            # avoid mutating class objects
            return obj
        elif hasattr(obj, '__dict__'):
            for attr_name, attr_value in vars(obj).items():
                setattr(obj, attr_name, self.recursive_convert(attr_value))
            return obj
        else:
            return obj

    def _convert_peft_config(self):
        pc = self.peft_model.peft_config
        self.peft_model.peft_config = self.recursive_convert(pc)

    def on_train_epoch_end(self, trainer, pl_module):
        self._convert_peft_config()
        save_path = os.path.join(self.save_dir, f"epoch_{trainer.current_epoch}")
        self.peft_model.save_pretrained(save_path)
        print(f"[PeftSaveCallback] Saved LoRA weights to {save_path}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self.save_every_n_steps is not None:
            global_step = trainer.global_step
            if global_step % self.save_every_n_steps == 0 and global_step > 0:
                self._convert_peft_config()
                save_path = os.path.join(self.save_dir, f"step_{global_step}")
                self.peft_model.save_pretrained(save_path)
                print(f"[PeftSaveCallback] Saved LoRA weights to {save_path}")
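A hedged sketch of attaching `PeftSaveCallback`; `peft_model` is assumed to be a `peft`-wrapped module exposing `save_pretrained`:

```python
import pytorch_lightning as pl

ckpt_cb = PeftSaveCallback(peft_model=peft_model,          # hypothetical PeftModel
                           save_dir="outputs/lora",
                           save_every_n_steps=500)
trainer = pl.Trainer(callbacks=[ckpt_cb], max_steps=10_000)
# adapters land in outputs/lora/step_500, step_1000, ... plus epoch_N snapshots
```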
126
hy3dshape/hy3dshape/utils/utils.py
Normal file
@@ -0,0 +1,126 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import logging
import os
from functools import wraps

import torch


def get_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # guard against attaching a duplicate handler on repeated calls
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)

        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger


logger = get_logger('hy3dgen.shapgen')


class synchronize_timer:
    """Synchronized timer to count the inference time of `nn.Module.forward`.

    Supports both context manager and decorator usage.

    Example as context manager:
    ```python
    with synchronize_timer('name') as t:
        run()
    ```

    Example as decorator:
    ```python
    @synchronize_timer('Export to trimesh')
    def export_to_trimesh(mesh_output):
        pass
    ```
    """

    def __init__(self, name=None):
        self.name = name

    def __enter__(self):
        """Context manager entry: start timing."""
        if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)
            self.start.record()
        return lambda: self.time

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Context manager exit: stop timing and log results."""
        if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
            self.end.record()
            torch.cuda.synchronize()
            self.time = self.start.elapsed_time(self.end)
            if self.name is not None:
                logger.info(f'{self.name} takes {self.time} ms')

    def __call__(self, func):
        """Decorator: wrap the function to time its execution."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                result = func(*args, **kwargs)
            return result

        return wrapper
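A short sketch of the timer in context-manager form; it only records when `HY3DGEN_DEBUG=1` and CUDA is available, since it relies on `torch.cuda.Event`:

```python
import os
os.environ['HY3DGEN_DEBUG'] = '1'  # timing is a no-op otherwise

with synchronize_timer('decode latents') as elapsed:
    run_decode()          # hypothetical GPU workload
print(elapsed(), 'ms')    # valid once the block has exited
```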
def smart_load_model(
    model_path,
    subfolder,
    use_safetensors,
    variant,
):
    original_model_path = model_path
    # try local path
    base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
    model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
    logger.info(f'Try to load model from local path: {model_path}')
    if not os.path.exists(model_path):
        logger.info('Model path does not exist; trying to download from Hugging Face')
        try:
            from huggingface_hub import snapshot_download
            # download only the requested subfolder
            path = snapshot_download(
                repo_id=original_model_path,
                allow_patterns=[f"{subfolder}/*"],  # pattern-match the subfolder
            )
            model_path = os.path.join(path, subfolder)  # keep the path-join logic unchanged
        except ImportError:
            logger.warning(
                "You need to install huggingface_hub to load models from the hub."
            )
            raise RuntimeError(f"Model path {model_path} not found")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path {original_model_path} not found")

    extension = 'ckpt' if not use_safetensors else 'safetensors'
    variant = '' if variant is None else f'.{variant}'
    ckpt_name = f'model{variant}.{extension}'
    config_path = os.path.join(model_path, 'config.yaml')
    ckpt_path = os.path.join(model_path, ckpt_name)
    return config_path, ckpt_path
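A hedged sketch of the resolution order: the local `HY3DGEN_MODELS` cache is checked first, then a Hugging Face snapshot of just the subfolder is fetched; the repo id and subfolder below are illustrative:

```python
config_path, ckpt_path = smart_load_model(
    'tencent/Hunyuan3D-2',            # example repo id
    subfolder='hunyuan3d-dit-v2-0',   # example subfolder
    use_safetensors=True,
    variant='fp16',
)
# -> <cache>/hunyuan3d-dit-v2-0/config.yaml and model.fp16.safetensors
```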
1
hy3dshape/hy3dshape/utils/visualizers/__init__.py
Executable file
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
57
hy3dshape/hy3dshape/utils/visualizers/color_util.py
Executable file
@@ -0,0 +1,57 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import numpy as np
import matplotlib.pyplot as plt


# Helper functions
def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
    colormap = plt.cm.get_cmap(colormap)
    if normalize:
        vmin = np.min(inp)
        vmax = np.max(inp)

    norm = plt.Normalize(vmin, vmax)
    return colormap(norm(inp))[:, :3]


def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
    # tex dims need to be power of two.
    array = np.ones((width, height, 3), dtype='float32')

    # width in texels of each checker
    checker_w = width / n_checkers_x
    checker_h = height / n_checkers_y

    for y in range(height):
        for x in range(width):
            color_key = int(x / checker_w) + int(y / checker_h)
            if color_key % 2 == 0:
                array[x, y, :] = [1., 0.874, 0.0]
            else:
                array[x, y, :] = [0., 0., 0.]
    return array


def gen_circle(width=256, height=256):
    xx, yy = np.mgrid[:width, :height]
    circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
    array = np.ones((width, height, 4), dtype='float32')
    array[:, :, 0] = (circle <= width)
    array[:, :, 1] = (circle <= width)
    array[:, :, 2] = (circle <= width)
    array[:, :, 3] = circle <= width
    return array
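A tiny sketch of the helpers: `get_colors` maps scalars to RGB rows in [0, 1], and the generators return float32 textures:

```python
import numpy as np

vals = np.array([0.0, 0.5, 1.0])
rgb = get_colors(vals)           # shape (3, 3): one viridis RGB row per value
checkers = gen_checkers(8, 8)    # (256, 256, 3) float32 checkerboard
sprite = gen_circle(16, 16)      # (16, 16, 4) RGBA point sprite
```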
64
hy3dshape/hy3dshape/utils/visualizers/html_util.py
Executable file
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import io
import base64
import numpy as np
from PIL import Image


def to_html_frame(content):

    html_frame = f"""
    <html>
      <body>
        {content}
      </body>
    </html>
    """

    return html_frame


def to_single_row_table(caption: str, content: str):

    table_html = f"""
    <table border = "1">
      <caption>{caption}</caption>
      <tr>
        <td>{content}</td>
      </tr>
    </table>
    """

    return table_html


def to_image_embed_tag(image: np.ndarray):

    # Convert np.ndarray to bytes
    img = Image.fromarray(image)
    raw_bytes = io.BytesIO()
    img.save(raw_bytes, "PNG")

    # Encode bytes to base64
    image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")

    image_tag = f"""
    <img src="data:image/png;base64,{image_base64}" alt="Embedded Image">
    """

    return image_tag
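A minimal sketch composing the three helpers into a standalone page, as the mesh loggers above do:

```python
import numpy as np

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # placeholder image
tag = to_image_embed_tag(img)                   # base64-embedded <img>
page = to_html_frame(to_single_row_table("preview", tag))
with open("preview.html", "w") as fp:
    fp.write(page)
```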
549
hy3dshape/hy3dshape/utils/visualizers/pythreejs_viewer.py
Executable file
@@ -0,0 +1,549 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


import numpy as np
from ipywidgets import embed
import pythreejs as p3s
import uuid

from .color_util import get_colors, gen_circle, gen_checkers


EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js"


class PyThreeJSViewer(object):

    def __init__(self, settings, render_mode="WEBSITE"):
        self.render_mode = render_mode
        self.__update_settings(settings)
        self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6)
        self._light2 = p3s.AmbientLight(intensity=0.5)
        self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"],
                                          aspect=self.__s["width"] / self.__s["height"], children=[self._light])
        self._orbit = p3s.OrbitControls(controlling=self._cam)
        self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"])  # "#4c4c80"
        self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit],
                                      width=self.__s["width"], height=self.__s["height"],
                                      antialias=self.__s["antialias"])

        self.__objects = {}
        self.__cnt = 0

    def jupyter_mode(self):
        self.render_mode = "JUPYTER"

    def offline(self):
        self.render_mode = "OFFLINE"

    def website(self):
        self.render_mode = "WEBSITE"

    def __get_shading(self, shading):
        shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black",
                "side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None],
                "bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0,
                "line_width": 1.0, "line_color": "black",
                "point_color": "red", "point_size": 0.01, "point_shape": "circle",
                "text_color": "red"
                }
        for k in shading:
            shad[k] = shading[k]
        return shad

    def __update_settings(self, settings={}):
        sett = {"width": 1600, "height": 800, "antialias": True, "scale": 1.5, "background": "#ffffff",
                "fov": 30}
        for k in settings:
            sett[k] = settings[k]
        self.__s = sett

    def __add_object(self, obj, parent=None):
        if not parent:  # Object is added to global scene and objects dict
            self.__objects[self.__cnt] = obj
            self.__cnt += 1
            self._scene.add(obj["mesh"])
        else:  # Object is added to parent object and NOT to objects dict
            parent.add(obj["mesh"])

        self.__update_view()

        if self.render_mode == "JUPYTER":
            return self.__cnt - 1
        elif self.render_mode == "WEBSITE":
            return self

    def __add_line_geometry(self, lines, shading, obj=None):
        lines = lines.astype("float32", copy=False)
        mi = np.min(lines, axis=0)
        ma = np.max(lines, axis=0)

        geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3)))
        material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"])
        lines = p3s.LineSegments2(geometry=geometry, material=material)
        line_obj = {"geometry": geometry, "mesh": lines, "material": material,
                    "max": ma, "min": mi, "type": "Lines", "wireframe": None}

        if obj:
            return self.__add_object(line_obj, obj), line_obj
        else:
            return self.__add_object(line_obj)

    def __update_view(self):
        if len(self.__objects) == 0:
            return
        ma = np.zeros((len(self.__objects), 3))
        mi = np.zeros((len(self.__objects), 3))
        for r, obj in enumerate(self.__objects):
            ma[r] = self.__objects[obj]["max"]
            mi[r] = self.__objects[obj]["min"]
        ma = np.max(ma, axis=0)
        mi = np.min(mi, axis=0)
        diag = np.linalg.norm(ma - mi)
        mean = ((ma - mi) / 2 + mi).tolist()
        scale = self.__s["scale"] * diag
        self._orbit.target = mean
        self._cam.lookAt(mean)
        self._cam.position = [mean[0], mean[1], mean[2] + scale]
        self._light.position = [mean[0], mean[1], mean[2] + scale]

        self._orbit.exec_three_obj_method('update')
        self._cam.exec_three_obj_method('updateProjectionMatrix')

    def __get_bbox(self, v):
        m = np.min(v, axis=0)
        M = np.max(v, axis=0)

        # Corners of the bounding box
        v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]],
                          [m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]])

        f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4],
                          [0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32)
        return v_box, f_box

    def __get_colors(self, v, f, c, sh):
        coloring = "VertexColors"
        if type(c) == np.ndarray and c.size == 3:  # Single color
            colors = np.ones_like(v)
            colors[:, 0] = c[0]
            colors[:, 1] = c[1]
            colors[:, 2] = c[2]
        elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3:  # Color values
            if c.shape[0] == f.shape[0]:  # one color per face
                colors = np.hstack([c, c, c]).reshape((-1, 3))
                coloring = "FaceColors"
            elif c.shape[0] == v.shape[0]:  # one color per vertex
                colors = c
            else:  # Wrong size, fall back to the default yellow
                print("Invalid color array given! Supported are numpy arrays.", type(c))
                colors = np.ones_like(v)
                colors[:, 0] = 1.0
                colors[:, 1] = 0.874
                colors[:, 2] = 0.0
        elif type(c) == np.ndarray and c.size == f.shape[0]:  # Function values for faces
            normalize = sh["normalize"][0] is not None and sh["normalize"][1] is not None
            cc = get_colors(c, sh["colormap"], normalize=normalize,
                            vmin=sh["normalize"][0], vmax=sh["normalize"][1])
            colors = np.hstack([cc, cc, cc]).reshape((-1, 3))
            coloring = "FaceColors"
        elif type(c) == np.ndarray and c.size == v.shape[0]:  # Function values for vertices
            normalize = sh["normalize"][0] is not None and sh["normalize"][1] is not None
            colors = get_colors(c, sh["colormap"], normalize=normalize,
                                vmin=sh["normalize"][0], vmax=sh["normalize"][1])

        else:
            colors = np.ones_like(v)
            colors[:, 0] = 1.0
            colors[:, 1] = 0.874
            colors[:, 2] = 0.0

            # No color
            if c is not None:
                print("Invalid color array given! Supported are numpy arrays.", type(c))

        return colors, coloring

    def __get_point_colors(self, v, c, sh):
        v_color = True
        if c is None:  # No color given, use global color
            colors = sh["point_color"]
            v_color = False
        elif isinstance(c, str):  # Single named color for all points
            colors = c
            v_color = False
        elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3:
            # Per-point RGB color
            colors = c.astype("float32", copy=False)

        elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3:
            # Function values for vertices, but the colors are features
            c_norm = np.linalg.norm(c, ord=2, axis=-1)
            normalize = sh["normalize"][0] is not None and sh["normalize"][1] is not None
            colors = get_colors(c_norm, sh["colormap"], normalize=normalize,
                                vmin=sh["normalize"][0], vmax=sh["normalize"][1])
            colors = colors.astype("float32", copy=False)

        elif type(c) == np.ndarray and c.size == v.shape[0]:  # Function color
            normalize = sh["normalize"][0] is not None and sh["normalize"][1] is not None
            colors = get_colors(c, sh["colormap"], normalize=normalize,
                                vmin=sh["normalize"][0], vmax=sh["normalize"][1])
            colors = colors.astype("float32", copy=False)

        else:
            print("Invalid color array given! Supported are numpy arrays.", type(c))
            colors = sh["point_color"]
            v_color = False

        return colors, v_color

    def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs):
        shading.update(kwargs)
        sh = self.__get_shading(shading)
        mesh_obj = {}

        # it is a tet mesh: split each tetrahedron into its four triangles
        if v.shape[1] == 3 and f.shape[1] == 4:
            f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype)
            for i in range(f.shape[0]):
                f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]])
                f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]])
                f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]])
                f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]])
            f = f_tmp

        if v.shape[1] == 2:
            v = np.append(v, np.zeros([v.shape[0], 1]), 1)

        # Type adjustment vertices
        v = v.astype("float32", copy=False)

        # Color setup
        colors, coloring = self.__get_colors(v, f, c, sh)

        # Type adjustment faces and colors
        c = colors.astype("float32", copy=False)

        # Material and geometry setup
        ba_dict = {"color": p3s.BufferAttribute(c)}
        if coloring == "FaceColors":
            verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
            for ii in range(f.shape[0]):
                verts[ii * 3] = v[f[ii, 0]]
                verts[ii * 3 + 1] = v[f[ii, 1]]
                verts[ii * 3 + 2] = v[f[ii, 2]]
            v = verts
        else:
            f = f.astype("uint32", copy=False).ravel()
            ba_dict["index"] = p3s.BufferAttribute(f, normalized=False)

        ba_dict["position"] = p3s.BufferAttribute(v, normalized=False)

        if uv is not None:
            uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv))
            if texture_data is None:
                texture_data = gen_checkers(20, 20)
            tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType")
            material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"],
                                                roughness=sh["roughness"], metalness=sh["metalness"],
                                                flatShading=sh["flat"],
                                                polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
            ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False))
        else:
            material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"],
                                                side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"],
                                                flatShading=sh["flat"],
                                                polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)

        if n is not None and coloring == "VertexColors":  # TODO: properly handle normals for FaceColors as well
            ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True)

        geometry = p3s.BufferGeometry(attributes=ba_dict)

        if coloring == "VertexColors" and n is None:
            geometry.exec_three_obj_method('computeVertexNormals')
        elif coloring == "FaceColors" and n is None:
            geometry.exec_three_obj_method('computeFaceNormals')

        # Mesh setup
        mesh = p3s.Mesh(geometry=geometry, material=material)

        # Wireframe setup
        mesh_obj["wireframe"] = None
        if sh["wireframe"]:
            wf_geometry = p3s.WireframeGeometry(mesh.geometry)
            wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"])
            wireframe = p3s.LineSegments(wf_geometry, wf_material)
            mesh.add(wireframe)
            mesh_obj["wireframe"] = wireframe

        # Bounding box setup
        if sh["bbox"]:
            v_box, f_box = self.__get_bbox(v)
            _, bbox = self.add_edges(v_box, f_box, sh, mesh)
            mesh_obj["bbox"] = [bbox, v_box, f_box]

        # Object setup
        mesh_obj["max"] = np.max(v, axis=0)
        mesh_obj["min"] = np.min(v, axis=0)
        mesh_obj["geometry"] = geometry
        mesh_obj["mesh"] = mesh
        mesh_obj["material"] = material
        mesh_obj["type"] = "Mesh"
        mesh_obj["shading"] = sh
        mesh_obj["coloring"] = coloring
        mesh_obj["arrays"] = [v, f, c]  # TODO: replace with proper storage or remove if not needed

        return self.__add_object(mesh_obj)
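A hedged usage sketch for the viewer in `WEBSITE` mode, mirroring how the training callbacks above drive it (`to_html` is defined later in this file); the arrays are placeholders:

```python
viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
viewer.add_mesh(vertices, faces)            # numpy arrays, hypothetical here
snippet = viewer.to_html(html_frame=False)  # embeddable HTML, as in log_local
viewer.reset()                              # clear objects between samples
```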
|
||||
def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs):
|
||||
shading.update(kwargs)
|
||||
if len(beginning.shape) == 1:
|
||||
if len(beginning) == 2:
|
||||
beginning = np.array([[beginning[0], beginning[1], 0]])
|
||||
else:
|
||||
if beginning.shape[1] == 2:
|
||||
beginning = np.append(
|
||||
beginning, np.zeros([beginning.shape[0], 1]), 1)
|
||||
if len(ending.shape) == 1:
|
||||
if len(ending) == 2:
|
||||
ending = np.array([[ending[0], ending[1], 0]])
|
||||
else:
|
||||
if ending.shape[1] == 2:
|
||||
ending = np.append(
|
||||
ending, np.zeros([ending.shape[0], 1]), 1)
|
||||
|
||||
sh = self.__get_shading(shading)
|
||||
lines = np.hstack([beginning, ending])
|
||||
lines = lines.reshape((-1, 3))
|
||||
return self.__add_line_geometry(lines, sh, obj)
|
||||
|
||||
def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs):
|
||||
shading.update(kwargs)
|
||||
if vertices.shape[1] == 2:
|
||||
vertices = np.append(
|
||||
vertices, np.zeros([vertices.shape[0], 1]), 1)
|
||||
sh = self.__get_shading(shading)
|
||||
lines = np.zeros((edges.size, 3))
|
||||
cnt = 0
|
||||
for e in edges:
|
||||
lines[cnt, :] = vertices[e[0]]
|
||||
lines[cnt + 1, :] = vertices[e[1]]
|
||||
cnt += 2
|
||||
return self.__add_line_geometry(lines, sh, obj)
|
||||
|
||||
    def add_points(self, points, c=None, shading=None, obj=None, **kwargs):
        shading = dict(shading or {})
        shading.update(kwargs)
        if len(points.shape) == 1:
            if len(points) == 2:
                points = np.array([[points[0], points[1], 0]])
        else:
            if points.shape[1] == 2:
                points = np.append(
                    points, np.zeros([points.shape[0], 1]), 1)
        sh = self.__get_shading(shading)
        points = points.astype("float32", copy=False)
        mi = np.min(points, axis=0)
        ma = np.max(points, axis=0)

        g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)}
        m_attributes = {"size": sh["point_size"]}

        if sh["point_shape"] == "circle":  # Plot circles
            tex = p3s.DataTexture(data=gen_circle(16, 16), format="RGBAFormat", type="FloatType")
            m_attributes["map"] = tex
            m_attributes["alphaTest"] = 0.5
            m_attributes["transparent"] = True  # the three.js Material flag is `transparent`, not `transparency`
        else:  # Plot squares (the default sprite shape; no texture needed)
            pass

        colors, v_colors = self.__get_point_colors(points, c, sh)
        if v_colors:  # Colors per point
            m_attributes["vertexColors"] = 'VertexColors'
            g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False)
        else:  # One color for all points
            m_attributes["color"] = colors

        material = p3s.PointsMaterial(**m_attributes)
        geometry = p3s.BufferGeometry(attributes=g_attributes)
        points = p3s.Points(geometry=geometry, material=material)
        point_obj = {"geometry": geometry, "mesh": points, "material": material,
                     "max": ma, "min": mi, "type": "Points", "wireframe": None}

        if obj:
            return self.__add_object(point_obj, obj), point_obj
        else:
            return self.__add_object(point_obj)
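
    # Usage sketch (hypothetical shading values): "point_shape": "circle"
    # takes the textured branch above; any other value falls through to the
    # default square sprites.
    #
    #   pts = np.random.rand(100, 3).astype("float32")
    #   p.add_points(pts, c=pts[:, 0], shading={"point_size": 0.05, "point_shape": "circle"})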
    def remove_object(self, obj_id):
        if obj_id not in self.__objects:
            print("Invalid object id. Valid ids are: ", list(self.__objects.keys()))
            return
        self._scene.remove(self.__objects[obj_id]["mesh"])
        del self.__objects[obj_id]
        self.__update_view()
    def reset(self):
        for obj_id in list(self.__objects.keys()):  # list() already copies the keys
            self._scene.remove(self.__objects[obj_id]["mesh"])
            del self.__objects[obj_id]
        self.__update_view()
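
    # Usage sketch: remove_object() drops one object by the id handed back
    # from the add_*() calls; reset() clears every object and refreshes the view.
    #
    #   oid = p.add_points(pts)
    #   p.remove_object(oid)
    #   p.reset()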
    def update_object(self, oid=0, vertices=None, colors=None, faces=None):
        obj = self.__objects[oid]
        if vertices is not None:
            if obj["coloring"] == "FaceColors":
                # Per-face coloring duplicates vertices, so re-expand them face by face.
                f = obj["arrays"][1]
                verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
                for ii in range(f.shape[0]):
                    verts[ii * 3] = vertices[f[ii, 0]]
                    verts[ii * 3 + 1] = vertices[f[ii, 1]]
                    verts[ii * 3 + 2] = vertices[f[ii, 2]]
                v = verts
            else:
                v = vertices.astype("float32", copy=False)
            obj["geometry"].attributes["position"].array = v
            # self.wireframe.attributes["position"].array = v # Wireframe updates?
            obj["geometry"].attributes["position"].needsUpdate = True
            # obj["geometry"].exec_three_obj_method('computeVertexNormals')
        if colors is not None:
            colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"])
            colors = colors.astype("float32", copy=False)
            obj["geometry"].attributes["color"].array = colors
            obj["geometry"].attributes["color"].needsUpdate = True
        if faces is not None:
            if obj["coloring"] == "FaceColors":
                print("Face updates are currently only possible in vertex color mode.")
                return
            f = faces.astype("uint32", copy=False).ravel()
            # print(obj["geometry"].attributes)  # debug leftover
            obj["geometry"].attributes["index"].array = f
            # self.wireframe.attributes["position"].array = v # Wireframe updates?
            obj["geometry"].attributes["index"].needsUpdate = True
            # obj["geometry"].exec_three_obj_method('computeVertexNormals')
        # self.mesh.geometry.verticesNeedUpdate = True
        # self.mesh.geometry.elementsNeedUpdate = True
        # self.update()
        if self.render_mode == "WEBSITE":
            return self
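
    # Usage sketch (assumes an existing object id `oid` and a per-vertex
    # offset array `offsets`, both hypothetical): buffers are swapped in
    # place, so this is cheap enough for simple animations.
    #
    #   for t in range(10):
    #       p.update_object(oid, vertices=v + t * offsets)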
    # def update(self):
    #     self.mesh.exec_three_obj_method('update')
    #     self.orbit.exec_three_obj_method('update')
    #     self.cam.exec_three_obj_method('updateProjectionMatrix')
    #     self.scene.exec_three_obj_method('update')
    def add_text(self, text, shading=None, **kwargs):
        shading = dict(shading or {})
        shading.update(kwargs)
        sh = self.__get_shading(shading)
        tt = p3s.TextTexture(string=text, color=sh["text_color"])
        sm = p3s.SpriteMaterial(map=tt)
        sprite = p3s.Sprite(material=sm, scaleToTexture=True)  # renamed to avoid shadowing the `text` argument
        self._scene.add(sprite)
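
    # Usage sketch: the sprite is added straight to the scene rather than to
    # the object registry, so it cannot be removed later via remove_object().
    #
    #   p.add_text("label", shading={"text_color": "black"})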
    # def add_widget(self, widget, callback):
    #     self.widgets.append(widget)
    #     widget.observe(callback, names='value')

    # def add_dropdown(self, options, default, desc, cb):
    #     widget = widgets.Dropdown(options=options, value=default, description=desc)
    #     self.__widgets.append(widget)
    #     widget.observe(cb, names="value")
    #     display(widget)

    # def add_button(self, text, cb):
    #     button = widgets.Button(description=text)
    #     self.__widgets.append(button)
    #     button.on_click(cb)
    #     display(button)
    def to_html(self, imports=True, html_frame=True):
        # Bake positions (fixes centering bug in offline rendering)
        if len(self.__objects) == 0:
            return ""  # empty snippet instead of None, so save() can always write the result
        ma = np.zeros((len(self.__objects), 3))
        mi = np.zeros((len(self.__objects), 3))
        for r, obj in enumerate(self.__objects):
            ma[r] = self.__objects[obj]["max"]
            mi[r] = self.__objects[obj]["min"]
        ma = np.max(ma, axis=0)
        mi = np.min(mi, axis=0)
        diag = np.linalg.norm(ma - mi)
        mean = (ma - mi) / 2 + mi
        for obj in self.__objects:
            v = self.__objects[obj]["geometry"].attributes["position"].array
            v -= mean
            # v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window

        scale = self.__s["scale"] * diag
        self._orbit.target = [0.0, 0.0, 0.0]
        self._cam.lookAt([0.0, 0.0, 0.0])
        # self._cam.position = [0.0, 0.0, scale]
        self._cam.position = [0.0, 0.5, scale * 1.3]  #! show four complete meshes in the window
        self._light.position = [0.0, 0.0, scale]

        state = embed.dependency_state(self._renderer)

        # Somehow these entries are missing when the state is exported in python.
        # Exporting from the GUI works, so we are inserting the missing entries.
        for k in state:
            if state[k]["model_name"] == "OrbitControlsModel":
                state[k]["state"]["maxAzimuthAngle"] = "inf"
                state[k]["state"]["maxDistance"] = "inf"
                state[k]["state"]["maxZoom"] = "inf"
                state[k]["state"]["minAzimuthAngle"] = "-inf"

        tpl = embed.load_requirejs_template
        if not imports:
            embed.load_requirejs_template = ""

        s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL)
        # s = embed.embed_snippet(self.__w, state=state)
        embed.load_requirejs_template = tpl

        if html_frame:
            s = "<html>\n<body>\n" + s + "\n</body>\n</html>"

        # Revert changes
        for obj in self.__objects:
            v = self.__objects[obj]["geometry"].attributes["position"].array
            v += mean
        self.__update_view()

        return s
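
    # Usage sketch: with imports=False the require.js preamble is omitted so
    # several snippets can share one page; html_frame=False returns a bare
    # fragment instead of a full <html> document.
    #
    #   fragment = p.to_html(imports=False, html_frame=False)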
    def save(self, filename=""):
        if filename == "":
            uid = str(uuid.uuid4()) + ".html"
        else:
            filename = filename.replace(".html", "")
            uid = filename + '.html'
        with open(uid, "w") as f:
            f.write(self.to_html())
        print("Plot saved to file %s." % uid)