"""
TinyByteCNN Model for Fiction vs Non-Fiction Classification
"""

import json
import os
import re
import unicodedata
import warnings
from pathlib import Path
from types import SimpleNamespace
from typing import List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def _default_config_dict():
    """Return a fresh dict holding the default architecture hyper-parameters."""
    # Built on every call so no caller can mutate a shared ``widths`` list.
    return {
        'vocab_size': 256,
        'embed_dim': 32,
        'widths': [128, 192, 256, 320],
        'use_gn': False,
        'head_drop': 0.1,
        'stochastic_depth': 0.05,
    }


def _config_to_dict(config):
    """Return the architecture fields of ``config`` as a plain, picklable dict.

    ``config`` may be None, a dict, or any object exposing the fields as
    attributes (class-level attributes included). Missing fields fall back to
    the defaults from ``_default_config_dict``; extra dict keys are kept.
    """
    values = _default_config_dict()
    if isinstance(config, dict):
        values.update(config)
    elif config is not None:
        for key in list(values):
            values[key] = getattr(config, key, values[key])
    values['widths'] = list(values['widths'])
    return values


def _load_safetensors(path):
    """Load a ``.safetensors`` file into a state dict (safetensors imported lazily)."""
    from safetensors.torch import load_file
    return load_file(str(path))


def _torch_load(path):
    """Load a pickled PyTorch checkpoint onto the CPU.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects and can run
    arbitrary code -- only load checkpoints from trusted sources. It is kept
    because existing checkpoints may pickle a custom config object.
    """
    return torch.load(path, weights_only=False, map_location='cpu')


def _read_config_json(path):
    """Return the JSON dict stored at ``path``, or None when the file is absent."""
    path = Path(path)
    if not path.exists():
        return None
    with open(path, encoding='utf-8') as f:
        return json.load(f)


