Machine Learning - Convolutional Neural Networks (CNN)
Overview
CNNs exploit spatial locality with convolutional filters and shared weights, enabling efficient learning on images and other grid-like data.
Minimal CNN (PyTorch)
import torch, torch.nn as nn
class Net(nn.Module):
    """Small two-stage CNN classifier.

    Architecture: two ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
    stages (16 then 32 channels), followed by a flatten and a single linear
    layer that produces class logits.

    The defaults reproduce the original hard-coded model: single-channel
    28x28 inputs and 10 output classes (flattened size ``32 * 7 * 7``).

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        image_size: Spatial size of the inputs, either an ``int`` for square
            images or a ``(height, width)`` tuple. It is used only to size the
            final linear layer, so inputs of a different spatial size will
            fail at the ``Linear`` layer.

    Raises:
        ValueError: If ``image_size`` is smaller than 4 in either dimension,
            which would leave no spatial positions after the two poolings.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=28):
        super().__init__()
        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        # Each 3x3 conv with padding=1 preserves the spatial size; each
        # MaxPool2d(2) halves it (flooring), so two stages yield size // 4.
        out_h, out_w = height // 4, width // 4
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"image_size must be at least 4 in each dimension, got {image_size!r}"
            )
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(), nn.Linear(32 * out_h * out_w, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be a batched ``(batch, in_channels, height, width)``
        tensor whose spatial size matches ``image_size`` from the constructor.
        """
        return self.classifier(self.features(x))
# One optimization step on a random batch: 8 single-channel 28x28 inputs
# paired with integer class labels drawn from [0, 10).
net = Net()
x = torch.randn(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Clear stale gradients, compute the loss, backpropagate, then update weights.
opt.zero_grad()
loss = loss_fn(net(x), y)
loss.backward()
opt.step()