Skip to content

Commit 7e25804

Browse files
authored
fix hf auth issue
fixes force using auth token even when not required by a repository, which caused a 401 error when a user had not saved a token to the config.yaml
1 parent bfb6a02 commit 7e25804

File tree

1 file changed

+112
-75
lines changed

1 file changed

+112
-75
lines changed

src/download_model.py

Lines changed: 112 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -6,83 +6,114 @@
66
import fnmatch
77
import humanfriendly
88
import atexit
9+
import yaml
10+
import functools
911

1012
class ModelDownloadedSignal(QObject):
11-
downloaded = Signal(str, str)
13+
downloaded = Signal(str, str)
1214

1315
model_downloaded_signal = ModelDownloadedSignal()
1416

1517
MODEL_DIRECTORIES = {
16-
"vector": "vector",
17-
"chat": "chat",
18-
"tts": "tts",
19-
"jeeves": "jeeves",
20-
"ocr": "ocr"
18+
"vector": "vector",
19+
"chat": "chat",
20+
"tts": "tts",
21+
"jeeves": "jeeves",
22+
"ocr": "ocr"
2123
}
2224

25+
@functools.lru_cache(maxsize=1)
26+
def get_hf_token():
27+
config_path = Path("config.yaml")
28+
if config_path.exists():
29+
try:
30+
with open(config_path, 'r', encoding='utf-8') as config_file:
31+
config = yaml.safe_load(config_file)
32+
return config.get('hf_access_token')
33+
except Exception as e:
34+
print(f"Warning: Could not load config: {e}")
35+
return None
36+
2337
class ModelDownloader(QObject):
24-
def __init__(self, model_info, model_type):
25-
super().__init__()
26-
self.model_info = model_info
27-
self.model_type = model_type
28-
self._model_directory = None
29-
self.api = HfApi()
30-
self.api.timeout = 60 # increase timeout
31-
disable_progress_bars()
32-
self.local_dir = self.get_model_directory()
33-
34-
def cleanup_incomplete_download(self):
35-
if self.local_dir.exists():
36-
import shutil
37-
shutil.rmtree(self.local_dir)
38-
39-
def get_model_url(self):
40-
if isinstance(self.model_info, dict):
41-
return self.model_info['repo_id']
42-
else:
43-
return self.model_info
44-
45-
def check_repo_type(self, repo_id):
46-
try:
47-
repo_info = self.api.repo_info(repo_id, timeout=60) # increase timeout
48-
if repo_info.private:
49-
return "private"
50-
elif getattr(repo_info, 'gated', False):
51-
return "gated"
52-
else:
53-
return "public"
54-
except GatedRepoError:
55-
return "gated"
56-
except RepositoryNotFoundError:
57-
return "not_found"
58-
except Exception as e:
59-
return f"error: {str(e)}"
60-
61-
def get_model_directory_name(self):
62-
if isinstance(self.model_info, dict):
63-
return self.model_info['cache_dir']
64-
else:
65-
return self.model_info.replace("/", "--")
66-
67-
def get_model_directory(self):
68-
return Path("Models") / self.model_type / self.get_model_directory_name()
69-
70-
def download_model(self, allow_patterns=None, ignore_patterns=None):
38+
def __init__(self, model_info, model_type):
39+
super().__init__()
40+
self.model_info = model_info
41+
self.model_type = model_type
42+
self._model_directory = None
43+
44+
self.hf_token = get_hf_token()
45+
46+
self.api = HfApi(token=False)
47+
self.api.timeout = 60
48+
disable_progress_bars()
49+
self.local_dir = self.get_model_directory()
50+
51+
def cleanup_incomplete_download(self):
52+
if self.local_dir.exists():
53+
import shutil
54+
shutil.rmtree(self.local_dir)
55+
56+
def get_model_url(self):
57+
if isinstance(self.model_info, dict):
58+
return self.model_info['repo_id']
59+
else:
60+
return self.model_info
61+
62+
def check_repo_type(self, repo_id):
63+
try:
64+
repo_info = self.api.repo_info(repo_id, timeout=60, token=False)
65+
if repo_info.private:
66+
return "private"
67+
elif getattr(repo_info, 'gated', False):
68+
return "gated"
69+
else:
70+
return "public"
71+
except Exception as e:
72+
if self.hf_token and ("401" in str(e) or "Unauthorized" in str(e)):
73+
try:
74+
api_with_token = HfApi(token=self.hf_token)
75+
repo_info = api_with_token.repo_info(repo_id, timeout=60)
76+
if repo_info.private:
77+
return "private"
78+
elif getattr(repo_info, 'gated', False):
79+
return "gated"
80+
else:
81+
return "public"
82+
except Exception as e2:
83+
return f"error: {str(e2)}"
84+
elif "404" in str(e):
85+
return "not_found"
86+
else:
87+
return f"error: {str(e)}"
88+
89+
def get_model_directory_name(self):
90+
if isinstance(self.model_info, dict):
91+
return self.model_info['cache_dir']
92+
else:
93+
return self.model_info.replace("/", "--")
94+
95+
def get_model_directory(self):
96+
return Path("Models") / self.model_type / self.get_model_directory_name()
97+
98+
def download_model(self, allow_patterns=None, ignore_patterns=None):
7199
repo_id = self.get_model_url()
72100

