Refactored batch update functions to utils #115

Open · wants to merge 3 commits into main
20 changes: 2 additions & 18 deletions src/extract/extract.py
@@ -22,6 +22,7 @@
from shared.vectorflow_request import VectorflowRequest
from services.rabbitmq.rabbit_service import create_connection_params
from pika.exceptions import AMQPConnectionError
+from shared.utils import update_batch_and_job_status

logging.basicConfig(filename='./extract-log.txt', level=logging.INFO)
logging.basicConfig(filename='./extract-error-log.txt', level=logging.ERROR)
@@ -134,24 +135,7 @@ def remove_from_minio(filename):
client = create_minio_client()
client.remove_object(os.getenv("MINIO_BUCKET"), filename)

-# TODO: refactor into utils
-def update_batch_and_job_status(job_id, batch_status, batch_id):
-    try:
-        if not job_id and batch_id:
-            job = safe_db_operation(batch_service.get_batch, batch_id)
-            job_id = job.job_id
-        updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status)
-        job = safe_db_operation(job_service.update_job_with_batch, job_id, updated_batch_status)
-        if job.job_status == JobStatus.COMPLETED:
-            logging.info(f"Job {job_id} completed successfully")
-        elif job.job_status == JobStatus.PARTIALLY_COMPLETED:
-            logging.info(f"Job {job_id} partially completed. {job.batches_succeeded} out of {job.total_batches} batches succeeded")
-        elif job.job_status == JobStatus.FAILED:
-            logging.info(f"Job {job_id} failed. {job.batches_succeeded} out of {job.total_batches} batches succeeded")
-
-    except Exception as e:
-        logging.error('Error updating job and batch status: %s', e)
-        safe_db_operation(job_service.update_job_status, job_id, JobStatus.FAILED)


####################
## RabbitMQ Logic ##
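For context, a minimal sketch of what a call site in extract.py looks like after this change, now that the local copy is deleted and only the import remains (the variables and the `BatchStatus` module path are assumptions, not part of the diff):

# Hypothetical sketch, not part of this PR's diff.
from shared.utils import update_batch_and_job_status
from shared.batch_status import BatchStatus  # assumed module path

# Roll a finished batch's status up into its parent job.
update_batch_and_job_status(job_id, BatchStatus.COMPLETED, batch_id)
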
43 changes: 42 additions & 1 deletion src/shared/utils.py
@@ -1,6 +1,14 @@
+import sys
+import os
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
import uuid
import requests
import json
+import logging
+from services.database.database import get_db, safe_db_operation
+import services.database.job_service as job_service
+from shared.job_status import JobStatus
+import services.database.batch_service as batch_service
+from shared.batch_status import BatchStatus  # not in the PR as submitted; update_batch_status below needs it (module path assumed, mirroring shared.job_status)
+import worker.config as config  # not in the PR as submitted; update_batch_status below reads config.MAX_BATCH_RETRIES

def generate_uuid_from_tuple(t, namespace_uuid='6ba7b810-9dad-11d1-80b4-00c04fd430c8'):
namespace = uuid.UUID(namespace_uuid)
@@ -30,4 +38,37 @@ def send_embeddings_to_webhook(embedded_chunks: list[dict], job):
json=data
)

-    return response
+    return response
+
+
+def update_batch_and_job_status(job_id, batch_status, batch_id):
+    try:
+        if not job_id and batch_id:
+            job = safe_db_operation(batch_service.get_batch, batch_id)
+            job_id = job.job_id
+        updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status)
+        job = safe_db_operation(job_service.update_job_with_batch, job_id, updated_batch_status)
+        if job.job_status == JobStatus.COMPLETED:
+            logging.info(f"Job {job_id} completed successfully")
+        elif job.job_status == JobStatus.PARTIALLY_COMPLETED:
+            logging.info(f"Job {job_id} partially completed. {job.batches_succeeded} out of {job.total_batches} batches succeeded")
+        elif job.job_status == JobStatus.FAILED:
+            logging.info(f"Job {job_id} failed. {job.batches_succeeded} out of {job.total_batches} batches succeeded")
+
+    except Exception as e:
+        logging.error('Error updating job and batch status: %s', e)
+        safe_db_operation(job_service.update_job_status, job_id, JobStatus.FAILED)
+
+
+
+def update_batch_status(job_id, batch_status, batch_id, retries=None, bypass_retries=False):
+    try:
+        updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status)
+        logging.info(f"Status for batch {batch_id} as part of job {job_id} updated to {updated_batch_status}")
+        if updated_batch_status == BatchStatus.FAILED and (retries == config.MAX_BATCH_RETRIES or bypass_retries):
+            logging.info(f"Batch {batch_id} failed. Updating job status.")
+            update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id)
+    except Exception as e:
+        logging.error('Error updating batch status: %s', e)


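A short usage sketch of the relocated helpers, showing the escalation path from a failed batch to its parent job (the identifiers are placeholders and the `BatchStatus` import path is assumed):

# Hypothetical sketch, not part of this PR's diff.
import shared.utils as utils
from shared.batch_status import BatchStatus  # assumed module path

# Record the failure; when retries reach config.MAX_BATCH_RETRIES, or when
# bypass_retries=True, the helper escalates through
# update_batch_and_job_status and the parent job is updated as well.
utils.update_batch_status("job-123", BatchStatus.FAILED, "batch-456", bypass_retries=True)
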
12 changes: 6 additions & 6 deletions src/worker/tests/test_worker.py
@@ -19,7 +19,7 @@ class TestWorker(unittest.TestCase):
@patch('services.database.job_service.get_job')
@patch('services.database.batch_service.get_batch')
@patch('worker.worker.embed_openai_batch')
-@patch('worker.worker.update_batch_status')
+@patch('shared.utils.update_batch_status')
def test_process_batch_success(
self,
mock_update_batch_and_job_status,
@@ -58,7 +58,7 @@ def test_process_batch_success(
@patch('services.database.job_service.get_job')
@patch('services.database.batch_service.get_batch')
@patch('worker.worker.embed_openai_batch')
-@patch('worker.worker.update_batch_status')
+@patch('shared.utils.update_batch_status')
def test_process_batch_success_different_model(
self,
mock_update_batch_and_job_status,
@@ -99,7 +99,7 @@ def test_process_batch_success_different_model(
@patch('services.database.job_service.get_job')
@patch('services.database.batch_service.get_batch')
@patch('worker.worker.embed_openai_batch')
-@patch('worker.worker.update_batch_status')
+@patch('shared.utils.update_batch_status')
def test_process_batch_failure_no_vectors(
self,
mock_update_batch_and_job_status,
@@ -138,7 +138,7 @@ def test_process_batch_failure_no_vectors(
@patch('services.database.job_service.get_job')
@patch('services.database.batch_service.get_batch')
@patch('worker.worker.embed_openai_batch')
-@patch('worker.worker.update_batch_status')
+@patch('shared.utils.update_batch_status')
def test_process_batch_failure_openai(
self,
mock_update_batch_and_job_status,
@@ -179,7 +179,7 @@ def test_process_batch_failure_openai(
@patch('services.database.job_service.get_job')
@patch('services.database.batch_service.get_batch')
@patch('worker.worker.embed_openai_batch')
-@patch('worker.worker.update_batch_and_job_status')
+@patch('shared.utils.update_batch_and_job_status')
def test_process_batch_failure_validate_chunks(
self,
mock_update_batch_and_job_status,
@@ -344,4 +344,4 @@ def test_chunk_sentence_by_characters_too_big(self):
self.assertEqual(type(chunks[0]), dict)

if __name__ == '__main__':
-    unittest.main()
+    unittest.main()
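The retargeted patches work because worker.py now does `import shared.utils as utils` and resolves each helper as a module attribute at call time, so replacing the attribute on `shared.utils` is what the worker actually observes. A minimal standalone sketch of that mechanism (assumes the repo's src directory is on sys.path):

# Hypothetical sketch, not part of this PR's diff.
from unittest.mock import patch
import shared.utils as utils

with patch('shared.utils.update_batch_status') as mock_update:
    # The attribute lookup happens on the shared.utils module object at
    # call time, so this call lands on the mock, not the real DB helper.
    utils.update_batch_status('job-1', 'FAILED', 'batch-1')
    mock_update.assert_called_once()
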
43 changes: 8 additions & 35 deletions src/worker/worker.py
@@ -14,6 +14,7 @@
import worker.config as config
import services.database.batch_service as batch_service
import services.database.job_service as job_service
+import shared.utils as utils
import tiktoken
from pika.exceptions import AMQPConnectionError
from shared.chunk_strategy import ChunkStrategy
@@ -22,7 +23,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from services.database.database import get_db, safe_db_operation
from shared.job_status import JobStatus
-from shared.utils import send_embeddings_to_webhook, generate_uuid_from_tuple
+from shared.utils import send_embeddings_to_webhook, generate_uuid_from_tuple
from services.rabbitmq.rabbit_service import create_connection_params
from worker.vector_uploader import VectorUploader

@@ -62,15 +63,15 @@ def process_batch(batch_id, source_data, vector_db_key, embeddings_api_key):
upload_to_vector_db(batch_id, embedded_chunks)
else:
logging.error(f"Failed to get OPEN AI embeddings for batch {batch.id}. Adding batch to retry queue.")
-update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id, batch.retries)
+utils.update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id, batch.retries)

except Exception as e:
logging.error('Error embedding batch: %s', e)
-update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id)
+utils.update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id)

else:
logging.error('Unsupported embeddings type: %s', embeddings_type.value)
-update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id, bypass_retries=True)
+utils.update_batch_status(batch.job_id, BatchStatus.FAILED, batch.id, bypass_retries=True)

# NOTE: this method will embed multiple chunks (a list of strings) at once and return a list of lists of floats (a list of embeddings)
# NOTE: this assumes that the embedded chunks are returned in the same order the raw chunks were sent
@@ -149,7 +150,7 @@ def chunk_data(batch, source_data, job):
chunked_data = validate_chunks(chunked_data, job.chunk_validation_url)

if not chunked_data:
-update_batch_and_job_status(batch.job_id, BatchStatus.FAILED, batch.id)
+utils.update_batch_and_job_status(batch.job_id, BatchStatus.FAILED, batch.id)
raise Exception("Failed to chunk data")
return chunked_data

@@ -323,16 +324,6 @@ def create_batches_for_embedding(chunks, max_batch_size):
embedding_batches = [chunks[i:i + max_batch_size] for i in range(0, len(chunks), max_batch_size)]
return embedding_batches

-# TODO: refactor into utils
-def update_batch_status(job_id, batch_status, batch_id, retries = None, bypass_retries=False):
-    try:
-        updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status)
-        logging.info(f"Status for batch {batch_id} as part of job {job_id} updated to {updated_batch_status}")
-        if updated_batch_status == BatchStatus.FAILED and (retries == config.MAX_BATCH_RETRIES or bypass_retries):
-            logging.info(f"Batch {batch_id} failed. Updating job status.")
-            update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id)
-    except Exception as e:
-        logging.error('Error updating batch status: %s', e)

def upload_to_vector_db(batch_id, text_embeddings_list):
try:
@@ -345,31 +336,13 @@ def upload_to_vector_db(batch_id, text_embeddings_list):

def process_webhook_response(response, job_id, batch_id):
if response and hasattr(response, 'status_code') and response.status_code == 200:
-update_batch_and_job_status(job_id, BatchStatus.COMPLETED, batch_id)
+utils.update_batch_and_job_status(job_id, BatchStatus.COMPLETED, batch_id)
else:
logging.error("Error sending embeddings to webhook. Response: %s", response)
-update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id)
+utils.update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id)
if response.json() and response.json()['error']:
logging.error("Error message: %s", response.json()['error'])

-# TODO: refactor into utils
-def update_batch_and_job_status(job_id, batch_status, batch_id):
-    try:
-        if not job_id and batch_id:
-            job = safe_db_operation(batch_service.get_batch, batch_id)
-            job_id = job.job_id
-        updated_batch_status = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status)
-        job = safe_db_operation(job_service.update_job_with_batch, job_id, updated_batch_status)
-        if job.job_status == JobStatus.COMPLETED:
-            logging.info(f"Job {job_id} completed successfully")
-        elif job.job_status == JobStatus.PARTIALLY_COMPLETED:
-            logging.info(f"Job {job_id} partially completed. {job.batches_succeeded} out of {job.total_batches} batches succeeded")
-        elif job.job_status == JobStatus.FAILED:
-            logging.info(f"Job {job_id} failed. {job.batches_succeeded} out of {job.total_batches} batches succeeded")
-
-    except Exception as e:
-        logging.error('Error updating job and batch status: %s', e)
-        safe_db_operation(job_service.update_job_status, job_id, JobStatus.FAILED)

def callback(ch, method, properties, body):
try:
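One review note: the relocated update_batch_status still reads config.MAX_BATCH_RETRIES from the worker package, so shared now depends on worker. A hedged alternative, kept out of this PR, would inject the retry ceiling instead (the parameter name and default below are illustrative only):

# Alternative sketch, not in this PR: pass the retry ceiling in so
# shared.utils needs no import of worker.config.
def update_batch_status(job_id, batch_status, batch_id,
                        retries=None, bypass_retries=False, max_retries=3):
    try:
        updated = safe_db_operation(batch_service.update_batch_status, batch_id, batch_status)
        logging.info(f"Status for batch {batch_id} as part of job {job_id} updated to {updated}")
        if updated == BatchStatus.FAILED and (retries == max_retries or bypass_retries):
            update_batch_and_job_status(job_id, BatchStatus.FAILED, batch_id)
    except Exception as e:
        logging.error('Error updating batch status: %s', e)

Worker call sites would then pass max_retries=config.MAX_BATCH_RETRIES explicitly, keeping the layering one-directional.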