GAN — генеративные состязательные сети


архитектура GAN
GAN — Generative Adversarial Networks — генеративные состязательные сети.

Пока мне не удалось найти устоявшегося русскоязычного названия.
Другие варианты:
генеративные соревновательные сети
порождающие соперничающие сети
порождающие соревнующиеся сети

Генеративные сети — это очень интересный класс нейронных сетей, которые учатся генерировать определённые объекты. Сейчас, подобные сети очень популярны и используются для самых разных задач — от генерирования пугающих картинок и суперразрешения до поиска лекарств от рака.

Впервые представлены в 2014 году Ian Goodfellow в [1].

В основе, лежит простая идея — давайте возьмём две нейронных сети и в ходе обучения заставим их соревноваться между собой: первая должна будет учиться обмануть вторую, а вторая — учиться не дать первой этого сделать.

Таким образом, генеративная cостязательная сеть, фактически, состоит из двух сетей:
1 — генератор (G)
2 — дискриминатор (D)

Генератор — нейронная сеть, которая получает на вход, так называемые, скрытые переменные (latent space) (случайный шум), а на выходе получаются данные ( изображение).

Дискриминатор – это обычный бинарный классификатор, который выдаёт:
1 — для реальных данных,
0 — для поддельных данных.

Подробнее с теорией работы GAN помогут статьи:
Автоэнкодеры в Keras, Часть 5: GAN(Generative Adversarial Networks) и tensorflow
Нейросетевая игра в имитацию

Пример реализации GAN в Keras

def make_trainable(net, val):
    net.trainable = val
    for l in net.layers:
        l.trainable = val

def create_gan(channels, height, width):

    input_img = Input(shape=(channels, height, width))

    m_height, m_width = int(height/8), int(width/8)

    # generator
    z = Input(shape=(latent_dim, ))
    x = Dense(256*m_height*m_width)(z)
    #x = BatchNormalization()(x)
    x = Activation('relu')(x)
    #x = Dropout(0.3)(x)

    x = Reshape((256, m_height, m_width))(x)

    x = Conv2DTranspose(256, kernel_size=(5, 5), strides=(2, 2), padding='same', activation='relu')(x)

    x = Conv2DTranspose(128, kernel_size=(5, 5), strides=(2, 2), padding='same', activation='relu')(x)

    x = Conv2DTranspose(64, kernel_size=(5, 5), strides=(2, 2), padding='same', activation='relu')(x)

    x = Conv2D(channels, (5, 5), padding='same')(x)
    g = Activation('tanh')(x)

    generator = Model(z, g, name='Generator')

    # discriminator
    x = Conv2D(128, (5, 5), padding='same')(input_img)
    #x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    #x = Dropout(0.3)(x)
    x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
    x = Conv2D(256, (5, 5), padding='same')(x)
    x = LeakyReLU()(x)
    x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
    x = Conv2D(512, (5, 5), padding='same')(x)
    x = LeakyReLU()(x)
    x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
    x = Flatten()(x)
    x = Dense(2048)(x)
    x = LeakyReLU()(x)
    x = Dense(1)(x)
    d = Activation('sigmoid')(x)

    discriminator = Model(input_img, d, name='Discriminator')

    gan = Sequential()
    gan.add(generator)
    make_trainable(discriminator, False) #discriminator.trainable = False
    gan.add(discriminator)

    return generator, discriminator, gan

gan_gen, gan_ds, gan = create_gan(channels, height, width)

gan_gen.summary()
gan_ds.summary()
gan.summary()

opt = Adam(lr=1e-3)
gopt = Adam(lr=1e-4)
dopt = Adam(lr=1e-4)

gan_gen.compile(loss='binary_crossentropy', optimizer=gopt)
gan.compile(loss='binary_crossentropy', optimizer=opt)

make_trainable(gan_ds, True)
gan_ds.compile(loss='binary_crossentropy', optimizer=dopt)

Процедура обучения GAN

* получаем порцию реальных картинок
* генерируем шум, на базе которого генератор генерирует картинки
* формируем батч для обучения дискриминатора, который состоит из реальных картинок (им присваивается метка 1) и подделок от генератора (метка 0)
* обучаем дискриминатор
* обучаем GAN (в нём обучается генератор, т.к. обучение дискриминатора отключено), подавая на вход шум и ожидая на выходе метку 1.

