# Origin: author faisalsns, commit 7bb1e78 ("initial commit"), 875 bytes.
# Section 6: Inference Script (`inference.py`)
import torch
import pandas as pd
from model import TabularModel
def load_model_and_predict(data, checkpoint_path='ev_classifier_model.pth', device='cpu'):
    """Load the trained EV classifier checkpoint and predict a class for each row.

    The checkpoint is expected to be a dict containing 'model_state_dict',
    'scaler' and 'label_encoders'. The stored preprocessors are applied to
    *data* before it is fed to the network.

    Args:
        data: pandas DataFrame of raw feature rows. Presumably it must contain
            the same feature columns used during training; TODO confirm
            against the training script.
        checkpoint_path: Path of the saved checkpoint. Defaults to the file
            name that was previously hard-coded.
        device: torch device used for inference (default 'cpu').

    Returns:
        A numpy array with one predicted class index per row of *data*,
        taken as the argmax over the model's outputs (output_size=2).

    Raises:
        KeyError: if a column the scaler was fitted on is missing from *data*.
        Errors raised by the stored preprocessors (for example, on unseen
        categorical values) propagate unchanged.
    """
    # SECURITY: the checkpoint stores fitted preprocessing objects, not just
    # tensors. Restoring them requires full unpickling (weights_only=False),
    # which can execute arbitrary code, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = TabularModel(input_size=9, hidden_sizes=[128, 64, 32], output_size=2)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    scaler = checkpoint['scaler']
    label_encoders = checkpoint['label_encoders']

    # Work on a copy so the caller's DataFrame is not mutated.
    features = data.copy()

    # Presumably label_encoders maps column name -> fitted encoder (verify
    # against training). Encoders for columns absent from the input, such as
    # a target-label encoder, are skipped.
    for column, encoder in label_encoders.items():
        if column in features.columns:
            features[column] = encoder.transform(features[column])

    # sklearn estimators fitted on a DataFrame record the training column order
    # in feature_names_in_. Select and reorder the columns to match when present.
    feature_names = getattr(scaler, 'feature_names_in_', None)
    if feature_names is not None:
        features = features[list(feature_names)]
    scaled = scaler.transform(features)

    inputs = torch.as_tensor(scaled, dtype=torch.float32, device=device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Presumably (batch, output_size) logits; TODO confirm in TabularModel.
        logits = model(inputs)
    predictions = logits.argmax(dim=1).cpu().numpy()
    return predictions
# Example usage: run this file directly to score a single sample row.
if __name__ == "__main__":
    # NOTE(review): the model is built with input_size=9, so an end-to-end run
    # presumably needs every feature column used in training. Only three are
    # shown here; fill in the remaining ones before running.
    sample_data = pd.DataFrame({
        'model_year': [2021],
        'make': ['TESLA'],
        'model': ['MODEL 3'],
        # ... other features
    })
    prediction = load_model_and_predict(sample_data)
    print(f"Prediction: {prediction}")