diff --git a/hy3dshape/hy3dshape/rembg.py b/hy3dshape/hy3dshape/rembg.py index 6247f06..bd94de3 100644 --- a/hy3dshape/hy3dshape/rembg.py +++ b/hy3dshape/hy3dshape/rembg.py @@ -17,8 +17,8 @@ from rembg import remove, new_session class BackgroundRemover(): - def __init__(self): - self.session = new_session() + def __init__(self, model_name: str = "bria-rmbg"): + self.session = new_session(model_name) def __call__(self, image: Image.Image): output = remove(image, session=self.session, bgcolor=[255, 255, 255, 0])