Skip to content

Commit 4633edd

Browse files
authored
option to use --reduced-test in dabstep evaluation (#16)
1 parent 3d3a673 commit 4633edd

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

eval/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
## 📊 Evaluation
32

43
The ReAct Data Science Agent includes two comprehensive evaluation frameworks to test its capabilities on real-world data science tasks:
@@ -22,12 +21,12 @@ The DABstep evaluation requires no manual data setup - everything is handled aut
2221
cd eval
2322
python dabstep.py
2423

25-
# Test with just the first few examples
26-
python dabstep.py --test-first-only
27-
2824
# Skip hard difficulty questions
2925
python dabstep.py --skip-hard
3026

27+
# Sample 30 easy and 30 hard tasks (mutually exclusive with --skip-hard)
28+
python dabstep.py --reduced-test
29+
3130
# Submit results (creates submission file)
3231
python dabstep.py --submit --which-split dev
3332
```
@@ -37,6 +36,7 @@ python dabstep.py --submit --which-split dev
3736
2. **Specialized prompts**: The agent receives detailed instructions about financial domain concepts from `manual.md`
3837
3. **Precise file paths**: Uses absolute paths like `/app/downloaded_data/data/context/payments.csv`
3938
4. **Domain validation**: Emphasizes reading domain manuals before analysis to ensure correct interpretations
39+
5. **Task sampling**: Supports either skipping hard tasks or sampling a balanced easy and hard tasks
4040

4141
### 🏆 Kaggle Competition Evaluation
4242

eval/dabstep.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import json
44
from dataclasses import dataclass
55
from pathlib import Path
6+
import random
67

7-
from datasets import load_dataset
8+
from datasets import load_dataset, concatenate_datasets
89
from open_data_scientist.codeagent import ReActDataScienceAgent
910

1011

@@ -150,31 +151,49 @@ def write_jsonl(data: list[dict], filepath: Path) -> None:
150151

151152

152153
def main(
153-
test_first_only=False,
154154
submit=False,
155155
data_dir=None,
156156
which_split="dev",
157157
skip_hard=False,
158+
reduced_test=False,
158159
):
160+
if skip_hard and reduced_test:
161+
raise ValueError("Cannot use both --skip-hard and --reduced-test at the same time")
162+
159163
# Load the dataset
160164
ds = load_dataset("adyen/DABstep", "tasks")
161165

162166
dataset = ds[which_split]
163167

164168
# Store hard tasks before filtering if we're skipping and submitting
165169
skipped_tasks = []
166-
if skip_hard and submit:
167-
skipped_tasks = [task for task in dataset if task.get("level") == "hard"]
168-
169170
if skip_hard:
171+
skipped_tasks = [task for task in dataset if task.get("level") == "hard"]
170172
dataset = dataset.filter(lambda example: example.get("level") != "hard")
171-
172-
if test_first_only:
173-
dataset = dataset.select([0, 1, 2])
173+
elif reduced_test:
174+
dataset = dataset.shuffle(seed=42)
175+
easy_tasks = dataset.filter(lambda x: x["level"] == "easy")
176+
hard_tasks = dataset.filter(lambda x: x["level"] == "hard")
177+
178+
# Sample 20 tasks from each difficulty level
179+
sampled_easy = easy_tasks.select(range(20))
180+
sampled_hard = hard_tasks.select(range(20))
181+
182+
sampled_ids = set()
183+
for task in sampled_easy:
184+
sampled_ids.add(task["task_id"])
185+
for task in sampled_hard:
186+
sampled_ids.add(task["task_id"])
187+
188+
skipped_tasks = [task for task in dataset if task["task_id"] not in sampled_ids]
189+
dataset = concatenate_datasets([sampled_easy, sampled_hard])
190+
dataset = dataset.shuffle(seed=42)
191+
else:
192+
print("Running all tasks")
174193

175194
number_of_examples = len(dataset)
176195
results = []
177-
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
196+
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
178197
future_to_task = {
179198
executor.submit(process_task, task, submit, data_dir): task
180199
for task in dataset
@@ -222,9 +241,6 @@ def main(
222241

223242
if __name__ == "__main__":
224243
parser = argparse.ArgumentParser(description="Run DABstep evaluation")
225-
parser.add_argument(
226-
"--test-first-only", action="store_true", help="Test only the first example"
227-
)
228244
parser.add_argument(
229245
"--submit", action="store_true", help="Submit the results to the leaderboard"
230246
)
@@ -237,6 +253,9 @@ def main(
237253
parser.add_argument(
238254
"--skip-hard", action="store_true", help="Skip examples with level=hard"
239255
)
256+
parser.add_argument(
257+
"--reduced-test", action="store_true", help="Sample 20 easy and 20 hard tasks"
258+
)
240259
parser.add_argument(
241260
"--data-dir",
242261
default=None,
@@ -245,9 +264,9 @@ def main(
245264
args = parser.parse_args()
246265

247266
main(
248-
test_first_only=args.test_first_only,
249267
submit=args.submit,
250268
data_dir=args.data_dir,
251269
which_split=args.which_split,
252270
skip_hard=args.skip_hard,
271+
reduced_test=args.reduced_test,
253272
)

0 commit comments

Comments
 (0)