Skip to content

Commit 2e095cd

Browse files
authored
Causal Regression Modeling (#788)
* init * implementation * readme * missing file * changes
1 parent 87c2978 commit 2e095cd

24 files changed

+850
-8
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ notebooks/
44
demo_data/
55
output*/
66
tmp/
7-
data*/
7+
data/
88
examples/data_oasst2
99
examples/output_oasst2
1010
data_old/

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Using CLI for fine-tuning LLMs:
5353

5454
## What's New
5555

56+
- [PR 788](https://github.com/h2oai/h2o-llmstudio/pull/788) A new problem type, Causal Regression Modeling, allows training on single-target regression data using LLMs.
5657
- [PR 747](https://github.com/h2oai/h2o-llmstudio/pull/747) Fully removed RLHF in favor of DPO/IPO/KTO optimization.
5758
- [PR 741](https://github.com/h2oai/h2o-llmstudio/pull/741) Removed the separate max length settings for prompt and answer in favor of a single `max_length` setting that better resembles the `chat_template` functionality from `transformers`.
5859
- [PR 592](https://github.com/h2oai/h2o-llmstudio/pull/599) Added `KTOPairLoss` for DPO modeling, allowing models to be trained with simple preference data. Data currently needs to be manually prepared by randomly matching positive and negative examples as pairs.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
The column in the dataset containing the expected output.
22

3-
For classification, this needs to be an integer column starting from zero containing the class label.
3+
For classification, this needs to be an integer column starting from zero containing the class label, while for regression, it needs to be a float column.

documentation/docs/tooltips/experiments/_metric.mdx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ Causal Classification Modeling
1212
- AUC: Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC).
1313
- Accuracy: Compute the accuracy of the model.
1414
- LogLoss: Compute the log loss of the model.
15+
16+
Causal Regression Modeling
17+
- MSE: Compute Mean Squared Error of the model.
18+
- MAE: Compute Mean Absolute Error of the model.

documentation/docs/tooltips/experiments/_problem-type.mdx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ Defines the problem type of the experiment, which also defines the settings H2O
22

33
- Causal Language Modeling: Used to fine-tune large language models
44

5-
- DPO Modeling: Used to fine-tune large language models using Direct Preference Optimization
5+
- Causal Classification Modeling: Used to fine-tune causal classification models
6+
7+
- Causal Regression Modeling: Used to fine-tune causal regression models
68

79
- Sequence To Sequence Modeling: Used to fine-tune large sequence to sequence models
810

9-
- Causal Classification Modeling: Used to fine-tune causal classification models
11+
- DPO Modeling: Used to fine-tune large language models using Direct Preference Optimization
12+

llm_studio/app_utils/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ def get_size(x):
5353
"start_page": "home",
5454
"problem_types": [
5555
"text_causal_language_modeling_config",
56-
"text_dpo_modeling_config",
57-
"text_sequence_to_sequence_modeling_config",
5856
"text_causal_classification_modeling_config",
57+
"text_causal_regression_modeling_config",
58+
"text_sequence_to_sequence_modeling_config",
59+
"text_dpo_modeling_config",
5960
],
6061
"problem_categories": ["text"],
6162
"dataset_keys": [

llm_studio/app_utils/hugging_face_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,15 @@ def publish_model_to_hugging_face(
253253
repo_type="model",
254254
commit_message="Upload classification_head.pth",
255255
)
256+
# push regression head to hub
257+
if os.path.isfile(f"{path_to_experiment}/regression_head.pth"):
258+
api.upload_file(
259+
path_or_fileobj=f"{path_to_experiment}/regression_head.pth",
260+
path_in_repo="regression_head.pth",
261+
repo_id=repo_id,
262+
repo_type="model",
263+
commit_message="Upload regression_head.pth",
264+
)
256265

257266
# push config to hub
258267
api.upload_file(

llm_studio/app_utils/sections/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ async def should_start_chat(q: Q):
145145
box="first",
146146
items=[
147147
ui.text(
148-
"Chatbot is not available for text classification problems. "
148+
"Chatbot is not available for this problem type. "
149149
"Please select a text generation problem."
150150
)
151151
],

llm_studio/app_utils/sections/experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1814,6 +1814,7 @@ async def experiment_download_model(q: Q):
18141814
"added_tokens.json",
18151815
"model_card.md",
18161816
"classification_head.pth",
1817+
"regression_head.pth",
18171818
]
18181819
FILES_TO_PUSH = set(
18191820
FILES_TO_PUSH
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import os
2+
from dataclasses import dataclass, field
3+
from typing import Any, Dict, List, Tuple
4+
5+
import llm_studio.src.datasets.text_causal_regression_ds
6+
import llm_studio.src.plots.text_causal_classification_modeling_plots
7+
from llm_studio.python_configs.base import DefaultConfig, DefaultConfigProblemBase
8+
from llm_studio.python_configs.text_causal_classification_modeling_config import (
9+
ConfigNLPCausalClassificationAugmentation as ConfigNLPCausalRegressionAugmentation,
10+
)
11+
from llm_studio.python_configs.text_causal_classification_modeling_config import (
12+
ConfigNLPCausalClassificationDataset,
13+
)
14+
from llm_studio.python_configs.text_causal_classification_modeling_config import (
15+
ConfigNLPCausalClassificationLogging as ConfigNLPCausalRegressionLogging,
16+
)
17+
from llm_studio.python_configs.text_causal_classification_modeling_config import (
18+
ConfigNLPCausalClassificationTokenizer as ConfigNLPCausalRegressionTokenizer,
19+
)
20+
from llm_studio.python_configs.text_causal_classification_modeling_config import (
21+
ConfigNLPCausalClassificationTraining,
22+
)
23+
from llm_studio.python_configs.text_causal_language_modeling_config import (
24+
ConfigNLPCausalLMArchitecture,
25+
ConfigNLPCausalLMEnvironment,
26+
)
27+
from llm_studio.src import possible_values
28+
from llm_studio.src.losses import text_causal_regression_modeling_losses
29+
from llm_studio.src.metrics import text_causal_regression_modeling_metrics
30+
from llm_studio.src.models import text_causal_regression_modeling_model
31+
from llm_studio.src.utils.modeling_utils import generate_experiment_name
32+
33+
34+
@dataclass
class ConfigNLPCausalRegressionDataset(ConfigNLPCausalClassificationDataset):
    """Dataset settings for causal regression experiments.

    Extends the causal classification dataset settings with the regression
    dataset class and a fixed ``num_classes`` of 1, whose visibility is set
    to -1 after the parent initialisation has run.
    """

    dataset_class: Any = llm_studio.src.datasets.text_causal_regression_ds.CustomDataset
    num_classes: int = 1

    def __post_init__(self):
        # Normalise ``prompt_column`` to a tuple of column names. A bare string
        # has to be wrapped explicitly: ``tuple("text")`` would split the
        # column name into single characters ("t", "e", "x", "t").
        if isinstance(self.prompt_column, str):
            self.prompt_column = (self.prompt_column,)
        else:
            self.prompt_column = tuple(self.prompt_column)
        super().__post_init__()

        self._visibility["num_classes"] = -1
50+
51+
52+
@dataclass
class ConfigNLPCausalRegressionTraining(ConfigNLPCausalClassificationTraining):
    """Training settings for causal regression experiments.

    Builds on the causal classification training settings, replacing the loss
    factory and its default loss, and defaulting the differential learning
    rate to the ``regression_head`` layer.
    """

    loss_class: Any = text_causal_regression_modeling_losses.Losses
    loss_function: str = "MSELoss"

    learning_rate: float = 0.0001
    differential_learning_rate_layers: Tuple[str, ...] = ("regression_head",)
    differential_learning_rate: float = 0.00001

    def __post_init__(self) -> None:
        super().__post_init__()

        # Layers that may be selected for the differential learning rate.
        layer_choices = possible_values.String(
            values=("backbone", "embed", "regression_head"),
            allow_custom=False,
            placeholder="Select optional layers...",
        )

        # Selectable losses are whatever the configured loss factory exposes.
        self._possible_values["loss_function"] = self.loss_class.names()
        self._possible_values["differential_learning_rate_layers"] = layer_choices
72+
73+
74+
@dataclass
class ConfigNLPCausalRegressionArchitecture(ConfigNLPCausalLMArchitecture):
    """Architecture settings for causal regression experiments.

    Identical to the causal language modeling architecture settings except
    for the model class, which points at the regression model.
    """

    model_class: Any = text_causal_regression_modeling_model.Model

    def __post_init__(self) -> None:
        # No regression-specific adjustments; defer entirely to the parent.
        super().__post_init__()
80+
81+
82+
@dataclass
class ConfigNLPCausalRegressionPrediction(DefaultConfig):
    """Prediction settings for causal regression experiments.

    Holds the metric factory, the selected metric name (``"MSE"`` by default)
    and the inference batch size.
    """

    metric_class: Any = text_causal_regression_modeling_metrics.Metrics
    metric: str = "MSE"
    batch_size_inference: int = 0

    def __post_init__(self) -> None:
        super().__post_init__()

        # Selectable metrics come from the configured metric factory.
        metric_names = self.metric_class.names()
        # NOTE(review): presumably (min, max, step) for the numeric input --
        # confirm against how `_possible_values` tuples are consumed.
        inference_batch_range = (0, 512, 1)

        self._possible_values["metric"] = metric_names
        self._possible_values["batch_size_inference"] = inference_batch_range

        self._visibility["metric_class"] = -1
95+
96+
97+
@dataclass
class ConfigNLPCausalRegressionEnvironment(ConfigNLPCausalLMEnvironment):
    """Environment settings for causal regression experiments.

    Same as the causal language modeling environment settings, but with the
    regression-specific model card and experiment summary card template names.
    """

    _model_card_template: str = "text_causal_regression_model_card_template.md"
    _summary_card_template: str = (
        "text_causal_regression_experiment_summary_card_template.md"
    )

    def __post_init__(self) -> None:
        # No regression-specific adjustments; defer entirely to the parent.
        super().__post_init__()
106+
107+
108+
@dataclass
class ConfigProblemBase(DefaultConfigProblemBase):
    """Top-level experiment configuration for causal regression modeling.

    Groups the dataset, tokenizer, architecture, training, augmentation,
    prediction, environment and logging sub-configurations. The default
    output directory is named after this module's file name.
    """

    output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}"
    experiment_name: str = field(default_factory=generate_experiment_name)
    llm_backbone: str = "h2oai/h2o-danube2-1.8b-chat"

    dataset: ConfigNLPCausalRegressionDataset = field(
        default_factory=ConfigNLPCausalRegressionDataset
    )
    tokenizer: ConfigNLPCausalRegressionTokenizer = field(
        default_factory=ConfigNLPCausalRegressionTokenizer
    )
    architecture: ConfigNLPCausalRegressionArchitecture = field(
        default_factory=ConfigNLPCausalRegressionArchitecture
    )
    training: ConfigNLPCausalRegressionTraining = field(
        default_factory=ConfigNLPCausalRegressionTraining
    )
    augmentation: ConfigNLPCausalRegressionAugmentation = field(
        default_factory=ConfigNLPCausalRegressionAugmentation
    )
    prediction: ConfigNLPCausalRegressionPrediction = field(
        default_factory=ConfigNLPCausalRegressionPrediction
    )
    environment: ConfigNLPCausalRegressionEnvironment = field(
        default_factory=ConfigNLPCausalRegressionEnvironment
    )
    logging: ConfigNLPCausalRegressionLogging = field(
        default_factory=ConfigNLPCausalRegressionLogging
    )

    def __post_init__(self) -> None:
        super().__post_init__()

        self._visibility["output_directory"] = -1

        # Suggested backbones; custom values remain allowed (allow_custom=True).
        suggested_backbones = (
            "h2oai/h2o-danube2-1.8b-base",
            "h2oai/h2o-danube2-1.8b-chat",
            "h2oai/h2ogpt-4096-llama2-7b",
            "h2oai/h2ogpt-4096-llama2-7b-chat",
            "h2oai/h2ogpt-4096-llama2-13b",
            "h2oai/h2ogpt-4096-llama2-13b-chat",
            "h2oai/h2ogpt-4096-llama2-70b",
            "h2oai/h2ogpt-4096-llama2-70b-chat",
            "tiiuae/falcon-7b",
            "mistralai/Mistral-7B-v0.1",
            "HuggingFaceH4/zephyr-7b-beta",
            "google/gemma-2b",
            "google/gemma-7b",
            "stabilityai/stablelm-3b-4e1t",
            "microsoft/phi-2",
            "facebook/opt-125m",
        )
        self._possible_values["llm_backbone"] = possible_values.String(
            values=suggested_backbones,
            allow_custom=True,
        )

    def check(self) -> Dict[str, List]:
        """Validate the configuration.

        Returns:
            A dict with parallel ``"title"`` and ``"message"`` lists, one entry
            per problem found. Both lists are empty when nothing is wrong.
        """
        errors: Dict[str, List] = {"title": [], "message": []}

        # Both the "None" string and a real None count as "no parent column".
        uses_parent_column = self.dataset.parent_id_column not in ["None", None]
        if uses_parent_column:
            errors["title"].append("Parent ID column is not supported for regression")
            errors["message"].append(
                "Parent ID column is not supported for regression datasets."
            )

        return errors
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import logging
2+
from typing import Any, Dict
3+
4+
import numpy as np
5+
import pandas as pd
6+
7+
from llm_studio.src.datasets.text_causal_language_modeling_ds import (
8+
CustomDataset as TextCausalLanguageModelingCustomDataset,
9+
)
10+
from llm_studio.src.utils.exceptions import LLMDataException
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class CustomDataset(TextCausalLanguageModelingCustomDataset):
    """Causal language modeling dataset adapted for single-target regression.

    Each sample produced by the parent dataset is extended with a
    ``class_label`` entry holding the float value of the answer column.
    """

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """Build the dataset.

        Args:
            df: Input data; ``cfg.dataset.answer_column`` must be convertible
                to float.
            cfg: Experiment configuration.
            mode: Dataset mode, forwarded to the parent dataset.

        Raises:
            LLMDataException: If a parent ID column is configured.
        """
        # Validate before any dataset preparation so an unsupported setup fails
        # fast. Both the "None" string and a real None mean "no parent column",
        # which is also how the regression problem config's check() treats it.
        if cfg.dataset.parent_id_column not in ("None", None):
            raise LLMDataException(
                "Parent ID column is not supported for regression datasets."
            )

        super().__init__(df=df, cfg=cfg, mode=mode)

        answer_values = df[cfg.dataset.answer_column].astype(float)
        self.answers_float = answer_values.values.tolist()

    def __getitem__(self, idx: int) -> Dict:
        """Return the parent sample for ``idx`` plus its float regression target."""
        sample = super().__getitem__(idx)
        sample["class_label"] = self.answers_float[idx]
        return sample

    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        """Store float logits and their flattened string form as predicted text."""
        output["logits"] = output["logits"].float()
        preds = output["logits"]
        # One string per prediction, as a flat array.
        preds = np.array(preds).astype(float).astype(str).reshape(-1)
        output["predicted_text"] = preds
        return super().postprocess_output(cfg, df, output)

    def clean_output(self, output, cfg):
        """Return ``output`` unchanged."""
        return output
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import logging
2+
from typing import Any, KeysView
3+
4+
from torch import Tensor, nn
5+
6+
__all__ = ["Losses"]
7+
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class MSELoss(nn.Module):
    """Mean squared error loss for single-target regression."""

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.loss_fn = nn.MSELoss()

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """Return the mean squared error between ``logits`` and ``labels``.

        Both tensors are flattened before the comparison. Flattening only the
        labels would let ``(N, 1)`` logits broadcast against ``(N,)`` labels
        into an ``(N, N)`` difference matrix and yield a wrong loss value.
        """
        return self.loss_fn(logits.reshape(-1), labels.reshape(-1))
20+
21+
22+
class MAELoss(nn.Module):
    """Mean absolute error loss for single-target regression."""

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.loss_fn = nn.L1Loss()

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """Return the mean absolute error between ``logits`` and ``labels``.

        Both tensors are flattened before the comparison. Flattening only the
        labels would let ``(N, 1)`` logits broadcast against ``(N,)`` labels
        into an ``(N, N)`` difference matrix and yield a wrong loss value.
        """
        return self.loss_fn(logits.reshape(-1), labels.reshape(-1))
30+
31+
32+
class Losses:
    """Registry mapping loss names to the regression loss classes."""

    _losses = {
        "MSELoss": MSELoss,
        "MAELoss": MAELoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        """Return a view of all registered loss names."""
        registry = cls._losses
        return registry.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up a loss class by name.

        Args:
            name: Name of the loss to look up.
        Returns:
            The registered loss class, or ``MSELoss`` when ``name`` is unknown.
        """
        try:
            return cls._losses[name]
        except KeyError:
            # Unknown names fall back to the default loss.
            return MSELoss

0 commit comments

Comments
 (0)