Commit 35af814: fix embedding args
Parent: 2433fdc

src/database_interactions.py: 85 additions, 41 deletions
@@ -29,6 +29,7 @@
 from utilities import my_cprint, get_model_native_precision, get_appropriate_dtype, supports_flash_attention
 from constants import VECTOR_MODELS
 
+# logging.basicConfig(level=logging.CRITICAL, force=True)
 logging.basicConfig(level=logging.INFO, force=True)
 # logging.basicConfig(level=logging.DEBUG, force=True)
 logger = logging.getLogger(__name__)
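A note on the force=True flag used above (Python 3.8+): it removes any handlers already attached to the root logger before applying the new configuration, so the last basicConfig call wins even if logging was configured earlier by an import. A minimal illustration:

import logging

logging.basicConfig(level=logging.INFO, force=True)
logging.getLogger(__name__).info("visible")

# re-configures despite the earlier call; without force=True this would be a no-op
logging.basicConfig(level=logging.CRITICAL, force=True)
logging.getLogger(__name__).info("now suppressed")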
@@ -46,12 +47,18 @@ def prepare_kwargs(self):
     def prepare_encode_kwargs(self):
         if self.is_query:
             self.encode_kwargs['batch_size'] = 1
+        self.encode_kwargs.setdefault('padding', True)
+        self.encode_kwargs.setdefault('truncation', True)
         return self.encode_kwargs
 
     def create(self):
         prepared_kwargs = self.prepare_kwargs()
         prepared_encode_kwargs = self.prepare_encode_kwargs()
 
+        # — Add these two lines to see exactly what the tokenizer will get
+        print(">>> [BaseEmbeddingModel.create] model_kwargs: ", prepared_kwargs)
+        print(">>> [BaseEmbeddingModel.create] encode_kwargs: ", prepared_encode_kwargs)
+
         return HuggingFaceEmbeddings(
             model_name=self.model_name,
             show_progress=not self.is_query,
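For context on the two setdefault calls: HuggingFaceEmbeddings (langchain_community) hands model_kwargs to the SentenceTransformer constructor and unpacks encode_kwargs into every encode() call, so these defaults travel with each batch. A minimal standalone sketch of the same pattern; the model name is only an example, and whether encode() tolerates padding/truncation keys depends on the installed sentence-transformers version, which is exactly what the print statements in create() are meant to surface:

from langchain_community.embeddings import HuggingFaceEmbeddings

encode_kwargs = {"batch_size": 8}
# setdefault only fills a key the caller has not set already
encode_kwargs.setdefault("padding", True)
encode_kwargs.setdefault("truncation", True)
print(encode_kwargs)  # {'batch_size': 8, 'padding': True, 'truncation': True}

embedder = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",  # example model, not pinned by this repo
    model_kwargs={"device": "cpu"},
    encode_kwargs=encode_kwargs,
)
# embed_documents() unpacks encode_kwargs into SentenceTransformer.encode();
# if the installed version rejects a key, the failure shows up here
vectors = embedder.embed_documents(["a short test sentence"])
print(len(vectors[0]))  # embedding dimension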
@@ -119,57 +126,72 @@ def prepare_encode_kwargs(self):
         return encode_kwargs
 
 
+# class Stella400MEmbedding(BaseEmbeddingModel):
+#     def prepare_kwargs(self):
+#         stella_kwargs = deepcopy(self.model_kwargs)
+#         compute_device = self.model_kwargs.get("device", "").lower()
+#         is_cuda = compute_device == "cuda"
+#         use_xformers = is_cuda and supports_flash_attention()
+
+#         logging.debug(f"Device: {compute_device}")
+#         logging.debug(f"is_cuda: {is_cuda}")
+#         logging.debug(f"use_xformers: {use_xformers}")
+
+#         stella_kwargs["config_kwargs"] = {
+#             "use_memory_efficient_attention": use_xformers,
+#             "unpad_inputs": use_xformers,
+#             "attn_implementation": "eager"  # sdpa is not implemented yet like it is for Stella and Snowflake
+#         }
+
+#         logging.debug("\nFinal config settings:")
+#         logging.debug(f"use_memory_efficient_attention: {stella_kwargs['config_kwargs']['use_memory_efficient_attention']}")
+#         logging.debug(f"unpad_inputs: {stella_kwargs['config_kwargs']['unpad_inputs']}")
+#         logging.debug(f"attn_implementation: {stella_kwargs['config_kwargs']['attn_implementation']}")
+
+#         return stella_kwargs
+
+
 class Stella400MEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
         stella_kwargs = deepcopy(self.model_kwargs)
-        compute_device = self.model_kwargs.get("device", "").lower()
+
+        # ---------- flash-attention toggle ----------
+        compute_device = stella_kwargs.get("device", "").lower()
         is_cuda = compute_device == "cuda"
         use_xformers = is_cuda and supports_flash_attention()
 
-        logging.debug(f"Device: {compute_device}")
-        logging.debug(f"is_cuda: {is_cuda}")
-        logging.debug(f"use_xformers: {use_xformers}")
-
-        stella_kwargs["config_kwargs"] = {
-            "use_memory_efficient_attention": use_xformers,
-            "unpad_inputs": use_xformers,
-            "attn_implementation": "eager"  # sdpa is not implemented yet like it is for Stella and Snowflake
+        stella_kwargs["tokenizer_kwargs"] = {
+            "padding": "longest",
+            "truncation": True,
+            "max_length": 8192
         }
 
-        logging.debug("\nFinal config settings:")
-        logging.debug(f"use_memory_efficient_attention: {stella_kwargs['config_kwargs']['use_memory_efficient_attention']}")
-        logging.debug(f"unpad_inputs: {stella_kwargs['config_kwargs']['unpad_inputs']}")
-        logging.debug(f"attn_implementation: {stella_kwargs['config_kwargs']['attn_implementation']}")
-
-        return stella_kwargs
-
-
-# class AlibabaEmbedding(BaseEmbeddingModel):
-#     def prepare_kwargs(self):
-#         ali_kwargs = deepcopy(self.model_kwargs)
-#         compute_device = ali_kwargs.get("device", "").lower()
-#         is_cuda = compute_device == "cuda"
-#         use_xformers = is_cuda and supports_flash_attention()
-#         ali_kwargs["tokenizer_kwargs"] = {
-#             "padding": "longest",
-#             "truncation": True,
-#             "max_length": 8192
-#         }
-#         ali_kwargs["config_kwargs"] = {
+        # # uncomment to use xformers
+        # stella_kwargs["config_kwargs"] = {
         #     "use_memory_efficient_attention": use_xformers,
         #     "unpad_inputs": use_xformers,
-        #     "attn_implementation": "eager" if use_xformers else "sdpa"
+        #     "attn_implementation": "eager"
         # }
-#         return ali_kwargs
 
-#     def prepare_encode_kwargs(self):
-#         encode_kwargs = super().prepare_encode_kwargs()
+        stella_kwargs["config_kwargs"] = {
+            "use_memory_efficient_attention": False,
+            "unpad_inputs": False,
+            "attn_implementation": "eager",
+            # "attn_implementation": "sdpa"
+        }
+
+        return stella_kwargs
+
+    def prepare_encode_kwargs(self):
+        encode_kwargs = super().prepare_encode_kwargs()
         # encode_kwargs.update({
         #     "padding": True,
        #     "truncation": True,
         #     "max_length": 8192
         # })
-#         return encode_kwargs
+        if self.is_query:
+            encode_kwargs["prompt_name"] = "s2p_query"
+        return encode_kwargs
 
 
 class AlibabaEmbedding(BaseEmbeddingModel):
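The new prepare_encode_kwargs injects prompt_name="s2p_query" for queries only. With stella_en_400M_v5, sentence-transformers looks that name up in the prompts dictionary shipped in the model's configuration and prepends the matching instruction text before tokenizing; passages are encoded without a prompt. A standalone sketch, using the model id from the Stella model card and an illustrative query:

from sentence_transformers import SentenceTransformer

# stella ships custom modeling code, hence trust_remote_code
model = SentenceTransformer("dunzhang/stella_en_400M_v5", trust_remote_code=True)

# queries get the s2p (sentence-to-passage) instruction; passages get none
query_vecs = model.encode(["what fixes an embedding arg mismatch?"], prompt_name="s2p_query")
doc_vecs = model.encode(["Padding and truncation belong in tokenizer_kwargs."])
print(query_vecs.shape, doc_vecs.shape)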
@@ -194,15 +216,14 @@ def prepare_kwargs(self):
 
     def prepare_encode_kwargs(self):
         encode_kwargs = super().prepare_encode_kwargs()
-        encode_kwargs.update({
-            "padding": True,
-            "truncation": True,
-            "max_length": 8192
-        })
+        # encode_kwargs.update({
+        #     "padding": True,
+        #     "truncation": True,
+        #     "max_length": 8192
+        # })
         return encode_kwargs
 
 
-
 def create_vector_db_in_process(database_name):
     create_vector_db = CreateVectorDB(database_name=database_name)
     create_vector_db.run()
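Both embedding classes now push padding and truncation into constructor-level tokenizer_kwargs (set in prepare_kwargs) instead of per-call encode arguments. In sentence-transformers 2.4+ the constructor forwards tokenizer_kwargs to AutoTokenizer.from_pretrained and config_kwargs to AutoConfig.from_pretrained, so the settings apply once at load time. A sketch of the equivalent direct construction, assuming that version floor; the gte model id is only an example of a model whose config defines these flags:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "Alibaba-NLP/gte-large-en-v1.5",  # example; any model exposing these config flags
    trust_remote_code=True,
    tokenizer_kwargs={"padding": "longest", "truncation": True, "max_length": 8192},
    config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False},
)
print(model.encode(["configured once at load time"]).shape)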
@@ -342,6 +363,30 @@ def create_database(self, texts, embeddings):
         with open(self.ROOT_DIRECTORY / "config.yaml", 'r', encoding='utf-8') as config_file:
             config_data = yaml.safe_load(config_file)
 
+        # --- memory-hygiene block ---------------------------------
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        try:
+            # reserve the exact block size early
+            dummy = np.empty(
+                (len(all_texts), config_data["EMBEDDING_MODEL_DIMENSIONS"]),
+                dtype=np.float32,
+            )
+            del dummy
+        except MemoryError:
+            raise MemoryError(
+                "Unable to reserve contiguous RAM for the embedding matrix. "
+                "Try a smaller batch, float16 storage, or run on a machine with "
+                "more free RAM."
+            )
+
+        # ─── ADD HERE ───
+        print(">>> [create_database] sample texts (first 5):", texts[:5])
+        print(">>> [create_database] type(texts):", type(texts))
+        # ─────────────────
+
         db = TileDB.from_texts(
             texts=all_texts,
             embedding=embeddings,
@@ -352,7 +397,6 @@ def create_database(self, texts, embeddings):
             index_type="FLAT",
             dimensions=config_data.get("EMBEDDING_MODEL_DIMENSIONS"),
             allow_dangerous_deserialization=True,
-            # vector_type=np.float32
         )
 
         my_cprint(f"Processed {len(all_texts)} chunks", "yellow")
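The memory-hygiene block added to create_database fails fast by allocating, then immediately releasing, an array with the same shape the embedding matrix will need. A self-contained sketch of the trick; the helper name and sizes are illustrative, and note that on Linux with overcommit enabled np.empty can succeed without actually committing pages, so the check is best-effort rather than a guarantee:

import gc
import numpy as np

def reserve_embedding_block(n_chunks: int, dims: int) -> None:
    """Raise MemoryError early if a contiguous float32 block won't fit."""
    gc.collect()  # drop lingering garbage before the big request
    try:
        scratch = np.empty((n_chunks, dims), dtype=np.float32)
        del scratch  # release at once; we only wanted the allocation to succeed
    except MemoryError:
        gb = n_chunks * dims * 4 / 1e9
        raise MemoryError(
            f"Cannot reserve ~{gb:.1f} GB of contiguous RAM; "
            "use a smaller batch or float16 storage."
        )

reserve_embedding_block(100_000, 1024)  # ~0.4 GB for a 100k-chunk corpus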
