|
27 | 27 | "\n",
|
28 | 28 | "Many use cases such as building a chatbot require text (text2text) generation models like **[BloomZ 7B1](https://huggingface.co/bigscience/bloomz-7b1)**, **[Flan T5 XXL](https://huggingface.co/google/flan-t5-xxl)**, and **[Flan T5 UL2](https://huggingface.co/google/flan-ul2)** to respond to user questions with insightful answers. The **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** models have picked up a lot of general knowledge in training, but we often need to ingest and use a large library of more specific information.\n",
|
29 | 29 | "\n",
|
30 |
| - "In this notebook we will demonstrate how to use **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated from **GPT-J-6B** embedding model. \n", |
| 30 | + "In this notebook we will demonstrate how to use **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated from **MiniLM-L6-v2** embedding model. \n", |
31 | 31 | "\n",
|
32 | 32 | "**This notebook serves a template such that you can easily replace the example dataset by your own to build a custom question and asnwering application.**"
|
33 | 33 | ]
|
|
45 | 45 | "cell_type": "code",
|
46 | 46 | "execution_count": null,
|
47 | 47 | "metadata": {
|
48 |
| - "collapsed": false, |
49 | 48 | "jupyter": {
|
50 | 49 | "outputs_hidden": false
|
51 | 50 | },
|
|
57 | 56 | "outputs": [],
|
58 | 57 | "source": [
|
59 | 58 | "!pip install --upgrade sagemaker --quiet\n",
|
60 |
| - "!pip install ipywidgets==7.0.0 --quiet\n", |
61 |
| - "!pip install langchain==0.0.148 --quiet\n", |
62 |
| - "!pip install faiss-cpu --quiet" |
| 59 | + "!pip install faiss-cpu --quiet\n", |
| 60 | + "!pip install langchain --quiet" |
63 | 61 | ]
|
64 | 62 | },
|
65 | 63 | {
|
|
70 | 68 | },
|
71 | 69 | "outputs": [],
|
72 | 70 | "source": [
|
73 |
| - "import time\n", |
74 |
| - "import sagemaker, boto3, json\n", |
75 |
| - "from sagemaker.session import Session\n", |
76 |
| - "from sagemaker.model import Model\n", |
77 |
| - "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", |
78 |
| - "from sagemaker.predictor import Predictor\n", |
| 71 | + "from sagemaker import Session\n", |
79 | 72 | "from sagemaker.utils import name_from_base\n",
|
80 |
| - "from typing import Any, Dict, List, Optional\n", |
81 |
| - "from langchain.embeddings import SagemakerEndpointEmbeddings\n", |
82 |
| - "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n", |
| 73 | + "from sagemaker.jumpstart.model import JumpStartModel\n", |
83 | 74 | "\n",
|
84 |
| - "sagemaker_session = Session()\n", |
85 |
| - "aws_role = sagemaker_session.get_caller_identity_arn()\n", |
86 |
| - "aws_region = boto3.Session().region_name\n", |
87 |
| - "sess = sagemaker.Session()\n", |
88 |
| - "model_version = \"1.*\"" |
89 |
| - ] |
90 |
| - }, |
91 |
| - { |
92 |
| - "cell_type": "code", |
93 |
| - "execution_count": null, |
94 |
| - "metadata": { |
95 |
| - "tags": [] |
96 |
| - }, |
97 |
| - "outputs": [], |
98 |
| - "source": [ |
99 |
| - "def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n", |
100 |
| - " client = boto3.client(\"runtime.sagemaker\")\n", |
101 |
| - " response = client.invoke_endpoint(\n", |
102 |
| - " EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n", |
103 |
| - " )\n", |
104 |
| - " return response\n", |
105 |
| - "\n", |
106 |
| - "\n", |
107 |
| - "def parse_response_model_flan_t5(query_response):\n", |
108 |
| - " model_predictions = json.loads(query_response[\"Body\"].read())\n", |
109 |
| - " generated_text = model_predictions[\"generated_texts\"]\n", |
110 |
| - " return generated_text\n", |
111 |
| - "\n", |
112 |
| - "\n", |
113 |
| - "def parse_response_multiple_texts_bloomz(query_response):\n", |
114 |
| - " generated_text = []\n", |
115 |
| - " model_predictions = json.loads(query_response[\"Body\"].read())\n", |
116 |
| - " for x in model_predictions[0]:\n", |
117 |
| - " generated_text.append(x[\"generated_text\"])\n", |
118 |
| - " return generated_text" |
| 75 | + "sagemaker_session = Session()" |
119 | 76 | ]
|
120 | 77 | },
|
121 | 78 | {
|
122 | 79 | "cell_type": "markdown",
|
123 | 80 | "metadata": {},
|
124 | 81 | "source": [
|
125 |
| - "Deploy SageMaker endpoint(s) for large language models and GPT-J 6B embedding model. Please uncomment the entries as below if you want to deploy multiple LLM models to compare their performance." |
| 82 | + "Deploy SageMaker endpoint(s) for large language models and MiniLM-L6-v2 embedding model. Please uncomment the entries as below if you want to deploy multiple LLM models to compare their performance." |
126 | 83 | ]
|
127 | 84 | },
|
128 | 85 | {
|
|
135 | 92 | "source": [
|
136 | 93 | "_MODEL_CONFIG_ = {\n",
|
137 | 94 | " \"huggingface-text2text-flan-t5-xxl\": {\n",
|
| 95 | + " \"model_version\": \"2.*\",\n", |
138 | 96 | " \"instance type\": \"ml.g5.12xlarge\",\n",
|
139 |
| - " \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n", |
140 |
| - " \"parse_function\": parse_response_model_flan_t5,\n", |
141 |
| - " \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n", |
142 | 97 | " },\n",
|
143 |
| - " \"huggingface-textembedding-gpt-j-6b\": {\n", |
| 98 | + " \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n", |
| 99 | + " \"model_version\": \"1.*\",\n", |
144 | 100 | " \"instance type\": \"ml.g5.24xlarge\",\n",
|
145 |
| - " \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n", |
146 | 101 | " },\n",
|
147 |
| - " # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n", |
148 |
| - " # \"instance type\": \"ml.g5.12xlarge\",\n", |
149 |
| - " # \"env\": {},\n", |
150 |
| - " # \"parse_function\": parse_response_multiple_texts_bloomz,\n", |
151 |
| - " # \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n", |
| 102 | + " # \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n", |
| 103 | + " # \"model_version\": \"3.*\",\n", |
| 104 | + " # \"instance type\": \"ml.g5.12xlarge\"\n", |
152 | 105 | " # },\n",
|
153 | 106 | " # \"huggingface-text2text-flan-ul2-bf16\": {\n",
|
154 |
| - " # \"instance type\": \"ml.g5.24xlarge\",\n", |
155 |
| - " # \"env\": {\n", |
156 |
| - " # \"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\",\n", |
157 |
| - " # \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"\n", |
158 |
| - " # },\n", |
159 |
| - " # \"parse_function\": parse_response_model_flan_t5,\n", |
160 |
| - " # \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n", |
161 |
| - " # },\n", |
| 107 | + " # \"model_version\": \"2.*\",\n", |
| 108 | + " # \"instance type\": \"ml.g5.24xlarge\"\n", |
| 109 | + " # }\n", |
162 | 110 | "}"
|
163 | 111 | ]
|
164 | 112 | },
|
|
168 | 116 | "metadata": {},
|
169 | 117 | "outputs": [],
|
170 | 118 | "source": [
|
171 |
| - "newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n", |
172 |
| - "\n", |
173 | 119 | "for model_id in _MODEL_CONFIG_:\n",
|
174 | 120 | " endpoint_name = name_from_base(f\"jumpstart-example-raglc-{model_id}\")\n",
|
175 | 121 | " inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n",
|
| 122 | + " model_version = _MODEL_CONFIG_[model_id][\"model_version\"]\n", |
176 | 123 | "\n",
|
177 |
| - " # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n", |
178 |
| - " deploy_image_uri = image_uris.retrieve(\n", |
179 |
| - " region=None,\n", |
180 |
| - " framework=None, # automatically inferred from model_id\n", |
181 |
| - " image_scope=\"inference\",\n", |
182 |
| - " model_id=model_id,\n", |
183 |
| - " model_version=model_version,\n", |
184 |
| - " instance_type=inference_instance_type,\n", |
185 |
| - " )\n", |
186 |
| - " # Retrieve the model uri.\n", |
187 |
| - " model_uri = model_uris.retrieve(\n", |
188 |
| - " model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", |
189 |
| - " )\n", |
190 |
| - " model_inference = Model(\n", |
191 |
| - " image_uri=deploy_image_uri,\n", |
192 |
| - " model_data=model_uri,\n", |
193 |
| - " role=aws_role,\n", |
194 |
| - " predictor_cls=Predictor,\n", |
195 |
| - " name=endpoint_name,\n", |
196 |
| - " env=_MODEL_CONFIG_[model_id][\"env\"],\n", |
197 |
| - " )\n", |
198 |
| - " model_predictor_inference = model_inference.deploy(\n", |
199 |
| - " initial_instance_count=1,\n", |
200 |
| - " instance_type=inference_instance_type,\n", |
201 |
| - " predictor_cls=Predictor,\n", |
202 |
| - " endpoint_name=endpoint_name,\n", |
203 |
| - " )\n", |
204 |
| - " print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n", |
205 |
| - " _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name" |
| 124 | + " print(f\"Deploying {model_id}...\")\n", |
| 125 | + "\n", |
| 126 | + " model = JumpStartModel(model_id=model_id, model_version=model_version)\n", |
| 127 | + "\n", |
| 128 | + " try:\n", |
| 129 | + " predictor = model.deploy(\n", |
| 130 | + " initial_instance_count=1,\n", |
| 131 | + " instance_type=inference_instance_type,\n", |
| 132 | + " endpoint_name=name_from_base(f\"jumpstart-example-raglc-{model_id}\"),\n", |
| 133 | + " )\n", |
| 134 | + " print(f\"Deployed endpoint: {predictor.endpoint_name}\")\n", |
| 135 | + " _MODEL_CONFIG_[model_id][\"predictor\"] = predictor\n", |
| 136 | + " except Exception as e:\n", |
| 137 | + " print(f\"Error deploying {model_id}: {str(e)}\")\n", |
| 138 | + "\n", |
| 139 | + "print(\"Deployment process completed.\")" |
206 | 140 | ]
|
207 | 141 | },
|
208 | 142 | {
|
|
229 | 163 | "metadata": {},
|
230 | 164 | "outputs": [],
|
231 | 165 | "source": [
|
232 |
| - "payload = {\n", |
233 |
| - " \"text_inputs\": question,\n", |
234 |
| - " \"max_length\": 100,\n", |
235 |
| - " \"num_return_sequences\": 1,\n", |
236 |
| - " \"top_k\": 50,\n", |
237 |
| - " \"top_p\": 0.95,\n", |
238 |
| - " \"do_sample\": True,\n", |
239 |
| - "}\n", |
240 |
| - "\n", |
241 | 166 | "list_of_LLMs = list(_MODEL_CONFIG_.keys())\n",
|
242 |
| - "list_of_LLMs.remove(\"huggingface-textembedding-gpt-j-6b\") # remove the embedding model\n", |
243 |
| - "\n", |
| 167 | + "list_of_LLMs = [model for model in list_of_LLMs if \"textembedding\" not in model]\n", |
244 | 168 | "\n",
|
245 | 169 | "for model_id in list_of_LLMs:\n",
|
246 |
| - " endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n", |
247 |
| - " query_response = query_endpoint_with_json_payload(\n", |
248 |
| - " json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n", |
249 |
| - " )\n", |
250 |
| - " generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n", |
251 |
| - " print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")" |
| 170 | + " predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n", |
| 171 | + " response = predictor.predict({\"inputs\": question})\n", |
| 172 | + " print(f\"For model: {model_id}, the generated output is:\\n\")\n", |
| 173 | + " print(f\"{response[0]['generated_text']}\\n\")" |
252 | 174 | ]
|
253 | 175 | },
|
254 | 176 | {
|
|
283 | 205 | "metadata": {},
|
284 | 206 | "outputs": [],
|
285 | 207 | "source": [
|
286 |
| - "parameters = {\n", |
287 |
| - " \"max_length\": 200,\n", |
288 |
| - " \"num_return_sequences\": 1,\n", |
289 |
| - " \"top_k\": 250,\n", |
290 |
| - " \"top_p\": 0.95,\n", |
291 |
| - " \"do_sample\": False,\n", |
292 |
| - " \"temperature\": 1,\n", |
293 |
| - "}\n", |
| 208 | + "prompt = f\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\n", |
294 | 209 | "\n",
|
295 | 210 | "for model_id in list_of_LLMs:\n",
|
296 |
| - " endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n", |
297 |
| - "\n", |
298 |
| - " prompt = _MODEL_CONFIG_[model_id][\"prompt\"]\n", |
299 |
| - "\n", |
300 |
| - " text_input = prompt.replace(\"{context}\", context)\n", |
301 |
| - " text_input = text_input.replace(\"{question}\", question)\n", |
302 |
| - " payload = {\"text_inputs\": text_input, **parameters}\n", |
303 |
| - "\n", |
304 |
| - " query_response = query_endpoint_with_json_payload(\n", |
305 |
| - " json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n", |
306 |
| - " )\n", |
307 |
| - " generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n", |
308 |
| - " print(\n", |
309 |
| - " f\"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}\"\n", |
310 |
| - " )" |
| 211 | + " predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n", |
| 212 | + " response = predictor.predict({\"inputs\": prompt})\n", |
| 213 | + " print(f\"For model: {model_id}, the generated output is:\\n\")\n", |
| 214 | + " print(f\"{response[0]['generated_text']}\\n\")" |
311 | 215 | ]
|
312 | 216 | },
|
313 | 217 | {
|
|
330 | 234 | "\n",
|
331 | 235 | "To achieve that, we will do following.\n",
|
332 | 236 | "\n",
|
333 |
| - "1. **Generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.**\n", |
| 237 | + "1. **Generate embedings for each of document in the knowledge library with SageMaker MiniLM-L6-v2 embedding model.**\n", |
334 | 238 | "2. **Identify top K most relevant documents based on user query.**\n",
|
335 | 239 | " - 2.1 **For a query of your interest, generate the embedding of the query using the same embedding model.**\n",
|
336 | 240 | " - 2.2 **Search the indexes of top K most relevant documents in the embedding space using in-memory Faiss search.**\n",
|
|
365 | 269 | "outputs": [],
|
366 | 270 | "source": [
|
367 | 271 | "from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
|
| 272 | + "from langchain.embeddings import SagemakerEndpointEmbeddings\n", |
| 273 | + "from typing import List\n", |
| 274 | + "import boto3\n", |
| 275 | + "\n", |
| 276 | + "aws_region = boto3.Session().region_name\n", |
368 | 277 | "\n",
|
369 | 278 | "\n",
|
370 | 279 | "class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):\n",
|
|
405 | 314 | "\n",
|
406 | 315 | "\n",
|
407 | 316 | "content_handler = ContentHandler()\n",
|
| 317 | + "endpoint_name = _MODEL_CONFIG_[\"huggingface-textembedding-all-MiniLM-L6-v2\"][\n", |
| 318 | + " \"predictor\"\n", |
| 319 | + "].endpoint_name\n", |
408 | 320 | "\n",
|
409 | 321 | "embeddings = SagemakerEndpointEmbeddingsJumpStart(\n",
|
410 |
| - " endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"],\n", |
| 322 | + " endpoint_name=endpoint_name,\n", |
411 | 323 | " region_name=aws_region,\n",
|
412 | 324 | " content_handler=content_handler,\n",
|
413 | 325 | ")"
|
|
428 | 340 | "source": [
|
429 | 341 | "from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint\n",
|
430 | 342 | "\n",
|
431 |
| - "parameters = {\n", |
432 |
| - " \"max_length\": 200,\n", |
433 |
| - " \"num_return_sequences\": 1,\n", |
434 |
| - " \"top_k\": 250,\n", |
435 |
| - " \"top_p\": 0.95,\n", |
436 |
| - " \"do_sample\": False,\n", |
437 |
| - " \"temperature\": 1,\n", |
438 |
| - "}\n", |
439 |
| - "\n", |
440 | 343 | "\n",
|
441 | 344 | "class ContentHandler(LLMContentHandler):\n",
|
442 | 345 | " content_type = \"application/json\"\n",
|
443 | 346 | " accepts = \"application/json\"\n",
|
444 | 347 | "\n",
|
445 | 348 | " def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n",
|
446 |
| - " input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", |
| 349 | + " input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n", |
447 | 350 | " return input_str.encode(\"utf-8\")\n",
|
448 | 351 | "\n",
|
449 | 352 | " def transform_output(self, output: bytes) -> str:\n",
|
450 | 353 | " response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
451 |
| - " return response_json[\"generated_texts\"][0]\n", |
| 354 | + " return response_json[0][\"generated_text\"]\n", |
452 | 355 | "\n",
|
453 | 356 | "\n",
|
454 | 357 | "content_handler = ContentHandler()\n",
|
| 358 | + "endpoint_name = _MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"predictor\"].endpoint_name\n", |
| 359 | + "\n", |
| 360 | + "parameters = {\n", |
| 361 | + " \"max_length\": 200,\n", |
| 362 | + " \"num_return_sequences\": 1,\n", |
| 363 | + " \"top_k\": 250,\n", |
| 364 | + " \"top_p\": 0.95,\n", |
| 365 | + " \"do_sample\": False,\n", |
| 366 | + " \"temperature\": 1,\n", |
| 367 | + "}\n", |
455 | 368 | "\n",
|
456 | 369 | "sm_llm = SagemakerEndpoint(\n",
|
457 |
| - " endpoint_name=_MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"endpoint_name\"],\n", |
| 370 | + " endpoint_name=endpoint_name,\n", |
458 | 371 | " region_name=aws_region,\n",
|
459 | 372 | " model_kwargs=parameters,\n",
|
460 | 373 | " content_handler=content_handler,\n",
|
|
568 | 481 | "from langchain.text_splitter import CharacterTextSplitter\n",
|
569 | 482 | "from langchain import PromptTemplate\n",
|
570 | 483 | "from langchain.chains.question_answering import load_qa_chain\n",
|
571 |
| - "from langchain.document_loaders.csv_loader import CSVLoader" |
| 484 | + "from langchain.document_loaders.csv_loader import CSVLoader\n", |
| 485 | + "import json" |
572 | 486 | ]
|
573 | 487 | },
|
574 | 488 | {
|
|
670 | 584 | "cell_type": "markdown",
|
671 | 585 | "metadata": {},
|
672 | 586 | "source": [
|
673 |
| - "Firstly, we **generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.**" |
| 587 | + "Firstly, we **generate embedings for each of document in the knowledge library with SageMaker MiniLM-L6-v2 embedding model.**" |
674 | 588 | ]
|
675 | 589 | },
|
676 | 590 | {
|
|
1384 | 1298 | ],
|
1385 | 1299 | "instance_type": "ml.t3.medium",
|
1386 | 1300 | "kernelspec": {
|
1387 |
| - "display_name": "Python 3 (Data Science 2.0)", |
| 1301 | + "display_name": "Python 3 (ipykernel)", |
1388 | 1302 | "language": "python",
|
1389 |
| - "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38" |
| 1303 | + "name": "python3" |
1390 | 1304 | },
|
1391 | 1305 | "language_info": {
|
1392 | 1306 | "codemirror_mode": {
|
|
1398 | 1312 | "name": "python",
|
1399 | 1313 | "nbconvert_exporter": "python",
|
1400 | 1314 | "pygments_lexer": "ipython3",
|
1401 |
| - "version": "3.8.13" |
| 1315 | + "version": "3.11.9" |
1402 | 1316 | }
|
1403 | 1317 | },
|
1404 | 1318 | "nbformat": 4,
|
|