Skip to content

Add basic structure for GDrive OAuth tool #269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
481 changes: 313 additions & 168 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ pypdf = "^4.2.0"
pyjwt = "^2.8.0"
pydantic-settings = "^2.3.1"
transformers = "^4.41.2"
google_auth_oauthlib ="^1.2.0"
google-auth-httplib2="^0.2.0"
google-api-python-client="^2.133.0"

[tool.poetry.group.dev]
optional = true
Expand Down
45 changes: 45 additions & 0 deletions src/backend/alembic/versions/a6efd9f047b4_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""empty message

Revision ID: a6efd9f047b4
Revises: 8bc604e45f2d
Create Date: 2024-06-21 10:04:26.857068

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
# `revision` is the ID of this migration; `down_revision` is the ID of the
# migration this one is applied on top of.
revision: str = "a6efd9f047b4"
down_revision: Union[str, None] = "8bc604e45f2d"
# No branch labels or cross-revision dependencies are declared here.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``tool_auth`` table and its unique (user_id, tool_id) index."""
    # Column definitions, in the order they are emitted in the CREATE TABLE.
    columns = [
        sa.Column("user_id", sa.Text(), nullable=False),
        sa.Column("tool_id", sa.Text(), nullable=False),
        sa.Column("token_type", sa.Text(), nullable=False),
        sa.Column("encrypted_access_token", sa.LargeBinary(), nullable=False),
        sa.Column("encrypted_refresh_token", sa.LargeBinary(), nullable=False),
        sa.Column("expires_at", sa.DateTime(), nullable=False),
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=True),
        sa.Column("updated_at", sa.DateTime(), nullable=True),
    ]
    # Table-level constraints: primary key on ``id`` and one row per
    # (user_id, tool_id) pair.
    constraints = [
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("user_id", "tool_id", name="_user_tool_uc"),
    ]
    op.create_table("tool_auth", *columns, *constraints)

    # NOTE(review): this unique index covers the same columns as the
    # "_user_tool_uc" constraint above - confirm both are wanted.
    op.create_index(
        "tool_auth_index",
        "tool_auth",
        ["user_id", "tool_id"],
        unique=True,
    )


def downgrade() -> None:
    """Revert ``upgrade``: drop the ``tool_auth`` index, then the table itself."""
    table_name = "tool_auth"
    # The index is removed first, mirroring the reverse of the creation order.
    op.drop_index("tool_auth_index", table_name=table_name)
    op.drop_table(table_name)
24 changes: 24 additions & 0 deletions src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
SearchFileTool,
TavilyInternetSearch,
)
from backend.tools.google_drive import (
GOOGLE_DRIVE_TOOL_ID,
GoogleDrive,
GoogleDriveAuth,
)

