-
Notifications
You must be signed in to change notification settings - Fork 186
Improving Transform and Rerank Module #396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 37 commits
509f5d9
90a4a04
7193ec0
1ae1b0d
71e129b
faa5b7b
c0e6ef6
9e40238
f4298b8
9c9fa75
665b546
e18989a
5d9761e
b5f094f
1e169eb
bf72344
758a6da
7bfa104
9e9f33e
a12e062
d9dc1ab
63a4a75
931edf0
a2c2cc6
41d6fd0
d32dc22
5a93ae7
921a98f
88130ba
379a503
acd6e7a
415aa21
0d009b2
60051a7
e0c0d0f
6558279
15a61eb
775da63
ad0d502
4b6427b
95dcea0
b92dca2
423510e
92f127a
547d4a3
c2ba8b3
0c81d18
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
This file was deleted.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
"""This module contains the functions to construct the prompt for task expansion.""" | ||
# Meta-instruction placed ahead of the task section. It asks the model for a
# general explanation of the task (trends and input/output format) rather than
# a commentary on each individual example.
METAPROMPT_BASE = "Carefully analyse the task description and examples of the task, and explain the task to give a clearer description. Do not explain each example, but rather capture the general trends. Also place special focus on the format of the input/output examples."  # noqa: E501

# Template for the task section; the two placeholders are filled in with
# str.format() by construct_prompt_for_task_explanation().
TASK = """
Task Description: {task_description}

Task Examples: {examples}
"""


