Skip to content

Commit 6a139fa

Browse files
Handle repository redirects and skip authorization header for LFS (#314)
1 parent b947dd2 commit 6a139fa

File tree

2 files changed

+137
-33
lines changed

2 files changed

+137
-33
lines changed

lib/bumblebee/huggingface/hub.ex

Lines changed: 59 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -67,34 +67,39 @@ defmodule Bumblebee.HuggingFace.Hub do
6767

6868
metadata_path = Path.join(dir, metadata_filename(url))
6969

70-
if offline do
71-
case load_json(metadata_path) do
72-
{:ok, %{"etag" => etag}} ->
73-
entry_path = Path.join(dir, entry_filename(url, etag))
74-
{:ok, entry_path}
75-
76-
_ ->
77-
{:error,
78-
"could not find file in local cache and outgoing traffic is disabled, url: #{url}"}
79-
end
80-
else
81-
head_result =
82-
if etag = opts[:etag] do
83-
{:ok, etag, url}
84-
else
85-
head_download(url, headers)
86-
end
87-
88-
with {:ok, etag, download_url} <- head_result do
89-
entry_path = Path.join(dir, entry_filename(url, etag))
90-
70+
cond do
71+
offline ->
9172
case load_json(metadata_path) do
92-
{:ok, %{"etag" => ^etag}} ->
73+
{:ok, %{"etag" => etag}} ->
74+
entry_path = Path.join(dir, entry_filename(url, etag))
9375
{:ok, entry_path}
9476

9577
_ ->
96-
case HTTP.download(download_url, entry_path, headers: headers)
97-
|> finish_request(download_url) do
78+
{:error,
79+
"could not find file in local cache and outgoing traffic is disabled, url: #{url}"}
80+
end
81+
82+
entry_path = opts[:etag] && cached_path_for_etag(dir, url, opts[:etag]) ->
83+
{:ok, entry_path}
84+
85+
true ->
86+
with {:ok, etag, download_url, redirect?} <- head_download(url, headers) do
87+
if entry_path = cached_path_for_etag(dir, url, etag) do
88+
{:ok, entry_path}
89+
else
90+
entry_path = Path.join(dir, entry_filename(url, etag))
91+
92+
headers =
93+
if redirect? do
94+
List.keydelete(headers, "Authorization", 0)
95+
else
96+
headers
97+
end
98+
99+
download_url
100+
|> HTTP.download(entry_path, headers: headers)
101+
|> finish_request(download_url)
102+
|> case do
98103
:ok ->
99104
:ok = store_json(metadata_path, %{"etag" => etag, "url" => url})
100105
{:ok, entry_path}
@@ -104,24 +109,45 @@ defmodule Bumblebee.HuggingFace.Hub do
104109
File.rm_rf!(entry_path)
105110
error
106111
end
112+
end
107113
end
108-
end
114+
end
115+
end
116+
117+
defp cached_path_for_etag(dir, url, etag) do
118+
metadata_path = Path.join(dir, metadata_filename(url))
119+
120+
case load_json(metadata_path) do
121+
{:ok, %{"etag" => ^etag}} ->
122+
Path.join(dir, entry_filename(url, etag))
123+
124+
_ ->
125+
nil
109126
end
110127
end
111128

112129
defp head_download(url, headers) do
113130
with {:ok, response} <-
114131
HTTP.request(:head, url, follow_redirects: false, headers: headers)
115-
|> finish_request(url),
116-
{:ok, etag} <- fetch_etag(response) do
117-
download_url =
118-
if response.status in 300..399 do
119-
HTTP.get_header(response, "location")
132+
|> finish_request(url) do
133+
if response.status in 300..399 do
134+
location = HTTP.get_header(response, "location")
135+
136+
# Follow relative redirects
137+
if URI.parse(location).host == nil do
138+
url =
139+
url
140+
|> URI.parse()
141+
|> Map.replace!(:path, location)
142+
|> URI.to_string()
143+
144+
head_download(url, headers)
120145
else
121-
url
146+
with {:ok, etag} <- fetch_etag(response), do: {:ok, etag, location, true}
122147
end
123-
124-
{:ok, etag, download_url}
148+
else
149+
with {:ok, etag} <- fetch_etag(response), do: {:ok, etag, url, false}
150+
end
125151
end
126152
end
127153

test/bumblebee/huggingface/hub_test.exs

Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,28 @@ defmodule Bumblebee.HuggingFace.HubTest do
115115
assert File.read!(path) == <<0, 1>>
116116
end
117117

118+
@tag :tmp_dir
119+
test "follows relative redirect before checking etag", %{bypass: bypass, tmp_dir: tmp_dir} do
120+
Bypass.expect_once(bypass, "HEAD", "/repo/file.json", fn conn ->
121+
conn
122+
|> Plug.Conn.put_resp_header("location", "/repo-renamed/file.json")
123+
|> Plug.Conn.resp(307, "")
124+
end)
125+
126+
Bypass.expect_once(bypass, "HEAD", "/repo-renamed/file.json", fn conn ->
127+
serve_with_etag(conn, ~s/"hash"/, "{}")
128+
end)
129+
130+
Bypass.expect_once(bypass, "GET", "/repo-renamed/file.json", fn conn ->
131+
serve_with_etag(conn, ~s/"hash"/, "{}")
132+
end)
133+
134+
url = url(bypass.port) <> "/repo/file.json"
135+
136+
assert {:ok, path} = Hub.cached_download(url, cache_dir: tmp_dir, offline: false)
137+
assert File.read!(path) == "{}"
138+
end
139+
118140
@tag :tmp_dir
119141
test "returns an error on missing etag header", %{bypass: bypass, tmp_dir: tmp_dir} do
120142
Bypass.expect_once(bypass, "HEAD", "/file.json", fn conn ->
@@ -183,6 +205,62 @@ defmodule Bumblebee.HuggingFace.HubTest do
183205
"could not find file in local cache and outgoing traffic is disabled, url: " <> _} =
184206
Hub.cached_download(url, cache_dir: tmp_dir, offline: true)
185207
end
208+
209+
@tag :tmp_dir
210+
test "includes authorization header when :auth_token is given",
211+
%{bypass: bypass, tmp_dir: tmp_dir} do
212+
Bypass.expect_once(bypass, "HEAD", "/file.json", fn conn ->
213+
assert {"authorization", "Bearer token"} in conn.req_headers
214+
215+
serve_with_etag(conn, ~s/"hash"/, "")
216+
end)
217+
218+
Bypass.expect_once(bypass, "GET", "/file.json", fn conn ->
219+
assert {"authorization", "Bearer token"} in conn.req_headers
220+
221+
serve_with_etag(conn, ~s/"hash"/, "{}")
222+
end)
223+
224+
url = url(bypass.port) <> "/file.json"
225+
226+
assert {:ok, _path} =
227+
Hub.cached_download(url, auth_token: "token", cache_dir: tmp_dir, offline: false)
228+
end
229+
230+
@tag :tmp_dir
231+
test "skips authentication header for redirected requests",
232+
%{bypass: bypass, tmp_dir: tmp_dir} do
233+
# Context: HuggingFace Hub returns redirects for files stored
234+
# in LFS and the redirect location already has an S3 signature in
235+
# URL params. If the location points to S3 directly, which is
236+
# the case within HF spaces, passing the Authorization header
237+
# leads to failure, presumably because it takes precedence over
238+
# the params signature
239+
240+
Bypass.expect_once(bypass, "HEAD", "/file.bin", fn conn ->
241+
assert {"authorization", "Bearer token"} in conn.req_headers
242+
243+
url = Plug.Conn.request_url(conn)
244+
245+
conn
246+
|> Plug.Conn.put_resp_header("x-linked-etag", ~s/"hash"/)
247+
|> Plug.Conn.put_resp_header("location", url <> "/storage")
248+
|> Plug.Conn.resp(302, "")
249+
end)
250+
251+
Bypass.expect_once(bypass, "GET", "/file.bin/storage", fn conn ->
252+
assert {"authorization", "Bearer token"} not in conn.req_headers
253+
254+
conn
255+
|> Plug.Conn.put_resp_header("etag", ~s/"hash"/)
256+
|> Plug.Conn.resp(200, <<0, 1>>)
257+
end)
258+
259+
url = url(bypass.port) <> "/file.bin"
260+
261+
assert {:ok, _path} =
262+
Hub.cached_download(url, auth_token: "token", cache_dir: tmp_dir, offline: false)
263+
end
186264
end
187265

188266
defp url(port), do: "http://localhost:#{port}"

0 commit comments

Comments
 (0)