| """ |
| TinyByteCNN Model for Fiction vs Non-Fiction Classification |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import unicodedata |
| import re |
| from typing import Union, List |
|
|
|
|
class SE(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The input is averaged over its last axis to get one summary value per
    channel. That summary goes through a bottleneck MLP
    (Linear -> SiLU -> Linear -> sigmoid). The resulting per-channel gate in
    (0, 1) is broadcast back over the last axis and multiplied into the input.

    In this file it receives Conv1d outputs, i.e. presumably [batch, channels,
    length] tensors — the averaged tensor's last dimension must equal ``c``.

    Args:
        c: Number of channels gated.
        r: Reduction ratio; the hidden width is ``max(c // r, 4)``.
    """

    def __init__(self, c, r=8):
        super().__init__()
        hidden = c // r
        if hidden < 4:
            # Never shrink the bottleneck below 4 units.
            hidden = 4
        self.fc1 = nn.Linear(c, hidden)
        self.fc2 = nn.Linear(hidden, c)

    def forward(self, x):
        # Squeeze: mean over the last axis gives one value per channel.
        squeezed = x.mean(dim=-1)
        # Excite: bottleneck MLP producing a gate in (0, 1) per channel.
        gate = torch.sigmoid(self.fc2(F.silu(self.fc1(squeezed))))
        # Re-insert the reduced axis so the gate broadcasts over it.
        return x * gate[..., None]
