Fine-tune a pre-trained ResNet18 on MNIST
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from tqdm import tqdm
import numpy as np

# Define data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),  # replicate the single channel to 3 (RGB) for ResNet
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    # convert dtype to float32
    # transforms.Lambda(lambda x: x.to(torch.float32)),
])
# Load pre-trained ResNet model
resnet = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
print("Loaded pre-trained ResNet18 model")
print(resnet.fc.in_features)

# Modify the last fully connected layer to match MNIST's number of classes (10)
num_classes = 10
resnet.fc = nn.Sequential(
    nn.Linear(resnet.fc.in_features, resnet.fc.in_features),
    nn.GELU(),
    nn.Linear(resnet.fc.in_features, num_classes),
)

# Freeze all layers except the last fully connected layer
for param in resnet.parameters():
    param.requires_grad = False
resnet.fc.requires_grad_(True)

# Define loss and optimizer (only the unfrozen head has trainable parameters)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, resnet.parameters()), lr=1e-4)

# Training loop
num_epochs = 50
print(f"Training on device {device}")
resnet.to(device)
print("Training ResNet18 model")
for epoch in range(num_epochs):
    resnet.train()
    epoch_loss = 0.0
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)  # move batch to the model's device
        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f}")

    # Evaluation after each epoch
    resnet.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = resnet(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy on the test set: {(100 * correct / total):.2f}%")
100%|██████████| 938/938 [00:03<00:00, 242.75it/s]
Loaded pre-trained ResNet18 model
512
Training on device cuda
Training ResNet18 model
Epoch [1/50] Loss: 1.0877
Accuracy on the test set: 75.42%
Epoch [2/50] Loss: 0.8051
Accuracy on the test set: 76.74%
Epoch [3/50] Loss: 0.7578
Accuracy on the test set: 78.27%
Epoch [4/50] Loss: 0.7290
Accuracy on the test set: 78.71%
Epoch [5/50] Loss: 0.7083
Accuracy on the test set: 79.62%
Epoch [6/50] Loss: 0.6761
Accuracy on the test set: 79.82%
Epoch [7/50] Loss: 0.6627
Accuracy on the test set: 80.47%
Epoch [8/50] Loss: 0.6423
Accuracy on the test set: 80.24%
Epoch [9/50] Loss: 0.6257
Accuracy on the test set: 81.11%
Epoch [10/50] Loss: 0.6131
Accuracy on the test set: 81.42%
Epoch [11/50] Loss: 0.5911
Accuracy on the test set: 82.02%
Epoch [12/50] Loss: 0.5765
Accuracy on the test set: 82.32%
Epoch [13/50] Loss: 0.5611
Accuracy on the test set: 82.30%
Epoch [14/50] Loss: 0.5466
Accuracy on the test set: 82.49%
Epoch [15/50] Loss: 0.5358
Accuracy on the test set: 82.81%
Epoch [16/50] Loss: 0.5266
Accuracy on the test set: 83.30%
Epoch [17/50] Loss: 0.5137
Accuracy on the test set: 83.37%
Epoch [18/50] Loss: 0.5051
Accuracy on the test set: 83.17%
Epoch [19/50] Loss: 0.4969
Accuracy on the test set: 83.46%
Epoch [20/50] Loss: 0.4811
Accuracy on the test set: 83.76%
Epoch [21/50] Loss: 0.4714
Accuracy on the test set: 83.57%
Epoch [22/50] Loss: 0.4624
Accuracy on the test set: 84.25%
Epoch [23/50] Loss: 0.4553
Accuracy on the test set: 84.27%
Epoch [24/50] Loss: 0.4506
Accuracy on the test set: 84.62%
Epoch [25/50] Loss: 0.4394
Accuracy on the test set: 83.97%
Epoch [26/50] Loss: 0.4346
Accuracy on the test set: 84.16%
Epoch [27/50] Loss: 0.4271
Accuracy on the test set: 84.38%
Epoch [28/50] Loss: 0.4193
Accuracy on the test set: 84.84%
Epoch [29/50] Loss: 0.4148
Accuracy on the test set: 85.05%
Epoch [30/50] Loss: 0.4040
Accuracy on the test set: 84.49%
Epoch [31/50] Loss: 0.3990
Accuracy on the test set: 84.59%
Epoch [32/50] Loss: 0.4016
Accuracy on the test set: 84.92%
Epoch [33/50] Loss: 0.3979
Accuracy on the test set: 85.01%
Epoch [34/50] Loss: 0.3844
Accuracy on the test set: 84.82%
Epoch [35/50] Loss: 0.3789
Accuracy on the test set: 85.49%
Epoch [36/50] Loss: 0.3760
Accuracy on the test set: 85.26%
Epoch [37/50] Loss: 0.3733
Accuracy on the test set: 85.36%
Epoch [38/50] Loss: 0.3655
Accuracy on the test set: 84.98%
Epoch [39/50] Loss: 0.3627
Accuracy on the test set: 85.19%
Epoch [40/50] Loss: 0.3517
Accuracy on the test set: 84.78%
Epoch [41/50] Loss: 0.3526
Accuracy on the test set: 85.43%
Epoch [42/50] Loss: 0.3523
Accuracy on the test set: 85.55%
Epoch [43/50] Loss: 0.3457
Accuracy on the test set: 85.02%
Epoch [44/50] Loss: 0.3447
Accuracy on the test set: 85.20%
Epoch [45/50] Loss: 0.3411
Accuracy on the test set: 85.47%
Epoch [46/50] Loss: 0.3312
Accuracy on the test set: 85.55%
Epoch [47/50] Loss: 0.3290
Accuracy on the test set: 85.52%
Epoch [48/50] Loss: 0.3277
Accuracy on the test set: 85.35%
Epoch [49/50] Loss: 0.3241
Accuracy on the test set: 85.80%
Epoch [50/50] Loss: 0.3217
Accuracy on the test set: 84.93%
# Final evaluation, additionally collecting softmax probabilities for every test image
resnet.eval()
correct = 0
total = 0
with torch.no_grad():
    predicted_list = []
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = resnet(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        softmax_outputs = nn.Softmax(dim=1)(outputs)
        predicted_list.append(softmax_outputs.data.cpu().numpy())
all_predicted = np.concatenate(predicted_list, axis=0)
print(f"Accuracy on the test set: {(100 * correct / total):.2f}%")
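`all_predicted` ends up as one softmax row per test image. As a follow-up sketch (pure NumPy; a tiny synthetic array stands in for the real `all_predicted` so the snippet runs on its own), hard labels and per-image confidence can be recovered like this:

```python
import numpy as np

# Synthetic stand-in for the (num_test_images, 10) softmax matrix above
all_predicted = np.array([
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
])

hard_labels = all_predicted.argmax(axis=1)  # most probable class per image
confidence = all_predicted.max(axis=1)      # probability assigned to that class
print(hard_labels.tolist())       # [1, 0]
print(float(confidence.mean()))   # 0.75
```

Low-confidence rows are a natural place to look when diagnosing the ~15% of test digits this frozen-backbone model still misclassifies.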