for epoch in range(epochs):
    print('Epoch {} from {} ...'.format(epoch, epochs))

    n = x_train.shape[0]
    image_batch = x_train[np.random.randint(0, n, size=batch_size),:,:,:]

    noise_gen = np.random.uniform(-1, 1, size=[batch_size, latent_dim])

    generated_images = gan_gen.predict(noise_gen, batch_size=batch_size)

    if epoch % 10 == 0:
        print('Save gens ...')
        save_images(generated_images)
        gan_gen.save_weights('temp/gan_gen_weights_'+str(height)+'.h5', True)
        gan_ds.save_weights('temp/gan_ds_weights_'+str(height)+'.h5', True)
        # save loss
        df = pd.DataFrame( {'d_loss': d_loss, 'g_loss': g_loss} )
        df.to_csv('temp/gan_loss.csv', index=False)

    x_train2 = np.concatenate( (image_batch, generated_images) )
    y_tr2 = np.zeros( [2*batch_size, 1] )
    y_tr2[:batch_size] = 1

    d_history = gan_ds.train_on_batch(x_train2, y_tr2)
    print('d:', d_history)
    d_loss.append( d_history )

    noise_gen = np.random.uniform(-1, 1, size=[batch_size, latent_dim])
    g_history = gan.train_on_batch(noise_gen, np.ones([batch_size, 1]))
    print('g:', g_history)
    g_loss.append( g_history )

Обратите внимание, что для обучения генератора не используются реальные изображения, а только метка дискриминатора. Т.е. генератор обучается на градиентах ошибки от дискриминатора.

Если посмотреть на кривые потерь, то видно, что дискриминатор быстро обучается отличать реальную картинку от первоначального мусора, выдаваемого генератором, но потом кривые начинают колебаться — генератор учится генерировать всё более подходящее изображение.
GAN - график обучения

Рекомендации по архитектуре из статьи конца 2015 года от facebook research про DCGAN (Deep Convolutional GAN):
Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks,
а так же набор рекомендаций, позволяющий заставить GAN работать:
How to Train a GAN? Tips and tricks to make GANs work

Заключение
GAN — это очень интересный класс генеративных моделей, относящийся к обучению без учителя.
Сейчас, они очень популярны и демонстрируют интересные результаты своей работы в самых разных областях применения.

далее: Создание покемонов с помощью генеративных состязательных сетей

Статьи
1. Goodfellow I. et al. Generative Adversarial Networks – 2014.

Ссылки
Нейросетевая игра в имитацию
Автоэнкодеры в Keras, Часть 5: GAN(Generative Adversarial Networks) и tensorflow
Генеративные модели от OpenAI
GAN by Example using Keras on Tensorflow Backend
Deep Convolutional GANs (DCGAN): Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
How to Train a GAN? Tips and tricks to make GANs work
MNIST Generative Adversarial Model in Keras
Generative Adversarial Networks Part 2 — Implementation with Keras 2.0
Keras implementation of Deep Convolutional Generative Adversarial Networks (DCGAN)
Учим робота готовить пиццу. Часть 2: Состязание нейронных сетей

По теме
Сегментация изображений при помощи нейронной сети: U-Net
Детектирование объектов — нейросетевой подход

Нейронная сеть
Нейронная сеть — введение
Принцип обучения многослойной нейронной сети с помощью алгоритма обратного распространения
Пример работы самоорганизующейся инкрементной нейронной сети SOINN


Добавить комментарий

Arduino

Что такое Arduino?
Зачем мне Arduino?
Начало работы с Arduino
Для начинающих ардуинщиков
Радиодетали (точка входа для начинающих ардуинщиков)
Первые шаги с Arduino

Разделы

  1. Преимуществ нет, за исключением читабельности: тип bool обычно имеет размер 1 байт, как и uint8_t. Думаю, компилятор в обоих случаях…

  2. Добрый день! Я недавно начал изучать программирование под STM32 и ваши уроки просто бесценны! Хотел узнать зачем использовать переменную типа…

3D-печать AI Arduino Bluetooth CraftDuino DIY Google IDE iRobot Kinect LEGO OpenCV Open Source Python Raspberry Pi RoboCraft ROS swarm ИК автоматизация андроид балансировать бионика версия видео военный датчик дрон интерфейс камера кибервесна манипулятор машинное обучение наше нейронная сеть подводный пылесос работа распознавание робот робототехника светодиод сервомашинка собака управление ходить шаг за шагом шаговый двигатель шилд юмор

OpenCV
Робототехника
Будущее за бионическими роботами?
Нейронная сеть - введение