Update app.py
app.py CHANGED
@@ -69,16 +69,16 @@ class Translators:
         # model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
         # model.half()  # recommended for GPU
         model.eval()
-        model.float()
+        # model.float()
         # Translating from one or several sentences to a sole language
         src_tokens = tokenizer.encode_source_tokens_to_input_ids(self.input_text, target_language=self.tl)
         # src_tokens may be a torch.Tensor or dict depending on tokenizer; ensure it's a tensor
-        if isinstance(src_tokens, torch.Tensor):
-            src_tokens = src_tokens.to(self.device)
-        else:
-            # if tokenizer returns dict-like inputs (input_ids, attention_mask)
-            for k, v in src_tokens.items():
-                src_tokens[k] = v.to(self.device)
+        # if isinstance(src_tokens, torch.Tensor):
+        #     src_tokens = src_tokens.to(self.device)
+        # else:
+        #     # if tokenizer returns dict-like inputs (input_ids, attention_mask)
+        #     for k, v in src_tokens.items():
+        #         src_tokens[k] = v.to(self.device)
         # src_tokens = src_tokens.to(self.device)
         # generated_tokens = model.generate(src_tokens)
         # return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
@@ -86,7 +86,7 @@ class Translators:
         # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
         # generated_tokens = model.generate(src_tokens.to(self.device))
         # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-        with torch.
+        with torch.no_grad():  # no_grad inference_mode
             generated_tokens = model.generate(src_tokens)
             result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
         return result
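The branch removed in the first hunk (and kept above in commented form) exists because the tokenizer may hand back either a bare torch.Tensor or a dict of tensors such as input_ids and attention_mask. A minimal sketch of that device-moving pattern, assuming only PyTorch; the helper name move_to_device is hypothetical:

import torch

def move_to_device(src_tokens, device):
    # Hypothetical helper mirroring the commented-out branch in the diff:
    # a bare tensor moves directly; dict-like tokenizer output moves per key.
    if isinstance(src_tokens, torch.Tensor):
        return src_tokens.to(device)
    return {k: v.to(device) for k, v in src_tokens.items()}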
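The commented lines at the top of the second hunk outline translating the same English source into several target languages in one batch. A sketch of that path, under the assumption (taken from those comments) that encode_source_tokens_to_input_ids_with_different_tags exists on this trust_remote_code tokenizer and returns a single tensor; translate_to_many is a hypothetical wrapper:

import torch

def translate_to_many(model, tokenizer, english_text, device):
    # Assumed custom tokenizer method, as shown in the diff's comments:
    # one copy of the source per requested target language tag.
    src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags(
        [english_text, english_text],
        target_languages_list=["de", "zh"],
    )
    with torch.no_grad():
        generated_tokens = model.generate(src_tokens.to(device))
    # One decoded translation per requested target language, in order.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)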
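The one active change in the second hunk wraps generation in torch.no_grad(); the inline comment names the two autograd-disabling context managers PyTorch offers, no_grad and inference_mode. A minimal sketch of that decode path, assuming model and tokenizer are loaded as earlier in the file; translate_one is a hypothetical wrapper:

import torch

def translate_one(model, tokenizer, src_tokens):
    # torch.no_grad() skips gradient tracking during generation;
    # torch.inference_mode() is the stricter, usually faster alternative
    # the inline comment alludes to.
    with torch.no_grad():
        generated_tokens = model.generate(src_tokens)
    # Decode the single generated sequence, dropping special tokens.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]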