Added text-generation-webui inference support #221

Open · wants to merge 3 commits into main
31 changes: 31 additions & 0 deletions README.md
@@ -140,6 +140,37 @@ character_maker(

The prompt above typically takes just over 2.5 seconds to complete on an A6000 GPU when using LLaMA 7B. If we were to run the same prompt adapted to be a single generation call (the standard practice today), it would take about 5 seconds to complete (4 of which is token generation and 1 of which is prompt processing). *This means Guidance acceleration delivers a 2x speedup over the standard approach for this prompt.* In practice the exact speed-up factor depends on the format of your specific prompt and the size of your model (larger models benefit more). Acceleration is also only supported for Transformers LLMs at the moment. See the [notebook](https://github.com/microsoft/guidance/blob/main/notebooks/guidance_acceleration.ipynb) for more details.


## Text-generation-webui integration

The `TGWUI` class integrates with [text-generation-webui](https://github.com/oobabooga/text-generation-webui), which is easy to set up and makes it possible to run larger models with less VRAM. It also means that the machine running `guidance` and the machine running the [text-generation-webui](https://github.com/oobabooga/text-generation-webui) inference server do not need to be the same. For example:



````python

guidance.llm = guidance.llms.TGWUI("http://127.0.0.1:9000")

# define the prompt
character_maker = guidance("""The following is a character profile for a Soccer Game in JSON format.
```json
{
"Nationality": "{{nationality}}",
"league": "{{league}}",
"name": "{{gen 'name'}}",
"age": {{gen 'age' pattern='[0-9]+' stop=','}},
"overall": {{gen 'overall' pattern='[0-9]+' stop=','}},
"description": "{{gen 'description' temperature=1.25}}",
}```""")

# generate a character
character_maker(
nationality="Türkiye",
league="Premier League"
)
````
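
If the model loaded in text-generation-webui follows an instruction template, the same endpoint can also be driven in role-based chat mode. The sketch below mirrors the chat-mode test added in this pull request; the port number and the prompt are only illustrative:

````python
import guidance

# connect to a text-generation-webui API server whose loaded model
# reports instruction-following support (port 9000 is an example)
guidance.llm = guidance.llms.TGWUI("http://127.0.0.1:9000", chat_mode=True)

assistant = guidance("""
{{#system~}}
You are a helpful and terse assistant.
{{~/system}}

{{#user~}}
{{query}}
{{~/user}}

{{#assistant~}}
{{gen 'answer' temperature=0 max_tokens=200}}
{{~/assistant}}
""")

assistant(query="How can I be more productive?")
````

As with the completion example above, all generation calls are forwarded to the remote `/api/v1/call` endpoint, so no model weights need to be loaded on the machine running `guidance`.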


## Token healing (<a href="https://github.com/microsoft/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb">notebook</a>)

The standard greedy tokenizations used by most language models introduce a subtle and powerful bias that can have all kinds of unintended consequences for your prompts. Using a process we call "token healing" `guidance` automatically removes these surprising biases, freeing you to focus on designing the prompts you want without worrying about tokenization artifacts.
1 change: 1 addition & 0 deletions guidance/llms/__init__.py
@@ -2,6 +2,7 @@
from ._transformers import Transformers
from ._mock import Mock
from ._llm import LLM, LLMSession, SyncSession
from ._tgwui import TGWUI
from ._deep_speed import DeepSpeed
from . import transformers
from . import caches
130 changes: 130 additions & 0 deletions guidance/llms/_tgwui.py
@@ -0,0 +1,130 @@
import requests

from typing import Any, Dict
from ._llm import LLM, LLMSession




class TGWUI(LLM):
    instruction_template = None

    def __init__(self, base_url, chat_mode=False):
        self.chat_mode = chat_mode  # by default models are not in role-based chat mode
        self.base_url = base_url
        self.model_info = self.get_model_info()
        self.model_name = self.model_info["model_name"]
        if self.model_info["instruction_following"] != chat_mode:
            print(
                "Warning: model " + self.model_name + " reports instruction_following="
                + str(self.model_info["instruction_following"]) + " but chat_mode=" + str(chat_mode)
            )

    def get_model_info(self):
        """Fetches information about the currently loaded model from the server."""
        response = requests.get(self.base_url + "/api/v1/model")
        return response.json()["results"]


    def __getitem__(self, key):
        """Gets an attribute from the LLM."""
        return getattr(self, key)

    def session(self, asynchronous=False):
        """Creates a session for this LLM backed by the remote text-generation-webui server."""
        return TGWUISession(self)

    def encode(self, string, **kwargs):
        payload = {"text": string, "kwargs": kwargs}
        response = requests.post(self.base_url + "/api/v1/encode", json=payload)
        return response.json()["results"][0]["tokens"]

    def decode(self, tokens, **kwargs):
        payload = {"tokens": tokens, "kwargs": kwargs}
        response = requests.post(self.base_url + "/api/v1/decode", json=payload)
        return response.json()["results"][0]["ids"]


    def role_start(self, role):
        if not self.model_info["instruction_following"]:
            raise ValueError(
                "Model does not support chat mode; it may be a plain next-word completion model."
            )
        elif role == "user":
            return self.model_info["instruction_template"]["user"]
        elif role == "assistant" or role == "system":
            return self.model_info["instruction_template"]["bot"]
        else:
            return ""


def role_end(self, role):
return ''

def end_of_text(self):
return self.model_info['eos_token']






class TGWUISession(LLMSession):
    def __init__(self, llm):
        self.llm = llm
        self._call_counts = {}  # tracks the number of repeated identical calls to the LLM with non-zero temperature

    def __enter__(self):
        return self

    async def __call__(
        self, prompt, stop=None, stop_regex=None, temperature=None, n=1, max_tokens=1000, logprobs=None,
        top_p=1.0, echo=False, logit_bias=None, token_healing=None, pattern=None, stream=None,
        cache_seed=0, caching=None, **completion_kwargs
    ):
        # forward every generation argument to the text-generation-webui server
        # and return the generated text from the first choice
        args = {
            "prompt": prompt, "stop": stop, "stop_regex": stop_regex, "temperature": temperature, "n": n,
            "max_tokens": max_tokens, "logprobs": logprobs, "top_p": top_p, "echo": echo, "logit_bias": logit_bias,
            "token_healing": token_healing, "pattern": pattern, "stream": stream, "cache_seed": cache_seed,
            "completion_kwargs": completion_kwargs, "chat": self.llm.chat_mode,
        }
        response = requests.post(self.llm.base_url + "/api/v1/call", json=args)
        resp = response.json()
        return resp["choices"][0]["text"]

def __exit__(self, exc_type, exc_value, traceback):
pass

def _gen_key(self, args_dict):
del args_dict["self"] # skip the "self" arg
return "_---_".join([str(v) for v in ([args_dict[k] for k in args_dict] + [self.llm.model_name, self.llm.__class__.__name__, self.llm.cache_version])])

def _cache_params(self, args_dict) -> Dict[str, Any]:
"""get the parameters for generating the cache key"""
key = self._gen_key(args_dict)
# if we have non-zero temperature we include the call count in the cache key
if args_dict.get("temperature", 0) > 0:
args_dict["call_count"] = self._call_counts.get(key, 0)

# increment the call count
self._call_counts[key] = args_dict["call_count"] + 1
args_dict["model_name"] = self.llm.model_name
args_dict["cache_version"] = self.llm.cache_version
args_dict["class_name"] = self.llm.__class__.__name__

return args_dict
81 changes: 80 additions & 1 deletion tests/test_program.py
@@ -1,6 +1,6 @@
import guidance
import pytest
from .utils import get_llm
from utils import get_llm

def test_chat_stream():
""" Test the behavior of `stream=True` for an openai chat endpoint.
@@ -159,3 +159,82 @@ async def call_async():
"Expect the exception to be propagated"

loop.close()




# TGWUI tests; some still have issues (TODO)
def test_basic_gen():
    model = get_llm("tgwui:http://127.0.0.1:9555", chat_mode=False)
prompt = guidance("if you give a mouse a cookie, {{gen 'next_verse' temperature=0.7}}", llm=model)
res = prompt()
assert res is not None # Assuming you expect a non-null result

def test_encode_to_decode():
    model = get_llm("tgwui:http://127.0.0.1:9555", chat_mode=False)
    string = "Hello World"
    tokens = model.encode(string)
    converted = model.decode(tokens)
print(converted)
assert string == converted

def test_chat_mode():
    model = get_llm("tgwui:http://127.0.0.1:9555", chat_mode=True)
experts = guidance('''
{{#system~}}
You are a helpful and terse assistant.
{{~/system}}

{{#user~}}
I want a response to the following question:
{{query}}
Name 3 world-class experts (past or present) who would be great at answering this?
Don't answer the question yet.
{{~/user}}

{{#assistant~}}
{{gen 'expert_names' temperature=0 max_tokens=300}}
{{~/assistant}}

{{#user~}}
Great, now please answer the question as if these experts had collaborated in writing a joint anonymous answer.
{{~/user}}

{{#assistant~}}
{{gen 'answer' temperature=0 max_tokens=500}}
{{~/assistant}}
''', llm=model)
res = experts(query='How can I be more productive?')
print(res)
assert res is not None # Assuming you expect a non-null result

def test_basic_geneach():
    model = get_llm("tgwui:http://127.0.0.1:9555", chat_mode=False)
prompt = guidance("""{{#geneach 'items' num_iterations=3}} "{{gen 'this'}}",{{/geneach}}""", llm=model)
res = prompt()
print(res)
assert res is not None # Assuming you expect a non-null result

def test_basic_pattern_gen():
    model = get_llm("tgwui:http://127.0.0.1:9555", chat_mode=False)
prompt = guidance("strength: {{gen 'strength' pattern='[0-9]+' temperature=0.7}}", llm=model)
res = prompt()
print(res)
    assert str(res["strength"]).isdigit()  # the pattern-constrained value should be numeric

def test_basic_select():
    model = get_llm("tgwui:http://127.0.0.1:9555", chat_mode=False)
valid_weapons = ["sword", "axe", "mace", "spear", "bow", "crossbow"]
prompt = guidance("weapon {{select 'weapon' options=valid_weapons}}", valid_weapons=valid_weapons, llm=model)
res = prompt()
print(res)
    assert res["weapon"] in valid_weapons  # the selected weapon should be one of the provided options


def test_basic_stop():
    model = get_llm("tgwui:http://127.0.0.1:9555", chat_mode=False)
prompt = guidance("how {{gen 'strength' stop=',' temperature=0.7}}", llm=model)
res = str(prompt())
print(res)
assert res.endswith(',')

26 changes: 26 additions & 0 deletions tests/utils.py
@@ -10,6 +10,8 @@ def get_llm(model_name, caching=False, **kwargs):
return get_openai_llm(model_name[7:], caching, **kwargs)
elif model_name.startswith("transformers:"):
return get_transformers_llm(model_name[13:], caching, **kwargs)
elif model_name.startswith("tgwui:"):
return get_tgwui_llm(model_name[6:], caching, **kwargs)

def get_openai_llm(model_name, caching=False, **kwargs):
""" Get an OpenAI LLM with model reuse and smart test skipping.
@@ -29,6 +31,7 @@ def get_openai_llm(model_name, caching=False, **kwargs):

transformers_model_cache = {}


def get_transformers_llm(model_name, caching=False):
""" Get an OpenAI LLM with model reuse.
"""
@@ -40,3 +43,26 @@ def get_transformers_llm(model_name, caching=False):
transformers_model_cache[key] = guidance.llms.Transformers(model_name, caching=caching)

return transformers_model_cache[key]




tgwui_model_cache = {}

def get_tgwui_llm(base_url, caching=False, **kwargs):
    """ Get a TGWUI LLM with client reuse.
    """

    # we cache the clients so lots of tests using the same server don't have to
    # reconnect over and over again
    chat_mode = kwargs.get("chat_mode", False)

    key = "tgwui_" + base_url + "_" + str(chat_mode) + "_" + str(caching)
    if key not in tgwui_model_cache:
        tgwui_model_cache[key] = guidance.llms.TGWUI(base_url, chat_mode=chat_mode)

    return tgwui_model_cache[key]