changes! #1

Draft: wants to merge 7 commits into main

1 change: 0 additions & 1 deletion .dockerignore
@@ -2,7 +2,6 @@
.pytest_cache
.ruff_cache
.tox
.venv
.gitignore
makefile
__pycache__
6 changes: 6 additions & 0 deletions .gitmodules
@@ -0,0 +1,6 @@
[submodule "speech-to-speech"]
path = speech-to-speech
url = https://github.com/huggingface/speech-to-speech
[submodule "fast-unidic"]
path = fast-unidic
url = https://huggingface.co/andito/fast-unidic
17 changes: 7 additions & 10 deletions dockerfiles/pytorch/Dockerfile
@@ -1,9 +1,9 @@
ARG BASE_IMAGE=nvidia/cuda:12.1.0-devel-ubuntu22.04

FROM $BASE_IMAGE as base
FROM $BASE_IMAGE AS base
SHELL ["/bin/bash", "-c"]

LABEL maintainer="Hugging Face"
LABEL maintainer="Andres Marafioti"

ENV DEBIAN_FRONTEND=noninteractive

@@ -35,20 +35,17 @@ RUN apt-get update && \

# Copying only necessary files as filtered by .dockerignore
COPY . .

# install wheel and setuptools
RUN pip install --no-cache-dir --upgrade pip ".[torch,st,diffusers]"
RUN pip install uv

# copy application
COPY src/huggingface_inference_toolkit huggingface_inference_toolkit
COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py
COPY fast-unidic /usr/local/lib/python3.10/dist-packages/unidic/dicdir/

RUN uv pip install -r speech-to-speech/requirements.txt --system
RUN uv pip install flash-attn --no-build-isolation --system

# copy entrypoint and change permissions
COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh

ENTRYPOINT ["bash", "-c", "./entrypoint.sh"]

FROM base AS vertex

# Install `google` extra for Vertex AI compatibility
RUN pip install --no-cache-dir --upgrade ".[google]"
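
Side note on the Dockerfile above: speech-to-speech and fast-unidic are Git submodules (see the .gitmodules change), so they have to be checked out in the build context before the COPY steps can pick them up. The fast-unidic files are copied into /usr/local/lib/python3.10/dist-packages/unidic/dicdir/ because the unidic package loads its dictionary from the dicdir folder inside its own installation. A minimal sanity check to run inside the built image (a sketch, not part of this PR; it only assumes the standard unidic package layout, where unidic.DICDIR points at that folder):

# check_unidic_dicdir.py - hedged sanity check, not part of this PR.
# Assumes the `unidic` PyPI package is installed and exposes DICDIR,
# the directory its tokenizer loads the dictionary from.
from pathlib import Path

import unidic

dicdir = Path(unidic.DICDIR)
print(f"unidic dictionary directory: {dicdir}")
# After the COPY in the Dockerfile, this directory should contain the
# fast-unidic files rather than being empty.
print("dictionary files present:", any(dicdir.iterdir()))
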
1 change: 1 addition & 0 deletions fast-unidic
Submodule fast-unidic added at 65e248
11 changes: 1 addition & 10 deletions scripts/entrypoint.sh
@@ -1,21 +1,12 @@
#!/bin/bash

# Define the default port
PORT=5000
PORT=80 # 80 is the default port for inference endpoints

# Check if AIP_MODE is set and adjust the port for Vertex AI
if [[ ! -z "${AIP_MODE}" ]]; then
PORT=${AIP_HTTP_PORT}
fi

# Check if HF_MODEL_DIR is set and if not skip installing custom dependencies
if [[ ! -z "${HF_MODEL_DIR}" ]]; then
# Check if requirements.txt exists and if so install dependencies
if [ -f "${HF_MODEL_DIR}/requirements.txt" ]; then
echo "Installing custom dependencies from ${HF_MODEL_DIR}/requirements.txt"
pip install -r ${HF_MODEL_DIR}/requirements.txt --no-cache-dir;
fi
fi

# Start the server
uvicorn webservice_starlette:app --host 0.0.0.0 --port ${PORT}
1 change: 1 addition & 0 deletions speech-to-speech
Submodule speech-to-speech added at 9998e2
167 changes: 35 additions & 132 deletions src/huggingface_inference_toolkit/webservice_starlette.py
@@ -2,147 +2,50 @@
from pathlib import Path
from time import perf_counter

import orjson
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse, Response
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.endpoints import WebSocketEndpoint
from starlette.routing import WebSocketRoute

from huggingface_inference_toolkit.async_utils import async_handler_call
from huggingface_inference_toolkit.const import (
HF_FRAMEWORK,
HF_HUB_TOKEN,
HF_MODEL_DIR,
HF_MODEL_ID,
HF_REVISION,
HF_TASK,
)
from huggingface_inference_toolkit.handler import (
get_inference_handler_either_custom_or_default_handler,
)
from huggingface_inference_toolkit.logging import logger
from huggingface_inference_toolkit.serialization.base import ContentType
from huggingface_inference_toolkit.serialization.json_utils import Jsoner
from huggingface_inference_toolkit.utils import (
_load_repository_from_hf,
convert_params_to_int_or_bool,
)
from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
import asyncio

import sys
sys.path.append('speech-to-speech')
from s2s_handler import EndpointHandler

async def prepare_model_artifacts():
async def prepare_handler():
global inference_handler
# 1. check if model artifacts available in HF_MODEL_DIR
if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0:
# 2. if not available, try to load from HF_MODEL_ID
if HF_MODEL_ID is not None:
_load_repository_from_hf(
repository_id=HF_MODEL_ID,
target_dir=HF_MODEL_DIR,
framework=HF_FRAMEWORK,
revision=HF_REVISION,
hf_hub_token=HF_HUB_TOKEN,
)
# 3. check if in Vertex AI environment and load from GCS
# If artifactUri not on Model Creation not set returns an empty string
elif len(os.environ.get("AIP_STORAGE_URI", "")) > 0:
_load_repository_from_gcs(
os.environ["AIP_STORAGE_URI"], target_dir=HF_MODEL_DIR
)
# 4. if not available, raise error
else:
raise ValueError(
f"""Can't initialize model.
Please set env HF_MODEL_DIR or provider a HF_MODEL_ID.
Provided values are:
HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}"""
)

logger.info(f"Initializing model from directory:{HF_MODEL_DIR}")
# 2. determine correct inference handler
inference_handler = get_inference_handler_either_custom_or_default_handler(
HF_MODEL_DIR, task=HF_TASK
)
inference_handler = EndpointHandler()
inference_handler.pipeline_manager.start()
logger.info("Model initialized successfully")


async def health(request):
return PlainTextResponse("Ok")


async def predict(request):
try:
# extracts content from request
content_type = request.headers.get("content-Type", None)
# try to deserialize payload
deserialized_body = ContentType.get_deserializer(content_type).deserialize(
await request.body()
)
# checks if input schema is correct
if "inputs" not in deserialized_body and "instances" not in deserialized_body:
raise ValueError(
f"Body needs to provide a inputs key, received: {orjson.dumps(deserialized_body)}"
)

# check for query parameter and add them to the body
if request.query_params and "parameters" not in deserialized_body:
deserialized_body["parameters"] = convert_params_to_int_or_bool(
dict(request.query_params)
)

# tracks request time
start_time = perf_counter()
# run async not blocking call
pred = await async_handler_call(inference_handler, deserialized_body)
# log request time
logger.info(
f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms"
)

# response extracts content from request
accept = request.headers.get("accept", None)
if accept is None or accept == "*/*":
accept = "application/json"
# deserialized and resonds with json
serialized_response_body = ContentType.get_serializer(accept).serialize(
pred, accept
)
return Response(serialized_response_body, media_type=accept)
except Exception as e:
logger.error(e)
return Response(
Jsoner.serialize({"error": str(e)}),
status_code=400,
media_type="application/json",
)


# Create app based on which cloud environment is used
if os.getenv("AIP_MODE", None) == "PREDICTION":
logger.info("Running in Vertex AI environment")
# extract routes from environment variables
_predict_route = os.getenv("AIP_PREDICT_ROUTE", None)
_health_route = os.getenv("AIP_HEALTH_ROUTE", None)
if _predict_route is None or _health_route is None:
raise ValueError(
"AIP_PREDICT_ROUTE and AIP_HEALTH_ROUTE need to be set in Vertex AI environment"
)

app = Starlette(
debug=False,
routes=[
Route(_health_route, health, methods=["GET"]),
Route(_predict_route, predict, methods=["POST"]),
],
on_startup=[prepare_model_artifacts],
)
else:
app = Starlette(
debug=False,
routes=[
Route("/", health, methods=["GET"]),
Route("/health", health, methods=["GET"]),
Route("/", predict, methods=["POST"]),
Route("/predict", predict, methods=["POST"]),
],
on_startup=[prepare_model_artifacts],
)
class WebSocketPredictEndpoint(WebSocketEndpoint):
encoding = "bytes"

async def on_connect(self, websocket):
await websocket.accept()

async def on_receive(self, websocket, data):
# Run the handler's processing in a separate thread to avoid blocking
loop = asyncio.get_event_loop()
response_data = await loop.run_in_executor(None, inference_handler.process_streaming_data, data)
if response_data:
await websocket.send_bytes(response_data)

async def on_disconnect(self, websocket, close_code):
pass

app = Starlette(
debug=False,
routes=[
Route("/", health, methods=["GET"]),
Route("/health", health, methods=["GET"]),
WebSocketRoute("/ws", WebSocketPredictEndpoint)
],
on_startup=[prepare_handler],
)
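
For completeness, here is a minimal client sketch for the new /ws route (a sketch, not part of the PR: it assumes the server is reachable on localhost at the default port 80 set in entrypoint.sh and that the `websockets` PyPI package is installed; the input file and chunk size are illustrative placeholders):

# ws_client_sketch.py - hedged sketch of streaming bytes to the /ws endpoint added above.
# Assumes `pip install websockets`; "sample.raw" and the 4096-byte chunks are placeholders.
import asyncio

import websockets


async def receive(ws) -> None:
    # The endpoint only sends bytes back when the handler produces output.
    async for message in ws:
        print(f"received {len(message)} bytes")


async def stream_audio(path: str, url: str = "ws://localhost:80/ws") -> None:
    async with websockets.connect(url) as ws:
        recv_task = asyncio.create_task(receive(ws))
        with open(path, "rb") as f:
            while chunk := f.read(4096):
                await ws.send(chunk)  # raw bytes, matching encoding = "bytes" on the endpoint
                await asyncio.sleep(0.05)  # crude pacing placeholder for a real audio stream
        await asyncio.sleep(2.0)  # give the handler a moment to flush remaining output
        recv_task.cancel()


asyncio.run(stream_audio("sample.raw"))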