def construct_prompt_for_task_explanation(
    instruction: str, demonstrations: str
) -> str:
    """Constructs prompt for task explanation.

    This is useful for clarifying the requirements of a task,
    and providing a clearer description of the task.

    Args:
        instruction (str): The task instruction.
        demonstrations (str): The task demonstrations.

    Returns:
        str: The constructed prompt, i.e. METAPROMPT_BASE and the filled-in
            TASK template joined by a "--------" separator line.
    """
    # str.format() substitutes the arguments verbatim, so braces that appear
    # inside the instruction or demonstrations are not re-interpreted.
    task = TASK.format(task_description=instruction, examples=demonstrations)
    prompt = "\n--------\n".join([METAPROMPT_BASE, task])
    return prompt
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,9 @@ | |
|
||
import datasets | ||
|
||
from prompt2model.dataset_retriever.task_expansion_prompt import ( | ||
construct_prompt_for_task_explanation, | ||
) | ||
from prompt2model.dataset_transformer.base import DatasetTransformer | ||
from prompt2model.dataset_transformer.prompt_template import ( | ||
construct_prompt_for_plan, | ||
|
@@ -31,99 +34,72 @@ class PromptBasedDatasetTransformer(DatasetTransformer): | |
|
||
def __init__( | ||
self, | ||
num_points_to_transform: int = 10, | ||
max_allowed_failed_transforms: int = 3, | ||
plan_prompt_fn: Callable[ | ||
[str, str, list[dict], int], str | ||
[str, str, list[dict]], str | ||
] = construct_prompt_for_plan, | ||
transform_prompt_fn: Callable[ | ||
[str, str, str, dict], str | ||
[str, str, str, str], str | ||
] = construct_prompt_for_transform_data, | ||
): | ||
"""Initialize the class. | ||
"""Initializes an instance of the PromptBasedDatasetTransformer class. | ||
|
||
Args: | ||
plan_prompt_fn: A function that takes in a description of the target task, | ||
example of the target task, | ||
list of dictionaries where each dictionary is a row from a potentially | ||
relevant dataset, | ||
and the number of rows to use from this potentially relevant dataset, | ||
and returns a plan prompt. | ||
|
||
transform_prompt_fn: A function that takes in a description of the target | ||
task, an example of the target task, | ||
plan for dataset transformation, | ||
and the row from a potentially relevant dataset to be transformed. | ||
num_points_to_transform: The number of points to transform. | ||
max_allowed_failed_transforms: The maximum number of | ||
failed transforms allowed. | ||
plan_prompt_fn: The function to construct the prompt for plan | ||
transform_prompt_fn: The function to construct the prompt | ||
ritugala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for transform data. | ||
""" | ||
self.plan_prompt_fn = plan_prompt_fn | ||
self.transform_prompt_fn = transform_prompt_fn | ||
self.plan: str = "" | ||
|
||
def make_dataset_from_samples( | ||
self, | ||
inputs: list[str], | ||
outputs: list[str], | ||
) -> datasets.DatasetDict: | ||
"""Given a list of inputs and outputs, make a dataset. | ||
|
||
This function takes in inputs and outputs, both as list of strings, | ||
and returns a DatasetDict object with a single split, "train". It has | ||
two columns, "input_col" and "output_col". | ||
|
||
|
||
Args: | ||
inputs: A list of inputs, each input is a string. | ||
outputs: A list of outputs, each output is a string. | ||
|
||
Returns: | ||
A DatasetDict object with a single split, "train". It has two | ||
columns, "input_col" and "output_col". | ||
""" | ||
if len(inputs) <= 0 or len(inputs) != len(outputs): | ||
raise ValueError("Length of inputs and outputs must be >0 and equal.") | ||
|
||
dataset_dict = {} | ||
dataset_dict["train"] = datasets.Dataset.from_dict( | ||
{"input_col": inputs, "output_col": outputs} | ||
self.num_points_to_transform = num_points_to_transform | ||
self.curr_failed_transforms = 0 | ||
self.max_allowed_failed_transforms = max_allowed_failed_transforms | ||
|
||
def generate_task_explanation(self, prompt_spec: PromptSpec) -> str: | ||
"""Generate task explanation.""" | ||
task_explanation_prompt = construct_prompt_for_task_explanation( | ||
prompt_spec.instruction, prompt_spec.examples | ||
) | ||
return datasets.DatasetDict(dataset_dict) | ||
return make_single_api_request(task_explanation_prompt, max_api_calls=10) | ||
ritugala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def transform_data( | ||
self, | ||
prompt_spec: PromptSpec, | ||
dataset: datasets.Dataset, | ||
num_points_to_transform: int, | ||
) -> datasets.DatasetDict: | ||
"""Transform the dataset according to the prompt_spec and dataset.""" | ||
def generate_plan( | ||
self, task_explanation: str, dataset: datasets.Dataset, prompt_spec: PromptSpec | ||
) -> str: | ||
"""Generate plan for the task.""" | ||
plan_prompt = self.plan_prompt_fn( | ||
prompt_spec.instruction, | ||
prompt_spec.examples, | ||
dataset, | ||
min(5, len(dataset)), | ||
task_explanation, prompt_spec.examples, dataset | ||
) | ||
self.plan = make_single_api_request(plan_prompt) | ||
|
||
logger.info(f"Plan created. Plan: {self.plan}") | ||
return make_single_api_request(plan_prompt, max_api_calls=10) | ||
ritugala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
inputs = [] | ||
outputs = [] | ||
|
||
max_len = min(num_points_to_transform, len(dataset)) | ||
len_count = 0 | ||
def generate_transform_prompts( | ||
self, | ||
task_explanation: str, | ||
dataset: datasets.Dataset, | ||
prompt_spec: PromptSpec, | ||
) -> list[str]: | ||
"""Get transform prompts for each row in the dataset.""" | ||
transform_prompts = [] | ||
for row in dataset: | ||
for i in range(min(self.num_points_to_transform, len(dataset))): | ||
row = dataset[i] | ||
transform_prompt = self.transform_prompt_fn( | ||
prompt_spec.instruction, | ||
prompt_spec.examples, | ||
self.plan, | ||
row, | ||
task_explanation, row, self.plan, prompt_spec.examples | ||
) | ||
transform_prompts.append(transform_prompt) | ||
return transform_prompts | ||
|
||
len_count += 1 | ||
if len_count >= max_len: | ||
break | ||
def generate_responses(self, transform_prompts_batch: list[str]) -> list[str]: | ||
"""Generate responses for the transform prompts.""" | ||
|
||
async def generate_responses(transform_prompts): | ||
responses = await api_tools.default_api_agent.generate_batch_completion( | ||
async def generate_responses_async(transform_prompts): | ||
"""Generate responses asynchronously using the specified model.""" | ||
responses = await api_tools.APIAgent( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we expose APIAgent as a function parameter with this APIAgent being the default choice? This is currently too restrictive. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @viswavi let me know if the generate_responses() function should have a parameter of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that either is probably fine, so I'm ok with the current design |
||
model_name="azure/GPT-3-5-turbo-chat", max_tokens=4000 | ||
).generate_batch_completion( | ||
transform_prompts, | ||
temperature=0, | ||
responses_per_request=1, | ||
|
@@ -133,20 +109,97 @@ async def generate_responses(transform_prompts): | |
|
||
try: | ||
loop = asyncio.get_event_loop() | ||
responses = loop.run_until_complete(generate_responses(transform_prompts)) | ||
responses = loop.run_until_complete( | ||
generate_responses_async(transform_prompts_batch) | ||
) | ||
except API_ERRORS as e: | ||
handle_api_error(e) | ||
# TODO: What to return here? | ||
ritugala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return responses | ||
|
||
def process_responses( | ||
self, responses: list, prompt_spec: PromptSpec | ||
) -> tuple[list[str], list[str]]: | ||
"""Process the responses received from the API. | ||
|
||
Also write the current set of inputs and outputs to a dump text just in case. | ||
|
||
Args: | ||
responses: A list of response strings from the API. | ||
prompt_spec: The PromptSpec object containing the instruction and examples. | ||
|
||
Returns: | ||
A tuple containing two lists: inputs and outputs. | ||
- inputs: A list of transformed input strings. | ||
- outputs: A list of transformed output strings. | ||
""" | ||
inputs, outputs = [], [] | ||
show_sample_flag = True | ||
ritugala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for response in responses: | ||
try: | ||
extraction = find_and_parse_json(response, ["input", "output"], []) | ||
if extraction is not None: | ||
inputs.append(str(extraction["input"])) | ||
outputs.append(str(extraction["output"])) | ||
if extraction["input"] is None or extraction["output"] is None: | ||
raise ValueError("Input or output is None") | ||
input = str(extraction["input"]).strip() | ||
output = str(extraction["output"]).strip() | ||
if input in prompt_spec.examples: | ||
raise ValueError("Repeated Task Examples from prompt") | ||
|
||
inputs.append(input) | ||
outputs.append(output) | ||
if show_sample_flag: | ||
logger.info(f"inputs\n{input}\n\nouputs\n{output}") | ||
show_sample_flag = False | ||
|
||
except Exception as e: | ||
logger.error(f"Error extracting from response: {response}\nError: {e}") | ||
continue | ||
logger.error(f"Error extracting from response: {e}") | ||
self.curr_failed_transforms += 1 | ||
if self.curr_failed_transforms > self.max_allowed_failed_transforms: | ||
break | ||
|
||
with open("dump.txt", "a") as file: | ||
file.write("Input: " + ", ".join(map(str, inputs)) + "\n") | ||
file.write("Output: " + ", ".join(map(str, outputs)) + "\n") | ||
ritugala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n") | ||
return inputs, outputs | ||
|
||
def transform_data( | ||
self, prompt_spec: PromptSpec, dataset: datasets.Dataset | ||
) -> tuple[list[str], list[str]]: | ||
"""Transforms the given dataset based on the provided prompt specification. | ||
|
||
return self.make_dataset_from_samples(inputs, outputs) | ||
Args: | ||
prompt_spec (PromptSpec): The prompt specification object that defines | ||
the transformation rules. | ||
dataset (datasets.Dataset): The dataset to be transformed. | ||
ritugala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Returns: | ||
A tuple containing two lists: inputs and outputs. | ||
""" | ||
task_explanation = self.generate_task_explanation(prompt_spec) | ||
self.plan = self.generate_plan(task_explanation, dataset, prompt_spec) | ||
logger.info(f"Plan created. Plan: {self.plan}") | ||
|
||
transform_prompts = self.generate_transform_prompts( | ||
task_explanation, dataset, prompt_spec | ||
) | ||
inputs, outputs = [], [] | ||
for batch_indices in range(0, len(transform_prompts), 100): | ||
transform_prompt_batch = transform_prompts[ | ||
batch_indices : batch_indices + 100 | ||
] | ||
responses = self.generate_responses(transform_prompt_batch) | ||
curr_inputs, curr_outputs = self.process_responses(responses, prompt_spec) | ||
inputs.extend(curr_inputs) | ||
outputs.extend(curr_outputs) | ||
if self.curr_failed_transforms > self.max_allowed_failed_transforms: | ||
ritugala marked this conversation as resolved.
Show resolved
Hide resolved
|
||
logger.error( | ||
f"Exceeded max allowed failed transforms: {self.curr_failed_transforms}" # noqa: E501 | ||
) | ||
break | ||
|
||
logger.info( | ||
f"Requested length: {self.num_points_to_transform}\nActual length: {len(inputs)}\n" # noqa: E501 | ||
) | ||
return inputs, outputs |
Uh oh!
There was an error while loading. Please reload this page.