6
6
import fnmatch
7
7
import humanfriendly
8
8
import atexit
9
+ import yaml
10
+ import functools
9
11
10
12
class ModelDownloadedSignal(QObject):
    """QObject that exposes a single ``downloaded`` signal."""

    # Carries two string arguments.
    # NOTE(review): presumably (model name, model type) - confirm against the emitters.
    downloaded = Signal(str, str)
12
14
13
15
model_downloaded_signal = ModelDownloadedSignal ()
14
16
15
17
# Model category -> directory name.
# NOTE(review): presumably the sub-directory used under "Models/" - confirm at the call sites.
MODEL_DIRECTORIES = {
    "vector": "vector",
    "chat": "chat",
    "tts": "tts",
    "jeeves": "jeeves",
    "ocr": "ocr",
}
22
24
25
@functools.lru_cache(maxsize=1)
def get_hf_token():
    """Return the ``hf_access_token`` value from ``config.yaml``, or ``None``.

    ``config.yaml`` is looked up relative to the current working directory.
    ``None`` is returned when the file is missing, cannot be read or parsed,
    does not contain a mapping at the top level, or lacks the key.

    The result (including ``None``) is memoized by ``lru_cache``; call
    ``get_hf_token.cache_clear()`` to force the file to be read again.
    """
    config_path = Path("config.yaml")
    if not config_path.exists():
        return None

    # Keep the try body limited to the read/parse step so that only genuine
    # load failures are reported as such.
    try:
        with open(config_path, "r", encoding="utf-8") as config_file:
            config = yaml.safe_load(config_file)
    except Exception as e:
        print(f"Warning: Could not load config: {e}")
        return None

    # safe_load yields None for an empty document and may yield a list or a
    # scalar; none of those can hold a token, so treat them as "no token"
    # instead of failing on ``.get``.
    if not isinstance(config, dict):
        return None
    return config.get("hf_access_token")
36
+
23
37
class ModelDownloader (QObject ):
24
- def __init__ (self , model_info , model_type ):
25
- super ().__init__ ()
26
- self .model_info = model_info
27
- self .model_type = model_type
28
- self ._model_directory = None
29
- self .api = HfApi ()
30
- self .api .timeout = 60 # increase timeout
31
- disable_progress_bars ()
32
- self .local_dir = self .get_model_directory ()
33
-
34
- def cleanup_incomplete_download (self ):
35
- if self .local_dir .exists ():
36
- import shutil
37
- shutil .rmtree (self .local_dir )
38
-
39
- def get_model_url (self ):
40
- if isinstance (self .model_info , dict ):
41
- return self .model_info ['repo_id' ]
42
- else :
43
- return self .model_info
44
-
45
- def check_repo_type (self , repo_id ):
46
- try :
47
- repo_info = self .api .repo_info (repo_id , timeout = 60 ) # increase timeout
48
- if repo_info .private :
49
- return "private"
50
- elif getattr (repo_info , 'gated' , False ):
51
- return "gated"
52
- else :
53
- return "public"
54
- except GatedRepoError :
55
- return "gated"
56
- except RepositoryNotFoundError :
57
- return "not_found"
58
- except Exception as e :
59
- return f"error: { str (e )} "
60
-
61
- def get_model_directory_name (self ):
62
- if isinstance (self .model_info , dict ):
63
- return self .model_info ['cache_dir' ]
64
- else :
65
- return self .model_info .replace ("/" , "--" )
66
-
67
- def get_model_directory (self ):
68
- return Path ("Models" ) / self .model_type / self .get_model_directory_name ()
69
-
70
- def download_model (self , allow_patterns = None , ignore_patterns = None ):
38
def __init__(self, model_info, model_type):
    """Prepare a downloader for a single model.

    Args:
        model_info: Either a repository id string, or a dict providing the
            ``'repo_id'`` and ``'cache_dir'`` keys read by the other methods
            of this class.
        model_type: Directory name placed between ``Models`` and the model's
            own directory when the local path is built.
    """
    super().__init__()
    self.model_info = model_info
    self.model_type = model_type
    self._model_directory = None

    # Token from config.yaml, or None when none is configured.
    self.hf_token = get_hf_token()

    # The default client is created with token=False; a token is only used
    # by the code paths that explicitly build a client with self.hf_token.
    self.api = HfApi(token=False)
    # NOTE(review): unclear whether HfApi reads this attribute; the calls in
    # this class also pass timeout=60 explicitly - confirm it is needed.
    self.api.timeout = 60
    disable_progress_bars()
    self.local_dir = self.get_model_directory()
50
+
51
def cleanup_incomplete_download(self):
    """Remove ``self.local_dir`` and everything below it, if it exists.

    Does nothing when the directory is absent.
    """
    if self.local_dir.exists():
        import shutil

        shutil.rmtree(self.local_dir)
55
+
56
def get_model_url(self):
    """Return the repository id for this model.

    When ``self.model_info`` is a dict its ``'repo_id'`` entry is returned;
    otherwise ``self.model_info`` itself is returned unchanged.
    """
    if isinstance(self.model_info, dict):
        return self.model_info["repo_id"]
    return self.model_info
