Commit 88d577b
feat: add first setup screen for LLM & Embedding models (#314) (bump:minor)
* fix: utf-8 txt reader
* fix: revise vectorstore import and make it optional
* feat: add cohere chat model with tool call support
* fix: simplify citation pipeline
* fix: improve citation logic
* fix: improve decompose func call
* fix: revise question rewrite prompt
* fix: revise chat box default placeholder
* fix: add key from ktem to cohere rerank
* fix: conv name suggestion
* fix: ignore default key cohere rerank
* fix: improve test connection UI
* fix: reorder requirements
* feat: add first setup screen
* fix: update requirements
* fix: vectorstore tests
* fix: update cohere version
* fix: relax langchain core version
* fix: add demo mode
* fix: update flowsettings
* fix: typo
* fix: fix bool env passing
1 parent 0bdb9a3 commit 88d577b

27 files changed, +644 -141 lines changed

flowsettings.py

Lines changed: 15 additions & 1 deletion
@@ -24,9 +24,13 @@
 except Exception:
     KH_APP_VERSION = "local"

+KH_ENABLE_FIRST_SETUP = True
+KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
+
 # App can be ran from anywhere and it's not trivial to decide where to store app data.
 # So let's use the same directory as the flowsetting.py file.
 KH_APP_DATA_DIR = this_dir / "ktem_app_data"
+KH_APP_DATA_EXISTS = KH_APP_DATA_DIR.exists()
 KH_APP_DATA_DIR.mkdir(parents=True, exist_ok=True)

 # User data directory
@@ -59,7 +63,9 @@
 KH_DOC_DIR = this_dir / "docs"

 KH_MODE = "dev"
-KH_FEATURE_USER_MANAGEMENT = True
+KH_FEATURE_USER_MANAGEMENT = config(
+    "KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool
+)
 KH_USER_CAN_SEE_PUBLIC = None
 KH_FEATURE_USER_MANAGEMENT_ADMIN = str(
     config("KH_FEATURE_USER_MANAGEMENT_ADMIN", default="admin")
@@ -202,6 +208,14 @@
     },
     "default": False,
 }
+KH_LLMS["cohere"] = {
+    "spec": {
+        "__type__": "kotaemon.llms.chats.LCCohereChat",
+        "model_name": "command-r-plus-08-2024",
+        "api_key": "your-key",
+    },
+    "default": False,
+}

 # additional embeddings configurations
 KH_EMBEDDINGS["cohere"] = {
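Note on the `cast=bool` arguments above (this is the "fix: fix bool env passing" item from the commit message): environment variables arrive as strings, and any non-empty string is truthy in Python. A minimal sketch, assuming python-decouple's `config` (which `flowsettings.py` already uses):

```python
import os

from decouple import config

os.environ["KH_DEMO_MODE"] = "false"

naive = config("KH_DEMO_MODE", default=False)              # the string "false"
proper = config("KH_DEMO_MODE", default=False, cast=bool)  # the bool False

print(bool(naive), proper)  # True False -- without the cast, demo mode turns on
```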

libs/kotaemon/kotaemon/embeddings/langchain_based.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ def __init__(

     def _get_lc_class(self):
         try:
-            from langchain_community.embeddings import CohereEmbeddings
+            from langchain_cohere import CohereEmbeddings
         except ImportError:
             from langchain.embeddings import CohereEmbeddings

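The preferred import now comes from the standalone `langchain-cohere` package (an observation, not stated in the diff: LangChain has been splitting provider integrations out of `langchain_community` into dedicated packages). The same fallback chain, shown standalone with a placeholder key and model:

```python
try:
    from langchain_cohere import CohereEmbeddings  # pip install langchain-cohere
except ImportError:
    from langchain.embeddings import CohereEmbeddings  # legacy fallback

# "your-key" and the model name are placeholders, not values from this commit
emb = CohereEmbeddings(cohere_api_key="your-key", model="embed-english-v3.0")
print(type(emb).__module__)  # shows which package was resolved
```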
Lines changed: 31 additions & 68 deletions
@@ -1,4 +1,4 @@
-from typing import Iterator, List
+from typing import List

 from pydantic import BaseModel, Field

@@ -7,53 +7,14 @@
 from kotaemon.llms import BaseLLM


-class FactWithEvidence(BaseModel):
-    """Class representing a single statement.
+class CiteEvidence(BaseModel):
+    """List of evidences (maximum 5) to support the answer."""

-    Each fact has a body and a list of sources.
-    If there are multiple facts make sure to break them apart
-    such that each one only uses a set of sources that are relevant to it.
-    """
-
-    fact: str = Field(..., description="Body of the sentence, as part of a response")
-    substring_quote: List[str] = Field(
+    evidences: List[str] = Field(
         ...,
         description=(
             "Each source should be a direct quote from the context, "
-            "as a substring of the original content"
-        ),
-    )
-
-    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
-        import regex
-
-        minor = quote
-        major = context
-
-        errs_ = 0
-        s = regex.search(f"({minor}){{e<={errs_}}}", major)
-        while s is None and errs_ <= errs:
-            errs_ += 1
-            s = regex.search(f"({minor}){{e<={errs_}}}", major)
-
-        if s is not None:
-            yield from s.spans()
-
-    def get_spans(self, context: str) -> Iterator[str]:
-        for quote in self.substring_quote:
-            yield from self._get_span(quote, context)
-
-
-class QuestionAnswer(BaseModel):
-    """A question and its answer as a list of facts each one should have a source.
-    each sentence contains a body and a list of sources."""
-
-    question: str = Field(..., description="Question that was asked")
-    answer: List[FactWithEvidence] = Field(
-        ...,
-        description=(
-            "Body of the answer, each fact should be "
-            "its separate object with a body and a list of sources"
+            "as a substring of the original content (max 15 words)."
         ),
     )

@@ -68,15 +29,16 @@ def run(self, context: str, question: str):
         return self.invoke(context, question)

     def prepare_llm(self, context: str, question: str):
-        schema = QuestionAnswer.schema()
+        schema = CiteEvidence.schema()
         function = {
             "name": schema["title"],
             "description": schema["description"],
             "parameters": schema,
         }
         llm_kwargs = {
             "tools": [{"type": "function", "function": function}],
-            "tool_choice": "auto",
+            "tool_choice": "required",
+            "tools_pydantic": [CiteEvidence],
         }
         messages = [
             SystemMessage(
@@ -85,7 +47,12 @@ def prepare_llm(self, context: str, question: str):
                     "questions with correct and exact citations."
                 )
             ),
-            HumanMessage(content="Answer question using the following context"),
+            HumanMessage(
+                content=(
+                    "Answer question using the following context. "
+                    "Use the provided function CiteEvidence() to cite your sources."
+                )
+            ),
             HumanMessage(content=context),
             HumanMessage(content=f"Question: {question}"),
             HumanMessage(
@@ -103,33 +70,29 @@ def invoke(self, context: str, question: str):
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
-            if not llm_output.messages or not llm_output.additional_kwargs.get(
-                "tool_calls"
-            ):
+            if not llm_output.additional_kwargs.get("tool_calls"):
                 return None
-            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
-                "arguments"
-            ]
-            output = QuestionAnswer.parse_raw(function_output)
-        except Exception as e:
-            print(e)
-            return None

-        return output
+            first_func = llm_output.additional_kwargs["tool_calls"][0]

-    async def ainvoke(self, context: str, question: str):
-        messages, llm_kwargs = self.prepare_llm(context, question)
+            if "function" in first_func:
+                # openai and cohere format
+                function_output = first_func["function"]["arguments"]
+            else:
+                # anthropic format
+                function_output = first_func["args"]

-        try:
-            print("CitationPipeline: async invoking LLM")
-            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
-            print("CitationPipeline: finish async invoking LLM")
-            function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
-                "arguments"
-            ]
-            output = QuestionAnswer.parse_raw(function_output)
+            print("CitationPipeline:", function_output)
+
+            if isinstance(function_output, str):
+                output = CiteEvidence.parse_raw(function_output)
+            else:
+                output = CiteEvidence.parse_obj(function_output)
         except Exception as e:
             print(e)
             return None

         return output
+
+    async def ainvoke(self, context: str, question: str):
+        raise NotImplementedError()
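The rewritten `invoke()` branches on two tool-call payload shapes. A self-contained sketch of that normalization, using fabricated payloads (only the dict shapes are taken from the diff; real provider responses carry more fields):

```python
from typing import List

from pydantic import BaseModel, Field


class CiteEvidence(BaseModel):
    """List of evidences (maximum 5) to support the answer."""

    evidences: List[str] = Field(...)


# OpenAI/Cohere style: arguments arrive as a JSON string
openai_style = {"function": {"arguments": '{"evidences": ["a direct quote"]}'}}
# Anthropic style: args arrive as an already-parsed dict
anthropic_style = {"args": {"evidences": ["a direct quote"]}}

for call in (openai_style, anthropic_style):
    payload = call["function"]["arguments"] if "function" in call else call["args"]
    evidence = (
        CiteEvidence.parse_raw(payload)       # JSON-string path
        if isinstance(payload, str)
        else CiteEvidence.parse_obj(payload)  # dict path
    )
    print(evidence.evidences)  # ['a direct quote'] in both cases
```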

libs/kotaemon/kotaemon/indices/rankings/cohere.py

Lines changed: 19 additions & 2 deletions
@@ -10,6 +10,7 @@
 class CohereReranking(BaseReranking):
     model_name: str = "rerank-multilingual-v2.0"
     cohere_api_key: str = config("COHERE_API_KEY", "")
+    use_key_from_ktem: bool = False

     def run(self, documents: list[Document], query: str) -> list[Document]:
         """Use Cohere Reranker model to re-order documents
@@ -18,9 +19,25 @@ def run(self, documents: list[Document], query: str) -> list[Document]:
             import cohere
         except ImportError:
             raise ImportError(
-                "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
+                "Please install Cohere `pip install cohere` to use Cohere Reranking"
             )

+        # try to get COHERE_API_KEY from embeddings
+        if not self.cohere_api_key and self.use_key_from_ktem:
+            try:
+                from ktem.embeddings.manager import (
+                    embedding_models_manager as embeddings,
+                )
+
+                cohere_model = embeddings.get("cohere")
+                ktem_cohere_api_key = cohere_model._kwargs.get(  # type: ignore
+                    "cohere_api_key"
+                )
+                if ktem_cohere_api_key != "your-key":
+                    self.cohere_api_key = ktem_cohere_api_key
+            except Exception as e:
+                print("Cannot get Cohere API key from `ktem`", e)
+
         if not self.cohere_api_key:
             print("Cohere API key not found. Skipping reranking.")
             return documents
@@ -35,7 +52,7 @@ def run(self, documents: list[Document], query: str) -> list[Document]:
         response = cohere_client.rerank(
             model=self.model_name, query=query, documents=_docs
         )
-        print("Cohere score", [r.relevance_score for r in response.results])
+        # print("Cohere score", [r.relevance_score for r in response.results])
         for r in response.results:
             doc = documents[r.index]
             doc.metadata["cohere_reranking_score"] = r.relevance_score
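A hedged usage sketch of the new key fallback (the import paths `kotaemon.base.Document` and `kotaemon.indices.rankings.CohereReranking` are assumptions based on the file location; the documents and query are illustrative):

```python
from kotaemon.base import Document
from kotaemon.indices.rankings import CohereReranking

reranker = CohereReranking(
    model_name="rerank-multilingual-v2.0",
    use_key_from_ktem=True,  # borrow the key from ktem's "cohere" embeddings entry
)
docs = [
    Document(text="Bananas are rich in potassium."),
    Document(text="Paris is the capital of France."),
]
ranked = reranker.run(docs, query="What is the capital of France?")
for doc in ranked:
    # each document comes back annotated with its relevance score
    print(doc.metadata.get("cohere_reranking_score"), doc.text)
```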

libs/kotaemon/kotaemon/llms/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
     LCAnthropicChat,
     LCAzureChatOpenAI,
     LCChatOpenAI,
+    LCCohereChat,
     LCGeminiChat,
     LlamaCppChat,
 )
@@ -31,6 +32,7 @@
     "ChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
+    "LCCohereChat",
     "LCAzureChatOpenAI",
     "LCChatOpenAI",
     "LlamaCppChat",

libs/kotaemon/kotaemon/llms/chats/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
     LCAzureChatOpenAI,
     LCChatMixin,
     LCChatOpenAI,
+    LCCohereChat,
     LCGeminiChat,
 )
 from .llamacpp import LlamaCppChat
@@ -18,6 +19,7 @@
     "ChatOpenAI",
     "LCAnthropicChat",
     "LCGeminiChat",
+    "LCCohereChat",
     "LCChatOpenAI",
     "LCAzureChatOpenAI",
     "LCChatMixin",

libs/kotaemon/kotaemon/llms/chats/langchain_based.py

Lines changed: 63 additions & 5 deletions
@@ -18,6 +18,9 @@ def _get_lc_class(self):
             "Please return the relevant Langchain class in in _get_lc_class"
         )

+    def _get_tool_call_kwargs(self):
+        return {}
+
     def __init__(self, stream: bool = False, **params):
         self._lc_class = self._get_lc_class()
         self._obj = self._lc_class(**params)
@@ -56,9 +59,7 @@ def prepare_response(self, pred):
             total_tokens = pred.llm_output["token_usage"]["total_tokens"]
             prompt_tokens = pred.llm_output["token_usage"]["prompt_tokens"]
         except Exception:
-            logger.warning(
-                f"Cannot get token usage from LLM output for {self._lc_class.__name__}"
-            )
+            pass

         return LLMInterface(
             text=all_text[0] if len(all_text) > 0 else "",
@@ -83,8 +84,30 @@ def invoke(
             LLMInterface: generated response
         """
         input_ = self.prepare_message(messages)
-        pred = self._obj.generate(messages=[input_], **kwargs)
-        return self.prepare_response(pred)
+
+        if "tools_pydantic" in kwargs:
+            tools = kwargs.pop(
+                "tools_pydantic",
+            )
+            lc_tool_call = self._obj.bind_tools(tools)
+            pred = lc_tool_call.invoke(
+                input_,
+                **self._get_tool_call_kwargs(),
+            )
+            if pred.tool_calls:
+                tool_calls = pred.tool_calls
+            else:
+                tool_calls = pred.additional_kwargs.get("tool_calls", [])
+
+            output = LLMInterface(
+                content="",
+                additional_kwargs={"tool_calls": tool_calls},
+            )
+        else:
+            pred = self._obj.generate(messages=[input_], **kwargs)
+            output = self.prepare_response(pred)
+
+        return output

     async def ainvoke(
         self, messages: str | BaseMessage | list[BaseMessage], **kwargs
@@ -235,6 +258,9 @@ class LCAnthropicChat(LCChatMixin, ChatLLM):  # type: ignore
         required=True,
     )

+    def _get_tool_call_kwargs(self):
+        return {"tool_choice": {"type": "any"}}
+
     def __init__(
         self,
         api_key: str | None = None,
@@ -291,3 +317,35 @@ def _get_lc_class(self):
             raise ImportError("Please install langchain-google-genai")

         return ChatGoogleGenerativeAI
+
+
+class LCCohereChat(LCChatMixin, ChatLLM):  # type: ignore
+    api_key: str = Param(
+        help="API key (https://dashboard.cohere.com/api-keys)", required=True
+    )
+    model_name: str = Param(
+        help=("Model name to use (https://dashboard.cohere.com/playground/chat)"),
+        required=True,
+    )
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model_name: str | None = None,
+        temperature: float = 0.7,
+        **params,
+    ):
+        super().__init__(
+            cohere_api_key=api_key,
+            model_name=model_name,
+            temperature=temperature,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_cohere import ChatCohere
+        except ImportError:
+            raise ImportError("Please install langchain-cohere")
+
+        return ChatCohere
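A sketch of the new `tools_pydantic` path end to end (assumes a valid Cohere API key; the schema is a stand-in for the `CiteEvidence` model from the citation pipeline above):

```python
from typing import List

from pydantic import BaseModel, Field

from kotaemon.llms import LCCohereChat


class CiteEvidence(BaseModel):
    """List of evidences (maximum 5) to support the answer."""

    evidences: List[str] = Field(...)


llm = LCCohereChat(model_name="command-r-plus-08-2024", api_key="your-key")

# Passing tools_pydantic routes invoke() through bind_tools() rather than
# generate(), and the tool calls come back in additional_kwargs:
output = llm.invoke(
    "Cite evidence for: the Eiffel Tower is in Paris.",
    tools_pydantic=[CiteEvidence],
)
print(output.additional_kwargs["tool_calls"])
```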

libs/kotaemon/kotaemon/llms/chats/openai.py

Lines changed: 6 additions & 0 deletions
@@ -292,6 +292,9 @@ def prepare_client(self, async_version: bool = False):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
         params_ = {
             "model": self.model,
             "temperature": self.temperature,
@@ -360,6 +363,9 @@ def prepare_client(self, async_version: bool = False):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
+        if "tools_pydantic" in kwargs:
+            kwargs.pop("tools_pydantic")
+
         params_ = {
             "model": self.azure_deployment,
             "temperature": self.temperature,
