import gradio as gr
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import torch
import re
import csv
import os

# Set device to CPU explicitly
device = "cpu"

# Load the model and tokenizer
model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.to(device)

# Create NER pipeline
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 means CPU
    aggregation_strategy="simple"  # Groups entities together
)


# Load stock symbols from CSV file
def load_stock_symbols_from_csv(csv_path='symbols.csv'):
    """Load stock symbols from CSV file"""
    stock_symbols = {}
    try:
        with open(csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                symbol = row['symbol']
                company_name = row['company_name']
                stock_symbols[symbol] = company_name
        print(f"Loaded {len(stock_symbols)} stock symbols from CSV")
    except FileNotFoundError:
        print(f"Warning: {csv_path} not found. No stock symbols loaded.")
    return stock_symbols


# Load stock symbols
STOCK_SYMBOLS = load_stock_symbols_from_csv()
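# Illustrative sketch of the symbols.csv layout expected by load_stock_symbols_from_csv()
# (the header names match the DictReader fields above; the example rows are assumptions
# for documentation, not part of any shipped file):
#
#     symbol,company_name
#     فولاد,فولاد مبارکه اصفهان
#     خودرو,ایران خودرو
#     شپنا,پالایش نفت اصفهان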
# Hypernym patterns (generic terms that can be made more specific)
HYPERNYM_PATTERNS = {
    "شرکت": "ORG",
    "سازمان": "ORG",
    "موسسه": "ORG",
    "بانک": "ORG",
    "دانشگاه": "ORG",
    "شهر": "LOC",
    "کشور": "LOC",
    "استان": "LOC",
    "آقای": "PER",
    "خانم": "PER",
    "دکتر": "PER",
    "مهندس": "PER",
}

# Label mapping for better readability
label_colors = {
    "B-PER": "#FF6B6B", "I-PER": "#FFB3B3",
    "B-ORG": "#4ECDC4", "I-ORG": "#A7E9E4",
    "B-LOC": "#95E1D3", "I-LOC": "#C7F0E8",
    "B-DAT": "#FFA07A", "I-DAT": "#FFDAB9",
    "B-TIM": "#DDA0DD", "I-TIM": "#E6D0E6",
    "B-MON": "#FFD700", "I-MON": "#FFEB99",
    "B-PCT": "#87CEEB", "I-PCT": "#B3DFEF",
    "STK": "#FF1493",  # Stock symbol - Deep Pink
    "HYP": "#A9A9A9",  # Hypernym - Dark Gray
}

label_names = {
    "PER": "شخص (Person)",
    "ORG": "سازمان (Organization)",
    "LOC": "مکان (Location)",
    "DAT": "تاریخ (Date)",
    "TIM": "زمان (Time)",
    "MON": "پول (Money)",
    "PCT": "درصد (Percent)",
    "STK": "نماد بورس (Stock Symbol)",
    "HYP": "واژه عمومی (Hypernym)",
}


def detect_stock_symbols(text):
    """Detect Persian stock market symbols in text"""
    stock_entities = []

    # Split text into words
    words = re.findall(r'[\u0600-\u06FF]+', text)

    for word in words:
        if word in STOCK_SYMBOLS:
            # Find all occurrences of this symbol in the text
            for match in re.finditer(re.escape(word), text):
                stock_entities.append({
                    'entity_group': 'STK',
                    'word': word,
                    'start': match.start(),
                    'end': match.end(),
                    'score': 0.99,  # High confidence for dictionary match
                    'full_name': STOCK_SYMBOLS[word]
                })

    return stock_entities


def detect_hypernyms(text, entities):
    """Detect hypernyms (general terms) in text and classify them"""
    hypernym_entities = []

    for hypernym, entity_type in HYPERNYM_PATTERNS.items():
        for match in re.finditer(re.escape(hypernym), text):
            start, end = match.start(), match.end()

            # Check if this position already has a specific entity
            is_covered = False
            for ent in entities:
                if start >= ent['start'] and end <= ent['end']:
                    is_covered = True
                    break

            if not is_covered:
                hypernym_entities.append({
                    'entity_group': 'HYP',
                    'word': hypernym,
                    'start': start,
                    'end': end,
                    'score': 0.95,
                    'base_type': entity_type,
                    'is_hypernym': True
                })

    return hypernym_entities


def merge_entities(entities, stock_entities, hypernym_entities):
    """Merge all entity types and remove overlaps, prioritizing specific entities"""
    all_entities = entities + stock_entities + hypernym_entities

    # Sort by start position
    all_entities.sort(key=lambda x: x['start'])

    # Remove overlapping entities (keep higher priority)
    # Priority: STK > specific entities > HYP
    filtered_entities = []
    for entity in all_entities:
        overlaps = False
        for existing in filtered_entities:
            # Check for overlap
            if not (entity['end'] <= existing['start'] or entity['start'] >= existing['end']):
                overlaps = True
                # If new entity is stock symbol, replace existing
                if entity['entity_group'] == 'STK' and existing['entity_group'] != 'STK':
                    filtered_entities.remove(existing)
                    overlaps = False
                # If existing is hypernym and new is specific, replace
                elif existing['entity_group'] == 'HYP' and entity['entity_group'] != 'HYP':
                    filtered_entities.remove(existing)
                    overlaps = False
                break
        if not overlaps:
            filtered_entities.append(entity)

    return sorted(filtered_entities, key=lambda x: x['start'])
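# Entity dict shapes produced by the three sources above (the concrete words, offsets
# and scores below are illustrative assumptions, not output captured from the model):
#
#     NER pipeline:          {'entity_group': 'ORG', 'word': 'بانک ملت', 'start': 0, 'end': 8, 'score': 0.97}
#     detect_stock_symbols:  {'entity_group': 'STK', 'word': 'وتجارت', 'start': 11, 'end': 17, 'score': 0.99, 'full_name': 'بانک تجارت'}
#     detect_hypernyms:      {'entity_group': 'HYP', 'word': 'بانک', 'start': 0, 'end': 4, 'score': 0.95, 'base_type': 'ORG', 'is_hypernym': True}
#
# Because all three share the same positional keys, merge_entities can resolve overlaps
# by span alone, keeping STK over model entities and model entities over HYP.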
def highlight_entities(text, entities):
    """Create HTML with highlighted entities"""
    if not entities:
        return text

    # Sort entities by start position (reverse order to replace from end to start)
    entities_sorted = sorted(entities, key=lambda x: x['start'], reverse=True)

    result = text
    for entity in entities_sorted:
        start = entity['start']
        end = entity['end']
        label = entity['entity_group']
        word = text[start:end]
        score = entity['score']

        # Get color for this label
        if label == 'STK':
            color = label_colors.get('STK')
            extra_info = f" - {entity.get('full_name', '')}" if 'full_name' in entity else ""
            title_text = f"Stock Symbol{extra_info} (confidence: {score:.2f})"
        elif label == 'HYP':
            color = label_colors.get('HYP')
            base_type = entity.get('base_type', '')
            title_text = f"Hypernym (general term for {base_type})"
        else:
            color = label_colors.get(f"B-{label}", "#CCCCCC")
            title_text = f"{label} (confidence: {score:.2f})"

        # Create highlighted span (the inline style is an assumed default that applies
        # the computed color and tooltip text)
        highlighted = (
            f'<span style="background-color: {color}; padding: 2px 5px; '
            f'border-radius: 4px;" title="{title_text}">{word} [{label}]</span>'
        )
        result = result[:start] + highlighted + result[end:]

    return result


def perform_ner(text):
    """Perform NER on input text"""
    if not text.strip():
        return "<div style='direction: rtl;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</div>", ""

    try:
        # Perform base NER
        entities = ner_pipeline(text)

        # Detect stock symbols
        stock_entities = detect_stock_symbols(text)

        # Detect hypernyms
        hypernym_entities = detect_hypernyms(text, entities)

        # Merge all entities
        all_entities = merge_entities(entities, stock_entities, hypernym_entities)

        # Create highlighted version (RTL wrapper styling is an assumed default)
        highlighted_html = (
            "<div style='direction: rtl; text-align: right; line-height: 2.2;'>"
            f"{highlight_entities(text, all_entities)}</div>"
        )

        # Build a Markdown summary for the entity list output (line format is illustrative)
        if all_entities:
            summary_lines = []
            for entity in all_entities:
                label = entity['entity_group']
                name = label_names.get(label, label)
                extra = ""
                if label == 'STK' and entity.get('full_name'):
                    extra = f" ({entity['full_name']})"
                elif label == 'HYP' and entity.get('base_type'):
                    extra = f" ({entity['base_type']})"
                summary_lines.append(
                    f"- **{entity['word']}**{extra}: {name}, confidence {entity['score']:.2f}"
                )
            entities_md = "\n".join(summary_lines)
        else:
            entities_md = "موجودیتی یافت نشد (No entities found)"

        return highlighted_html, entities_md

    except Exception as e:
        return f"خطا (Error): {str(e)}", ""


# Save stock symbols to CSV (helper; not called at startup)
def save_symbols_to_csv(output_path='symbols.csv'):
    """Save current stock symbols to CSV file"""
    with open(output_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['symbol', 'company_name'])
        for symbol, name in STOCK_SYMBOLS.items():
            writer.writerow([symbol, name])
    print(f"Saved {len(STOCK_SYMBOLS)} symbols to {output_path}")
# Example texts
examples = [
    ["باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی میکرد."],
    ["شرکت گوگل در کالیفرنیا واقع شده است."],
    ["رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد."],
    ["دانشگاه تهران یکی از قدیمیترین دانشگاههای ایران است."],
    ["علی و حسین به همراه مریم به مشهد سفر کردند."],
    ["سهام فولاد و خودرو امروز رشد خوبی داشتند و شپنا هم صعودی بود."],
    ["بانک ملت و وتجارت در بازار بورس فعال هستند."],
    ["آقای احمدی مدیرعامل شرکت پتروشیمی است."],
    ["وبملت و فملی امروز در صف خرید قرار گرفتند."],
]

# Create Gradio interface
with gr.Blocks(title="Persian NER - شناسایی موجودیتهای نامدار فارسی", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
# 🇮🇷 Persian Named Entity Recognition + Stock Symbols
# شناسایی موجودیتهای نامدار فارسی + نمادهای بورس

این سیستم موجودیتهای نامدار مانند اسامی اشخاص، سازمانها، مکانها، تاریخها، **نمادهای بورس** و **واژههای عمومی (Hypernyms)** را در متن فارسی شناسایی میکند.

This system identifies named entities including person names, organizations, locations, dates, **stock symbols**, and **hypernyms** in Persian text.

**Model:** ParsBERT-NER (HooshvareLab) + Custom Stock Symbol Detection
**Stock Symbols Loaded:** {len(STOCK_SYMBOLS)} symbols from Tehran Stock Exchange (TSE)
**Running on:** CPU (may be slow for long texts)

---

### 📊 APIs for Updating Stock Symbols:

**Recommended Python Libraries:**
1. **tsetmc-api** - `pip install tsetmc-api` - Direct access to TSETMC data
2. **tehran-stocks** - `pip install tehran-stocks` - Full stock price history with ORM
3. **tse-dataloader** - Data extraction from Tehran Stock Exchange

**Example Usage:**
```python
# Using tsetmc-api
from tsetmc_api import market_watch
stocks = market_watch.get_market_watch()

# Using tehran-stocks
from tehran_stocks import Stocks
all_stocks = Stocks.query.all()
```

**Official TSE Website:** https://tse.ir
**TSETMC Data Portal:** http://www.tsetmc.com
""")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="متن فارسی خود را وارد کنید (Enter Persian Text)",
                placeholder="مثال: سهام فولاد و خودرو امروز رشد کردند...",
                lines=5,
                rtl=True
            )
            submit_btn = gr.Button("🔍 تحلیل متن (Analyze Text)", variant="primary")
        with gr.Column():
            output_html = gr.HTML(label="متن با موجودیتهای برجسته (Text with Highlighted Entities)")
            output_entities = gr.Markdown(label="لیست موجودیتها (Entity List)")

    gr.Examples(
        examples=examples,
        inputs=input_text,
        label="مثالها (Examples)"
    )

    # Legend
    gr.Markdown(f"""
### راهنمای رنگها (Color Guide):
- 🔴 **PER (شخص)**: اسامی اشخاص / Person names
- 🔵 **ORG (سازمان)**: نام سازمانها / Organizations
- 🟢 **LOC (مکان)**: نام مکانها / Locations
- 🟠 **DAT (تاریخ)**: تاریخها / Dates
- 🟣 **TIM (زمان)**: زمانها / Times
- 🟡 **MON (پول)**: مقادیر پولی / Money
- 🔷 **PCT (درصد)**: درصدها / Percentages
- 💗 **STK (نماد بورس)**: نمادهای بورس تهران / Tehran Stock Exchange symbols
- ⚫ **HYP (واژه عمومی)**: واژههای عمومی / Hypernyms (general terms)

---

### 📝 تعداد نمادهای بورس: {len(STOCK_SYMBOLS)} نماد

*برای بهروزرسانی نمادها، فایل CSV را جایگزین کنید یا از API استفاده کنید.*
""")

    # Event handlers
    submit_btn.click(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )
    input_text.submit(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
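# Usage note: assuming this script is saved as app.py, `python app.py` serves the UI at
# http://127.0.0.1:7860 by default; demo.launch(server_name="0.0.0.0") exposes it on the
# local network.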