61
+
62
def check_repo_type(self, repo_id):
    """Classify a repository by querying its metadata.

    The lookup is first made with ``self.api`` and ``token=False``. If that
    fails with an "unauthorized" error and ``self.hf_token`` is set, it is
    retried once with a client built from that token.

    Args:
        repo_id: Repository id passed to ``repo_info``.

    Returns:
        ``"private"``, ``"gated"`` or ``"public"`` when metadata was
        obtained, ``"not_found"`` for a 404 failure, otherwise
        ``"error: <message>"``.
    """

    def classify(repo_info):
        # "private" takes precedence over "gated"; ``gated`` may be missing
        # on the info object, hence getattr with a default.
        if repo_info.private:
            return "private"
        if getattr(repo_info, "gated", False):
            return "gated"
        return "public"

    try:
        return classify(self.api.repo_info(repo_id, timeout=60, token=False))
    except Exception as e:
        message = str(e)

        # Prefer the HTTP status carried by the exception: the message
        # usually embeds the request URL, so a repo id containing "401" or
        # "404" would otherwise be misclassified. Fall back to substring
        # matching only when no status code is available.
        status = getattr(getattr(e, "response", None), "status_code", None)
        if isinstance(status, int):
            unauthorized = status == 401
            not_found = status == 404
        else:
            unauthorized = "401" in message or "Unauthorized" in message
            not_found = "404" in message

        if self.hf_token and unauthorized:
            try:
                api_with_token = HfApi(token=self.hf_token)
                return classify(api_with_token.repo_info(repo_id, timeout=60))
            except Exception as e2:
                return f"error: {str(e2)}"
        if not_found:
            return "not_found"
        return f"error: {message}"
88
+
89
def get_model_directory_name(self):
    """Return the directory name used for this model.

    When ``self.model_info`` is a dict its ``'cache_dir'`` entry is
    returned; otherwise ``self.model_info`` is treated as a string and every
    ``/`` in it is replaced with ``--``.
    """
    if isinstance(self.model_info, dict):
        return self.model_info["cache_dir"]
    return self.model_info.replace("/", "--")
94
+
95
def get_model_directory(self):
    """Return the relative path ``Models/<model_type>/<directory name>``.

    The last component comes from ``get_model_directory_name()``.
    """
    return Path("Models") / self.model_type / self.get_model_directory_name()
97
+
98
+ def download_model (self , allow_patterns = None , ignore_patterns = None ):
71
99
repo_id = self .get_model_url ()
72
100
73
- # only download if repo is public
74
- # https://huggingface.co/docs/hub/models-gated#access-gated-models-as-a-user
75
- # https://huggingface.co/docs/hub/en/enterprise-hub-tokens-management
76
101
repo_type = self .check_repo_type (repo_id )
77
- if repo_type != "public" :
102
+ if repo_type not in [ "public" , "gated" ] :
78
103
if repo_type == "private" :
79
- print (f"Repository { repo_id } is private and requires a token. Aborting download." )
80
- elif repo_type == "gated" :
81
- print (f"Repository { repo_id } is gated. Please request access through the web interface. Aborting download." )
104
+ print (f"Repository { repo_id } is private and requires a token." )
105
+ if not self .hf_token :
106
+ print ("No Hugging Face token found. Please add one through the credentials menu." )
107
+ return
82
108
elif repo_type == "not_found" :
83
109
print (f"Repository { repo_id } not found. Aborting download." )
110
+ return
84
111
else :
85
112
print (f"Error checking repository { repo_id } : { repo_type } . Aborting download." )
113
+ return
114
+
115
+ if repo_type == "gated" and not self .hf_token :
116
+ print (f"Repository { repo_id } is gated. Please add a Hugging Face token and request access through the web interface." )
86
117
return
87
118
88
119
local_dir = self .get_model_directory ()
@@ -91,13 +122,12 @@ def download_model(self, allow_patterns=None, ignore_patterns=None):
91
122
atexit .register (self .cleanup_incomplete_download )
92
123
93
124
try :
94
- repo_files = list (self .api .list_repo_tree (repo_id , recursive = True ))
95
- """
96
- allow_patterns: If provided, only matching files are downloaded (ignore_patterns is disregarded)
97
- ignore_patterns: If provided alone, matching files are excluded
98
- neither: Uses default ignore patterns (.gitattributes, READMEs, etc.) with smart model file filtering
99
- both: Behaves same as allow_patterns only
100
- """
125
+ if repo_type == "gated" and self .hf_token :
126
+ api_for_listing = HfApi (token = self .hf_token )
127
+ repo_files = list (api_for_listing .list_repo_tree (repo_id , recursive = True ))
128
+ else :
129
+ repo_files = list (self .api .list_repo_tree (repo_id , recursive = True , token = False ))
130
+
101
131
if allow_patterns is not None :
102
132
final_ignore_patterns = None
103
133
elif ignore_patterns is not None :
@@ -154,14 +184,21 @@ def download_model(self, allow_patterns=None, ignore_patterns=None):
154
184
print (f"- { file } " )
155
185
print (f"\n Downloading to { local_dir } ..." )
156
186
157
- snapshot_download (
158
- repo_id = repo_id ,
159
- local_dir = str (local_dir ),
160
- max_workers = 4 ,
161
- ignore_patterns = final_ignore_patterns ,
162
- allow_patterns = allow_patterns ,
163
- etag_timeout = 60 # increase timeout
164
- )
187
+ download_kwargs = {
188
+ 'repo_id' : repo_id ,
189
+ 'local_dir' : str (local_dir ),
190
+ 'max_workers' : 4 ,
191
+ 'ignore_patterns' : final_ignore_patterns ,
192
+ 'allow_patterns' : allow_patterns ,
193
+ 'etag_timeout' : 60
194
+ }
195
+
196
+ if repo_type == "gated" and self .hf_token :
197
+ download_kwargs ['token' ] = self .hf_token
198
+ elif repo_type == "public" :
199
+ download_kwargs ['token' ] = False
200
+
201
+ snapshot_download (** download_kwargs )
165
202
166
203
print ("\033 [92mModel downloaded and ready to use.\033 [0m" )
167
204
atexit .unregister (self .cleanup_incomplete_download )
0 commit comments