Closed
Potential issue in prompt handling in generate()
Description
generate() in torchtune/recipes/generate.py receives a cfg dict that includes 'prompt', 'system', 'user' and other keys. However, it passes only cfg.prompt to self.convert_prompt_to_tokens(). This causes a failure: self.convert_prompt_to_tokens() looks for 'system' and 'user' keys in the prompt it receives, but in this config they are top-level keys of cfg, and cfg.prompt itself is None (see the traceback under "Error message").
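The failure can be reproduced without torchtune. This is a minimal sketch under stated assumptions: convert_prompt_to_tokens_sketch is a hypothetical stand-in that copies only the membership check shown in the traceback (generate.py, line 109) and joins strings instead of tokenizing, and the cfg values are made up for illustration.

```python
from typing import Any, List, Mapping, Optional


def convert_prompt_to_tokens_sketch(prompt: Optional[Mapping[str, Any]]) -> str:
    # Hypothetical stand-in for the recipe method: it mirrors the key checks
    # from the traceback, then joins the text instead of tokenizing it.
    parts: List[str] = []
    if "system" in prompt and prompt["system"] is not None:
        parts.append(prompt["system"])
    if "user" in prompt and prompt["user"] is not None:
        parts.append(prompt["user"])
    return "\n".join(parts)


# Hypothetical config: "system" and "user" sit at the top level, so "prompt" is None.
cfg = {"prompt": None, "system": "You are a helpful assistant.", "user": "Hello!"}

try:
    # Mirrors generate() passing only cfg.prompt.
    convert_prompt_to_tokens_sketch(cfg["prompt"])
except TypeError as err:
    # `"system" in None` raises TypeError because None is not a container.
    print(f"TypeError: {err}")
```

The exact wording of the TypeError may vary slightly between Python versions, but it names 'NoneType' in each case.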
Documentation
There are two documents that describe using generate():
Potential solution
A potential solution could be to pass cfg to self.convert_prompt_to_tokens() instead of only cfg.prompt.
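A sketch of the proposed change, reusing the same assumptions as a hypothetical illustration: convert_prompt_to_tokens_sketch is not the recipe's real method, and cfg is a plain dict standing in for the recipe config. Handing over the whole cfg lets the existing 'system' and 'user' checks find the top-level keys.

```python
from typing import Any, List, Mapping


def convert_prompt_to_tokens_sketch(prompt: Mapping[str, Any]) -> str:
    # Hypothetical stand-in: same key checks as the traceback shows,
    # joining text instead of tokenizing it.
    parts: List[str] = []
    if "system" in prompt and prompt["system"] is not None:
        parts.append(prompt["system"])
    if "user" in prompt and prompt["user"] is not None:
        parts.append(prompt["user"])
    return "\n".join(parts)


# Hypothetical config with top-level "system" and "user" keys.
cfg = {"prompt": None, "system": "You are a helpful assistant.", "user": "Hello!"}

# Proposed fix: pass the whole config rather than cfg["prompt"].
text = convert_prompt_to_tokens_sketch(cfg)
print(text)
```

One trade-off of this approach is that the method then receives every config key, although it still reads only 'system' and 'user'.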
Error message
Here is an error message:
Traceback (most recent call last):
  File "/usr/local/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/workspace/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/workspace/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/workspace/torchtune/torchtune/_cli/run.py", line 214, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/workspace/torchtune/torchtune/_cli/run.py", line 108, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/usr/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/usr/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/torchtune/recipes/generate.py", line 203, in <module>
  File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
  File "/workspace/torchtune/recipes/generate.py", line 199, in main
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/torchtune/recipes/generate.py", line 123, in generate
    cfg.prompt,
  File "/workspace/torchtune/recipes/generate.py", line 109, in convert_prompt_to_tokens
    if "system" in prompt and prompt["system"] is not None:
TypeError: argument of type 'NoneType' is not iterable