|
|
|
|
|
import torch |
|
|
import pandas as pd |
|
|
from model import TabularModel |
|
|
|
|
|
def load_model_and_predict(data, checkpoint_path='ev_classifier_model.pth'):
    """Load the trained EV classifier checkpoint and predict classes for *data*.

    The checkpoint is read for the model weights (``'model_state_dict'``) and
    the preprocessing objects saved alongside them (``'scaler'`` and
    ``'label_encoders'``). Those objects are applied to *data* before the
    forward pass.

    Args:
        data: pandas DataFrame of raw feature rows to classify. It is not
            modified.
        checkpoint_path: Path to the saved checkpoint. Defaults to
            ``'ev_classifier_model.pth'``, the previously hard-coded path.

    Returns:
        A list with one predicted class index per row of *data*. The index
        is the arg-max over the model's outputs; the model is built with
        ``output_size=2``.

    Raises:
        ValueError: If *data* lacks a column that has a saved label encoder.
    """
    # NOTE(review): the checkpoint stores pickled preprocessing objects, not
    # only tensors. It therefore cannot be loaded with weights_only=True, the
    # default since torch 2.6. Unpickling can execute arbitrary code, so only
    # load checkpoint files from a trusted source.
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # The architecture must match the one the weights were trained with.
    model = TabularModel(input_size=9, hidden_sizes=[128, 64, 32], output_size=2)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    scaler = checkpoint['scaler']
    label_encoders = checkpoint['label_encoders']

    # Work on a copy so the caller's DataFrame is not mutated.
    features = data.copy()

    # Assumes label_encoders maps column name -> fitted encoder exposing
    # .transform (e.g. sklearn LabelEncoder). TODO: confirm against the
    # training script.
    missing = [column for column in label_encoders if column not in features.columns]
    if missing:
        raise ValueError(f"input data is missing encoded columns: {missing}")
    for column, encoder in label_encoders.items():
        features[column] = encoder.transform(features[column])

    # Assumes scaler is a fitted transformer (e.g. sklearn StandardScaler)
    # over the model's 9 input features, in training column order.
    # TODO: confirm.
    scaled = scaler.transform(features)

    inputs = torch.as_tensor(scaled, dtype=torch.float32)
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(inputs)

    # Presumably the model returns (batch, 2) class scores; take the
    # highest-scoring class per row.
    predictions = logits.argmax(dim=1).tolist()

    return predictions
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run a single hand-written record through the prediction path as a demo.
    # NOTE(review): the model is built with input_size=9, but this record
    # supplies only three columns. Confirm which features the checkpoint's
    # preprocessing expects before relying on this demo.
    demo_row = {
        'model_year': 2021,
        'make': 'TESLA',
        'model': 'MODEL 3',
    }
    sample_data = pd.DataFrame([demo_row])

    prediction = load_model_and_predict(sample_data)
    print(f"Prediction: {prediction}")