|
|
|
|
class SepResBlock(nn.Module):
    """Depthwise-separable residual block with squeeze-excitation.

    Main path: depthwise conv (kernel ``k``, ``stride``, ``dilation``) -> norm
    -> SiLU -> pointwise 1x1 conv -> norm -> SE gate -> dropout.
    Shortcut: identity, or a strided 1x1 conv (no norm) when the stride or the
    channel count changes. Output: ``SiLU(shortcut + main)``.

    Args:
        c_in: Input channels.
        c_out: Output channels.
        k: Depthwise kernel size; padding ``((k - 1) // 2) * dilation`` keeps
            the length unchanged for odd ``k`` at stride 1.
        stride: Stride of the depthwise conv and of the shortcut projection.
        dilation: Dilation of the depthwise conv.
        use_gn: Use ``GroupNorm(32, c)`` instead of ``BatchNorm1d`` (channel
            counts must then be divisible by 32).
        se_ratio: Reduction ratio forwarded to ``SE``.
        drop: Dropout probability applied to the main path before the sum.
    """

    def __init__(self, c_in, c_out, k=7, stride=1, dilation=1, use_gn=False, se_ratio=8, drop=0.0):
        super().__init__()

        def make_norm(channels):
            if use_gn:
                return nn.GroupNorm(32, channels)
            return nn.BatchNorm1d(channels)

        # Submodules are created in a fixed order so seeded initialisation and
        # state_dict keys stay stable.
        same_len_pad = dilation * ((k - 1) // 2)
        self.dw = nn.Conv1d(c_in, c_in, k, stride=stride, dilation=dilation,
                            padding=same_len_pad, groups=c_in, bias=False)
        self.bn1 = make_norm(c_in)
        self.pw = nn.Conv1d(c_in, c_out, 1, bias=False)
        self.bn2 = make_norm(c_out)
        self.se = SE(c_out, se_ratio)
        self.drop = nn.Dropout(p=drop)

        # The shortcut needs a projection only when its shape would differ.
        needs_projection = c_in != c_out or stride != 1
        self.proj = nn.Conv1d(c_in, c_out, 1, stride=stride, bias=False) if needs_projection else None

    def forward(self, x):
        shortcut = x if self.proj is None else self.proj(x)
        out = F.silu(self.bn1(self.dw(x)))
        out = self.se(self.bn2(self.pw(out)))
        return F.silu(shortcut + self.drop(out))
|
|
|
|
class TinyByteCNN(nn.Module):
    """TinyByteCNN for Fiction vs Non-Fiction Classification.

    Byte-level 1D CNN. The pipeline is:

    * byte embedding;
    * stride-2 stem convolution;
    * 10 ``SepResBlock`` layers in four stages, where the first block of each
      stage has stride 2;
    * global mean pooling and max pooling, concatenated;
    * dropout plus a linear head producing one logit per example.

    ``config`` may be ``None`` (defaults), a dict (merged over the defaults),
    or any object exposing the attributes ``vocab_size``, ``embed_dim``,
    ``widths`` (4 ints), ``use_gn``, ``head_drop`` and ``stochastic_depth``.
    """

    def __init__(self, config=None):
        super().__init__()

        if config is None:
            config = self._config_from_mapping({})
        elif isinstance(config, dict):
            # save_pretrained stores the config as a plain dict; accept it directly.
            config = self._config_from_mapping(config)

        self.config = config

        # One learned vector per byte value.
        self.embed = nn.Embedding(config.vocab_size, config.embed_dim)

        # Stem: stride-2 conv that halves the length and lifts to the first stage width.
        self.stem = nn.Conv1d(config.embed_dim, config.widths[0], 5, stride=2, padding=2, bias=False)
        self.bn0 = nn.BatchNorm1d(config.widths[0]) if not config.use_gn else nn.GroupNorm(32, config.widths[0])

        # Per stage: (number of blocks, output width, dilation of each block).
        cfg = [
            (2, config.widths[0], [1, 2]),
            (2, config.widths[1], [1, 2]),
            (3, config.widths[2], [1, 2, 4]),
            (3, config.widths[3], [1, 2, 8]),
        ]

        stages = []
        c_prev = config.widths[0]
        for blocks, c, ds in cfg:
            for i in range(blocks):
                # The first block of every stage halves the sequence length.
                stride = 2 if i == 0 else 1
                # NOTE: despite its name, ``stochastic_depth`` is passed as the
                # block's element-wise dropout probability on its main path.
                stages.append(SepResBlock(c_prev, c, k=7, stride=stride, dilation=ds[i],
                                          use_gn=config.use_gn, drop=config.stochastic_depth))
                c_prev = c

        self.stages = nn.Sequential(*stages)

        # The head sees [mean-pool, max-pool] concatenated, hence 2x the last width.
        self.head = nn.Sequential(
            nn.Dropout(p=config.head_drop),
            nn.Linear(2 * config.widths[-1], 1),
        )

    def forward(self, x_bytes):
        """
        Args:
            x_bytes: [B, T] tensor of byte values (cast to long for the embedding)
        Returns:
            logits: [B] tensor of binary classification logits
        """
        x = self.embed(x_bytes.long())          # [B, T, E]
        x = x.transpose(1, 2).contiguous()      # [B, E, T] for Conv1d
        x = F.silu(self.bn0(self.stem(x)))
        x = self.stages(x)

        # Global pooling over the length axis.
        avg = x.mean(dim=-1)
        mx = x.amax(dim=-1)
        feats = torch.cat([avg, mx], dim=1)

        logits = self.head(feats).squeeze(1)
        return logits

    # ------------------------------------------------------------------ #
    # Config helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _default_config_dict():
        """Return a fresh dict of the default hyper-parameters."""
        return {
            'vocab_size': 256,
            'embed_dim': 32,
            'widths': [128, 192, 256, 320],
            'use_gn': False,
            'head_drop': 0.1,
            'stochastic_depth': 0.05,
        }

    @classmethod
    def _config_from_mapping(cls, mapping):
        """Build an attribute-style config object from a mapping.

        Keys missing from ``mapping`` fall back to the defaults, so partial
        config files still load.
        """
        values = cls._default_config_dict()
        values.update(mapping)
        return type('Config', (), values)()

    @staticmethod
    def _config_to_dict(config):
        """Convert a config object into a plain dict of its public, non-callable attributes.

        Config objects built with ``type()`` at runtime cannot be pickled by
        reference, so checkpoints store this dict instead.
        """
        if config is None:
            return None
        if isinstance(config, dict):
            return dict(config)
        values = {}
        for name in dir(config):
            if name.startswith('_'):
                continue
            value = getattr(config, name)
            if not callable(value):
                values[name] = value
        return values

    @classmethod
    def _load_json_config(cls, config_path):
        """Read a config.json file into a config object (merged over the defaults)."""
        import json
        with open(config_path, encoding='utf-8') as f:
            return cls._config_from_mapping(json.load(f))

    # ------------------------------------------------------------------ #
    # Weight loading helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _load_safetensors(path):
        """Load a state dict from a .safetensors file."""
        from safetensors.torch import load_file
        return load_file(str(path))

    @staticmethod
    def _torch_load(path):
        """Load a pickled checkpoint onto the CPU."""
        # SECURITY: weights_only=False unpickles arbitrary objects, which can
        # execute code. Only load checkpoints from trusted sources. It is kept
        # because older checkpoints store the config as a pickled object.
        return torch.load(path, weights_only=False, map_location='cpu')

    @classmethod
    def _with_weights(cls, state_dict, config):
        """Instantiate the model from ``config`` and load ``state_dict`` into it."""
        model = cls(config)
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def _from_checkpoint(cls, checkpoint):
        """Build a model from a torch checkpoint dict.

        Accepts both the ``{'model_state_dict', 'config'}`` layout written by
        save_pretrained and a raw state dict.
        """
        config = checkpoint.get('config', None)
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        return cls._with_weights(state_dict, config)

    @classmethod
    def from_pretrained(cls, path_or_repo, use_safetensors=True):
        """Load pretrained model (supports both .bin and .safetensors).

        ``path_or_repo`` is resolved as follows:

        * A directory: ``model.safetensors`` (with an optional ``config.json``
          next to it) is used when ``use_safetensors`` is true or when there
          is no ``pytorch_model.bin``. Otherwise ``pytorch_model.bin`` is used.
        * A single file: ``*.safetensors`` is loaded with the default config;
          any other file is loaded with ``torch.load``.
        * Otherwise, a Hugging Face Hub repo id. If ``use_safetensors``,
          ``model.safetensors`` (plus ``config.json`` when available) is tried
          first. Then ``pytorch_model.bin`` is tried.

        Non-safetensors files are read with ``torch.load(weights_only=False)``.
        That can execute code, so only load trusted checkpoints.

        Raises:
            FileNotFoundError: ``path_or_repo`` is a directory containing
                neither weights file.
        """
        import os
        from pathlib import Path

        path_or_repo = os.fspath(path_or_repo)

        if os.path.isdir(path_or_repo):
            base_path = Path(path_or_repo)
            safetensors_path = base_path / "model.safetensors"
            pytorch_path = base_path / "pytorch_model.bin"
            config_path = base_path / "config.json"

            if safetensors_path.exists() and (use_safetensors or not pytorch_path.exists()):
                config = cls._load_json_config(config_path) if config_path.exists() else None
                return cls._with_weights(cls._load_safetensors(safetensors_path), config)
            if pytorch_path.exists():
                return cls._from_checkpoint(cls._torch_load(pytorch_path))
            # Fail loudly rather than silently returning None.
            raise FileNotFoundError(
                f"No model.safetensors or pytorch_model.bin found in {path_or_repo!r}"
            )

        if os.path.isfile(path_or_repo):
            if path_or_repo.endswith('.safetensors'):
                return cls._with_weights(cls._load_safetensors(path_or_repo), None)
            return cls._from_checkpoint(cls._torch_load(path_or_repo))

        # Anything else is treated as a Hugging Face Hub repo id.
        from huggingface_hub import hf_hub_download

        if use_safetensors:
            try:
                model_file = hf_hub_download(repo_id=path_or_repo, filename="model.safetensors")
            except Exception:
                # Best effort: the repo may only ship pytorch_model.bin.
                model_file = None
            if model_file is not None:
                try:
                    config_file = hf_hub_download(repo_id=path_or_repo, filename="config.json")
                except Exception:
                    # No config.json on the Hub: use the default architecture.
                    config = None
                else:
                    config = cls._load_json_config(config_file)
                # Loading errors (e.g. shape mismatches) propagate instead of being swallowed.
                return cls._with_weights(cls._load_safetensors(model_file), config)

        model_file = hf_hub_download(repo_id=path_or_repo, filename="pytorch_model.bin")
        return cls._from_checkpoint(cls._torch_load(model_file))

    def save_pretrained(self, save_path):
        """Save model to directory as ``pytorch_model.bin``.

        The file holds ``{'model_state_dict': ..., 'config': ...}``, with the
        config stored as a plain dict. The default ``Config`` is an instance of
        a class created with ``type()`` at runtime, which pickle cannot
        serialize by reference.
        """
        import os
        os.makedirs(save_path, exist_ok=True)

        torch.save({
            'model_state_dict': self.state_dict(),
            'config': self._config_to_dict(self.config),
        }, os.path.join(save_path, 'pytorch_model.bin'))
|
|
|
|
def preprocess_text(text: str, max_len: int = 4096) -> torch.Tensor:
    """
    Convert a string into the fixed-length byte tensor the model consumes.

    Steps:

    * NFC-normalise the text.
    * Turn CRLF line endings into LF.
    * Collapse every run of three or more whitespace characters into a single
      space.
    * Encode to UTF-8, truncate to ``max_len`` bytes and right-pad with zero
      bytes.

    Args:
        text: Input text string
        max_len: Length of the output sequence (default 4096)

    Returns:
        uint8 tensor of shape [1, max_len] containing byte values
    """
    cleaned = unicodedata.normalize('NFC', text).replace('\r\n', '\n')
    cleaned = re.sub(r'\s{3,}', ' ', cleaned)

    # Truncation happens at the byte level, so a multi-byte character may be cut.
    payload = cleaned.encode('utf-8', errors='ignore')[:max_len]

    buffer = np.zeros(max_len, dtype=np.uint8)
    if payload:
        buffer[:len(payload)] = np.frombuffer(payload, dtype=np.uint8)

    return torch.from_numpy(buffer).unsqueeze(0)
|
|
|
|
def classify_text(text: Union[str, List[str]], model=None, device='cpu'):
    """
    Classify text as fiction or non-fiction.

    Each input is run through the model on its own. The model is moved to
    ``device`` and switched to eval mode in place.

    Args:
        text: Single string or list of strings to classify
        model: Pre-loaded model (optional; loaded from "fiction_classifier_hf" if omitted)
        device: Device to run on ('cpu', 'cuda', 'mps')

    Returns:
        For a single string, a dict with these keys:

        * 'text': the input, truncated to 100 chars plus '...' when longer.
        * 'prediction': "Non-Fiction" when sigmoid(logit) > 0.5, otherwise
          "Fiction".
        * 'confidence': the probability of the predicted class.
        * 'probability_nonfiction': sigmoid(logit).

        For a list of strings, a list of such dicts in input order.
    """
    if model is None:
        model = TinyByteCNN.from_pretrained("fiction_classifier_hf")

    model = model.to(device)
    model.eval()

    single = isinstance(text, str)
    batch = [text] if single else text

    def _summarize(sample):
        ids = preprocess_text(sample).to(device)
        with torch.no_grad():
            p_nonfiction = torch.sigmoid(model(ids)).item()

        is_nonfiction = p_nonfiction > 0.5
        return {
            'text': sample if len(sample) <= 100 else sample[:100] + '...',
            'prediction': "Non-Fiction" if is_nonfiction else "Fiction",
            'confidence': p_nonfiction if is_nonfiction else 1 - p_nonfiction,
            'probability_nonfiction': p_nonfiction,
        }

    summaries = [_summarize(sample) for sample in batch]
    return summaries[0] if single else summaries
|
|
|
|
if __name__ == "__main__":
    # Demo: classify one sample sentence with a locally saved checkpoint.
    demo_text = "The detective's coffee had gone cold hours ago, but she hardly noticed."

    demo_model = TinyByteCNN.from_pretrained("fiction_model_output_cnn/best_model.pt")
    outcome = classify_text(demo_text, demo_model)

    for label, value in (("Text", outcome['text']), ("Prediction", outcome['prediction'])):
        print(f"{label}: {value}")
    print(f"Confidence: {outcome['confidence']:.1%}")