Spaces:
Running
Running
优化预测函数,增加执行时间和输入文本长度的打印,调整处理逻辑以提高可读性
Browse files- blkeras.py +7 -1
- preprocess.py +2 -1
blkeras.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
from huggingface_hub import login
|
| 3 |
from huggingface_hub import hf_hub_download
|
| 4 |
|
|
@@ -98,6 +99,8 @@ def predict(text: str, stock_codes: list):
|
|
| 98 |
from preprocess import get_document_vector, get_stock_info, preprocessing_entry, process_entities, process_pos_tags, processing_entry
|
| 99 |
|
| 100 |
try:
|
|
|
|
|
|
|
| 101 |
input_text = text
|
| 102 |
affected_stock_codes = stock_codes
|
| 103 |
|
|
@@ -110,7 +113,7 @@ def predict(text: str, stock_codes: list):
|
|
| 110 |
processed_entry = processing_entry(input_text)
|
| 111 |
|
| 112 |
# 解包 processed_entry 中的各个值
|
| 113 |
-
lemmatized_entry, pos_tag, ner,
|
| 114 |
|
| 115 |
# 分别打印每个变量,便于调试
|
| 116 |
#print("Lemmatized Entry:", lemmatized_entry)
|
|
@@ -403,6 +406,9 @@ def predict(text: str, stock_codes: list):
|
|
| 403 |
print(f"predict() error: {e}")
|
| 404 |
print(traceback_str)
|
| 405 |
return {"predict() error": str(e), "traceback": traceback_str}
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
|
| 408 |
def stock_fix_for_1118_model(score, predictions, last_prices, is_index=True):
|
|
|
|
| 1 |
import os
|
| 2 |
+
from tracemalloc import start
|
| 3 |
from huggingface_hub import login
|
| 4 |
from huggingface_hub import hf_hub_download
|
| 5 |
|
|
|
|
| 99 |
from preprocess import get_document_vector, get_stock_info, preprocessing_entry, process_entities, process_pos_tags, processing_entry
|
| 100 |
|
| 101 |
try:
|
| 102 |
+
|
| 103 |
+
start_time = datetime.now()
|
| 104 |
input_text = text
|
| 105 |
affected_stock_codes = stock_codes
|
| 106 |
|
|
|
|
| 113 |
processed_entry = processing_entry(input_text)
|
| 114 |
|
| 115 |
# 解包 processed_entry 中的各个值
|
| 116 |
+
lemmatized_entry, pos_tag, ner, _ , sentiment_score = processed_entry
|
| 117 |
|
| 118 |
# 分别打印每个变量,便于调试
|
| 119 |
#print("Lemmatized Entry:", lemmatized_entry)
|
|
|
|
| 406 |
print(f"predict() error: {e}")
|
| 407 |
print(traceback_str)
|
| 408 |
return {"predict() error": str(e), "traceback": traceback_str}
|
| 409 |
+
finally:
|
| 410 |
+
end_time = datetime.now()
|
| 411 |
+
print(f"predict() Text: {input_text[:200] if len(input_text) > 200 else input_text} \n execution time: {end_time - start_time}, Text Length: {len(input_text)} \n")
|
| 412 |
|
| 413 |
|
| 414 |
def stock_fix_for_1118_model(score, predictions, last_prices, is_index=True):
|
preprocess.py
CHANGED
|
@@ -562,8 +562,9 @@ def processing_entry(entry):
|
|
| 562 |
ner = named_entity_recognition(cleaned_text)
|
| 563 |
# print(f"named_entity_recognition: {db_ner}")
|
| 564 |
|
| 565 |
-
dependency_parsed = dependency_parsing(cleaned_text)
|
| 566 |
# print(f"dependency_parsing: {db_dependency_parsing}")
|
|
|
|
| 567 |
|
| 568 |
sentiment_score = get_sentiment_score(cleaned_text)
|
| 569 |
# print(f"sentiment_score: {sentiment_score}")
|
|
|
|
| 562 |
ner = named_entity_recognition(cleaned_text)
|
| 563 |
# print(f"named_entity_recognition: {db_ner}")
|
| 564 |
|
| 565 |
+
# dependency_parsed = dependency_parsing(cleaned_text)
|
| 566 |
# print(f"dependency_parsing: {db_dependency_parsing}")
|
| 567 |
+
dependency_parsed = None
|
| 568 |
|
| 569 |
sentiment_score = get_sentiment_score(cleaned_text)
|
| 570 |
# print(f"sentiment_score: {sentiment_score}")
|