class SE(nn.Module):
    """Squeeze-Excitation module: rescales each channel by a learned gate."""

    def __init__(self, c, r=8):
        super().__init__()
        m = max(c // r, 4)  # bottleneck width, never below 4
        self.fc1 = nn.Linear(c, m)
        self.fc2 = nn.Linear(m, c)

    def forward(self, x):
        # x: [B, C, T]
        s = x.mean(dim=-1)                  # [B, C]
        s = F.silu(self.fc1(s))
        s = torch.sigmoid(self.fc2(s))      # [B, C]
        return x * s.unsqueeze(-1)


class SepResBlock(nn.Module):
    """Separable Residual Block with SE attention.

    Depthwise conv -> norm -> SiLU -> pointwise conv -> norm -> SE, added to a
    (possibly 1x1-projected) shortcut, followed by SiLU.
    """

    def __init__(self, c_in, c_out, k=7, stride=1, dilation=1,
                 use_gn=False, se_ratio=8, drop=0.0):
        super().__init__()
        Norm = (lambda c: nn.GroupNorm(32, c)) if use_gn else nn.BatchNorm1d
        self.dw = nn.Conv1d(c_in, c_in, k, stride=stride, dilation=dilation,
                            padding=((k - 1) // 2) * dilation, groups=c_in,
                            bias=False)
        self.bn1 = Norm(c_in)
        self.pw = nn.Conv1d(c_in, c_out, 1, bias=False)
        self.bn2 = Norm(c_out)
        self.se = SE(c_out, se_ratio)
        # NOTE: element-wise dropout on the residual branch (the model passes its
        # ``stochastic_depth`` setting here, but this is not per-sample drop-path).
        self.drop = nn.Dropout(p=drop)
        self.proj = None
        if stride != 1 or c_in != c_out:
            # Shortcut must match the branch's channel count and stride.
            self.proj = nn.Conv1d(c_in, c_out, 1, stride=stride, bias=False)

    def forward(self, x):
        y = self.dw(x)
        y = F.silu(self.bn1(y))
        y = self.pw(y)
        y = self.bn2(y)
        y = self.se(y)
        if self.proj is not None:
            x = self.proj(x)
        y = self.drop(y)
        return F.silu(x + y)


class TinyByteCNN(nn.Module):
    """TinyByteCNN for Fiction vs Non-Fiction Classification.

    Byte embedding -> strided stem conv -> 4 stages of ``SepResBlock`` (each
    stage halves the length) -> concat(global avg, global max) -> linear logit.
    """

    def __init__(self, config=None):
        """Build the network.

        Args:
            config: None (defaults), a dict, or an attribute-style object with
                ``vocab_size``, ``embed_dim``, ``widths`` (4 ints), ``use_gn``,
                ``head_drop`` and ``stochastic_depth``. Missing fields take the
                default values.
        """
        super().__init__()
        # Keep an attribute-style config so ``model.config.widths`` always works.
        if config is None:
            config = SimpleNamespace(**_default_config_dict())
        elif isinstance(config, dict):
            config = SimpleNamespace(**_config_to_dict(config))
        self.config = config
        hp = _config_to_dict(config)
        widths = hp['widths']

        # Embedding layer for bytes
        self.embed = nn.Embedding(hp['vocab_size'], hp['embed_dim'])

        # Stem convolution (stride 2)
        self.stem = nn.Conv1d(hp['embed_dim'], widths[0], 5, stride=2,
                              padding=2, bias=False)
        self.bn0 = (nn.GroupNorm(32, widths[0]) if hp['use_gn']
                    else nn.BatchNorm1d(widths[0]))

        # Stages: (number of blocks, output width, per-block dilations).
        # The first block of each stage downsamples with stride 2.
        stage_plan = [
            (2, widths[0], [1, 2]),
            (2, widths[1], [1, 2]),
            (3, widths[2], [1, 2, 4]),
            (3, widths[3], [1, 2, 8]),
        ]
        blocks = []
        c_prev = widths[0]
        for n_blocks, c_out, dilations in stage_plan:
            for i in range(n_blocks):
                blocks.append(SepResBlock(
                    c_prev, c_out, k=7,
                    stride=2 if i == 0 else 1,
                    dilation=dilations[i],
                    use_gn=hp['use_gn'],
                    drop=hp['stochastic_depth'],
                ))
                c_prev = c_out
        self.stages = nn.Sequential(*blocks)

        # Classification head over concat(avg-pool, max-pool) features
        self.head = nn.Sequential(
            nn.Dropout(p=hp['head_drop']),
            nn.Linear(2 * widths[-1], 1),
        )

    def forward(self, x_bytes):
        """
        Args:
            x_bytes: [B, T] integer tensor of byte values (e.g. uint8)

        Returns:
            logits: [B] tensor of binary classification logits
        """
        x = self.embed(x_bytes.long())          # [B, T, E]
        x = x.transpose(1, 2).contiguous()      # [B, E, T]
        x = F.silu(self.bn0(self.stem(x)))      # [B, C0, T/2]
        x = self.stages(x)                      # [B, C, T/32]

        # Global pooling
        avg = x.mean(dim=-1)
        mx = x.amax(dim=-1)
        feats = torch.cat([avg, mx], dim=1)
        logits = self.head(feats).squeeze(1)
        return logits

    @classmethod
    def from_pretrained(cls, path_or_repo, use_safetensors=True):
        """Load a pretrained model (supports both .bin and .safetensors).

        Resolution order:
          * existing directory -> ``model.safetensors`` (with optional
            ``config.json``) when ``use_safetensors`` is true, else
            ``pytorch_model.bin``;
          * existing file -> ``*.safetensors`` (with optional sibling
            ``config.json``) or a pickled PyTorch checkpoint;
          * anything else -> treated as a Hugging Face Hub repo id.

        Args:
            path_or_repo: Local path (``str`` or ``os.PathLike``) or Hub repo id.
            use_safetensors: Prefer the safetensors format when it is available.

        Returns:
            A ``TinyByteCNN`` holding the loaded weights (left in training mode
            as constructed; call ``.eval()`` for inference).

        Raises:
            FileNotFoundError: a local directory contains no usable weights file.
        """
        location = os.fspath(path_or_repo)
        if os.path.isdir(location):
            return cls._from_directory(Path(location), use_safetensors)
        if os.path.isfile(location):
            if location.endswith('.safetensors'):
                config = _read_config_json(Path(location).with_name('config.json'))
                return cls._from_state_dict(_load_safetensors(location), config)
            return cls._from_checkpoint(_torch_load(location))
        return cls._from_hub(location, use_safetensors)

    @classmethod
    def _from_state_dict(cls, state_dict, config=None):
        """Build a model for ``config`` and load ``state_dict`` into it (strict)."""
        model = cls(config)
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def _from_checkpoint(cls, checkpoint):
        """Build a model from a torch checkpoint: a wrapper dict or a raw state dict."""
        config = checkpoint.get('config', None)
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        return cls._from_state_dict(state_dict, config)

    @classmethod
    def _from_directory(cls, base, use_safetensors):
        """Load from a directory holding ``model.safetensors`` / ``pytorch_model.bin``."""
        safetensors_path = base / "model.safetensors"
        pytorch_path = base / "pytorch_model.bin"
        if use_safetensors and safetensors_path.exists():
            config = _read_config_json(base / "config.json")
            return cls._from_state_dict(_load_safetensors(safetensors_path), config)
        if pytorch_path.exists():
            return cls._from_checkpoint(_torch_load(pytorch_path))
        expected = "'pytorch_model.bin'"
        if use_safetensors:
            expected += " or 'model.safetensors'"
        raise FileNotFoundError(f"No model weights in {str(base)!r}: expected {expected}")

    @classmethod
    def _from_hub(cls, repo_id, use_safetensors):
        """Download from the Hugging Face Hub, preferring safetensors when asked."""
        from huggingface_hub import hf_hub_download
        if use_safetensors:
            try:
                model_file = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
                try:
                    config_file = hf_hub_download(repo_id=repo_id, filename="config.json")
                except Exception:
                    config = None  # config.json is optional; use defaults
                else:
                    config = _read_config_json(config_file)
                return cls._from_state_dict(_load_safetensors(model_file), config)
            except Exception as err:
                # Best effort: fall back to the pickled format, but say why.
                warnings.warn(
                    f"Could not load model.safetensors from {repo_id!r} ({err!r}); "
                    "falling back to pytorch_model.bin"
                )
        model_file = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        return cls._from_checkpoint(_torch_load(model_file))

    def save_pretrained(self, save_path):
        """Save weights and config to ``<save_path>/pytorch_model.bin``.

        The config is stored as a plain dict of the architecture fields so the
        checkpoint pickles reliably (ad-hoc config classes cannot be pickled).
        """
        os.makedirs(save_path, exist_ok=True)
        torch.save({
            'model_state_dict': self.state_dict(),
            'config': _config_to_dict(self.config),
        }, os.path.join(save_path, 'pytorch_model.bin'))


# Any run of 3+ consecutive whitespace characters; compiled once.
_WHITESPACE_RUN = re.compile(r'\s{3,}')


def preprocess_text(text: str, max_len: int = 4096) -> torch.Tensor:
    """
    Preprocess text to bytes for model input

    Args:
        text: Input text string
        max_len: Maximum sequence length (default 4096)

    Returns:
        uint8 tensor of shape [1, max_len]: the UTF-8 bytes, truncated to
        ``max_len`` (possibly mid-character) and right-padded with zeros.
    """
    # Unicode NFC normalize
    text = unicodedata.normalize('NFC', text)
    # Replace \r\n -> \n
    text = text.replace('\r\n', '\n')
    # Replace every run of 3+ whitespace chars (spaces, newlines, tabs...) with a
    # single space; runs of 1-2 are kept verbatim. Must match training-time input.
    text = _WHITESPACE_RUN.sub(' ', text)
    # Convert to bytes (unencodable characters, e.g. lone surrogates, are dropped)
    text_bytes = text.encode('utf-8', errors='ignore')[:max_len]
    # Pad or truncate to max_len
    input_ids = np.zeros(max_len, dtype=np.uint8)
    input_ids[:len(text_bytes)] = list(text_bytes)
    return torch.from_numpy(input_ids).unsqueeze(0)  # Add batch dimension


def classify_text(text: Union[str, List[str]], model=None, device='cpu'):
    """
    Classify text as fiction or non-fiction

    Args:
        text: Single string or list of strings to classify
        model: Pre-loaded model (optional). It is moved to ``device`` and put in
            eval mode. When omitted, "fiction_classifier_hf" is loaded on every
            call, so pass a model when classifying repeatedly.
        device: Device to run on ('cpu', 'cuda', 'mps')

    Returns:
        For a string, one result dict; for a list, a list of result dicts with
        keys 'text' (first 100 chars), 'prediction', 'confidence' and
        'probability_nonfiction'.
    """
    if model is None:
        model = TinyByteCNN.from_pretrained("fiction_classifier_hf")
    model = model.to(device)
    model.eval()

    # Handle single text or batch
    texts = [text] if isinstance(text, str) else text

    results = []
    with torch.no_grad():
        for t in texts:
            input_ids = preprocess_text(t).to(device)
            prob = torch.sigmoid(model(input_ids)).item()
            is_nonfiction = prob > 0.5
            results.append({
                'text': t[:100] + '...' if len(t) > 100 else t,
                'prediction': "Non-Fiction" if is_nonfiction else "Fiction",
                'confidence': prob if is_nonfiction else (1 - prob),
                'probability_nonfiction': prob,
            })

    return results[0] if isinstance(text, str) else results


if __name__ == "__main__":
    # Example usage
    sample_text = "The detective's coffee had gone cold hours ago, but she hardly noticed."

    # Load and use model
    model = TinyByteCNN.from_pretrained("fiction_model_output_cnn/best_model.pt")
    result = classify_text(sample_text, model)
    print(f"Text: {result['text']}")
    print(f"Prediction: {result['prediction']}")
    print(f"Confidence: {result['confidence']:.1%}")