fix shape training
This commit is contained in:
@@ -548,7 +548,7 @@ class PointCrossAttentionEncoder(nn.Module):
|
||||
|
||||
if pc_sharpedge_size == 0:
|
||||
print(
|
||||
f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is not given, using pc_size as pc_sharpedge_size')
|
||||
f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is zero')
|
||||
else:
|
||||
print(
|
||||
f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}')
|
||||
|
||||
@@ -32,6 +32,7 @@ from transformers import (
|
||||
Dinov2Model,
|
||||
Dinov2Config,
|
||||
)
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
@@ -66,9 +67,10 @@ class ImageEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
self.model = self.MODEL_CLASS.from_pretrained(version)
|
||||
self.model = AutoModel.from_pretrained(version)
|
||||
else:
|
||||
self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))
|
||||
|
||||
self.model.eval()
|
||||
self.model.requires_grad_(False)
|
||||
self.use_cls_token = use_cls_token
|
||||
@@ -240,11 +242,26 @@ class SingleImageEncoder(nn.Module):
|
||||
def __init__(self, main_image_encoder, drop_ratio=0.0):
    """Set up the wrapper around a single image encoder.

    Args:
        main_image_encoder: Specification handed to ``build_image_encoder``
            to construct the wrapped encoder.
        drop_ratio: Threshold compared against a per-sample uniform random
            number in ``forward``; samples whose draw is not above it have
            their encoder output zeroed.
    """
    super().__init__()
    # Dropping starts out disabled: while this flag is True, ``forward``
    # returns the encoder output untouched.
    self.disable_drop = True
    self.drop_ratio = drop_ratio
    self.main_image_encoder = build_image_encoder(main_image_encoder)
|
||||
|
||||
def forward(self, image, mask=None, **kwargs):
    """Encode ``image`` and optionally zero out whole samples at random.

    Args:
        image: Batch of inputs forwarded to ``self.main_image_encoder``;
            ``len(image)`` is used as the batch size.
        mask: Optional mask forwarded unchanged to the encoder.
        **kwargs: Extra keyword arguments forwarded to the encoder.

    Returns:
        dict: ``{'main': <encoder output>}``. When ``self.disable_drop`` is
        False, each sample's output is multiplied in place by 0 if its
        uniform random draw is not greater than ``self.drop_ratio``.
    """
    outputs = {
        'main': self.main_image_encoder(image, mask=mask, **kwargs),
    }
    if self.disable_drop:
        return outputs

    main = outputs['main']
    # Draw the random numbers on the same device as the encoder output
    # instead of a hard-coded 'cuda', so the drop path also works on CPU
    # or on a non-default accelerator.
    random_p = torch.rand(len(image), device=main.device)
    remain_bool_tensor = random_p > self.drop_ratio
    # Reshape to (N, 1, 1) so the per-sample keep flag broadcasts over the
    # remaining dimensions (assumes a 3-D encoder output — TODO confirm).
    main *= remain_bool_tensor.view(-1, 1, 1)
    return outputs
