@@ -96,6 +96,11 @@ def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         srt_outputs = srt_runner.forward(
             prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
         )
+        srt_outputs_lora_path_none = srt_runner.forward(
+            prompts,
+            max_new_tokens=max_new_tokens,
+            lora_paths=[None] * len(prompts),
+        )
 
         with HFRunner(
             base_path, torch_dtype=torch_dtype, model_type="generation"
@@ -169,18 +174,20 @@ def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print(f"{srt_outputs.output_strs=}")
         print(f"{hf_no_lora_outputs.output_strs=}")
         print(f"{srt_no_lora_outputs.output_strs=}")
+        print(f"{srt_outputs_lora_path_none.output_strs=}")
         for i in range(len(prompts)):
             assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
                 srt_outputs.output_strs[i].strip(" "),
                 hf_outputs.output_strs[i],
             )
-            # assert (
-            #     srt_no_lora_outputs.output_strs[i].strip(" ")
-            #     == hf_no_lora_outputs.output_strs[i]
-            # ), (
-            #     srt_no_lora_outputs.output_strs[i].strip(" "),
-            #     hf_no_lora_outputs.output_strs[i],
-            # )
+            assert (
+                srt_no_lora_outputs.output_strs[i].strip(" ")
+                == hf_no_lora_outputs.output_strs[i]
+            ), (
+                srt_no_lora_outputs.output_strs[i].strip(" "),
+                hf_no_lora_outputs.output_strs[i],
+            )
+            assert srt_outputs_lora_path_none == srt_no_lora_outputs
 
     def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print("=================== testing serving =======================")
@@ -257,7 +264,7 @@ def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens
             srt_no_lora_logprobs = torch.Tensor(
                 srt_no_lora_outputs.top_input_logprobs[i]
             )
-            srt_logprobs = torch.uensor(srt_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
             print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
 
             print(f"{srt_no_lora_outputs.output_strs=}")
@@ -280,7 +287,7 @@ def test_all(self):
         tp_size = 1
         max_new_tokens = 32
         self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
-        # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
+        self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
         # self.base_inference(
         #     PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
         # )