"""
List of available tools. Each tool should have a name, implementation, is_visible and category.
Expand All @@ -32,6 +37,7 @@ class ToolName(StrEnum):
Python_Interpreter = "toolkit_python_interpreter"
Calculator = "calculator"
Tavily_Internet_Search = "web_search"
Google_Drive = GOOGLE_DRIVE_TOOL_ID


ALL_TOOLS = {
Expand Down Expand Up @@ -143,6 +149,24 @@ class ToolName(StrEnum):
category=Category.DataLoader,
description="Returns a list of relevant document snippets for a textual query retrieved from the internet using Tavily.",
),
ToolName.Google_Drive: ManagedTool(
name=ToolName.Google_Drive,
display_name="Google Drive",
implementation=GoogleDrive,
parameter_definitions={
"query": {
"description": "Query to search google drive documents with.",
"type": "str",
"required": True,
}
},
is_visible=False,
is_available=GoogleDrive.is_available(),
auth_implementation=GoogleDriveAuth,
error_message="Google Drive not available",
category=Category.DataLoader,
description="Returns a list of relevant document snippets for the user's google drive.",
),
}


Expand Down
78 changes: 78 additions & 0 deletions src/backend/crud/tool_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from sqlalchemy.orm import Session

from backend.database_models.tool_auth import ToolAuth


def create_tool_auth(db: Session, tool_auth: ToolAuth) -> ToolAuth:
    """
    Persist a new tool auth record.

    A tool auth record stores the access tokens for tools that need auth.

    Args:
        db (Session): Database session.
        tool_auth (ToolAuth): Record to insert.

    Returns:
        ToolAuth: The same object that was passed in, refreshed from the
            database after the commit.
    """
    db.add(tool_auth)
    db.commit()
    # Reload the instance so attributes populated by the database are current.
    db.refresh(tool_auth)
    return tool_auth


def get_tool_auth(db: Session, tool_id: str, user_id: str) -> ToolAuth:
    """
    Look up the tool auth record for a (tool, user) pair.

    Args:
        db (Session): Database session.
        tool_id (str): Tool ID.
        user_id (str): User ID.

    Returns:
        ToolAuth: First record matching both IDs. ``Query.first()`` yields
            ``None`` when no row matches, so callers should handle that case.
    """
    matches_tool = ToolAuth.tool_id == tool_id
    matches_user = ToolAuth.user_id == user_id
    lookup = db.query(ToolAuth).filter(matches_tool, matches_user)
    return lookup.first()


def update_tool_auth(
    db: Session, tool_auth: ToolAuth, new_tool_auth: ToolAuth
) -> ToolAuth:
    """
    Overwrite an existing tool auth record with new values.

    Args:
        db (Session): Database session.
        tool_auth (ToolAuth): Existing record to modify in place.
        new_tool_auth (ToolAuth): Source of the new values; every key returned
            by its ``model_dump()`` is copied onto ``tool_auth``.

    Returns:
        ToolAuth: The updated ``tool_auth`` instance, refreshed after commit.
    """
    # NOTE(review): ``model_dump()`` is a Pydantic-style method, yet the
    # parameter is annotated as the ORM ``ToolAuth`` - confirm what callers
    # actually pass here.
    new_values = new_tool_auth.model_dump()
    for field_name in new_values:
        setattr(tool_auth, field_name, new_values[field_name])

    db.commit()
    db.refresh(tool_auth)
    return tool_auth


def delete_tool_auth(db: Session, user_id: str, tool_id: str) -> None:
    """
    Remove the tool auth record(s) for a (user, tool) pair.

    Args:
        db (Session): Database session.
        user_id (str): User ID.
        tool_id (str): Tool ID.
    """
    # The delete is issued directly on the filtered query; no rows are loaded.
    db.query(ToolAuth).filter(
        ToolAuth.tool_id == tool_id,
        ToolAuth.user_id == user_id,
    ).delete()
    db.commit()
1 change: 1 addition & 0 deletions src/backend/database_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from backend.database_models.file import *
from backend.database_models.message import *
from backend.database_models.organization import *
from backend.database_models.tool_auth import *
from backend.database_models.user import *
17 changes: 17 additions & 0 deletions src/backend/database_models/tool_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from sqlalchemy import DateTime, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from backend.database_models.base import Base


class ToolAuth(Base):
    """
    Stored OAuth credentials for one user and one tool.

    The table holds at most one row per (user_id, tool_id) pair, enforced by
    the ``_user_tool_uc`` unique constraint.
    """

    __tablename__ = "tool_auth"
    __table_args__ = (
        UniqueConstraint("user_id", "tool_id", name="_user_tool_uc"),
    )

    # Owner of the credentials and the tool they were issued for.
    user_id: Mapped[str] = mapped_column(Text, nullable=False)
    tool_id: Mapped[str] = mapped_column(Text, nullable=False)

    token_type: Mapped[str] = mapped_column(Text, nullable=False)
    # Tokens are stored as bytes; the column names indicate they are expected
    # to be encrypted before being assigned - encryption is not done here.
    encrypted_access_token: Mapped[bytes] = mapped_column()
    encrypted_refresh_token: Mapped[bytes] = mapped_column()
    expires_at = mapped_column(DateTime, nullable=False)

    # NOTE(review): no primary key or timestamp columns are declared in this
    # class - presumably they are inherited from ``Base``; verify.
21 changes: 21 additions & 0 deletions src/backend/routers/auth.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import json
import os
from typing import Union

from authlib.integrations.starlette_client import OAuthError
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from starlette.requests import Request

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.config.routers import RouterName
from backend.config.tools import ALL_TOOLS
from backend.crud import blacklist as blacklist_crud
from backend.database_models import Blacklist
from backend.database_models.database import DBSessionDep
Expand Down Expand Up @@ -191,3 +195,20 @@ async def authorize(
token = JWTService().create_and_encode_jwt(user)

return {"token": token}


# Tool based auth is experimental and in development
@router.get("/tool/auth")
async def login(request: Request, session: DBSessionDep):
    """
    Handle a tool's auth callback and redirect back to the frontend.

    The ``state`` query parameter must be a JSON object containing a
    ``tool_id`` key. When that tool is registered in ``ALL_TOOLS`` and has an
    ``auth_implementation``, its ``process_auth_token`` is invoked; a truthy
    return value is treated as an error message.

    Args:
        request (Request): Incoming callback request.
        session (DBSessionDep): Database session.

    Returns:
        RedirectResponse: Redirect to ``FRONTEND_HOSTNAME``. An ``error`` query
            parameter is appended when ``state`` is missing or malformed, or
            when the tool's auth implementation reports an error.
    """
    # Imported locally so the handler does not rely on a module-level import.
    from urllib.parse import quote

    redirect_url = os.getenv("FRONTEND_HOSTNAME")

    def error_redirect(message: str) -> RedirectResponse:
        # Percent-encode the message so characters such as "&" or "#" cannot
        # split or truncate the query string.
        return RedirectResponse(redirect_url + "?error=" + quote(message, safe=""))

    # TODO: Store user id and tool id in the DB for state key
    try:
        state = json.loads(request.query_params.get("state"))
        tool_id = state["tool_id"]
    except (TypeError, ValueError, KeyError):
        # Covers a missing parameter (None), malformed JSON, a JSON value that
        # is not an object, and an object without "tool_id". Previously these
        # surfaced as unhandled exceptions.
        return error_redirect("Invalid or missing state parameter")

    # Single lookup; a non-string tool_id can never match a registered tool.
    tool = ALL_TOOLS.get(tool_id) if isinstance(tool_id, str) else None
    if tool is not None and tool.auth_implementation is not None:
        err = tool.auth_implementation.process_auth_token(request, session)
        if err:
            return error_redirect(err)

    return RedirectResponse(redirect_url)
20 changes: 16 additions & 4 deletions src/backend/routers/tool.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
from typing import Optional

from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Request

from backend.config.routers import RouterName
from backend.config.tools import AVAILABLE_TOOLS
from backend.crud import agent as agent_crud
from backend.database_models.database import DBSessionDep
from backend.schemas.tool import ManagedTool
from backend.services.auth.utils import get_header_user_id

router = APIRouter(prefix="/v1/tools")
router.name = RouterName.TOOL


@router.get("", response_model=list[ManagedTool])
def list_tools(session: DBSessionDep, agent_id: str | None = None) -> list[ManagedTool]:
def list_tools(
request: Request, session: DBSessionDep, agent_id: str | None = None
) -> list[ManagedTool]:
"""
List all available tools.

