| """ |
| Example script showing how to use the Hello World model with its dataset. |
| """ |
|
|
| from transformers import PreTrainedTokenizerFast |
| from model import HelloWorldModel, HelloWorldConfig |
| from datasets import load_dataset |
| import torch |
|
|
|
|
def main():
    """Demonstrate loading the Hello World model, its dataset, and running inference.

    Walks through six numbered steps: loading the model/tokenizer/config,
    loading the dataset (both via the model's own helper and via the
    ``datasets`` library directly), batch tokenization, batched inference,
    the model's generation helper, and per-example predictions over the
    first few test-set entries. All output goes to stdout.
    """
    print("Loading Hello World Model and Dataset Example\n")
    print("=" * 50)

    # Load config alongside model/tokenizer to demonstrate the full
    # from_pretrained API surface; the config object is not used further
    # in this example (kept deliberately — it also validates the hub repo).
    print("Loading model and tokenizer...")
    config = HelloWorldConfig.from_pretrained("chiedo/hello-world")
    model = HelloWorldModel.from_pretrained("chiedo/hello-world")
    tokenizer = PreTrainedTokenizerFast.from_pretrained("chiedo/hello-world")

    # Step 1: the model class exposes its own dataset loader.
    print("\n1. Loading dataset using model's load_dataset method:")
    dataset = HelloWorldModel.load_dataset("chiedo/hello-world")

    if dataset:
        # Fixed: was an f-string with no placeholders.
        print("Dataset loaded successfully!")
        print(f"Splits available: {list(dataset.keys())}")
        print(f"Train examples: {len(dataset['train'])}")
        print(f"Validation examples: {len(dataset['validation'])}")
        print(f"Test examples: {len(dataset['test'])}")

        # Show a small sample of the training split (at most 3 rows).
        print("\nFirst 3 training examples:")
        for i in range(min(3, len(dataset['train']))):
            example = dataset['train'][i]
            print(f" {i+1}. Text: '{example['text']}', Label: {example['label']}")

    # Step 2: load the same dataset directly with the `datasets` library.
    print("\n2. Loading dataset directly with datasets library:")
    dataset_direct = load_dataset("chiedo/hello-world")

    # Map integer label ids to their human-readable names via the
    # split's ClassLabel feature.
    label_names = dataset_direct['train'].features['label'].names
    print(f"Label categories: {label_names}")

    # Step 3: tokenize a small batch with the model's batching helper.
    print("\n3. Processing a batch from the dataset:")
    batch_texts = dataset_direct['train']['text'][:3]
    print(f"Batch texts: {batch_texts}")

    inputs = model.prepare_dataset_batch(batch_texts, tokenizer)
    print(f"Tokenized input shape: {inputs['input_ids'].shape}")

    # Step 4: forward pass; no_grad avoids building the autograd graph
    # since this is inference only.
    print("\n4. Running model inference on dataset batch:")
    with torch.no_grad():
        outputs = model(**inputs)
    print(f"Model output shape: {outputs.logits.shape}")

    # Step 5: the model's convenience generation method.
    print("\n5. Testing generate_hello_world function:")
    result = model.generate_hello_world()
    print(f"Generated output: {result}")

    # Step 6: per-example inference over the first 3 test examples.
    print("\n6. Iterating through test set:")
    for i, example in enumerate(dataset_direct['test']):
        if i >= 3:
            break
        text = example['text']
        label_id = example['label']
        label_name = label_names[label_id]

        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Greedy next-token prediction from the final position's logits.
        predicted_token = outputs.logits[0, -1].argmax().item()

        print(f" Text: '{text}' | Label: {label_name} | Predicted next token ID: {predicted_token}")

    print("\n" + "=" * 50)
    print("Example completed successfully!")
|
|
|
|
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()