init
13
hy3dpaint/src/__init__.py
Executable file
@@ -0,0 +1,13 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of the original licenses of these third-party
# components and must ensure that the usage of the third-party components adheres to
# all relevant laws and regulations.

# For the avoidance of doubt, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

13
hy3dpaint/src/data/__init__.py
Executable file
@@ -0,0 +1,13 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of the original licenses of these third-party
# components and must ensure that the usage of the third-party components adheres to
# all relevant laws and regulations.

# For the avoidance of doubt, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

219
hy3dpaint/src/data/dataloader/loader_util.py
Normal file
@@ -0,0 +1,219 @@
#!/usr/bin/env python3

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of the original licenses of these third-party
# components and must ensure that the usage of the third-party components adheres to
# all relevant laws and regulations.

# For the avoidance of doubt, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import cv2
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageOps, ImageChops


class BaseDataset(Dataset):
    def __init__(self, json_path, num_view=4, image_size=512):
        self.data = list()
        self.num_view = num_view
        self.image_size = image_size
        if isinstance(json_path, str):
            json_path = [json_path]
        for jp in json_path:
            with open(jp) as f:
                self.data.extend(json.load(f))
        print("============= length of dataset %d =============" % len(self.data))

    def __len__(self):
        return len(self.data)

    def load_image(self, pil_img, color, image_size=None):
        if image_size is None:
            image_size = self.image_size
        if isinstance(pil_img, str):
            pil_img = Image.open(pil_img)
        if pil_img.mode == "L":
            pil_img = pil_img.convert("RGB")
        pil_img = pil_img.resize((image_size, image_size))
        image = np.asarray(pil_img, dtype=np.float32) / 255.0
        if image.shape[2] == 3:
            # RGB input: treat it as fully opaque (single-channel alpha of ones,
            # matching the shape of the RGBA branch below)
            alpha = np.ones_like(image[:, :, :1])
        else:
            alpha = image[:, :, 3:]
        # Composite the RGB channels over the given background color
        image = image[:, :, :3] * alpha + color * (1 - alpha)
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def _apply_scaling(self, image, scale_factor, width, height, bg_color, scale_width=True):
        """Apply scaling to image with proper cropping or padding."""
        if scale_width:
            new_width = int(width * scale_factor)
            new_height = height
        else:
            new_width = width
            new_height = int(height * scale_factor)

        image = image.resize((new_width, new_height), resample=Image.BILINEAR)

        if scale_factor > 1.0:
            # Crop to original size
            left = (new_width - width) // 2
            top = (new_height - height) // 2
            image = image.crop((left, top, left + width, top + height))
        else:
            # Pad to original size
            pad_width = (width - new_width) // 2
            pad_height = (height - new_height) // 2
            image = ImageOps.expand(
                image,
                (
                    pad_width,
                    pad_height,
                    width - new_width - pad_width,
                    height - new_height - pad_height,
                ),
                fill=bg_color,
            )
        return image

    def _apply_rotation(self, image, bg_color):
        """Apply random rotation to image."""
        original_size = image.size
        angle = random.uniform(-30, 30)
        image = image.convert("RGBA")
        rotated_image = image.rotate(angle, resample=Image.BILINEAR, expand=True)

        # Create background with bg_color
        background = Image.new("RGBA", rotated_image.size, (bg_color[0], bg_color[1], bg_color[2], 255))
        background.paste(rotated_image, (0, 0), rotated_image)
        image = background.convert("RGB")

        # Crop to original size
        left = (image.width - original_size[0]) // 2
        top = (image.height - original_size[1]) // 2
        right = left + original_size[0]
        bottom = top + original_size[1]

        return image.crop((left, top, right, bottom))

    def _apply_translation(self, image, bg_color):
        """Apply random translation to image."""
        max_dx = 0.1 * image.size[0]
        max_dy = 0.1 * image.size[1]
        dx = int(random.uniform(-max_dx, max_dx))
        dy = int(random.uniform(-max_dy, max_dy))

        image = ImageChops.offset(image, dx, dy)

        # Fill the wrapped-around edges with the background color
        # (ImageChops.offset wraps pixels around the image borders)
        width, height = image.size
        if dx > 0:
            image.paste(bg_color, (0, 0, dx, height))
        elif dx < 0:
            image.paste(bg_color, (width + dx, 0, width, height))

        if dy > 0:
            image.paste(bg_color, (0, 0, width, dy))
        elif dy < 0:
            image.paste(bg_color, (0, height + dy, width, height))

        return image

    def _apply_perspective(self, image, bg_color):
        """Apply random perspective transformation to image."""
        image_np = np.array(image)
        height, width = image_np.shape[:2]

        # Define original and new corner points
        original_points = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
        perspective_scale = 0.2

        new_points = np.float32(
            [
                [random.uniform(0, width * perspective_scale), random.uniform(0, height * perspective_scale)],
                [random.uniform(width * (1 - perspective_scale), width), random.uniform(0, height * perspective_scale)],
                [
                    random.uniform(width * (1 - perspective_scale), width),
                    random.uniform(height * (1 - perspective_scale), height),
                ],
                [
                    random.uniform(0, width * perspective_scale),
                    random.uniform(height * (1 - perspective_scale), height),
                ],
            ]
        )

        matrix = cv2.getPerspectiveTransform(original_points, new_points)
        image_np = cv2.warpPerspective(
            image_np, matrix, (width, height), borderMode=cv2.BORDER_CONSTANT, borderValue=bg_color
        )

        return Image.fromarray(image_np)

    def augment_image(
        self,
        image,
        bg_color,
        identity_prob=0.5,
        rotate_prob=0.3,
        scale_prob=0.5,
        translate_prob=0.5,
        perspective_prob=0.3,
    ):
        if random.random() < identity_prob:
            return image

        # Convert torch tensor back to a PIL image for augmentation
        image = Image.fromarray((image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
        bg_color = (int(bg_color[0] * 255), int(bg_color[1] * 255), int(bg_color[2] * 255))

        # Random rotation
        if random.random() < rotate_prob:
            image = self._apply_rotation(image, bg_color)

        # Random scaling
        if random.random() < scale_prob:
            width, height = image.size
            scale_factor = random.uniform(0.8, 1.2)

            if random.random() < 0.5:
                # Scale both dimensions with the same factor
                image = self._apply_scaling(image, scale_factor, width, height, bg_color, scale_width=True)
                image = self._apply_scaling(image, scale_factor, width, height, bg_color, scale_width=False)
            else:
                # Scale width and height independently
                scale_factor_w = random.uniform(0.8, 1.2)
                scale_factor_h = random.uniform(0.8, 1.2)
                image = self._apply_scaling(image, scale_factor_w, width, height, bg_color, scale_width=True)
                image = self._apply_scaling(image, scale_factor_h, width, height, bg_color, scale_width=False)

        # Random translation
        if random.random() < translate_prob:
            image = self._apply_translation(image, bg_color)

        # Random perspective
        if random.random() < perspective_prob:
            image = self._apply_perspective(image, bg_color)

        # Convert back to a torch tensor
        image = image.convert("RGB")
        image = np.asarray(image, dtype=np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()

        return image

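A minimal sketch of the load_image/augment_image round trip, assuming BaseDataset and the imports above are importable; the synthetic RGBA image and the empty JSON manifest are illustrative stand-ins, not part of this commit:

import json
import tempfile
from PIL import Image

# Stand-in manifest so the constructor can run
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump([], f)
    manifest = f.name

ds = BaseDataset(json_path=manifest, num_view=4, image_size=512)
rgba = Image.new("RGBA", (300, 200), (255, 0, 0, 128))  # half-transparent red
gray = [0.5, 0.5, 0.5]
image, alpha = ds.load_image(rgba, gray)   # tensors of shape (3, 512, 512) and (1, 512, 512)
augmented = ds.augment_image(image, gray)  # returned unchanged or randomly augmented
print(image.shape, alpha.shape, augmented.shape)
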
146
hy3dpaint/src/data/dataloader/objaverse_loader_forTexturePBR.py
Normal file
@@ -0,0 +1,146 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of the original licenses of these third-party
# components and must ensure that the usage of the third-party components adheres to
# all relevant laws and regulations.

# For the avoidance of doubt, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
import glob
import json
import random
import torch
from .loader_util import BaseDataset


class TextureDataset(BaseDataset):
    def __init__(
        self, json_path, num_view=6, image_size=512, lighting_suffix_pool=["light_PL", "light_AL", "light_ENVMAP"]
    ):
        self.data = list()
        self.num_view = num_view
        self.image_size = image_size
        self.lighting_suffix_pool = lighting_suffix_pool
        if isinstance(json_path, str):
            json_path = [json_path]
        for jp in json_path:
            with open(jp) as f:
                self.data.extend(json.load(f))
        print("============= length of dataset %d =============" % len(self.data))

    def __getitem__(self, index):
        images_ref = list()
        images_albedo = list()
        images_mr = list()
        images_normal = list()
        images_position = list()
        bg_white = [1.0, 1.0, 1.0]
        bg_black = [0.0, 0.0, 0.0]
        bg_gray = [127 / 255.0, 127 / 255.0, 127 / 255.0]
        dirx = self.data[index]

        condition_dict = {}

        # Collect the rendered views (num_view, default 6) and the condition images
        fix_num_view = self.num_view
        available_views = []
        for ext in ["*_albedo.png", "*_albedo.jpg", "*_albedo.jpeg"]:
            available_views.extend(glob.glob(os.path.join(dirx, "render_tex", ext)))
        cond_images = (
            glob.glob(os.path.join(dirx, "render_cond", "*.png"))
            + glob.glob(os.path.join(dirx, "render_cond", "*.jpg"))
            + glob.glob(os.path.join(dirx, "render_cond", "*.jpeg"))
        )

        # Make sure there are enough views to sample from
        if len(available_views) < fix_num_view:
            print(
                f"Warning: Only {len(available_views)} views available, but {fix_num_view} requested. "
                "Using all available views."
            )
            images_gen = available_views
        else:
            images_gen = random.sample(available_views, fix_num_view)

        if not cond_images:
            raise ValueError(f"No condition images found in {os.path.join(dirx, 'render_cond')}")
        ref_image_path = random.choice(cond_images)
        light_suffix = None
        for suffix in self.lighting_suffix_pool:
            if suffix in ref_image_path:
                light_suffix = suffix
                break
        if light_suffix is None:
            raise ValueError(f"light suffix not found in {ref_image_path}")
        # Pick the same view rendered under a different lighting condition
        ref_image_diff_light_path = random.choice(
            [
                ref_image_path.replace(light_suffix, tar_suffix)
                for tar_suffix in self.lighting_suffix_pool
                if tar_suffix != light_suffix
            ]
        )
        images_ref_paths = [ref_image_path, ref_image_diff_light_path]

        # Data augmentation: both reference images share the background color
        # drawn for the first one (bg_c_record)
        bg_c_record = None
        for i, image_ref in enumerate(images_ref_paths):
            if random.random() < 0.6:
                bg_c = bg_gray
            else:
                if random.random() < 0.5:
                    bg_c = bg_black
                else:
                    bg_c = bg_white
            if i == 0:
                bg_c_record = bg_c
            image, alpha = self.load_image(image_ref, bg_c_record)
            image = self.augment_image(image, bg_c_record).float()
            images_ref.append(image)
        condition_dict["images_cond"] = torch.stack(images_ref, dim=0).float()

        # Load the PBR targets for each sampled view: albedo, metallic-roughness, normal, position
        for i, image_gen in enumerate(images_gen):
            images_albedo.append(self.augment_image(self.load_image(image_gen, bg_gray)[0], bg_gray))
            images_mr.append(
                self.augment_image(self.load_image(image_gen.replace("_albedo", "_mr"), bg_gray)[0], bg_gray)
            )
            images_normal.append(
                self.augment_image(self.load_image(image_gen.replace("_albedo", "_normal"), bg_gray)[0], bg_gray)
            )
            images_position.append(
                self.augment_image(self.load_image(image_gen.replace("_albedo", "_pos"), bg_gray)[0], bg_gray)
            )

        condition_dict["images_albedo"] = torch.stack(images_albedo, dim=0).float()
        condition_dict["images_mr"] = torch.stack(images_mr, dim=0).float()
        condition_dict["images_normal"] = torch.stack(images_normal, dim=0).float()
        condition_dict["images_position"] = torch.stack(images_position, dim=0).float()
        condition_dict["name"] = dirx
        return condition_dict  # each image stack: (N, 3, H, W)


if __name__ == "__main__":
    dataset = TextureDataset(json_path=["../../../train_examples/examples.json"])
    sample = dataset[0]  # fetch once; each __getitem__ call re-rolls the random augmentation
    print("images_cond", sample["images_cond"].shape)
    print("images_albedo", sample["images_albedo"].shape)
    print("images_mr", sample["images_mr"].shape)
    print("images_normal", sample["images_normal"].shape)
    print("images_position", sample["images_position"].shape)
    print("name", sample["name"])

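For reference, the globbing in __getitem__ implies a per-object directory layout along these lines (file names are illustrative; each condition render must carry one of the lighting_suffix_pool suffixes in its name):

<object_dir>/
├── render_tex/
│   ├── 000_albedo.png      (with matching 000_mr.png, 000_normal.png, 000_pos.png)
│   └── ...
└── render_cond/
    ├── 000_light_PL.png
    ├── 000_light_AL.png
    └── 000_light_ENVMAP.png
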
10
hy3dpaint/src/data/dataloader/pbr_data_format.txt
Normal file
@@ -0,0 +1,10 @@
+-----------------+-----------------------------------+
| Key             | Value                             |
+-----------------+-----------------------------------+
| images_cond     | torch.Size([2, 2, 3, 512, 512])   |
| images_albedo   | torch.Size([2, 6, 3, 512, 512])   |
| images_mr       | torch.Size([2, 6, 3, 512, 512])   |
| images_normal   | torch.Size([2, 6, 3, 512, 512])   |
| images_position | torch.Size([2, 6, 3, 512, 512])   |
| caption         | ['high quality', 'high quality']  |
+-----------------+-----------------------------------+
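These shapes line up with batching the TextureDataset above at batch_size=2: the two reference lightings per sample give images_cond its second dimension of 2, and num_view=6 gives the PBR stacks theirs. A minimal sketch, assuming an examples.json manifest (note that `caption` is not produced by the TextureDataset in this commit, so it presumably enters at a later collation step):

from torch.utils.data import DataLoader

dataset = TextureDataset(json_path=["train_examples/examples.json"])
loader = DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))
for key in ("images_cond", "images_albedo", "images_mr", "images_normal", "images_position"):
    print(key, batch[key].shape)  # e.g. images_albedo -> torch.Size([2, 6, 3, 512, 512])
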
79
hy3dpaint/src/data/objaverse_hunyuan.py
Executable file
@@ -0,0 +1,79 @@
#!/usr/bin/env python3

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of the original licenses of these third-party
# components and must ensure that the usage of the third-party components adheres to
# all relevant laws and regulations.

# For the avoidance of doubt, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import pytorch_lightning as pl
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs["train"] = train
        if validation is not None:
            self.dataset_configs["validation"] = validation
        if test is not None:
            self.dataset_configs["test"] = test

    def setup(self, stage):
        from src.utils.train_util import instantiate_from_config

        if stage in ["fit"]:
            dataset_dict = {}
            for k in self.dataset_configs:
                dataset_dict[k] = []
                for loader in self.dataset_configs[k]:
                    dataset_dict[k].append(instantiate_from_config(loader))
            self.datasets = dataset_dict
            print(self.datasets)
        else:
            raise NotImplementedError

    def train_dataloader(self):
        datasets = ConcatDataset(self.datasets["train"])
        sampler = DistributedSampler(datasets)
        return DataLoader(
            datasets,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            sampler=sampler,
            prefetch_factor=2,
            pin_memory=True,
        )

    def val_dataloader(self):
        datasets = ConcatDataset(self.datasets["validation"])
        sampler = DistributedSampler(datasets)
        return DataLoader(datasets, batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def test_dataloader(self):
        datasets = ConcatDataset(self.datasets["test"])
        return DataLoader(datasets, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

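A minimal sketch of driving DataModuleFromConfig with the instantiate_from_config-style entries that setup() expects; the target path and manifest are illustrative, and train_dataloader() also requires torch.distributed to be initialized because it builds a DistributedSampler:

train_cfg = [
    {
        "target": "src.data.dataloader.objaverse_loader_forTexturePBR.TextureDataset",
        "params": {"json_path": ["train_examples/examples.json"], "num_view": 6},
    }
]
dm = DataModuleFromConfig(batch_size=8, num_workers=4, train=train_cfg)
dm.setup("fit")
loader = dm.train_dataloader()  # call torch.distributed.init_process_group() first
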
13
hy3dpaint/src/utils/__init__.py
Executable file
@@ -0,0 +1,13 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of the original licenses of these third-party
# components and must ensure that the usage of the third-party components adheres to
# all relevant laws and regulations.

# For the avoidance of doubt, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

40
hy3dpaint/src/utils/train_util.py
Executable file
@@ -0,0 +1,40 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of the original licenses of these third-party
# components and must ensure that the usage of the third-party components adheres to
# all relevant laws and regulations.

# For the avoidance of doubt, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import importlib


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

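A small usage sketch of the two helpers, using a standard-library class so it runs standalone:

cls = get_obj_from_str("collections.OrderedDict")  # dynamic import by dotted path
cfg = {"target": "collections.OrderedDict", "params": {"a": 1}}
obj = instantiate_from_config(cfg)                 # an OrderedDict with a=1
print(cls, obj)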