CNN-Based Deepfake Image Classification

This project implements a deepfake image classifier using a custom Convolutional Neural Network (CNN) built in PyTorch. The task is binary classification: distinguishing real images from fake ones. The training data comes from the Celeb-DF v2 dataset, from which a subsample of 5110 frames was selected. These frames are split evenly into two categories, real and fake, and stored in separate folders in Google Drive, so the code begins by mounting Google Drive to access the images.

The code imports standard deep learning libraries, including PyTorch and torchvision, along with matplotlib for visualization and scikit-learn for evaluation metrics. The device configuration checks whether a GPU is available and, if so, uses it as the computation device to accelerate training.

All images are resized to 128x128 pixels with torchvision transforms and converted to tensors. PyTorch’s ImageFolder utility loads the images and labels them automatically based on their folder names. The dataset is then split into 80% training and 20% validation subsets with torch’s random_split, and DataLoader batches the data with a batch size of 32, shuffling the training data each epoch.

The model architecture is defined in a custom class called SimpleCNN. It consists of three convolutional layers, each followed by a LeakyReLU activation (negative slope 0.1) and a max-pooling layer that reduces the spatial resolution. The output of the last convolutional block is flattened and passed through a fully connected linear layer that produces a single value, which a Sigmoid activation maps to a probability between 0 and 1. This is the probability of the class that ImageFolder assigns label 1. Because ImageFolder assigns class indices in alphabetical order of the folder names, folders named fake and real would give fake the label 0 and real the label 1, in which case the output is the probability of the image being real rather than fake; the dataset’s class_to_idx mapping confirms which class the output represents. The model is trained with Binary Cross-Entropy Loss (BCELoss) and the RMSprop optimizer with a learning rate of 0.001, for 10 epochs.
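A minimal sketch of the model and data-split setup described above. The text does not give kernel sizes or channel counts, so this assumes 3x3 convolutions with padding 1, 2x2 max pooling, and hypothetical channel widths of 32, 64, and 128. With three 2x2 pooling stages, a 128x128 input shrinks to 16x16, which fixes the input size of the linear layer. The split uses range(5110) as a stand-in for the ImageFolder dataset so the 80/20 arithmetic and batch counts can be checked without the image files.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split


class SimpleCNN(nn.Module):
    # Hypothetical widths/kernels: 3x3 convs with padding 1 keep the spatial size,
    # and each MaxPool2d(2) halves it: 128 -> 64 -> 32 -> 16.
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 16 * 16, 1),  # one raw score per image
            nn.Sigmoid(),                 # mapped to a probability in [0, 1]
        )

    def forward(self, x):
        return self.classifier(self.features(x))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)

# 80/20 split of the 5110 frames; range() stands in for the ImageFolder dataset.
n_total = 5110
n_train = int(0.8 * n_total)  # 4088
n_val = n_total - n_train     # 1022
train_set, val_set = random_split(
    range(n_total), [n_train, n_val], generator=torch.Generator().manual_seed(42)
)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)  # 128 batches
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)     # 32 batches
```

A common alternative is to drop the final Sigmoid and train with nn.BCEWithLogitsLoss, which fuses the two steps in a more numerically stable way; the sketch keeps Sigmoid plus BCELoss to match the description.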
In each epoch, the model is first put in training mode and processes each training batch: it computes the loss, backpropagates the gradients, and updates the weights. It then switches to evaluation mode and computes the validation loss and accuracy without updating the weights. Training loss, validation loss, training accuracy, and validation accuracy are recorded for every epoch, and once training is complete, matplotlib plots the training and validation losses and accuracies across all epochs to visualize the learning progress.

Finally, the model is evaluated on the validation set by collecting its predictions and the true labels. These are used to compute a classification report, which gives the precision, recall, and F1-score for each class, and to generate a confusion matrix, which shows the number of correct and incorrect predictions for each class. The confusion matrix is visualized with matplotlib to give insight into the model’s classification performance.
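The per-epoch train/validate cycle and the prediction collection described above can be sketched as follows. This is a sketch under stated assumptions: the helper names run_epoch, fit, and collect_predictions are hypothetical, hard predictions come from thresholding the Sigmoid output at 0.5 (the text does not state the threshold), and a tiny stand-in model on random 16x16 tensors replaces SimpleCNN and the Celeb-DF frames so the loop runs on its own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, device, optimizer=None):
    """One pass over `loader`; trains if an optimizer is given, otherwise only evaluates.
    Returns (mean loss per sample, accuracy)."""
    training = optimizer is not None
    model.train(training)  # training mode vs. evaluation mode
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):  # no gradient tracking during validation
        for images, labels in loader:
            images = images.to(device)
            # BCELoss expects float targets with the same (N, 1) shape as the output.
            targets = labels.float().unsqueeze(1).to(device)
            probs = model(images)
            loss = criterion(probs, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * images.size(0)
            # Assumed 0.5 threshold to turn probabilities into 0/1 predictions.
            correct += ((probs >= 0.5).float() == targets).sum().item()
            seen += images.size(0)
    return total_loss / seen, correct / seen


def fit(model, train_loader, val_loader, epochs=10, lr=0.001, device="cpu"):
    criterion = nn.BCELoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    for _ in range(epochs):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion, device)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)
    return history


@torch.no_grad()
def collect_predictions(model, loader, device="cpu"):
    """Hard 0/1 predictions and true labels for every sample in the loader."""
    model.eval()
    preds, targets = [], []
    for images, labels in loader:
        probs = model(images.to(device)).squeeze(1)
        preds.extend((probs >= 0.5).long().cpu().tolist())
        targets.extend(labels.tolist())
    return preds, targets


# Hypothetical stand-in data and model so the sketch runs without the image folders.
torch.manual_seed(0)
images = torch.rand(64, 3, 16, 16)
labels = torch.randint(0, 2, (64,))
train_loader = DataLoader(TensorDataset(images[:48], labels[:48]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(images[48:], labels[48:]), batch_size=32)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 1), nn.Sigmoid())

history = fit(model, train_loader, val_loader, epochs=2)
preds, targets = collect_predictions(model, val_loader)
```

Each list in history holds one value per epoch, so it can be passed straight to matplotlib’s plt.plot for the loss and accuracy curves. The preds and targets lists fit scikit-learn’s sklearn.metrics.classification_report(targets, preds) and confusion_matrix(targets, preds), both of which take the true labels first.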