init
This commit is contained in:
111
hy3dpaint/utils/torchvision_fix.py
Normal file
111
hy3dpaint/utils/torchvision_fix.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Torchvision compatibility fix for functional_tensor module
|
||||
# This file helps resolve compatibility issues between different torchvision versions
|
||||
|
||||
import sys
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
def fix_torchvision_functional_tensor():
|
||||
"""
|
||||
Fix torchvision.transforms.functional_tensor import issue
|
||||
"""
|
||||
try:
|
||||
# Check if the module exists in the expected location
|
||||
import torchvision.transforms.functional_tensor
|
||||
print("torchvision.transforms.functional_tensor is available")
|
||||
return True
|
||||
except ImportError:
|
||||
print("torchvision.transforms.functional_tensor not found, applying compatibility fix...")
|
||||
|
||||
try:
|
||||
# Create a mock functional_tensor module with the required functions
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
class FunctionalTensorMock:
|
||||
"""Mock module to replace functional_tensor"""
|
||||
|
||||
@staticmethod
|
||||
def _get_grayscale_weights(img):
|
||||
"""Helper to create grayscale weights based on image dimensions"""
|
||||
weights = torch.tensor([0.299, 0.587, 0.114], device=img.device, dtype=img.dtype)
|
||||
return weights.view(1, 3, 1, 1) if len(img.shape) == 4 else weights.view(3, 1, 1)
|
||||
|
||||
@staticmethod
|
||||
def _try_import_fallback(module_names, attr_name):
|
||||
"""Helper to try importing from multiple modules"""
|
||||
for module_name in module_names:
|
||||
try:
|
||||
module = __import__(module_name, fromlist=[attr_name])
|
||||
if hasattr(module, attr_name):
|
||||
return getattr(module, attr_name)
|
||||
except ImportError:
|
||||
continue
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def rgb_to_grayscale(img, num_output_channels=1):
|
||||
"""Convert RGB image to grayscale"""
|
||||
if hasattr(F, 'rgb_to_grayscale'):
|
||||
return F.rgb_to_grayscale(img, num_output_channels)
|
||||
|
||||
# Fallback implementation
|
||||
weights = FunctionalTensorMock._get_grayscale_weights(img)
|
||||
grayscale = torch.sum(img * weights, dim=-3, keepdim=True)
|
||||
|
||||
if num_output_channels == 3:
|
||||
repeat_dims = (1, 3, 1, 1) if len(img.shape) == 4 else (3, 1, 1)
|
||||
grayscale = grayscale.repeat(*repeat_dims)
|
||||
|
||||
return grayscale
|
||||
|
||||
@staticmethod
|
||||
def resize(img, size, interpolation=2, antialias=None):
|
||||
"""Resize function wrapper"""
|
||||
# Try v2.functional first, then regular functional, then torch.nn.functional
|
||||
resize_func = FunctionalTensorMock._try_import_fallback([
|
||||
'torchvision.transforms.v2.functional',
|
||||
'torchvision.transforms.functional'
|
||||
], 'resize')
|
||||
|
||||
if resize_func:
|
||||
try:
|
||||
return resize_func(img, size, interpolation=interpolation, antialias=antialias)
|
||||
except TypeError:
|
||||
# Fallback for older versions without antialias parameter
|
||||
return resize_func(img, size, interpolation=interpolation)
|
||||
|
||||
# Final fallback using torch.nn.functional
|
||||
import torch.nn.functional as torch_F
|
||||
size = (size, size) if isinstance(size, int) else size
|
||||
img_input = img.unsqueeze(0) if len(img.shape) == 3 else img
|
||||
return torch_F.interpolate(img_input, size=size, mode='bilinear', align_corners=False)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to regular functional module"""
|
||||
func = self._try_import_fallback([
|
||||
'torchvision.transforms.functional',
|
||||
'torchvision.transforms.v2.functional'
|
||||
], name)
|
||||
|
||||
if func:
|
||||
return func
|
||||
|
||||
raise AttributeError(f"'{name}' not found in functional_tensor mock")
|
||||
|
||||
# Create the mock module instance and monkey patch
|
||||
sys.modules['torchvision.transforms.functional_tensor'] = FunctionalTensorMock()
|
||||
print("Applied compatibility fix: created functional_tensor mock module")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to create functional_tensor mock: {e}")
|
||||
return False
|
||||
|
||||
def apply_fix():
|
||||
"""Apply the torchvision compatibility fix"""
|
||||
print(f"Torchvision version: {torchvision.__version__}")
|
||||
return fix_torchvision_functional_tensor()
|
||||
|
||||
if __name__ == "__main__":
|
||||
apply_fix()
|
||||
|
||||
Reference in New Issue
Block a user