
Dataset load for benchmarking #75


Open · wants to merge 1 commit into base: main
44 changes: 44 additions & 0 deletions prompttools/benchmarks/load_data.py
@@ -0,0 +1,44 @@
from datasets import load_dataset_builder, load_dataset, get_dataset_config_names, Dataset, DatasetBuilder
Contributor:

We'll want to guard prompttools from requiring datasets to be installed as part of the requirements.

Suggested change:

try:
    from datasets import load_dataset_builder, load_dataset, get_dataset_config_names, Dataset
    from datasets.dataset_dict import DatasetDict
except ImportError:
    load_dataset = None

from datasets.dataset_dict import DatasetDict
from typing import Literal


class DatasetLoader:
    r"""
    A helper class for loading a dataset via the `datasets` library.

    Args:
        dataset_name (str): The name of the dataset.
        split (Literal["train", "validation", "test"] | None): load a specific
            split of the dataset, or every split if None.
    """

    def __init__(
        self,
        dataset_name: str,
        split: Literal["train", "validation", "test"] | None,
    ):
        self.dataset_name = dataset_name
Contributor:

Suggested change:

if load_dataset is None:
    raise ModuleNotFoundError(
        "Package `datasets` is required to be installed to use this experiment. "
        "Please use `pip install datasets` to install the package."
    )

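Taken together with the guarded import suggested above, this is the usual optional-dependency pattern: the module still imports when `datasets` is absent, and the failure only surfaces when a DatasetLoader is actually constructed. A rough sketch of how the two suggestions might fit together (type annotations and the remaining methods are omitted; this is not the merged code):

# Sketch only: combines the two suggested changes from this review.
try:
    from datasets import load_dataset_builder, load_dataset, get_dataset_config_names, Dataset
    from datasets.dataset_dict import DatasetDict
except ImportError:
    load_dataset = None


class DatasetLoader:
    def __init__(self, dataset_name: str, split=None):
        # Fail at construction time rather than at import time.
        if load_dataset is None:
            raise ModuleNotFoundError(
                "Package `datasets` is required to be installed to use this experiment. "
                "Please use `pip install datasets` to install the package."
            )
        self.dataset_name = dataset_name
        self.split = split
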
        self.split = split
        super().__init__()

    def builder(self) -> DatasetBuilder:
        r"""
        Initializes and prepares the dataset builder.
        """
        return load_dataset_builder(path=self.dataset_name)

    def load_dataset(self) -> DatasetDict | Dataset:
        r"""
        Return the loaded dataset.
        """
        if self.split is None:
            return load_dataset(path=self.dataset_name)
        else:
            return load_dataset(path=self.dataset_name, split=self.split)

    def get_config(self) -> list:
        r"""
        Return the dataset's configuration names.
        """
        return get_dataset_config_names(self.dataset_name)


# Example use case
# d = DatasetLoader(dataset_name='rotten_tomatoes', split=None)
# print(d.builder().info.description)
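
For reference, a minimal usage sketch of DatasetLoader, following the commented example above. It assumes the `datasets` package is installed and that the `rotten_tomatoes` dataset named in that example can be downloaded:

# Minimal usage sketch, not part of the diff.
loader = DatasetLoader(dataset_name="rotten_tomatoes", split="train")

print(loader.get_config())                # available configuration names
print(loader.builder().info.description)  # dataset description from the builder metadata

train_data = loader.load_dataset()        # only the requested split is returned
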
3 changes: 2 additions & 1 deletion prompttools/playground/requirements.txt
@@ -3,4 +3,5 @@ jinja2
huggingface_hub
llama-cpp-python
anthropic
pyperclip
pyperclip
datasets
4 changes: 2 additions & 2 deletions prompttools/version.py
@@ -1,2 +1,2 @@
__version__ = '0.0.30a0+b3007fc'
git_version = 'b3007fcc8f6dc39a859cad5ae92a64696cebf124'
__version__ = '0.0.30a0+df575ec'
git_version = 'df575ece8a0c66206b28611e84e55d420a0047b3'