
Commit bf7f7f9: Max Length rework (#741)

* implementation * c * implementation * docs * c * dont allow freeze and lora * implementation * format * rm * c * c * c * fix * r * c * format * fix * docs * readme * ui test * fixing unfreeze + lora + dpo * fixing position_id issue * c * c

1 parent 9cafe8c · commit bf7f7f9

33 files changed, +137 -133 lines

Makefile

Lines changed: 13 additions & 0 deletions

```diff
@@ -99,6 +99,19 @@ test: reports
 	-o log_cli=true -o log_level=INFO -o log_file=reports/tests.log \
 	tests/* 2>&1 | tee reports/tests.log'
 
+
+.PHONY: test-debug
+test-debug: reports
+	@bash -c 'set -o pipefail; export PYTHONPATH=$(PWD); \
+	$(PIPENV) run pytest -v --junitxml=reports/junit.xml \
+	--import-mode importlib \
+	--html=./reports/pytest.html \
+	-k test_encode \
+	-s \
+	-o log_cli=false -o log_level=WARNING -o log_file=/dev/null \
+	tests/*'
+
+
 .PHONY: test-ui
 test-ui: reports setup-ui
 	@bash -c 'set -o pipefail; \
```

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -54,6 +54,7 @@ Using CLI for fine-tuning LLMs:
 ## What's New
 
 - [PR 747](https://github.com/h2oai/h2o-llmstudio/pull/747) Fully removed RLHF in favor of DPO/IPO/KTO optimization.
+- [PR 741](https://github.com/h2oai/h2o-llmstudio/pull/741) Removing separate max length settings for prompt and answer in favor of a single `max_length` settings better resembling `chat_template` functionality from `transformers`.
 - [PR 592](https://github.com/h2oai/h2o-llmstudio/pull/599) Added `KTOPairLoss` for DPO modeling allowing to train models with simple preference data. Data currently needs to be manually prepared by randomly matching positive and negative examples as pairs.
 - [PR 592](https://github.com/h2oai/h2o-llmstudio/pull/592) Starting to deprecate RLHF in favor of DPO/IPO optimization. Training is disabled, but old experiments are still viewable. RLHF will be fully removed in a future release.
 - [PR 530](https://github.com/h2oai/h2o-llmstudio/pull/530) Introduced a new problem type for DPO/IPO optimization. This optimization technique can be used as an alternative to RLHF.
```
documentation/docs/get-started/llm-studio-performance.md

Lines changed: 0 additions & 2 deletions

```diff
@@ -137,8 +137,6 @@ problem_type: text_causal_language_modeling
 tokenizer:
     add_prompt_answer_tokens: false
     max_length: 512
-    max_length_answer: 256
-    max_length_prompt: 256
     padding_quantile: 1.0
 training:
     batch_size: 2
```

documentation/docs/guide/experiments/experiment-settings.md

Lines changed: 0 additions & 10 deletions

```diff
@@ -19,8 +19,6 @@
 import DSaddEosTokentoprompt from '../../tooltips/experiments/_add-eos-token-to-prompt.mdx';
 import DSaddEosTokentoanswer from '../../tooltips/experiments/_add-eos-token-to-answer.mdx';
 import DSmaskPromptlabels from '../../tooltips/experiments/_mask-prompt-labels.mdx';
-import TSmaxLengthPrompt from '../../tooltips/experiments/_max-length-prompt.mdx';
-import TSmaxLengthAnswer from '../../tooltips/experiments/_max-length-answer.mdx';
 import TSmaxLength from '../../tooltips/experiments/_max-length.mdx';
 import TSaddpromptanswertokens from '../../tooltips/experiments/_add-prompt-answer-tokens.mdx';
 import TSpaddingQuantile from '../../tooltips/experiments/_padding-quantile.mdx';
@@ -173,14 +171,6 @@ The settings under each category are listed and described below.
 
 ## Tokenizer settings
 
-### Max length prompt
-
-<TSmaxLengthPrompt/>
-
-### Max length answer
-
-<TSmaxLengthAnswer/>
-
 ### Max length
 
 <TSmaxLength/>
```

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,3 +1,3 @@
 The column in the dataset containing the expected output.
 
-For classification, this needs to be an integer column containing the class label.
+For classification, this needs to be an integer column starting from zero containing the class label.
```
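
A quick illustration of the updated requirement that classification labels be zero-based integers; this is a generic pandas sketch, not code from the repository:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "text": ["great product", "terrible service", "it was okay"],
        "sentiment": ["positive", "negative", "neutral"],
    }
)

# Map string classes to a zero-based integer label column, as the tooltip requires.
df["label"] = df["sentiment"].astype("category").cat.codes
print(df["label"].tolist())  # e.g. [2, 0, 1]
```
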

documentation/docs/tooltips/experiments/_max-length-answer.mdx

Lines changed: 0 additions & 1 deletion
This file was deleted.

documentation/docs/tooltips/experiments/_max-length-prompt.mdx

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/example_oasst2.yaml

Lines changed: 0 additions & 2 deletions

```diff
@@ -76,8 +76,6 @@ problem_type: text_causal_language_modeling
 tokenizer:
     add_prompt_answer_tokens: false
     max_length: 512
-    max_length_answer: 256
-    max_length_prompt: 256
    padding_quantile: 1.0
     tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
 training:
```

llm_studio/app_utils/sections/chat_update.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -113,7 +113,7 @@ async def answer_chat(q: Q) -> str:
     logger.info(f"Full prompt: {full_prompt}")
 
     inputs = cfg.dataset.dataset_class.encode(
-        tokenizer, full_prompt, cfg.tokenizer.max_length_prompt, "left"
+        tokenizer, full_prompt, cfg.tokenizer.max_length, "left"
     )
     inputs["prompt_input_ids"] = (
         inputs.pop("input_ids").unsqueeze(0).to(cfg.environment._device)
```

llm_studio/python_configs/text_causal_classification_modeling_config.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -92,14 +92,11 @@ def __post_init__(self):
 
 @dataclass
 class ConfigNLPCausalClassificationTokenizer(ConfigNLPCausalLMTokenizer):
-    max_length_prompt: int = 512
     max_length: int = 512
 
     def __post_init__(self):
         super().__post_init__()
 
-        self._visibility["max_length_answer"] = -1
-
 
 @dataclass
 class ConfigNLPCausalClassificationArchitecture(ConfigNLPCausalLMArchitecture):
```

llm_studio/python_configs/text_causal_language_modeling_config.py

Lines changed: 3 additions & 7 deletions

```diff
@@ -174,8 +174,8 @@ def __post_init__(self):
         )
         self._possible_values["differential_learning_rate_layers"] = (
             possible_values.String(
-                values=("backbone", "embed"),
-                allow_custom=False,
+                values=("backbone", "embed", "head"),
+                allow_custom=True,
                 placeholder="Select optional layers...",
             )
         )
@@ -250,17 +250,13 @@ def __post_init__(self):
 
 @dataclass
 class ConfigNLPCausalLMTokenizer(DefaultConfig):
-    max_length_prompt: int = 256
-    max_length_answer: int = 256
     max_length: int = 512
     add_prompt_answer_tokens: bool = False
     padding_quantile: float = 1.0
     tokenizer_kwargs: str = '{"use_fast": true, "add_prefix_space": false}'
 
     def __post_init__(self):
         super().__post_init__()
-        self._possible_values["max_length_prompt"] = (32, 1024 * 16, 32)
-        self._possible_values["max_length_answer"] = (32, 1024 * 16, 32)
         self._possible_values["max_length"] = (32, 1024 * 16, 32)
         self._possible_values["padding_quantile"] = (0, 1, 0.01)
         self._padding_side = "left"
@@ -353,7 +349,7 @@ def __post_init__(self):
 
         self._possible_values["num_beams"] = (1, 4, 1)
         self._possible_values["temperature"] = (0, 10, 0.05)
-        self._possible_values["repetition_penalty"] = (1, 10, 0.05)
+        self._possible_values["repetition_penalty"] = (1, 10, 0.025)
         self._possible_values["top_k"] = (0, 100, 1)
         self._possible_values["top_p"] = (0.5, 1, 0.05)
         self._possible_values["num_history"] = (1, 50, 1)
```

llm_studio/src/datasets/text_causal_language_modeling_ds.py

Lines changed: 14 additions & 26 deletions

```diff
@@ -42,6 +42,9 @@ def __getitem__(self, idx: int) -> Dict:
         input_text_dict["prompts"] = [
             self.parse_prompt(self.cfg, prompt) for prompt in input_text_dict["prompts"]
         ]
+        input_text_dict["answers"] = [
+            self.parse_answer(self.cfg, answer) for answer in input_text_dict["answers"]
+        ]
 
         sample = dict()
         system_encoding, prompt_encodings, answer_encodings = self.get_encodings(
@@ -72,7 +75,7 @@ def __getitem__(self, idx: int) -> Dict:
             self.pad_tokens(
                 answer_encodings[-1],
                 attention_mask=torch.ones_like(answer_encodings[-1]),
-                max_length=self.cfg.tokenizer.max_length_answer,
+                max_length=self.cfg.tokenizer.max_length,
                 pad_token_id=self.tokenizer.pad_token_id,
                 direction="right",
                 prefix="answer_",
@@ -99,14 +102,6 @@ def __getitem__(self, idx: int) -> Dict:
             )
         )
 
-        # make sure system encoding is always prepended if max_length exceeded
-        if sample["input_ids"][0] != self.tokenizer.pad_token_id:
-            sample["input_ids"][: len(system_encoding)] = system_encoding
-            if self.cfg.dataset.mask_prompt_labels and "labels" in sample.keys():
-                sample["labels"][: len(system_encoding)] = -100
-        if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id:
-            sample["prompt_input_ids"][: len(system_encoding)] = system_encoding
-
         return sample
 
     @staticmethod
@@ -122,6 +117,12 @@ def parse_prompt(cfg: Any, prompt: str):
             )
         return prompt
 
+    @staticmethod
+    def parse_answer(cfg: Any, answer: str):
+        if cfg.dataset.add_eos_token_to_answer:
+            answer += cfg._tokenizer_eos_token
+        return answer
+
     @staticmethod
     def parse_system(cfg: Any, system: str):
         # no system tokens if empty
@@ -375,9 +376,6 @@ def get_labels(self, prompt_encodings, answer_encodings):
             ]
         ).to(torch.bool)
         labels.masked_fill_(prompt_mask, -100)
-        if self.cfg.dataset.add_eos_token_to_answer:
-            # eos_token may be equal to pad_token. Add the label back manually.
-            labels[-1] = self.tokenizer.eos_token_id
         if self.cfg.tokenizer.max_length < len(labels):
             labels = labels[-self.cfg.tokenizer.max_length :]
 
@@ -446,27 +444,16 @@ def augment_data(self, encodings):
     def _get_sample_encoding(self, system: str, prompt: str, answer: str) -> List:
         if len(system) > 0:
             system_encoding = self.encode(
-                self.tokenizer, system, self.cfg.tokenizer.max_length_prompt, "right"
+                self.tokenizer, system, self.cfg.tokenizer.max_length, "right"
             )["input_ids"]
         else:
             system_encoding = torch.empty(0)
         prompt_encoding = self.encode(
-            self.tokenizer, prompt, self.cfg.tokenizer.max_length_prompt, "left"
+            self.tokenizer, prompt, self.cfg.tokenizer.max_length, "left"
         )["input_ids"]
-        max_length_answer = self.cfg.tokenizer.max_length_answer - int(
-            self.cfg.dataset.add_eos_token_to_answer
-        )
         answer_encoding = self.encode(
-            self.tokenizer, answer, max_length_answer, "right"
+            self.tokenizer, answer, self.cfg.tokenizer.max_length, "right"
         )["input_ids"]
-        if self.cfg.dataset.add_eos_token_to_answer:
-            answer_encoding = torch.cat(
-                [
-                    answer_encoding,
-                    torch.Tensor([self.tokenizer.eos_token_id]),
-                ],
-                dim=0,
-            )
 
         return [system_encoding, prompt_encoding, answer_encoding]
 
@@ -482,6 +469,7 @@ def pad_tokens(
         sample = {}
 
         if max_length < len(input_ids):
+            logger.info(f"Input exceeds max_length of {max_length}, truncating sample.")
             input_ids = input_ids[-max_length:]
             attention_mask = attention_mask[-max_length:]
 
```
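
Two behavioral changes are visible above: the EOS token is now appended to the answer text in `parse_answer` (so `get_labels` no longer has to patch `eos_token_id` back in manually), and system, prompt, and answer are all encoded against the same `max_length`, with over-long inputs truncated and logged. A standalone sketch of that flow, using a generic Hugging Face tokenizer and a simplified stand-in for the dataset's `encode` helper (assumptions, not the project's actual implementation):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
max_length = 512  # single budget shared by system, prompt, and answer


def parse_answer(answer: str, add_eos_token_to_answer: bool = True) -> str:
    # EOS becomes part of the answer text, so the tokenizer sees it directly.
    if add_eos_token_to_answer:
        answer += tokenizer.eos_token
    return answer


def encode(text: str, truncation_side: str) -> torch.Tensor:
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(ids) > max_length:
        # prompts keep their tail ("left"), answers keep their head ("right")
        ids = ids[-max_length:] if truncation_side == "left" else ids[:max_length]
    return torch.tensor(ids)


prompt_encoding = encode("How do I fine-tune an LLM?", truncation_side="left")
answer_encoding = encode(parse_answer("Use H2O LLM Studio."), truncation_side="right")
print(len(prompt_encoding), len(answer_encoding))
```
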

llm_studio/src/models/text_causal_classification_modeling_model.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -5,7 +5,11 @@
 from transformers import AutoModelForCausalLM
 
 from llm_studio.src.utils.data_utils import batch_padding
-from llm_studio.src.utils.modeling_utils import create_nlp_backbone, prepare_lora
+from llm_studio.src.utils.modeling_utils import (
+    create_nlp_backbone,
+    forward,
+    prepare_lora,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -67,7 +71,8 @@ def forward(
             padding_side=self.cfg.tokenizer._padding_side,
         )
 
-        output = self.backbone(
+        output = forward(
+            self.backbone,
             input_ids=batch["prompt_input_ids"],
             attention_mask=batch["prompt_attention_mask"],
         )
```

llm_studio/src/models/text_causal_language_modeling_model.py

Lines changed: 2 additions & 4 deletions

```diff
@@ -8,6 +8,7 @@
 from llm_studio.src.utils.data_utils import batch_padding
 from llm_studio.src.utils.modeling_utils import (
     create_nlp_backbone,
+    forward,
     generate,
     prepare_lora,
 )
@@ -92,10 +93,7 @@ def forward(
             padding_side=self.cfg.tokenizer._padding_side,
         )
 
-        output = self.backbone(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-        )
+        output = forward(self.backbone, batch["input_ids"], batch["attention_mask"])
 
         if "labels" in batch:
             loss = self.loss_fn(output.logits, batch["labels"])
```
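
Both models above (and the DPO model below) now route the backbone call through a shared `forward` helper imported from `modeling_utils` instead of calling `self.backbone(...)` directly. The helper itself is not shown in this excerpt; given the "fixing position_id issue" note in the commit message, a plausible minimal sketch is a thin wrapper that derives `position_ids` from the attention mask so left-padded batches keep correct token positions. This is an assumption about its purpose, not the actual implementation:

```python
import torch


def forward(backbone, input_ids, attention_mask=None, **kwargs):
    """Hypothetical sketch of a shared forward wrapper (assumed behavior)."""
    if attention_mask is not None:
        # Count real tokens so positions start at 0 after any left padding.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)
        kwargs["position_ids"] = position_ids
    return backbone(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
```

Such a wrapper supports both calling styles seen in the diffs: positional, as in `forward(self.backbone, batch["input_ids"], batch["attention_mask"])`, and keyword arguments, as in the classification and DPO models.
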

llm_studio/src/models/text_dpo_modeling_model.py

Lines changed: 16 additions & 6 deletions

```diff
@@ -13,6 +13,7 @@
 from llm_studio.src.utils.data_utils import batch_padding
 from llm_studio.src.utils.modeling_utils import (
     create_nlp_backbone,
+    forward,
     generate,
     prepare_lora,
 )
@@ -83,7 +84,12 @@ def __init__(self, cfg: Any):
 
         if cfg.training.lora:
             self.backbone = prepare_lora(cfg=cfg, backbone=self.backbone)
+
+        if cfg.training.lora and not cfg.training.lora_unfreeze_layers:
+            self.backbone_orig = None
         else:
+            if cfg.environment._local_rank == 0:
+                logger.info("Duplicating backbone for reference model.")
             self.backbone_orig, self.backbone_orig_config = create_nlp_backbone(
                 cfg, model_class=AutoModelForCausalLM
             )
@@ -137,7 +143,8 @@ def forward(
                     f"{answer}_labels",
                 ],
             )
-            logits = self.backbone(
+            logits = forward(
+                self.backbone,
                 input_ids=batch[f"{answer}_input_ids"],
                 attention_mask=batch[f"{answer}_attention_mask"],
             ).logits
@@ -152,18 +159,21 @@ def forward(
             )
 
             with torch.no_grad():
-                if self.cfg.training.lora:
-                    with self.backbone.disable_adapter():
-                        reference_logits = self.backbone(
+                if self.backbone_orig:
+                    with torch.no_grad():
+                        reference_logits = forward(
+                            self.backbone_orig,
                             input_ids=batch[f"{answer}_input_ids"],
                             attention_mask=batch[f"{answer}_attention_mask"],
                         ).logits
                 else:
-                    with torch.no_grad():
-                        reference_logits = self.backbone_orig(
+                    with self.backbone.disable_adapter():
+                        reference_logits = forward(
+                            self.backbone,
                             input_ids=batch[f"{answer}_input_ids"],
                             attention_mask=batch[f"{answer}_attention_mask"],
                         ).logits
+
             outputs[f"{answer}_reference_logps"] = get_batch_logps(
                 reference_logits,
                 batch[f"{answer}_labels"],
```
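
The reference-policy handling also changes: with pure LoRA (no unfrozen base layers) no duplicate backbone is kept, and reference logits come from the same backbone with adapters disabled; in every other case a second, frozen backbone (`backbone_orig`) is instantiated. A compact restatement of that branch with the reasoning spelled out in comments; `disable_adapter()` is the `peft` context manager the diff already relies on:

```python
import torch


def get_reference_logits(model, input_ids, attention_mask):
    """Reference (no-gradient) logits for DPO, mirroring the branch above."""
    with torch.no_grad():
        if model.backbone_orig is not None:
            # Base weights are being updated (full fine-tuning, freezing, or
            # LoRA with unfrozen layers), so a frozen duplicate is the reference.
            return model.backbone_orig(
                input_ids=input_ids, attention_mask=attention_mask
            ).logits
        # Pure LoRA: the frozen base weights already are the reference policy,
        # so disabling the adapters avoids keeping a second copy in memory.
        with model.backbone.disable_adapter():
            return model.backbone(
                input_ids=input_ids, attention_mask=attention_mask
            ).logits
```
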
