Commit 4ba750e

#4725: Change model deployment to JumpStart
1 parent faf8648 commit 4ba750e

File tree: 1 file changed, +72 −158 lines

introduction_to_amazon_algorithms/jumpstart-foundation-models/question_answering_retrieval_augmented_generation/question_answering_langchain_jumpstart.ipynb

Lines changed: 72 additions & 158 deletions
@@ -27,7 +27,7 @@
 "\n",
 "Many use cases such as building a chatbot require text (text2text) generation models like **[BloomZ 7B1](https://huggingface.co/bigscience/bloomz-7b1)**, **[Flan T5 XXL](https://huggingface.co/google/flan-t5-xxl)**, and **[Flan T5 UL2](https://huggingface.co/google/flan-ul2)** to respond to user questions with insightful answers. The **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** models have picked up a lot of general knowledge in training, but we often need to ingest and use a large library of more specific information.\n",
 "\n",
-"In this notebook we will demonstrate how to use **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated from **GPT-J-6B** embedding model. \n",
+"In this notebook we will demonstrate how to use **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated from the **MiniLM-L6-v2** embedding model.\n",
 "\n",
 "**This notebook serves as a template so that you can easily replace the example dataset with your own to build a custom question answering application.**"
 ]
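
The retrieve-then-read loop the notebook assembles fits in a few lines. A minimal sketch of the flow, where `retriever` and `llm` are placeholders for the FAISS index and the LangChain SagemakerEndpoint wrapper constructed in later cells:

# Sketch: embed the question, fetch the k nearest documents, and stuff
# them into the prompt template used throughout this notebook.
def answer(question, retriever, llm, k=3):
    docs = retriever.similarity_search(question, k=k)
    context = "\n".join(doc.page_content for doc in docs)
    return llm(f"Answer based on context:\n\n{context}\n\n{question}")
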
@@ -45,7 +45,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"collapsed": false,
 "jupyter": {
 "outputs_hidden": false
 },
@@ -57,9 +56,8 @@
 "outputs": [],
 "source": [
 "!pip install --upgrade sagemaker --quiet\n",
-"!pip install ipywidgets==7.0.0 --quiet\n",
-"!pip install langchain==0.0.148 --quiet\n",
-"!pip install faiss-cpu --quiet"
+"!pip install faiss-cpu --quiet\n",
+"!pip install langchain --quiet"
 ]
 },
 {
@@ -70,59 +68,18 @@
 },
 "outputs": [],
 "source": [
-"import time\n",
-"import sagemaker, boto3, json\n",
-"from sagemaker.session import Session\n",
-"from sagemaker.model import Model\n",
-"from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
-"from sagemaker.predictor import Predictor\n",
+"from sagemaker import Session\n",
 "from sagemaker.utils import name_from_base\n",
-"from typing import Any, Dict, List, Optional\n",
-"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
-"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
+"from sagemaker.jumpstart.model import JumpStartModel\n",
 "\n",
-"sagemaker_session = Session()\n",
-"aws_role = sagemaker_session.get_caller_identity_arn()\n",
-"aws_region = boto3.Session().region_name\n",
-"sess = sagemaker.Session()\n",
-"model_version = \"1.*\""
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"tags": []
-},
-"outputs": [],
-"source": [
-"def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n",
-"    client = boto3.client(\"runtime.sagemaker\")\n",
-"    response = client.invoke_endpoint(\n",
-"        EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n",
-"    )\n",
-"    return response\n",
-"\n",
-"\n",
-"def parse_response_model_flan_t5(query_response):\n",
-"    model_predictions = json.loads(query_response[\"Body\"].read())\n",
-"    generated_text = model_predictions[\"generated_texts\"]\n",
-"    return generated_text\n",
-"\n",
-"\n",
-"def parse_response_multiple_texts_bloomz(query_response):\n",
-"    generated_text = []\n",
-"    model_predictions = json.loads(query_response[\"Body\"].read())\n",
-"    for x in model_predictions[0]:\n",
-"        generated_text.append(x[\"generated_text\"])\n",
-"    return generated_text"
+"sagemaker_session = Session()"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Deploy SageMaker endpoint(s) for large language models and GPT-J 6B embedding model. Please uncomment the entries as below if you want to deploy multiple LLM models to compare their performance."
+"Deploy SageMaker endpoint(s) for the large language models and the MiniLM-L6-v2 embedding model. Uncomment the entries below if you want to deploy multiple LLMs and compare their performance."
 ]
 },
 {
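
JumpStartModel is what lets this commit delete the helper functions above and the manual image_uris/model_uris retrieval further down: it resolves the container image, model artifact, and default deployment settings from the model ID alone. An illustrative check (attribute names come from the SageMaker SDK's Model base class):

from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="huggingface-text2text-flan-t5-xxl", model_version="2.*")
print(model.image_uri)   # container image resolved from the model ID
print(model.model_data)  # model artifact location resolved from the model ID
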
@@ -135,30 +92,21 @@
 "source": [
 "_MODEL_CONFIG_ = {\n",
 "    \"huggingface-text2text-flan-t5-xxl\": {\n",
+"        \"model_version\": \"2.*\",\n",
 "        \"instance type\": \"ml.g5.12xlarge\",\n",
-"        \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
-"        \"parse_function\": parse_response_model_flan_t5,\n",
-"        \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
 "    },\n",
-"    \"huggingface-textembedding-gpt-j-6b\": {\n",
+"    \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n",
+"        \"model_version\": \"1.*\",\n",
 "        \"instance type\": \"ml.g5.24xlarge\",\n",
-"        \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
 "    },\n",
-"    # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n",
-"    #     \"instance type\": \"ml.g5.12xlarge\",\n",
-"    #     \"env\": {},\n",
-"    #     \"parse_function\": parse_response_multiple_texts_bloomz,\n",
-"    #     \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n",
+"    # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n",
+"    #     \"model_version\": \"3.*\",\n",
+"    #     \"instance type\": \"ml.g5.12xlarge\"\n",
 "    # },\n",
 "    # \"huggingface-text2text-flan-ul2-bf16\": {\n",
-"    #     \"instance type\": \"ml.g5.24xlarge\",\n",
-"    #     \"env\": {\n",
-"    #         \"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\",\n",
-"    #         \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"\n",
-"    #     },\n",
-"    #     \"parse_function\": parse_response_model_flan_t5,\n",
-"    #     \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
-"    # },\n",
+"    #     \"model_version\": \"2.*\",\n",
+"    #     \"instance type\": \"ml.g5.24xlarge\"\n",
+"    # }\n",
 "}"
 ]
 },
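
If you want to extend _MODEL_CONFIG_ with other models, the SDK can enumerate the JumpStart catalog. A hedged example; the filter expression is illustrative, see the SDK documentation for the full filter grammar:

from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# List text2text generation models currently available through JumpStart.
print(list_jumpstart_models(filter="task == text2text"))
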
@@ -168,41 +116,27 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n",
-"\n",
 "for model_id in _MODEL_CONFIG_:\n",
 "    endpoint_name = name_from_base(f\"jumpstart-example-raglc-{model_id}\")\n",
 "    inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n",
+"    model_version = _MODEL_CONFIG_[model_id][\"model_version\"]\n",
 "\n",
-"    # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n",
-"    deploy_image_uri = image_uris.retrieve(\n",
-"        region=None,\n",
-"        framework=None,  # automatically inferred from model_id\n",
-"        image_scope=\"inference\",\n",
-"        model_id=model_id,\n",
-"        model_version=model_version,\n",
-"        instance_type=inference_instance_type,\n",
-"    )\n",
-"    # Retrieve the model uri.\n",
-"    model_uri = model_uris.retrieve(\n",
-"        model_id=model_id, model_version=model_version, model_scope=\"inference\"\n",
-"    )\n",
-"    model_inference = Model(\n",
-"        image_uri=deploy_image_uri,\n",
-"        model_data=model_uri,\n",
-"        role=aws_role,\n",
-"        predictor_cls=Predictor,\n",
-"        name=endpoint_name,\n",
-"        env=_MODEL_CONFIG_[model_id][\"env\"],\n",
-"    )\n",
-"    model_predictor_inference = model_inference.deploy(\n",
-"        initial_instance_count=1,\n",
-"        instance_type=inference_instance_type,\n",
-"        predictor_cls=Predictor,\n",
-"        endpoint_name=endpoint_name,\n",
-"    )\n",
-"    print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n",
-"    _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name"
+"    print(f\"Deploying {model_id}...\")\n",
+"\n",
+"    model = JumpStartModel(model_id=model_id, model_version=model_version)\n",
+"\n",
+"    try:\n",
+"        predictor = model.deploy(\n",
+"            initial_instance_count=1,\n",
+"            instance_type=inference_instance_type,\n",
+"            endpoint_name=endpoint_name,  # reuse the name generated above\n",
+"        )\n",
+"        print(f\"Deployed endpoint: {predictor.endpoint_name}\")\n",
+"        _MODEL_CONFIG_[model_id][\"predictor\"] = predictor\n",
+"    except Exception as e:\n",
+"        print(f\"Error deploying {model_id}: {str(e)}\")\n",
+"\n",
+"print(\"Deployment process completed.\")"
 ]
 },
 {
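
Because the loop stores each Predictor in _MODEL_CONFIG_, teardown is symmetric. A cleanup sketch (not part of this commit) for when you are done, to stop per-hour instance charges:

for model_id, config in _MODEL_CONFIG_.items():
    predictor = config.get("predictor")
    if predictor is not None:
        predictor.delete_model()     # remove the SageMaker model resource
        predictor.delete_endpoint()  # delete the endpoint and its config
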
@@ -229,26 +163,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"payload = {\n",
-"    \"text_inputs\": question,\n",
-"    \"max_length\": 100,\n",
-"    \"num_return_sequences\": 1,\n",
-"    \"top_k\": 50,\n",
-"    \"top_p\": 0.95,\n",
-"    \"do_sample\": True,\n",
-"}\n",
-"\n",
 "list_of_LLMs = list(_MODEL_CONFIG_.keys())\n",
-"list_of_LLMs.remove(\"huggingface-textembedding-gpt-j-6b\")  # remove the embedding model\n",
-"\n",
+"list_of_LLMs = [model for model in list_of_LLMs if \"textembedding\" not in model]\n",
 "\n",
 "for model_id in list_of_LLMs:\n",
-"    endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
-"    query_response = query_endpoint_with_json_payload(\n",
-"        json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
-"    )\n",
-"    generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
-"    print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")"
+"    predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n",
+"    response = predictor.predict({\"inputs\": question})\n",
+"    print(f\"For model: {model_id}, the generated output is:\\n\")\n",
+"    print(f\"{response[0]['generated_text']}\\n\")"
 ]
 },
 {
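
The new predict path relies on the Hugging Face container's response shape, a list of {"generated_text": ...} objects. Generation parameters can be passed alongside the prompt; the parameter names below follow the Hugging Face inference toolkit and the values are illustrative:

response = predictor.predict(
    {"inputs": question, "parameters": {"max_new_tokens": 100, "do_sample": True}}
)
print(response[0]["generated_text"])
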
@@ -283,31 +205,13 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"parameters = {\n",
-"    \"max_length\": 200,\n",
-"    \"num_return_sequences\": 1,\n",
-"    \"top_k\": 250,\n",
-"    \"top_p\": 0.95,\n",
-"    \"do_sample\": False,\n",
-"    \"temperature\": 1,\n",
-"}\n",
+"prompt = f\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\n",
 "\n",
 "for model_id in list_of_LLMs:\n",
-"    endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
-"\n",
-"    prompt = _MODEL_CONFIG_[model_id][\"prompt\"]\n",
-"\n",
-"    text_input = prompt.replace(\"{context}\", context)\n",
-"    text_input = text_input.replace(\"{question}\", question)\n",
-"    payload = {\"text_inputs\": text_input, **parameters}\n",
-"\n",
-"    query_response = query_endpoint_with_json_payload(\n",
-"        json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
-"    )\n",
-"    generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
-"    print(\n",
-"        f\"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}\"\n",
-"    )"
+"    predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n",
+"    response = predictor.predict({\"inputs\": prompt})\n",
+"    print(f\"For model: {model_id}, the generated output is:\\n\")\n",
+"    print(f\"{response[0]['generated_text']}\\n\")"
 ]
 },
 {
@@ -330,7 +234,7 @@
 "\n",
 "To achieve that, we will do the following.\n",
 "\n",
-"1. **Generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.**\n",
+"1. **Generate embeddings for each document in the knowledge library with the SageMaker MiniLM-L6-v2 embedding model.**\n",
 "2. **Identify the top K most relevant documents based on the user query.**\n",
 "    - 2.1 **For a query of your interest, generate the embedding of the query using the same embedding model.**\n",
 "    - 2.2 **Search the indexes of top K most relevant documents in the embedding space using in-memory Faiss search.**\n",
@@ -365,6 +269,11 @@
 "outputs": [],
 "source": [
 "from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
+"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
+"from typing import List\n",
+"import boto3\n",
+"\n",
+"aws_region = boto3.Session().region_name\n",
 "\n",
 "\n",
 "class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):\n",
@@ -405,9 +314,12 @@
 "\n",
 "\n",
 "content_handler = ContentHandler()\n",
+"endpoint_name = _MODEL_CONFIG_[\"huggingface-textembedding-all-MiniLM-L6-v2\"][\n",
+"    \"predictor\"\n",
+"].endpoint_name\n",
 "\n",
 "embeddings = SagemakerEndpointEmbeddingsJumpStart(\n",
-"    endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"],\n",
+"    endpoint_name=endpoint_name,\n",
 "    region_name=aws_region,\n",
 "    content_handler=content_handler,\n",
 ")"
@@ -428,33 +340,34 @@
 "source": [
 "from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint\n",
 "\n",
-"parameters = {\n",
-"    \"max_length\": 200,\n",
-"    \"num_return_sequences\": 1,\n",
-"    \"top_k\": 250,\n",
-"    \"top_p\": 0.95,\n",
-"    \"do_sample\": False,\n",
-"    \"temperature\": 1,\n",
-"}\n",
-"\n",
 "\n",
 "class ContentHandler(LLMContentHandler):\n",
 "    content_type = \"application/json\"\n",
 "    accepts = \"application/json\"\n",
 "\n",
 "    def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n",
-"        input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n",
+"        input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
 "        return input_str.encode(\"utf-8\")\n",
 "\n",
 "    def transform_output(self, output: bytes) -> str:\n",
 "        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
-"        return response_json[\"generated_texts\"][0]\n",
+"        return response_json[0][\"generated_text\"]\n",
 "\n",
 "\n",
 "content_handler = ContentHandler()\n",
+"endpoint_name = _MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"predictor\"].endpoint_name\n",
+"\n",
+"parameters = {\n",
+"    \"max_length\": 200,\n",
+"    \"num_return_sequences\": 1,\n",
+"    \"top_k\": 250,\n",
+"    \"top_p\": 0.95,\n",
+"    \"do_sample\": False,\n",
+"    \"temperature\": 1,\n",
+"}\n",
 "\n",
 "sm_llm = SagemakerEndpoint(\n",
-"    endpoint_name=_MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"endpoint_name\"],\n",
+"    endpoint_name=endpoint_name,\n",
 "    region_name=aws_region,\n",
 "    model_kwargs=parameters,\n",
 "    content_handler=content_handler,\n",
@@ -568,7 +481,8 @@
 "from langchain.text_splitter import CharacterTextSplitter\n",
 "from langchain import PromptTemplate\n",
 "from langchain.chains.question_answering import load_qa_chain\n",
-"from langchain.document_loaders.csv_loader import CSVLoader"
+"from langchain.document_loaders.csv_loader import CSVLoader\n",
+"import json"
 ]
 },
 {
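
With sm_llm defined, these imports are wired into a question-answering chain. A minimal sketch, assuming the notebook's later cells follow the standard LangChain "stuff" pattern; LangChain LLM wrappers are also callable on a plain prompt string, which makes for a one-line smoke test:

# Optional smoke test: the ContentHandler serializes this prompt into
# {"inputs": ...} and unpacks generated_text from the response.
print(sm_llm("Answer based on context:\n\nSageMaker is a managed ML service.\n\nWhat is SageMaker?"))

prompt_template = "Answer based on context:\n\n{context}\n\n{question}"
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
chain = load_qa_chain(llm=sm_llm, prompt=PROMPT)
# `docs` would hold the documents retrieved from the FAISS index built below:
# result = chain({"input_documents": docs, "question": question}, return_only_outputs=True)
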
@@ -670,7 +584,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Firstly, we **generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.**"
+"First, we **generate embeddings for each document in the knowledge library with the SageMaker MiniLM-L6-v2 embedding model.**"
 ]
 },
 {
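
A sketch of the indexing step the following cells perform, assuming the standard LangChain FAISS pattern; the CSV file name and chunk size are illustrative. Note that all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings (GPT-J-6B produced 4096-dimensional ones), so any previously saved index must be rebuilt:

from langchain.vectorstores import FAISS

# Quick dimension check against the embedding endpoint: expect 384.
print(len(embeddings.embed_query("What is Amazon SageMaker?")))

# Load the CSV knowledge library, split it into chunks, embed each chunk
# through the MiniLM endpoint, and build an in-memory FAISS index.
documents = CSVLoader(file_path="Amazon_SageMaker_FAQs.csv").load()
texts = CharacterTextSplitter(chunk_size=300, chunk_overlap=0).split_documents(documents)
index = FAISS.from_documents(texts, embeddings)
docs = index.similarity_search(question, k=3)
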
@@ -1384,9 +1298,9 @@
 ],
 "instance_type": "ml.t3.medium",
 "kernelspec": {
-"display_name": "Python 3 (Data Science 2.0)",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
-"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38"
+"name": "python3"
 },
 "language_info": {
 "codemirror_mode": {
@@ -1398,7 +1312,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.8.13"
+"version": "3.11.9"
 }
 },
 "nbformat": 4,
