Skip to content

Commit 3b44a78

Browse files
DawiAlotaibiDawi-Alotaibi
andauthored
fixed error message for GatedRepoError (#1832)
Co-authored-by: Dawi-Alotaibi <[email protected]>
1 parent 4107cc4 commit 3b44a78

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

tests/torchtune/_cli/test_download.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,44 @@ def test_download_calls_snapshot(self, capsys, monkeypatch, snapshot_download):
6565

6666
# Make sure it was called twice
6767
assert snapshot_download.call_count == 3
68+
69+
# GatedRepoError without --hf-token (expect prompt for token)
70+
def test_gated_repo_error_no_token(self, capsys, monkeypatch, snapshot_download):
71+
model = "meta-llama/Llama-2-7b"
72+
testargs = f"tune download {model}".split()
73+
monkeypatch.setattr(sys, "argv", testargs)
74+
75+
# Expect GatedRepoError without --hf-token provided
76+
with pytest.raises(SystemExit, match="2"):
77+
runpy.run_path(TUNE_PATH, run_name="__main__")
78+
79+
out_err = capsys.readouterr()
80+
# Check that error message prompts for --hf-token
81+
assert (
82+
"It looks like you are trying to access a gated repository." in out_err.err
83+
)
84+
assert (
85+
"Please ensure you have access to the repository and have provided the proper Hugging Face API token"
86+
in out_err.err
87+
)
88+
89+
# GatedRepoError with --hf-token (should not ask for token)
90+
def test_gated_repo_error_with_token(self, capsys, monkeypatch, snapshot_download):
91+
model = "meta-llama/Llama-2-7b"
92+
testargs = f"tune download {model} --hf-token valid_token".split()
93+
monkeypatch.setattr(sys, "argv", testargs)
94+
95+
# Expect GatedRepoError with --hf-token provided
96+
with pytest.raises(SystemExit, match="2"):
97+
runpy.run_path(TUNE_PATH, run_name="__main__")
98+
99+
out_err = capsys.readouterr()
100+
# Check that error message does not prompt for --hf-token again
101+
assert (
102+
"It looks like you are trying to access a gated repository." in out_err.err
103+
)
104+
assert "Please ensure you have access to the repository." in out_err.err
105+
assert (
106+
"Please ensure you have access to the repository and have provided the proper Hugging Face API token"
107+
not in out_err.err
108+
)

torchtune/_cli/download.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,18 @@ def _download_cmd(self, args: argparse.Namespace) -> None:
131131
token=args.hf_token,
132132
)
133133
except GatedRepoError:
134-
self._parser.error(
135-
"It looks like you are trying to access a gated repository. Please ensure you "
136-
"have access to the repository and have provided the proper Hugging Face API token "
137-
"using the option `--hf-token` or by running `huggingface-cli login`."
138-
"You can find your token by visiting https://huggingface.co/settings/tokens"
139-
)
134+
if args.hf_token:
135+
self._parser.error(
136+
"It looks like you are trying to access a gated repository. Please ensure you "
137+
"have access to the repository."
138+
)
139+
else:
140+
self._parser.error(
141+
"It looks like you are trying to access a gated repository. Please ensure you "
142+
"have access to the repository and have provided the proper Hugging Face API token "
143+
"using the option `--hf-token` or by running `huggingface-cli login`."
144+
"You can find your token by visiting https://huggingface.co/settings/tokens"
145+
)
140146
except RepositoryNotFoundError:
141147
self._parser.error(
142148
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."

0 commit comments

Comments
 (0)