TiberiuCristianLeon committed on
Commit
4e055aa
·
verified ·
1 Parent(s): c6ff6a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -65,21 +65,22 @@ class Translators:
65
  def mitre(self):
66
  from transformers import AutoModel, AutoTokenizer
67
  tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True, use_fast=False)
68
- model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True).to(self.device)
69
- # model.half() # recommended
 
70
  model.eval()
71
-
72
  # Translating from one or several sentences to a sole language
73
  src_tokens = tokenizer.encode_source_tokens_to_input_ids([self.input_text, ], target_language=self.tl)
74
- src_tokens = src_tokens.to(self.device)
 
 
75
  # Translating from one or several sentences to corresponding languages
76
  # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
77
- # generated_tokens = model.generate(src_tokensto(self.device))
78
  # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
79
-
80
- with torch.no_grad():
81
- generated_tokens = model.generate(src_tokens)
82
- return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
83
 
84
  def hplt(self, opus = False):
85
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']
 
65
  def mitre(self):
66
  from transformers import AutoModel, AutoTokenizer
67
  tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True, use_fast=False)
68
+ # model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True).to(self.device)
69
+ model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
70
+ model.half() # recommended
71
  model.eval()
 
72
  # Translating from one or several sentences to a sole language
73
  src_tokens = tokenizer.encode_source_tokens_to_input_ids([self.input_text, ], target_language=self.tl)
74
+ # src_tokens = src_tokens.to(self.device)
75
+ generated_tokens = model.generate(src_tokens)
76
+ return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
77
  # Translating from one or several sentences to corresponding languages
78
  # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
79
+ # generated_tokens = model.generate(src_tokens.to(self.device))
80
  # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
81
+ # with torch.no_grad():
82
+ # generated_tokens = model.generate(src_tokens)
83
+ # return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
 
84
 
85
  def hplt(self, opus = False):
86
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']