DEV Community

Mathys Daviet
Mathys Daviet

Posted on

Transform WGAN into cGAN

I am working on developing a Generative Adversarial Network (GAN) with the aim of generating new microstructures based on their characteristics. The objective is to create a microstructure using a given characteristic, provided to the GAN in vector form. This process is implemented using a database containing 40,000 pairs of microstructures and their corresponding characteristics. I have already coded a Wasserstein GAN (WGAN) that successfully generates coherent microstructures from the database, although it currently lacks a connection to the specified characteristics. Additionally, I have coded a conditional GAN (cGAN) that operates on the MNIST dataset. However, I require your assistance in merging these two code structures. Thank you very much for any help you can provide!

##cGAN
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm

# Common config
batch_size = 64

# Generator config
sample_size = 100  # Random sample size
g_alpha = 0.01  # LeakyReLU alpha
g_lr = 1.0e-4  # Learning rate

# Discriminator config
d_alpha = 0.01  # LeakyReLU alpha
d_lr = 1.0e-4  # Learning rate

# Data Loader for MNIST.
# shuffle=True re-orders the data each epoch (important for stable GAN
# training); drop_last=True guarantees every batch has exactly batch_size
# samples, which the fixed-size real/fake target tensors below rely on.
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


# Converts conditions into feature vectors
class Condition(nn.Module):
    """Embed class labels (0-9) into 784-dim feature vectors.

    The label is one-hot encoded and projected through a Linear +
    BatchNorm1d + LeakyReLU stack, so it can be added to image-sized
    (784-element) feature vectors elsewhere in the network.
    """

    def __init__(self, alpha: float):
        super().__init__()

        # One-hot encoding to features: 10 => 784
        self.fc = nn.Sequential(
            nn.Linear(10, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

    def forward(self, labels: torch.Tensor):
        # One-hot encode, cast Long -> Float, then project to features.
        one_hot = F.one_hot(labels, num_classes=10).float()
        return self.fc(one_hot)


# Reshape helper
class Reshape(nn.Module):
    """Reshape the non-batch dimensions of a tensor to a fixed shape."""

    def __init__(self, *shape):
        super().__init__()

        # Target shape, excluding the (inferred) batch dimension.
        self.shape = shape

    def forward(self, x):
        target = (-1,) + tuple(self.shape)
        return x.reshape(*target)


# Generator network
class Generator(nn.Module):
    """Conditional generator: labels -> fake 1 x 28 x 28 images.

    A random latent vector is projected to 784 features, the label
    embedding (from Condition) is added, and two transposed convolutions
    upsample the result to a 28 x 28 image with values in [0, 1].
    """

    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # Latent vector size used by torch.randn in forward().
        self.sample_size = sample_size

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28, sigmoid keeps pixels in [0, 1]
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.Sigmoid())

        # Label -> 784-dim feature vector
        self.cond = Condition(alpha)

    def forward(self, labels: torch.Tensor):
        # One latent vector per requested label.
        z = torch.randn(len(labels), self.sample_size)

        # Condition the latent features by adding the label embedding.
        h = self.fc(z) + self.cond(labels)  # => 784
        h = self.reshape(h)                 # => 16 x 7 x 7
        h = self.conv1(h)                   # => 32 x 14 x 14
        return self.conv2(h)                # => 1 x 28 x 28


# Discriminator network
class Discriminator(nn.Module):
    """Conditional discriminator that returns the BCE loss directly.

    Image features (two strided conv layers) and label features (from
    Condition, reshaped to the same 16 x 7 x 7 map) are summed before the
    final classifier, making the real/fake decision label-dependent.
    """

    def __init__(self, alpha: float):
        super().__init__()

        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16,
                      kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 (flattened to 784) => scalar logit
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

        # Label => 784 features => 16 x 7 x 7 (matches conv2's output)
        self.cond = nn.Sequential(
            Condition(alpha),
            Reshape(16, 7, 7))

    def forward(self, images: torch.Tensor,
                labels: torch.Tensor,
                targets: torch.Tensor):
        """Return BCE-with-logits loss of (images, labels) vs targets."""
        label_map = self.cond(labels)

        features = self.conv2(self.conv1(images))  # => 16 x 7 x 7
        logits = self.fc(features + label_map)     # => 1

        return F.binary_cross_entropy_with_logits(logits, targets)


# To save grid images
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    image_grid = make_grid(images, ncol)  # Into a grid
    image_grid = image_grid.permute(1, 2, 0)  # Channel to last
    image_grid = image_grid.cpu().numpy()  # Into Numpy

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(f'generated_{epoch:03d}.jpg')
    plt.close()


# Real / Fake targets for the BCE loss. Fixed at batch_size rows, which
# is safe because the DataLoader uses drop_last=True.
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)

