@@ -6,6 +6,8 @@
 import numpy as np
 import pandas as pd
 from datasets import Dataset
+from rich.console import Console
+from rich.table import Table
 from setfit import SetFitModel, Trainer, TrainingArguments
 from sklearn.metrics import accuracy_score, f1_score
 
@@ -29,6 +31,8 @@
 BASE_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
 MODEL_DIR: Path = Path("models/job_interest_classifier")
 
+console = Console()
+
 
 def compute_metrics(y_pred, y_true) -> dict[str, float]:
     """Return accuracy and F1 for SetFit trainer."""
@@ -38,11 +42,18 @@ def compute_metrics(y_pred, y_true) -> dict[str, float]:
     }
 
 
-def classify_texts(texts: pd.Series, model_dir: Path) -> list[int]:
+def classify_texts(
+    texts: pd.Series, model_dir: Path, probabilities: bool = False
+) -> list:
     """Predict binary labels for a sequence of vacancy texts."""
     model = SetFitModel.from_pretrained(model_dir)
-    raw = model.predict(texts.to_list())
-    return np.atleast_1d(raw).astype(int).tolist()
+
+    if probabilities:
+        probas = model.predict_proba(texts.to_list())
+        return [proba[1] * 100 for proba in probas]
+    else:
+        raw = model.predict(texts.to_list())
+        return np.atleast_1d(raw).astype(int).tolist()
 
 
 def train() -> dict[str, float]:
@@ -71,15 +82,32 @@ def train() -> dict[str, float]:
     return trainer.evaluate()
 
 
-def test() -> dict[str, float]:
+def test() -> None:
     """Compute and print accuracy and F1 (if labels present) for test data."""
     df = TEST_DF
-    preds = classify_texts(df["text"], MODEL_DIR)
+    preds = classify_texts(df["text"], MODEL_DIR, probabilities=False)
+    probas = classify_texts(df["text"], MODEL_DIR, probabilities=True)
+
+    results_table = Table(title="Classification Results")
+    results_table.add_column("Index", justify="right", style="cyan")
+    results_table.add_column("Text", style="magenta")
+    results_table.add_column("Probability (%)", justify="center", style="yellow")
+    results_table.add_column("Prediction", justify="center", style="yellow")
+
+    for (idx, text), proba, pred in zip(df["text"].items(), probas, preds):
+        results_table.add_row(str(idx), text, f"{proba:.2f}%", str(pred))
+
+    console.print(results_table)
+
+    metrics = compute_metrics(preds, df["label"])
+    metrics_table = Table(title="Metrics")
+    metrics_table.add_column("Metric", style="cyan")
+    metrics_table.add_column("Value", justify="center", style="yellow")
 
-    for (idx, text), pred in zip(df["text"].items(), preds):
-        print(f"{idx}: {text} -> {pred}")
+    for metric_name, metric_value in metrics.items():
+        metrics_table.add_row(metric_name, f"{metric_value:.4f}")
 
-    return compute_metrics(preds, df["label"])
+    console.print(metrics_table)
 
 
 if __name__ == "__main__":