"""Convert FTW ``.ckpt`` checkpoints into plain ``smp.Unet`` state dicts.

The input files are Lightning-style checkpoints: a top-level ``"state_dict"``
entry in which the network's weights are stored under the ``"model."`` prefix.
For every file listed in ``CHECKPOINTS`` the script:

1. loads the file on the CPU;
2. keeps only the network's weights, with the prefix removed;
3. validates them by loading them strictly into a freshly built ``smp.Unet``;
4. saves that model's ``state_dict`` next to the input with a ``.pth`` suffix.
"""

from pathlib import Path

import torch

# Checkpoint file -> number of output classes of the U-Net it contains.
# One mapping instead of two parallel lists, so a file can never be silently
# paired with the wrong class count.
CHECKPOINTS = {
    "2_Class_CCBY_FTW_Pretrained.ckpt": 2,
    "2_Class_FULL_FTW_Pretrained.ckpt": 2,
    "3_Class_CCBY_FTW_Pretrained.ckpt": 3,
    "3_Class_FULL_FTW_Pretrained.ckpt": 3,
}

# Prefix under which the network's weights live in the checkpoint's
# "state_dict". Other entries (e.g. "criterion.weight") are not part of the
# U-Net and must not be passed to load_state_dict.
MODEL_PREFIX = "model."


def extract_model_state_dict(state_dict, prefix=MODEL_PREFIX):
    """Return the entries of *state_dict* keyed under *prefix*, with the prefix removed.

    Only a *leading* prefix is stripped. A key such as
    ``"model.decoder.submodel.bias"`` becomes ``"decoder.submodel.bias"``,
    whereas a blanket ``str.replace`` would corrupt it to ``"decoder.subbias"``.
    Entries without the prefix, e.g. ``"criterion.weight"``, are dropped, so
    checkpoints with and without such extras are handled the same way.

    Args:
        state_dict: Mapping of parameter/buffer names to values.
        prefix: Leading key prefix that marks the network's entries.

    Returns:
        A new dict containing only the matching entries, prefix removed.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def pth_path(path):
    """Return *path* as a ``Path`` with its final suffix replaced by ``.pth``.

    Unlike ``str.replace(".ckpt", ".pth")``, this only touches the last
    suffix. It also never returns the input path unchanged, so the source
    checkpoint cannot be overwritten.
    """
    return Path(path).with_suffix(".pth")


def convert_checkpoint(path, num_classes, *, encoder_name="efficientnet-b3", in_channels=8):
    """Convert the checkpoint at *path* into a ``.pth`` U-Net state dict.

    Args:
        path: Checkpoint file that contains a ``"state_dict"`` entry.
        num_classes: Number of output classes of the stored U-Net.
        encoder_name: smp encoder the stored U-Net was built with.
        in_channels: Number of input channels of the stored U-Net.

    Returns:
        The path the converted state dict was written to.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the extracted weights do not exactly match the
            U-Net (``load_state_dict`` is strict by default).
    """
    # Imported here rather than at module top so the pure helpers above stay
    # importable (and testable) without segmentation_models_pytorch installed.
    import segmentation_models_pytorch as smp

    # weights_only=True restricts unpickling to tensors and plain containers.
    # map_location="cpu" lets checkpoints saved on a GPU load anywhere.
    checkpoint = torch.load(path, weights_only=True, map_location="cpu")
    weights = extract_model_state_dict(checkpoint["state_dict"])

    # encoder_weights=None: every weight comes from the checkpoint, so no
    # pretrained encoder weights are needed.
    model = smp.Unet(
        encoder_name=encoder_name,
        in_channels=in_channels,
        classes=num_classes,
        encoder_weights=None,
    )
    # A strict load fails loudly on missing or unexpected keys, which
    # validates the conversion before anything is written.
    model.load_state_dict(weights)

    out_path = pth_path(path)
    torch.save(model.state_dict(), out_path)
    return out_path


def main():
    """Convert every checkpoint listed in ``CHECKPOINTS``."""
    for path, num_classes in CHECKPOINTS.items():
        convert_checkpoint(path, num_classes)


if __name__ == "__main__":
    main()