# Generator and discriminator networks
generator = Generator(sample_size, g_alpha)
discriminator = Discriminator(d_alpha)

# One optimizer per network so each update touches only its own weights
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)

# Train loop
for epoch in range(100):

    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):
        # ===============================
        # Disciminator Network Training
        # ===============================

        # Images from MNIST are considered as real
        d_loss = discriminator(images, labels, real_targets)

        # Images from Generator are considered as fake
        d_loss += discriminator(generator(labels), labels, fake_targets)

        # Discriminator paramter update
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============================
        # Generator Network Training
        # ===============================

        # Images from Generator should be as real as ones from MNIST
        g_loss = discriminator(generator(labels), labels, real_targets)

        # Generator parameter update
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Keep losses for logging
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Print loss
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save generated images
    labels = torch.LongTensor(list(range(10))).repeat(8).flatten()
    save_image_grid(epoch, generator(labels), ncol=10)

##WGAN
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.layers import Dropout
# from tensorflow.keras.constraints import ClipConstraint

import numpy as np
import pandas as pd
from utils import load_parquet_files, wasserstein_loss, generator_loss, ClipConstraint, generate_batches, Conv2DCircularPadding
import matplotlib.pyplot as plt
import os
from vtk import vtkStructuredPoints, vtkXMLImageDataWriter
import vtk
from vtkmodules.util import numpy_support
# from sklearn.model_selection import train_test_split

class Generator(models.Model):
    """DCGAN-style generator mapping a latent vector to a 256 x 256 x 1 image.

    Upsampling path: Dense seed (8 x 8 x 512), three strided transposed
    convolutions up to 64 x 64, then two UpSampling2D + Conv2D stages up
    to 256 x 256. The final sigmoid keeps pixel values in [0, 1].
    """

    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        # Size of the random input vector fed to the network.
        self.latent_dim = latent_dim
        self.model = self.build_model()

    def build_model(self):
        # Sequential([...]) is equivalent to repeated model.add(...) calls.
        return models.Sequential([
            # latent_dim -> 8 x 8 x 512 seed volume
            layers.Dense(8 * 8 * 512, input_dim=self.latent_dim, activation='relu'),
            layers.Reshape((8, 8, 512)),
            # 8 x 8 -> 16 x 16
            layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='same', activation='relu'),
            layers.BatchNormalization(),
            Dropout(0.25),
            # 16 x 16 -> 32 x 32
            layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', activation='relu'),
            layers.BatchNormalization(),
            # 32 x 32 -> 64 x 64
            layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', activation='relu'),
            layers.BatchNormalization(),
            Dropout(0.25),
            # 64 x 64 -> 128 x 128
            layers.UpSampling2D(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
            layers.BatchNormalization(),
            # 128 x 128 -> 256 x 256, pixels in [0, 1]
            layers.UpSampling2D(),
            layers.Conv2D(1, kernel_size=3, padding='same', activation='sigmoid'),
        ])

    def call(self, inputs):
        return self.model(inputs)