|
||||
|
||||
|
||||
outputs = {
|
||||
'main': self.main_image_encoder(image, mask=mask, **kwargs),
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
@@ -31,6 +33,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from .moe_layers import MoEBlock
|
||||
from ...utils import logger, synchronize_timer, smart_load_model
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
@@ -464,6 +467,74 @@ class FinalLayer(nn.Module):
|
||||
|
||||
class HunYuanDiTPlain(nn.Module):
|
||||
|
||||
@classmethod
@synchronize_timer('HunYuanDiTPlain Model Loading')
def from_single_file(
    cls,
    ckpt_path,
    config_path,
    device='cuda',
    dtype=torch.float16,
    use_safetensors=None,
    **kwargs,
):
    """Build a model from a YAML config file and a single checkpoint file.

    Args:
        ckpt_path: Path to the checkpoint. When ``use_safetensors`` is
            truthy, ``'.ckpt'`` in the path is replaced by
            ``'.safetensors'`` before the file is looked up.
        config_path: Path to a YAML file; its ``params`` mapping (optionally
            nested under a top-level ``model`` key) supplies the constructor
            arguments.
        device: Device the loaded model is moved to.
        dtype: Dtype the loaded model is cast to.
        use_safetensors: Load with ``safetensors`` when truthy, otherwise
            with ``torch.load``.
        **kwargs: Overrides merged on top of the config's ``params``.

    Returns:
        The constructed model with the checkpoint weights loaded.

    Raises:
        FileNotFoundError: If the resolved checkpoint path does not exist.
    """
    # Read the constructor configuration first.
    with open(config_path, 'r') as config_file:
        config = yaml.safe_load(config_file)

    # Resolve the checkpoint location and make sure it is present.
    if use_safetensors:
        ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Model file {ckpt_path} not found")

    logger.info(f"Loading model from {ckpt_path}")
    # Weights are always materialised on the CPU; the move to `device`
    # happens once the module has been populated.
    if not use_safetensors:
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    else:
        import safetensors.torch
        ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')

    # Both the checkpoint and the config may wrap their payload in 'model'.
    ckpt = ckpt['model'] if 'model' in ckpt else ckpt
    config = config['model'] if 'model' in config else config

    # Caller-supplied keyword arguments take precedence over the config.
    model_kwargs = config['params']
    model_kwargs.update(kwargs)

    model = cls(**model_kwargs)
    model.load_state_dict(ckpt)
    model.to(device=device, dtype=dtype)
    return model
|
||||
|
||||
@classmethod
def from_pretrained(
    cls,
    model_path,
    device='cuda',
    dtype=torch.float16,
    use_safetensors=False,
    variant='fp16',
    subfolder='hunyuan3d-dit-v2-1',
    **kwargs,
):
    """Resolve a model location and load it via ``from_single_file``.

    Args:
        model_path: Model identifier or path passed to ``smart_load_model``.
        device: Device forwarded to ``from_single_file``.
        dtype: Dtype forwarded to ``from_single_file``.
        use_safetensors: Forwarded to both ``smart_load_model`` and
            ``from_single_file``.
        variant: Weight variant forwarded to ``smart_load_model``.
        subfolder: Subfolder forwarded to ``smart_load_model``.
        **kwargs: Extra keyword arguments forwarded to ``from_single_file``.

    Returns:
        Whatever ``from_single_file`` returns for the resolved paths.
    """
    # smart_load_model yields the pair in (config, checkpoint) order.
    resolved = smart_load_model(
        model_path,
        subfolder=subfolder,
        use_safetensors=use_safetensors,
        variant=variant,
    )
    config_path, ckpt_path = resolved

    # from_single_file takes the checkpoint first, then the config.
    loader_options = dict(device=device, dtype=dtype, use_safetensors=use_safetensors)
    return cls.from_single_file(ckpt_path, config_path, **loader_options, **kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=1024,
|
||||
|
||||
@@ -256,17 +256,14 @@ class Diffuser(pl.LightningModule):
|
||||
def forward(self, batch):
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16): #float32 for text
|
||||
contexts = self.cond_stage_model(image=batch.get('image'), text=batch.get('text'), mask=batch.get('mask'))
|
||||
# t5_text = contexts['t5_text']['prompt_embeds']
|
||||
# nan_count = torch.isnan(t5_text).sum()
|
||||
# if nan_count > 0:
|
||||
# print("t5_text has %d NaN values"%(nan_count))
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
with torch.no_grad():
|
||||
latents = self.first_stage_model.encode(batch[self.first_stage_key], sample_posterior=True)
|
||||
latents = self.z_scale_factor * latents
|
||||
# print(latents.shape)
|
||||
|
||||
# check vae encode and decode is ok? answer is ok !
|
||||
# check vae encode and decode is ok? answer is ok!
|
||||
# import time
|
||||
# from hy3dshape.pipelines import export_to_trimesh
|
||||
# latents = 1. / self.z_scale_factor * latents
|
||||
@@ -333,9 +330,6 @@ class Diffuser(pl.LightningModule):
|
||||
image = batch.get("image", None)
|
||||
mask = batch.get('mask', None)
|
||||
|
||||
# if not isinstance(image, torch.Tensor): print(image.shape)
|
||||
# if isinstance(mask, torch.Tensor): print(mask.shape)
|
||||
|
||||
outputs = self.pipeline(image=image,
|
||||
mask=mask,
|
||||
generator=generator,
|
||||
@@ -350,5 +344,6 @@ class Diffuser(pl.LightningModule):
|
||||
f.write(traceback.format_exc())
|
||||
f.write("\n")
|
||||
outputs = [None]
|
||||
|
||||
self.cond_stage_model.disable_drop = False
|
||||
return [outputs]
|
||||
|
||||
@@ -323,7 +323,9 @@ class ImageConditionalFixASLDiffuserLogger(Callback):
|
||||
save_path = os.path.join(visual_dir, os.path.basename(image_path))
|
||||
save_path = os.path.splitext(save_path)[0] + '.glb'
|
||||
|
||||
print(image_path)
|
||||
if isinstance(image_path, str):
|
||||
print(image_path)
|
||||
|
||||
with torch.no_grad():
|
||||
mesh = pl_module.sample(batch={"image": image_path}, **self.kwargs)[0][0]
|
||||
if isinstance(mesh, tuple) and len(mesh)==2:
|
||||
|
||||
Reference in New Issue
Block a user