Energy-Based Generative Adversarial Network (EBGAN)
This folder contains a PyTorch implementation of an Energy-Based Generative Adversarial Network (EBGAN). EBGAN treats the discriminator as an energy function that assigns low energy to real data and higher energy to generated samples, and it trains the discriminator and the generator network against each other.
Overview
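EBGAN uses an energy function to measure the quality of samples: low energy means a sample looks like real data, high energy means it does not. The discriminator, an autoencoder-like network whose reconstruction error typically serves as the energy, is trained to assign low energy to real samples and high energy (up to a margin) to generated ones. The generator is trained to produce samples that receive low energy. The EBGAN authors report that this results in a more stable training process compared to traditional GANs.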
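The objective described above can be sketched in a few lines. This is a minimal, framework-free illustration that works on plain Python floats and lists, not the code in EBGAN.py: the function names, the mean-squared reconstruction error as the energy, and the default margin of 1.0 are all assumptions made for the example. In a PyTorch training loop the same arithmetic would be applied to tensors.

```python
def reconstruction_energy(x, x_rec):
    """Mean squared reconstruction error, used here as the energy of one sample.

    x and x_rec are flat sequences of equal length (hypothetical stand-ins for
    an image and its autoencoder reconstruction).
    """
    if len(x) != len(x_rec):
        raise ValueError("x and x_rec must have the same length")
    return sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)


def discriminator_loss(energy_real, energy_fake, margin=1.0):
    """Margin loss: push real energy down and fake energy up to the margin."""
    return energy_real + max(0.0, margin - energy_fake)


def generator_loss(energy_fake):
    """The generator is rewarded for samples that receive low energy."""
    return energy_fake


# Example: a well-reconstructed real sample and a poorly reconstructed fake one.
e_real = reconstruction_energy([0.0, 1.0], [0.0, 1.0])
e_fake = reconstruction_energy([0.0, 1.0], [0.0, 0.0])
d_loss = discriminator_loss(e_real, e_fake, margin=1.0)
g_loss = generator_loss(e_fake)
```

Once the energy of a generated sample exceeds the margin, the hinge term becomes zero, so the discriminator stops being rewarded for pushing fake energy higher.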
Usage
To use this implementation, follow these steps:
- Clone the repository.
- Install dependencies: Make sure you have Python 3 and pip installed, then install the required dependencies (PyTorch, torchvision, matplotlib, and numpy).
- Train the EBGAN: Run the EBGAN.py script to train the EBGAN model. This trains the EBGAN on the MNIST dataset and saves the trained models (G_ebgan.pth and D_ebgan.pth).
- Generate new images: After training, you can generate new images using the trained generator by running the test_EBGAN.py script. This script loads the trained generator model and generates a grid of sample images.
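The last step produces a grid of sample images. As a rough illustration of what that tiling involves, the sketch below arranges a batch of equally sized images into a single grid using only plain Python lists. The helper name tile_images and its row-major, zero-padded layout are assumptions for the example. The actual test_EBGAN.py may build its grid differently, for instance with torchvision or matplotlib utilities.

```python
def tile_images(images, ncols):
    """Tile equally sized 2-D images into one grid image, in row-major order.

    images: list of images, each a list of rows (lists of pixel values).
    ncols:  number of images per grid row.
    Unused cells in the last grid row are left as zeros.
    """
    if not images or ncols < 1:
        raise ValueError("need at least one image and ncols >= 1")
    height, width = len(images[0]), len(images[0][0])
    nrows = -(-len(images) // ncols)  # ceiling division
    grid = [[0.0] * (width * ncols) for _ in range(height * nrows)]
    for idx, img in enumerate(images):
        top = (idx // ncols) * height
        left = (idx % ncols) * width
        for r in range(height):
            # Copy one row of the image into its slot in the grid.
            grid[top + r][left:left + width] = img[r]
    return grid


# Three 2x2 "images" arranged two per row give a 4x4 grid.
samples = [[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]]]
grid = tile_images(samples, ncols=2)
```

For MNIST, each image would be 28x28, so 64 generated samples with ncols=8 would yield a 224x224 grid.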
Files
- EBGAN.py: Contains the implementation of the EBGAN model, the training loop, and the saving of trained models.
- test_EBGAN.py: Uses the trained generator to generate sample images after training.
Contributing
Contributions are welcome! If you have ideas for improvements or new features, feel free to open an issue or submit a pull request.