class Discriminator(models.Model):
    """WGAN critic: 256 x 256 x 1 image -> unbounded real-valued score.

    Two variants: standard zero ('same') padding with ClipConstraint
    weight clipping, or circular-padding layers (whose weights are then
    clipped manually in the training loop).
    """

    def __init__(self, circ_pad):
        super(Discriminator, self).__init__()
        # True -> use Conv2DCircularPadding layers instead of zero padding.
        self.circ_pad = circ_pad
        self.model = self.build_model()

    def build_model(self):
        if self.circ_pad:
            # Circular (periodic) padding variant.
            conv_stack = [
                Conv2DCircularPadding(32, kernel_size=3, strides=1, input_shape=(256, 256, 1)),
                layers.LeakyReLU(alpha=0.2),
                Conv2DCircularPadding(64, kernel_size=5, strides=2),
                layers.LeakyReLU(alpha=0.2),
                Conv2DCircularPadding(128, kernel_size=7, strides=2),
                layers.LeakyReLU(alpha=0.2),
                Conv2DCircularPadding(256, kernel_size=9, strides=2),
                layers.LeakyReLU(alpha=0.2),
                Conv2DCircularPadding(512, kernel_size=5, strides=2),
                layers.LeakyReLU(alpha=0.2),
            ]
        else:
            # Zero-padding variant; ClipConstraint enforces the WGAN
            # Lipschitz weight clip at [-0.2, 0.2].
            conv_stack = [
                layers.Conv2D(32, kernel_size=3, strides=1, padding='same', kernel_constraint=ClipConstraint(0.2), input_shape=(256, 256, 1)),
                layers.LeakyReLU(alpha=0.2),
                layers.Conv2D(64, kernel_size=5, strides=2, padding='same', kernel_constraint=ClipConstraint(0.2)),
                layers.LeakyReLU(alpha=0.2),
                layers.Conv2D(128, kernel_size=7, strides=2, padding='same', kernel_constraint=ClipConstraint(0.2)),
                layers.LeakyReLU(alpha=0.2),
                layers.Conv2D(256, kernel_size=7, strides=2, padding='same', kernel_constraint=ClipConstraint(0.2)),
                layers.LeakyReLU(alpha=0.2),
                layers.Conv2D(512, kernel_size=5, strides=2, padding='same', kernel_constraint=ClipConstraint(0.2)),
                layers.LeakyReLU(alpha=0.2),
            ]

        # Linear output head: an unbounded critic score (no sigmoid).
        return models.Sequential(conv_stack + [
            layers.Flatten(),
            layers.Dense(1, activation='linear'),
        ])

    def call(self, inputs):
        return self.model(inputs)

