Causal Regression Modeling #788

Merged
merged 5 commits into from
Jul 30, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -4,7 +4,7 @@ notebooks/
demo_data/
output*/
tmp/
-data*/
+data/
examples/data_oasst2
examples/output_oasst2
data_old/
1 change: 1 addition & 0 deletions README.md
@@ -53,6 +53,7 @@ Using CLI for fine-tuning LLMs:

 ## What's New

+- [PR 788](https://github.com/h2oai/h2o-llmstudio/pull/788) New problem type for Causal Regression Modeling allows to train single target regression data using LLMs.
 - [PR 747](https://github.com/h2oai/h2o-llmstudio/pull/747) Fully removed RLHF in favor of DPO/IPO/KTO optimization.
 - [PR 741](https://github.com/h2oai/h2o-llmstudio/pull/741) Removing separate max length settings for prompt and answer in favor of a single `max_length` settings better resembling `chat_template` functionality from `transformers`.
 - [PR 592](https://github.com/h2oai/h2o-llmstudio/pull/599) Added `KTOPairLoss` for DPO modeling allowing to train models with simple preference data. Data currently needs to be manually prepared by randomly matching positive and negative examples as pairs.
@@ -1,3 +1,3 @@
The column in the dataset containing the expected output.

-For classification, this needs to be an integer column starting from zero containing the class label.
+For classification, this needs to be an integer column starting from zero containing the class label, while for regression, it needs to be a float column.
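The float requirement for regression targets can be checked before an experiment is started. The following is a minimal sketch in plain Python, assuming the answer column has already been read into a list of raw values. The helper name `to_float_targets` is hypothetical and not part of H2O LLM Studio.

```python
import math
from typing import Iterable, List


def to_float_targets(values: Iterable[object]) -> List[float]:
    """Convert raw answer-column values to floats, rejecting bad entries.

    Hypothetical helper for illustration; not part of H2O LLM Studio.
    """
    targets = []
    for row, value in enumerate(values):
        try:
            target = float(value)
        except (TypeError, ValueError):
            raise ValueError(f"row {row}: {value!r} is not a number") from None
        # NaN and infinity parse as floats but are unusable as targets.
        if math.isnan(target) or math.isinf(target):
            raise ValueError(f"row {row}: {value!r} is not a finite number")
        targets.append(target)
    return targets


targets = to_float_targets(["0.5", 3, 2.25])
```

Rejecting non-finite values is a design choice of this sketch; the tooltip itself only asks for a float column.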
4 changes: 4 additions & 0 deletions documentation/docs/tooltips/experiments/_metric.mdx
@@ -12,3 +12,7 @@ Causal Classification Modeling
 - AUC: Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC).
 - Accuracy: Compute the accuracy of the model.
 - LogLoss: Compute the log loss of the model.
+
+Causal Regression Modeling
+- MSE: Compute Mean Squared Error of the model.
+- MAE: Compute Mean Absolute Error of the model.
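For reference, the two regression metrics named above follow their standard definitions. The sketch below computes both in plain Python; the function names are chosen for illustration and do not mirror the metric classes added by this PR.

```python
from typing import Sequence


def mean_squared_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Average of the squared differences between targets and predictions."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Average of the absolute differences between targets and predictions."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


y_true = [1.0, 2.0, 4.0]
y_pred = [1.5, 2.0, 2.0]
mse = mean_squared_error(y_true, y_pred)   # (0.25 + 0 + 4) / 3
mae = mean_absolute_error(y_true, y_pred)  # (0.5 + 0 + 2) / 3
```

MSE penalizes the single large error (2.0) much more heavily than MAE does, which is the usual consideration when choosing between the two.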
7 changes: 5 additions & 2 deletions documentation/docs/tooltips/experiments/_problem-type.mdx
@@ -2,8 +2,11 @@ Defines the problem type of the experiment, which also defines the settings H2O

 - Causal Language Modeling: Used to fine-tune large language models

-- DPO Modeling: Used to fine-tune large language models using Direct Preference Optimization
+- Causal Classification Modeling: Used to fine-tune causal classification models

+- Causal Regression Modeling: Used to fine-tune causal regression models
+
 - Sequence To Sequence Modeling: Used to fine-tune large sequence to sequence models

-- Causal Classification Modeling: Used to fine-tune causal classification models
+- DPO Modeling: Used to fine-tune large language models using Direct Preference Optimization
+
5 changes: 3 additions & 2 deletions llm_studio/app_utils/config.py
@@ -53,9 +53,10 @@ def get_size(x):
     "start_page": "home",
     "problem_types": [
         "text_causal_language_modeling_config",
-        "text_dpo_modeling_config",
-        "text_sequence_to_sequence_modeling_config",
         "text_causal_classification_modeling_config",
+        "text_causal_regression_modeling_config",
+        "text_sequence_to_sequence_modeling_config",
+        "text_dpo_modeling_config",
     ],
     "problem_categories": ["text"],
     "dataset_keys": [
9 changes: 9 additions & 0 deletions llm_studio/app_utils/hugging_face_utils.py
@@ -253,6 +253,15 @@ def publish_model_to_hugging_face(
             repo_type="model",
             commit_message="Upload classification_head.pth",
         )
+    # push regression head to hub
+    if os.path.isfile(f"{path_to_experiment}/regression_head.pth"):
+        api.upload_file(
+            path_or_fileobj=f"{path_to_experiment}/regression_head.pth",
+            path_in_repo="regression_head.pth",
+            repo_id=repo_id,
+            repo_type="model",
+            commit_message="Upload regression_head.pth",
+        )

     # push config to hub
     api.upload_file(
2 changes: 1 addition & 1 deletion llm_studio/app_utils/sections/chat.py
@@ -145,7 +145,7 @@ async def should_start_chat(q: Q):
             box="first",
             items=[
                 ui.text(
-                    "Chatbot is not available for text classification problems. "
+                    "Chatbot is not available for this problem type. "
                     "Please select a text generation problem."
                 )
             ],
1 change: 1 addition & 0 deletions llm_studio/app_utils/sections/experiment.py
@@ -1814,6 +1814,7 @@ async def experiment_download_model(q: Q):
         "added_tokens.json",
         "model_card.md",
         "classification_head.pth",
+        "regression_head.pth",
     ]
     FILES_TO_PUSH = set(
         FILES_TO_PUSH
175 changes: 175 additions & 0 deletions llm_studio/python_configs/text_causal_regression_modeling_config.py
@@ -0,0 +1,175 @@
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

import llm_studio.src.datasets.text_causal_regression_ds
import llm_studio.src.plots.text_causal_classification_modeling_plots
from llm_studio.python_configs.base import DefaultConfig, DefaultConfigProblemBase
from llm_studio.python_configs.text_causal_classification_modeling_config import (
    ConfigNLPCausalClassificationAugmentation as ConfigNLPCausalRegressionAugmentation,
)
from llm_studio.python_configs.text_causal_classification_modeling_config import (
    ConfigNLPCausalClassificationDataset,
)
from llm_studio.python_configs.text_causal_classification_modeling_config import (
    ConfigNLPCausalClassificationLogging as ConfigNLPCausalRegressionLogging,
)
from llm_studio.python_configs.text_causal_classification_modeling_config import (
    ConfigNLPCausalClassificationTokenizer as ConfigNLPCausalRegressionTokenizer,
)
from llm_studio.python_configs.text_causal_classification_modeling_config import (
    ConfigNLPCausalClassificationTraining,
)
from llm_studio.python_configs.text_causal_language_modeling_config import (
    ConfigNLPCausalLMArchitecture,
    ConfigNLPCausalLMEnvironment,
)
from llm_studio.src import possible_values
from llm_studio.src.losses import text_causal_regression_modeling_losses
from llm_studio.src.metrics import text_causal_regression_modeling_metrics
from llm_studio.src.models import text_causal_regression_modeling_model
from llm_studio.src.utils.modeling_utils import generate_experiment_name


@dataclass
class ConfigNLPCausalRegressionDataset(ConfigNLPCausalClassificationDataset):
    dataset_class: Any = llm_studio.src.datasets.text_causal_regression_ds.CustomDataset
    num_classes: int = 1

    def __post_init__(self):
        self.prompt_column = (
            tuple(
                self.prompt_column,
            )
            if isinstance(self.prompt_column, str)
            else tuple(self.prompt_column)
        )
        super().__post_init__()

        self._visibility["num_classes"] = -1


@dataclass
class ConfigNLPCausalRegressionTraining(ConfigNLPCausalClassificationTraining):
    loss_class: Any = text_causal_regression_modeling_losses.Losses
    loss_function: str = "MSELoss"

    learning_rate: float = 0.0001
    differential_learning_rate_layers: Tuple[str, ...] = ("regression_head",)
    differential_learning_rate: float = 0.00001

    def __post_init__(self):
        super().__post_init__()
        self._possible_values["loss_function"] = self.loss_class.names()

        self._possible_values["differential_learning_rate_layers"] = (
            possible_values.String(
                values=("backbone", "embed", "regression_head"),
                allow_custom=False,
                placeholder="Select optional layers...",
            )
        )


@dataclass
class ConfigNLPCausalRegressionArchitecture(ConfigNLPCausalLMArchitecture):
    model_class: Any = text_causal_regression_modeling_model.Model

    def __post_init__(self):
        super().__post_init__()


@dataclass
class ConfigNLPCausalRegressionPrediction(DefaultConfig):
    metric_class: Any = text_causal_regression_modeling_metrics.Metrics
    metric: str = "MSE"
    batch_size_inference: int = 0

    def __post_init__(self):
        super().__post_init__()

        self._possible_values["metric"] = self.metric_class.names()
        self._possible_values["batch_size_inference"] = (0, 512, 1)

        self._visibility["metric_class"] = -1


@dataclass
class ConfigNLPCausalRegressionEnvironment(ConfigNLPCausalLMEnvironment):
    _model_card_template: str = "text_causal_regression_model_card_template.md"
    _summary_card_template: str = (
        "text_causal_regression_experiment_summary_card_template.md"
    )

    def __post_init__(self):
        super().__post_init__()


@dataclass
class ConfigProblemBase(DefaultConfigProblemBase):
    output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}"
    experiment_name: str = field(default_factory=generate_experiment_name)
    llm_backbone: str = "h2oai/h2o-danube2-1.8b-chat"

    dataset: ConfigNLPCausalRegressionDataset = field(
        default_factory=ConfigNLPCausalRegressionDataset
    )
    tokenizer: ConfigNLPCausalRegressionTokenizer = field(
        default_factory=ConfigNLPCausalRegressionTokenizer
    )
    architecture: ConfigNLPCausalRegressionArchitecture = field(
        default_factory=ConfigNLPCausalRegressionArchitecture
    )
    training: ConfigNLPCausalRegressionTraining = field(
        default_factory=ConfigNLPCausalRegressionTraining
    )
    augmentation: ConfigNLPCausalRegressionAugmentation = field(
        default_factory=ConfigNLPCausalRegressionAugmentation
    )
    prediction: ConfigNLPCausalRegressionPrediction = field(
        default_factory=ConfigNLPCausalRegressionPrediction
    )
    environment: ConfigNLPCausalRegressionEnvironment = field(
        default_factory=ConfigNLPCausalRegressionEnvironment
    )
    logging: ConfigNLPCausalRegressionLogging = field(
        default_factory=ConfigNLPCausalRegressionLogging
    )

    def __post_init__(self):
        super().__post_init__()

        self._visibility["output_directory"] = -1

        self._possible_values["llm_backbone"] = possible_values.String(
            values=(
                "h2oai/h2o-danube2-1.8b-base",
                "h2oai/h2o-danube2-1.8b-chat",
                "h2oai/h2ogpt-4096-llama2-7b",
                "h2oai/h2ogpt-4096-llama2-7b-chat",
                "h2oai/h2ogpt-4096-llama2-13b",
                "h2oai/h2ogpt-4096-llama2-13b-chat",
                "h2oai/h2ogpt-4096-llama2-70b",
                "h2oai/h2ogpt-4096-llama2-70b-chat",
                "tiiuae/falcon-7b",
                "mistralai/Mistral-7B-v0.1",
                "HuggingFaceH4/zephyr-7b-beta",
                "google/gemma-2b",
                "google/gemma-7b",
                "stabilityai/stablelm-3b-4e1t",
                "microsoft/phi-2",
                "facebook/opt-125m",
            ),
            allow_custom=True,
        )

    def check(self) -> Dict[str, List]:
        errors: Dict[str, List] = {"title": [], "message": []}

        if self.dataset.parent_id_column not in ["None", None]:
            errors["title"] += ["Parent ID column is not supported for regression"]
            errors["message"] += [
                "Parent ID column is not supported for regression datasets."
            ]

        return errors
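`ConfigNLPCausalRegressionDataset.__post_init__` in this file coerces `prompt_column` into a tuple. One Python detail is worth keeping in mind when reading that code: calling `tuple()` on a `str` iterates over its characters, and a trailing comma inside the call does not change that. The sketch below shows a normalization that wraps a single column name instead. `normalize_columns` is a hypothetical helper written for illustration, not a function from the PR.

```python
from typing import Iterable, Tuple, Union


def normalize_columns(columns: Union[str, Iterable[str]]) -> Tuple[str, ...]:
    """Return column names as a tuple, treating a str as one column name.

    Hypothetical helper for illustration; not part of the PR.
    """
    if isinstance(columns, str):
        return (columns,)  # wrap the name, do not iterate over its characters
    return tuple(columns)


single = normalize_columns("instruction")
several = normalize_columns(["system", "instruction"])

# For comparison: the trailing comma in a call is ignored, so this
# splits the string into single characters.
chars = tuple("abc",)
```

Whether the string branch is ever reached in practice depends on how the config is constructed, which this diff does not show.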
38 changes: 38 additions & 0 deletions llm_studio/src/datasets/text_causal_regression_ds.py
@@ -0,0 +1,38 @@
import logging
from typing import Any, Dict

import numpy as np
import pandas as pd

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as TextCausalLanguageModelingCustomDataset,
)
from llm_studio.src.utils.exceptions import LLMDataException

logger = logging.getLogger(__name__)


class CustomDataset(TextCausalLanguageModelingCustomDataset):
    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        super().__init__(df=df, cfg=cfg, mode=mode)
        self.answers_float = df[cfg.dataset.answer_column].astype(float).values.tolist()

        if cfg.dataset.parent_id_column != "None":
            raise LLMDataException(
                "Parent ID column is not supported for regression datasets."
            )

    def __getitem__(self, idx: int) -> Dict:
        sample = super().__getitem__(idx)
        sample["class_label"] = self.answers_float[idx]
        return sample

    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        output["logits"] = output["logits"].float()
        preds = output["logits"]
        preds = np.array(preds).astype(float).astype(str).reshape(-1)
        output["predicted_text"] = preds
        return super().postprocess_output(cfg, df, output)

    def clean_output(self, output, cfg):
        return output
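`postprocess_output` in this dataset class flattens the model outputs into a one-dimensional array of strings stored under `predicted_text`. The following sketch reproduces the same flattening idea in plain Python, assuming the predictions arrive as a nested list with one value per sample. `predictions_to_text` is a hypothetical name, and string formatting of floats may differ in detail from NumPy's `astype(str)`.

```python
from typing import List, Sequence


def predictions_to_text(logits: Sequence[Sequence[float]]) -> List[str]:
    """Flatten per-sample predictions and render each one as a string.

    Hypothetical helper for illustration; not part of the PR.
    """
    # Flatten the (N, 1) nesting into N floats, then stringify each value.
    flat = [float(value) for row in logits for value in row]
    return [str(value) for value in flat]


predicted_text = predictions_to_text([[0.5], [2], [-1.25]])
```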
53 changes: 53 additions & 0 deletions llm_studio/src/losses/text_causal_regression_modeling_losses.py
@@ -0,0 +1,53 @@
import logging
from typing import Any, KeysView

from torch import Tensor, nn

__all__ = ["Losses"]


logger = logging.getLogger(__name__)


class MSELoss(nn.Module):
    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.loss_fn = nn.MSELoss()

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        return self.loss_fn(logits, labels.reshape(-1))


class MAELoss(nn.Module):
    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.loss_fn = nn.L1Loss()

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        return self.loss_fn(logits, labels.reshape(-1))


class Losses:
    """Losses factory."""

    _losses = {
        "MSELoss": MSELoss,
        "MAELoss": MAELoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Losses.

        Args:
            name: losses name
        Returns:
            A class to build the Losses
        """
        return cls._losses.get(name, MSELoss)
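`Losses.get` looks a loss class up by name and falls back to `MSELoss` when the name is not registered, so a misspelled or unsupported name does not raise. The sketch below illustrates that registry pattern in plain Python without PyTorch. The stand-in loss functions and the `LossRegistry` name are assumptions made for illustration only.

```python
from typing import Callable, Dict, KeysView, Sequence

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse_loss(preds: Sequence[float], labels: Sequence[float]) -> float:
    """Mean squared error over paired predictions and labels."""
    return sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(preds)


def mae_loss(preds: Sequence[float], labels: Sequence[float]) -> float:
    """Mean absolute error over paired predictions and labels."""
    return sum(abs(p - y) for p, y in zip(preds, labels)) / len(preds)


class LossRegistry:
    """Hypothetical stand-in for the Losses factory in the diff."""

    _losses: Dict[str, LossFn] = {"MSELoss": mse_loss, "MAELoss": mae_loss}

    @classmethod
    def names(cls) -> KeysView[str]:
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> LossFn:
        # Unknown names fall back to MSE instead of raising.
        return cls._losses.get(name, mse_loss)


loss_fn = LossRegistry.get("MAELoss")
fallback = LossRegistry.get("HuberLoss")  # not registered in this sketch
value = loss_fn([1.0, 3.0], [2.0, 1.0])   # (1 + 2) / 2
```

The silent fallback keeps older configs loadable, at the cost of hiding typos in `loss_function`; a stricter variant would raise `KeyError` for unknown names.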