Skip to content

Commit d195363

Browse files
committed
feat: enhance text classification output with probabilities and rich table display
- Added support for returning prediction probabilities in the classify_texts function.
- Updated the test function to display classification results and metrics in formatted tables using Rich.
- Added the Rich library as a dependency in pyproject.toml.
1 parent fd9bf37 commit d195363

File tree

7 files changed

+138
-420
lines changed

7 files changed

+138
-420
lines changed

README.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@ Demonstrate how to fine-tune a Sentence-Transformers encoder with the [SetFit](h
1212

1313
## Quick start
1414

15-
`uv` handles the virtual environment and dependencies automatically:
15+
[Taskfile](https://taskfile.dev/installation/) handles task management, and [uv](https://docs.astral.sh/uv/getting-started/installation/) automatically manages the virtual environment and dependencies (including the Python executable).
1616

1717
```bash
1818
# Train the model and save the best checkpoint under models/
19-
uv run main.py train
19+
task train
2020

2121
# Evaluate on the held-out test set
22-
uv run main.py test
22+
task test
23+
24+
# List all available tasks
25+
task --list
2326
```
2427

2528
## Dataset

Taskfile.yaml

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
version: "3"
2+
3+
tasks:
4+
check-uv:
5+
cmds:
6+
- |
7+
if ! command -v uv &> /dev/null; then
8+
echo "UV is not installed. Please install UV:"
9+
echo "https://docs.astral.sh/uv/getting-started/installation/"
10+
exit 1
11+
fi
12+
silent: true
13+
14+
check-asciinema:
15+
cmds:
16+
- |
17+
if ! command -v asciinema &> /dev/null; then
18+
echo "Asciinema is not installed. Please install Asciinema:"
19+
echo "https://docs.asciinema.org/getting-started/"
20+
exit 1
21+
fi
22+
silent: true
23+
24+
check-bunx:
25+
cmds:
26+
- |
27+
if ! command -v bunx &> /dev/null; then
28+
echo "Bun is not installed. Please install Bun:"
29+
echo "https://bun.sh/docs/installation"
30+
exit 1
31+
fi
32+
silent: true
33+
34+
train:
35+
deps:
36+
- check-uv
37+
desc: Train the model using data/train.csv and data/eval.csv
38+
cmds:
39+
- uv run main.py train
40+
41+
test:
42+
deps:
43+
- check-uv
44+
desc: Test the model performance using data/test.csv
45+
cmds:
46+
- uv run main.py test
47+
48+
rec:
49+
deps:
50+
- check-asciinema
51+
- check-bunx
52+
desc: Record a terminal session using Asciinema and convert to SVG
53+
cmds:
54+
- asciinema rec -i 2 --cols 120 --rows 34 docs/demo.cast
55+
- bunx svg-term-cli --in docs/demo.cast --out docs/demo.svg --window --padding 2
56+
- rm docs/demo.cast
57+
generates:
58+
- docs/demo.svg

docs/demo.cast

Lines changed: 0 additions & 408 deletions
This file was deleted.

docs/demo.svg

Lines changed: 1 addition & 1 deletion
Loading

main.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
import pandas as pd
88
from datasets import Dataset
9+
from rich.console import Console
10+
from rich.table import Table
911
from setfit import SetFitModel, Trainer, TrainingArguments
1012
from sklearn.metrics import accuracy_score, f1_score
1113

@@ -29,6 +31,8 @@
2931
BASE_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
3032
MODEL_DIR: Path = Path("models/job_interest_classifier")
3133

34+
console = Console()
35+
3236

3337
def compute_metrics(y_pred, y_true) -> dict[str, float]:
3438
"""Return accuracy and F1 for SetFit trainer."""
@@ -38,11 +42,18 @@ def compute_metrics(y_pred, y_true) -> dict[str, float]:
3842
}
3943

4044

41-
def classify_texts(texts: pd.Series, model_dir: Path) -> list[int]:
45+
def classify_texts(
46+
texts: pd.Series, model_dir: Path, probabilities: bool = False
47+
) -> list:
4248
"""Predict binary labels for a sequence of vacancy texts."""
4349
model = SetFitModel.from_pretrained(model_dir)
44-
raw = model.predict(texts.to_list())
45-
return np.atleast_1d(raw).astype(int).tolist()
50+
51+
if probabilities:
52+
probas = model.predict_proba(texts.to_list())
53+
return [proba[1] * 100 for proba in probas]
54+
else:
55+
raw = model.predict(texts.to_list())
56+
return np.atleast_1d(raw).astype(int).tolist()
4657

4758

4859
def train() -> dict[str, float]:
@@ -71,15 +82,32 @@ def train() -> dict[str, float]:
7182
return trainer.evaluate()
7283

7384

74-
def test() -> dict[str, float]:
85+
def test() -> None:
7586
"""Compute and print accuracy and F1 (if labels present) for test data."""
7687
df = TEST_DF
77-
preds = classify_texts(df["text"], MODEL_DIR)
88+
preds = classify_texts(df["text"], MODEL_DIR, probabilities=False)
89+
probas = classify_texts(df["text"], MODEL_DIR, probabilities=True)
90+
91+
results_table = Table(title="Classification Results")
92+
results_table.add_column("Index", justify="right", style="cyan")
93+
results_table.add_column("Text", style="magenta")
94+
results_table.add_column("Probability (%)", justify="center", style="yellow")
95+
results_table.add_column("Prediction", justify="center", style="yellow")
96+
97+
for (idx, text), proba, pred in zip(df["text"].items(), probas, preds):
98+
results_table.add_row(str(idx), text, f"{proba:.2f}%", str(pred))
99+
100+
console.print(results_table)
101+
102+
metrics = compute_metrics(preds, df["label"])
103+
metrics_table = Table(title="Metrics")
104+
metrics_table.add_column("Metric", style="cyan")
105+
metrics_table.add_column("Value", justify="center", style="yellow")
78106

79-
for (idx, text), pred in zip(df["text"].items(), preds):
80-
print(f"{idx}: {text} -> {pred}")
107+
for metric_name, metric_value in metrics.items():
108+
metrics_table.add_row(metric_name, f"{metric_value:.4f}")
81109

82-
return compute_metrics(preds, df["label"])
110+
console.print(metrics_table)
83111

84112

85113
if __name__ == "__main__":

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies = [
88
"datasets>=3.6.0",
99
"fire>=0.7.0",
1010
"pandas>=2.2.3",
11+
"rich>=14.0.0",
1112
"scikit-learn>=1.6.1",
1213
"setfit>=1.1.2",
1314
]

uv.lock

Lines changed: 36 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)