73-
# only download if repo is public
74-
# https://huggingface.co/docs/hub/models-gated#access-gated-models-as-a-user
75-
# https://huggingface.co/docs/hub/en/enterprise-hub-tokens-management
76101
repo_type = self.check_repo_type(repo_id)
77-
if repo_type != "public":
102+
if repo_type not in ["public", "gated"]:
78103
if repo_type == "private":
79-
print(f"Repository {repo_id} is private and requires a token. Aborting download.")
80-
elif repo_type == "gated":
81-
print(f"Repository {repo_id} is gated. Please request access through the web interface. Aborting download.")
104+
print(f"Repository {repo_id} is private and requires a token.")
105+
if not self.hf_token:
106+
print("No Hugging Face token found. Please add one through the credentials menu.")
107+
return
82108
elif repo_type == "not_found":
83109
print(f"Repository {repo_id} not found. Aborting download.")
110+
return
84111
else:
85112
print(f"Error checking repository {repo_id}: {repo_type}. Aborting download.")
113+
return
114+
115+
if repo_type == "gated" and not self.hf_token:
116+
print(f"Repository {repo_id} is gated. Please add a Hugging Face token and request access through the web interface.")
86117
return
87118

88119
local_dir = self.get_model_directory()
@@ -91,13 +122,12 @@ def download_model(self, allow_patterns=None, ignore_patterns=None):
91122
atexit.register(self.cleanup_incomplete_download)
92123

93124
try:
94-
repo_files = list(self.api.list_repo_tree(repo_id, recursive=True))
95-
"""
96-
allow_patterns: If provided, only matching files are downloaded (ignore_patterns is disregarded)
97-
ignore_patterns: If provided alone, matching files are excluded
98-
neither: Uses default ignore patterns (.gitattributes, READMEs, etc.) with smart model file filtering
99-
both: Behaves same as allow_patterns only
100-
"""
125+
if repo_type == "gated" and self.hf_token:
126+
api_for_listing = HfApi(token=self.hf_token)
127+
repo_files = list(api_for_listing.list_repo_tree(repo_id, recursive=True))
128+
else:
129+
repo_files = list(self.api.list_repo_tree(repo_id, recursive=True, token=False))
130+
101131
if allow_patterns is not None:
102132
final_ignore_patterns = None
103133
elif ignore_patterns is not None:
@@ -154,14 +184,21 @@ def download_model(self, allow_patterns=None, ignore_patterns=None):
154184
print(f"- {file}")
155185
print(f"\nDownloading to {local_dir}...")
156186

157-
snapshot_download(
158-
repo_id=repo_id,
159-
local_dir=str(local_dir),
160-
max_workers=4,
161-
ignore_patterns=final_ignore_patterns,
162-
allow_patterns=allow_patterns,
163-
etag_timeout=60 # increase timeout
164-
)
187+
download_kwargs = {
188+
'repo_id': repo_id,
189+
'local_dir': str(local_dir),
190+
'max_workers': 4,
191+
'ignore_patterns': final_ignore_patterns,
192+
'allow_patterns': allow_patterns,
193+
'etag_timeout': 60
194+
}
195+
196+
if repo_type == "gated" and self.hf_token:
197+
download_kwargs['token'] = self.hf_token
198+
elif repo_type == "public":
199+
download_kwargs['token'] = False
200+
201+
snapshot_download(**download_kwargs)
165202

166203
print("\033[92mModel downloaded and ready to use.\033[0m")
167204
atexit.unregister(self.cleanup_incomplete_download)

0 commit comments

Comments
 (0)