Returns:
list[ManagedTool]: List of available tools.
"""
all_tools = AVAILABLE_TOOLS.values()
if agent_id:
agent_tools = []
agent = agent_crud.get_agent_by_id(session, agent_id)
Expand All @@ -32,6 +36,14 @@ def list_tools(session: DBSessionDep, agent_id: str | None = None) -> list[Manag

for tool in agent.tools:
agent_tools.append(AVAILABLE_TOOLS[tool])
return agent_tools
all_tools = agent_tools

return AVAILABLE_TOOLS.values()
user_id = get_header_user_id(request)
for tool in all_tools:
if tool.is_available and tool.auth_implementation is not None:
tool.is_auth_required = tool.auth_implementation.is_auth_required(
session, user_id
)
tool.auth_url = tool.auth_implementation.get_auth_url(user_id)

return all_tools
5 changes: 5 additions & 0 deletions src/backend/schemas/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ class ManagedTool(Tool):
is_available: bool = False
error_message: Optional[str] = ""
category: Category = Category.DataLoader

is_auth_required: bool = False # Per user
auth_url: Optional[str] = "" # Per user

implementation: Any = Field(exclude=True)
auth_implementation: Any = Field(default=None, exclude=True)

class Config:
from_attributes = True
Expand Down
22 changes: 22 additions & 0 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from abc import abstractmethod
from typing import Any, Dict, List

from fastapi import Request

from backend.database_models.database import DBSessionDep


class BaseTool:
"""
Expand All @@ -13,3 +17,21 @@ def is_available(cls) -> bool: ...

@abstractmethod
def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: ...


class BaseAuth:
    """
    Abstract base class for auth for tools.

    Every hook is a classmethod, so implementations are called on the class
    itself rather than on an instance. The ``@abstractmethod`` markers
    document which methods subclasses must override; they are not enforced at
    instantiation time because this class does not use ``abc.ABCMeta``.
    """

    @classmethod
    @abstractmethod
    def get_auth_url(cls, user_id: str) -> str:
        """Return the auth URL for ``user_id``."""
        ...

    @classmethod
    @abstractmethod
    def is_auth_required(cls, session: DBSessionDep, user_id: str) -> bool:
        """Return whether ``user_id`` still needs to authenticate for the tool."""
        ...

    @classmethod
    @abstractmethod
    def process_auth_token(cls, request: Request, session: DBSessionDep) -> str:
        """
        Process the auth callback ``request``.

        Returns:
            str: An error message. Implementations should return a falsy value
                (e.g. an empty string) when processing succeeded.
        """
        ...
Loading
Loading