class GAN(models.Model):
    """Wires the generator and critic together for WGAN training.

    Compiles the critic standalone, then freezes it inside a combined
    latent -> image -> score model used to train the generator.
    """

    def __init__(self, generator, discriminator, data_path):
        super(GAN, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.compile_discriminator()
        self.compile_gan()
        self.data_path = data_path
        # Training set: a list of DataFrames, one microstructure each.
        self.data = load_parquet_files(self.data_path, test=False)

    def compile_discriminator(self):
        # 'learning_rate' replaces the 'lr' argument, which has been
        # removed from Keras optimizers (TF >= 2.11 raises on 'lr').
        self.discriminator.compile(loss=wasserstein_loss,
                                   optimizer=RMSprop(learning_rate=0.000005),
                                   metrics=['accuracy'])
        # Freeze the critic for the combined model built next; its own
        # training still works because it was compiled while trainable.
        self.discriminator.trainable = False

    def compile_gan(self):
        # Combined model: latent z -> fake image -> (frozen) critic score.
        z = layers.Input(shape=(self.generator.latent_dim,))
        fake_image = self.generator(z)
        validity = self.discriminator(fake_image)
        self.model = models.Model(z, validity)
        self.model.compile(loss=generator_loss, optimizer=RMSprop(learning_rate=0.0005))

    def generate_latent_points(self, latent_dim, n_samples):
        """Sample n_samples standard-normal latent vectors."""
        x_input = np.random.randn(latent_dim * n_samples)
        return x_input.reshape(n_samples, latent_dim)

    def generate_latent_points_uniform(self, latent_dim, n_samples):
        """Sample n_samples latent vectors uniformly from [-1, 1]."""
        return np.random.uniform(-1, 1, size=(n_samples, latent_dim))

    def generate_real_samples(self, n_samples):
        """Draw n_samples real microstructures without replacement.

        Returns:
            (samples, labels): samples of shape (n_samples, H, W, 1) and
            labels filled with -1, the WGAN target for real data.
        """
        dfs_array = [df.to_numpy() for df in self.data]
        sampled_indices = np.random.choice(len(dfs_array), size=n_samples, replace=False)
        real_samples = np.stack([dfs_array[i] for i in sampled_indices], axis=0)
        real_samples = np.expand_dims(real_samples, axis=-1)  # add channel dim
        labels = -(np.ones((n_samples, 1)))
        return real_samples, labels

    def generate_and_save_samples(self, epoch, latent_dim, n_samples, output_dir):
        """Generate n_samples images and save the first three as .npy files."""
        z = self.generate_latent_points(latent_dim, n_samples)
        generated_samples = self.generator.predict(z)
        for i in range(3):
            np.save(os.path.join(output_dir, f'generated_example_epoch{epoch}_sample{i}.npy'),
                    generated_samples[i])

# def train_gan(generator, discriminator, gan, latent_dim, n_epochs, n_batch, output_dir):
#     d_losses, g_losses = [], []
#     current_epoch = 0  # Ajoutez cette ligne
#
#     for epoch in range(n_epochs):
#         current_epoch += 1  # Ajoutez cette ligne
#         for _ in range(n_batch):
#             z = gan.generate_latent_points(latent_dim, n_batch)
#             X_fake = generator.predict(z)
#             # X_fake = tf.cast(X_fake > 0.5, tf.float32)
#             X_real, y_real = gan.generate_real_samples(n_samples = n_batch)
#
#             # Entraînement du discriminateur
#             d_loss_real = discriminator.train_on_batch(X_real, y_real)
#             d_loss_fake = discriminator.train_on_batch(X_fake, -np.ones((n_batch, 1)))
#             d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
#
#             # Entraînement du générateur
#             z = gan.generate_latent_points(latent_dim, n_batch)
#             y_gan = np.ones((n_batch, 1))
#             g_loss = gan.model.train_on_batch(z, y_gan)
#
#         # Enregistrement des pertes pour la visualisation
#         d_losses.append(d_loss[0])
#         g_losses.append(g_loss)
#
#         # Affichage des résultats et sauvegarde des exemples générés
#         print(f"Epoch {current_epoch}, [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}], [G loss: {g_loss}]")
#         gan.generate_and_save_samples(current_epoch, latent_dim, n_batch, output_dir)
#
#     # Affichage des courbes d'entraînement
#     plot_training_history(d_losses, g_losses)
def train_wgan(generator, discriminator, gan, latent_dim, n_epochs, n_critic, batch_size, output_dir, circ_pad):
    """Train the WGAN with alternating critic/generator updates.

    Args:
        generator, discriminator: the two networks.
        gan: GAN wrapper holding the combined model and the dataset.
        latent_dim: size of the latent vectors.
        n_epochs: number of passes over the dataset.
        n_critic: critic updates per generator update.
        batch_size: samples per batch (short final batch is dropped).
        output_dir: where generated .npy samples are written each epoch.
        circ_pad: if True, clip critic weights manually here (the
            circular-padding layers carry no ClipConstraint).
    """
    d_losses, g_losses = [], []
    current_epoch = 0

    for epoch in range(n_epochs):
        # Build the shuffled batches for this epoch.
        batches = generate_batches(gan.data, batch_size)
        current_epoch += 1

        for batch in batches:
            # Update the critic (discriminator) multiple times
            for _ in range(n_critic):
                z = gan.generate_latent_points(latent_dim, batch_size)
                X_fake = generator.predict(z)

                # Add the trailing channel dimension expected by Conv2D.
                real_sample_batch = np.array([np.expand_dims(sample, axis = -1) for sample in batch])

                # WGAN target convention here: real = -1, fake = +1.
                X_real, y_real = real_sample_batch, -(np.ones((batch_size, 1)))

                d_loss_real = discriminator.train_on_batch(X_real, y_real)
                d_loss_fake = discriminator.train_on_batch(X_fake, np.ones((batch_size, 1)))  # Use +1 as the target for fake samples
                d_loss = np.mean(np.array(d_loss_fake) - np.array(d_loss_real))  # Wasserstein loss

                # Clip the weights of the discriminator. Only needed for
                # the circular-padding variant: the zero-padding variant
                # clips via ClipConstraint on each Conv2D instead.
                if circ_pad :
                    for layer in discriminator.model.layers:
                        weights = layer.get_weights()
                        weights = [np.clip(w, -0.2, 0.2) for w in weights]
                        layer.set_weights(weights)

            # Update the generator through the combined (frozen-critic) model
            z = gan.generate_latent_points(latent_dim, batch_size)
            y_gan = np.ones((batch_size, 1))
            g_loss = gan.model.train_on_batch(z, y_gan)

            # Record losses for visualization (one entry per batch)
            d_losses.append(-d_loss)  # Negative of Wasserstein loss for the critic
            g_losses.append(g_loss)

        # Display results and save generated samples
        print(f"Epoch {current_epoch}, [D loss: {d_losses[-1]}], [G loss: {g_loss}]")
        gan.generate_and_save_samples(current_epoch, latent_dim, batch_size, output_dir)

    # Display training curves
    plot_training_history(d_losses, g_losses)
def plot_training_history(d_losses, g_losses):
    """Plot the critic and generator loss curves recorded during training."""
    plt.figure(figsize=(10, 5))
    plt.plot(d_losses, label='Discriminator Loss', linestyle='--')
    plt.plot(g_losses, label='Generator Loss')
    # NOTE(review): losses are appended once per batch in train_wgan, so
    # the x axis is really batches, not epochs — confirm intent.
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('WGAN Training History')
    plt.show()


# Padding choice: True -> circular padding in the critic
circ_pad = True

# Training hyper-parameters
latent_dim = 100
n_epochs = 2
n_batch = 32
data_path = '/Users/mathys/Documents/Projet 3A/preprocessed_data'  # Path to the preprocessed dataset
output_directory = '/Users/mathys/Documents/Projet 3A/result_WGAN'  # Where generated samples are written

if not  os.path.exists(output_directory):
    os.makedirs(output_directory)

# Instantiate the networks and the combined GAN wrapper
generator = Generator(latent_dim)
discriminator = Discriminator(circ_pad)
gan = GAN(generator, discriminator, data_path)

generator.summary()
discriminator.summary()

# Train the GAN
train_wgan(generator, discriminator, gan, latent_dim, n_epochs, n_critic = 1, batch_size = n_batch, output_dir = output_directory, circ_pad = circ_pad)

# Generate samples with the trained generator
z = gan.generate_latent_points(latent_dim=latent_dim, n_samples=n_batch)
generated_samples = generator(z)

# Histogram of layer weights, to inspect the effect of weight clipping
generator_weights = [layer.get_weights()[0].flatten() for layer in generator.layers if len(layer.get_weights()) > 0]
discriminator_weights = [layer.get_weights()[0].flatten() for layer in discriminator.layers if len(layer.get_weights()) > 0]

plt.figure(figsize=(10,5))
for weights in generator_weights :
    plt.hist(weights, bins = 50, alpha = 0.5)
plt.title('Histogramme des poids du générateur')
plt.show()

plt.figure(figsize=(10,5))
for weights in discriminator_weights :
    plt.hist(weights, bins = 50, alpha = 0.5)
plt.title('Histogramme des poids du discriminateur')
plt.show()

##utils

import os
import pandas as pd
import tensorflow as tf
import numpy as np
from tensorflow.keras import backend
from tensorflow.keras.constraints import Constraint
from tensorflow.keras import layers
def load_parquet_files(root_folder, test):
    """Load every .parquet file found one level below root_folder.

    Args:
        root_folder: directory whose subfolders contain .parquet files.
        test: if True, stop after scanning ~1000 entries of root_folder
            (quick sampling mode for debugging).

    Returns:
        A list of pandas DataFrames, one per parquet file found.
    """
    dfs = []
    scanned = 0  # entries of root_folder processed (only used when test)

    for entry in os.listdir(root_folder):
        entry_path = os.path.join(root_folder, entry)

        if os.path.isdir(entry_path):
            # Load every parquet file directly inside this subfolder.
            for filename in os.listdir(entry_path):
                if not filename.endswith(".parquet"):
                    continue
                dfs.append(pd.read_parquet(os.path.join(entry_path, filename)))

        if test:
            scanned += 1
            if scanned > 1000:
                break

    return dfs

def generate_batches(data, batch_size):
    """Shuffle the dataset and split it into full-size batches.

    Args:
        data: list of pandas DataFrames.
        batch_size: number of samples per batch.

    Returns:
        A list of batches (lists of numpy arrays), each exactly
        batch_size long; a short trailing batch is dropped.
    """
    data_np = [df.to_numpy() for df in data]
    np.random.shuffle(data_np)
    batches = [data_np[i:i+batch_size] for i in range(0, len(data_np), batch_size)]

    # Drop the last, incomplete batch. The 'batches' guard avoids an
    # IndexError on batches[-1] when the dataset is empty.
    if batches and len(batches[-1]) != batch_size:
        batches.pop()

    return batches
def wasserstein_loss(y_true, y_pred):
    """Wasserstein critic loss: mean(y_true * y_pred).

    With real targets of -1 and fake targets of +1, minimizing this
    pushes the critic score down on real samples and up on fakes.
    """
    # return tf.reduce_mean(y_true * y_pred)
    return backend.mean(y_true * y_pred)


def generator_loss(y_true, y_pred):
    """Generator loss: maximize the critic score on generated samples.

    y_true is ignored; only the mean critic prediction matters.
    """
    return -tf.reduce_mean(y_pred)


class ClipConstraint(Constraint):
    """Keras weight constraint clipping weights to [-clip_value, clip_value].

    Used to enforce the WGAN Lipschitz condition via weight clipping.
    """

    def __init__(self, clip_value):
        # Symmetric clipping bound.
        self.clip_value = clip_value

    def __call__(self, weights):
        # Applied by Keras to the weights after each update.
        return tf.clip_by_value(weights, -self.clip_value, self.clip_value)

    def get_config(self):
        # Makes the constraint serializable with the model.
        return{'clip_value': self.clip_value}

class Conv2DCircularPadding(layers.Layer):
    """Conv2D with circular (periodic) padding on height and width.

    The input is wrapped around on both spatial axes before a 'valid'
    convolution, so the output spatial size matches 'same' padding while
    treating the image as periodic.
    """

    def __init__(self, filters, kernel_size, strides=(1, 1), activation=None, **kwargs):
        super(Conv2DCircularPadding, self).__init__(**kwargs)
        # Keep plain constructor args so get_config() is serializable.
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.activation = activation
        self.conv = layers.Conv2D(filters, kernel_size, strides=strides,
                                  padding='valid', activation=activation)

    @staticmethod
    def _circular_pad(x, axis, pad_lo, pad_hi):
        """Concatenate wrap-around slices of x along one spatial axis."""
        def sl(start, stop):
            s = [slice(None)] * 4  # NHWC rank
            s[axis] = slice(start, stop)
            return x[tuple(s)]

        parts = []
        if pad_lo:
            parts.append(sl(-pad_lo, None))
        parts.append(x)
        if pad_hi:
            parts.append(sl(None, pad_hi))
        # Skip the concat entirely when no padding is needed; the old
        # '-half_pad:' slicing duplicated the whole tensor for pad 0,
        # because x[:, -0:, :] selects everything.
        return tf.concat(parts, axis=axis) if len(parts) > 1 else x

    def call(self, input_tensor):
        k = self.conv.kernel_size[0]
        # 'same'-equivalent amounts: (k-1)//2 before, k//2 after
        # (identical for odd kernels, correct for even ones too).
        pad_lo = (k - 1) // 2
        pad_hi = k // 2

        padded = self._circular_pad(input_tensor, 1, pad_lo, pad_hi)
        padded = self._circular_pad(padded, 2, pad_lo, pad_hi)

        return self.conv(padded)

    def get_config(self):
        config = super(Conv2DCircularPadding, self).get_config()
        # Store constructor arguments, not the Layer instance (a Layer
        # object is not JSON-serializable).
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "activation": self.activation,
        })
        return config
Enter fullscreen mode Exit fullscreen mode

AWS GenAI LIVE image

How is generative AI increasing efficiency?

Join AWS GenAI LIVE! to find out how gen AI is reshaping productivity, streamlining processes, and driving innovation.

Learn more

Top comments (0)

Billboard image

Create up to 10 Postgres Databases on Neon's free plan.

If you're starting a new project, Neon has got your databases covered. No credit cards. No trials. No getting in your way.

Try Neon for Free →

👋 Kindness is contagious

Please leave a ❤️ or a friendly comment on this post if you found it helpful!

Okay