GAN и диффузионные модели: как научить нейросеть рисовать

в 11:16, , рубрики: AI-арт, GAN, pytorch, stable diffusion, генеративные сети, датасеты, диффузионные модели, искусственный интеллект, машинное обучение, синтетические данные
GAN и диффузионные модели: как научить нейросеть рисовать - 1

Привет! Сегодня хочу поговорить о двух очень горячих темах в области искусственного интеллекта — генеративно‑состязательные сети (GAN) и диффузионные модели (типа Stable Diffusion). Я сама как‑то подсела на все эти AI‑картинки и поняла, что нужно срочно поделиться тем что накопала. Поехали!:‑)

GAN: Генератор vs. Дискриминатор

Что это вообще и как работает?

Вообрази, что у нас есть два персонажа, которые начинают баттл: генератор (он придумывает новые картинки) и дискриминатор (он проверяет качество этих картинок и говорит фейк это или оригинал). Как в классическом фильме про хорошего и плохого копа, только у нас генератор - художник, а дискриминатор — критик.

  • Генератор (G) получает случайный шум (например, вектор из случайных чисел) и пытается преобразовать его в настоящую картинку (или другой тип данных — картинки, тексты, что угодно).

  • Дискриминатор (D) пытается понять, настоящая ли картинка (из реальной выборки) или сгенерированная (из генератора).

Они тренируются вместе. Генератор хочет обмануть дискриминатор, а дискриминатор — научиться отлавливать подделки. С течением обучения оба становятся всё круче: генератор рисует всё лучше, дискриминатор всё лучше распознаёт фейки. Когда система уравновешивается, генератор может делать настолько реалистичные изображения (или данные), что дискриминатор часто сдаётся и принимает подделку за правду.

Примеры применения GAN

  • Нарисуй котика на коньках: Кто-то делает модель, которая по твоему описанию генерирует изображение. Вот тебе котик, во львиной шубе, на коньках, в закатном свете!

  • Синтетические данные: Иногда нужна куча данных, которых не так просто получить в реальности. Скажем, у тебя стартап, связанный с распознаванием автомобильных номеров (и у тебя мало примеров). GAN может сгенерировать фотки номеров разных типов, осветлённые/потемнённые, с разными шрифтами и так далее.

Диффузионные модели: немножко магии

Stable Diffusion и друзья

Если GAN - это битва генератора и дискриминатора, то в диффузионных моделях (например, Stable Diffusion) используется другая концепция. Тут берётся картинка и постепенно зашумляется (добавляется шум до полного хаоса). А потом модель учится в обратном порядке вычистить шум, восстанавливая исходный вид. По сути, она в процессе обучения узнаёт, как переходить из случайного шума к осмысленному изображению.

Почему это классно?

  • Шаг за шагом: Модель постепенно отбрасывает шум, восстанавливая детали. Этот процесс напоминает волшебное появление картинки из белого шума.

  • Stable Diffusion умеет по тексту создавать (или дорисовывать) изображения, вставлять новую деталь в уже существующую картинку, менять фон и многое другое.

На что это похоже на практике?

  • "Нарисуй робота, собирающего цветы в поле" - пожалуйста! Модель на вход получает твой текст, генерит нечто психоделическое. Или, если хорошо подобрать подсказку (промпт), может получиться очень реалистичная или, наоборот, сюрреалистичная картинка.

  • Создание стиля: Можешь задать дорисуй это лицо в стиле аниме - и модель выдаст вариант в стиле японской анимации.

  • Идеально для мемов: Фанаты по всему миру генерируют мемы, стикеры, всякие шуточные коллажи. Интернет буквально пестрит этими AI-картинками.

Как начать эксперименты?

Выбери библиотеку

Когда речь заходит о том, чтобы начать свой путь в мире GAN или диффузионных моделей, выбор библиотеки играет ключевую роль. PyTorch считается одним из самых популярных вариантов благодаря обилию обучающих материалов и готовых проектов для GAN, что упрощает знакомство с этими сетями и ускоряет первые эксперименты. Если же вам ближе философия и стиль TensorFlow/Keras, здесь тоже есть огромное количество туториалов для новичков, позволяющих освоить основные принципы генеративных алгоритмов. А для работы с диффузионными моделями, особенно со Stable Diffusion, отлично подойдёт библиотека diffusers от Hugging Face: с её помощью буквально за пару строк кода можно запустить генерацию изображений и попробовать самые современные подходы к созданию реалистичных картинок.

Найди датасет

MNIST (просто цифры) банально, но удобно, если хочется понять механику GAN. CIFAR-10 (разные маленькие картинки) чаще всего используют, чтобы тренировать базовые GAN-модели. CelebA (лица знаменитостей) классический датасет, на котором многие показывают возможности GAN’ов.

Для диффузионных моделей обычно нужно очень много картинок, поэтому многие используют общедоступные большие датасеты или берут уже натренированные модели (как Stable Diffusion).

Подводные камни

  • Переобучение (Overfitting): Твой генератор может выучить одну-две картинки и постоянно генерить что-то на них похожее. Нужно следить за тем, чтобы у модели было достаточно разнообразных данных.

  • Нестабильность обучения: GAN’ы могут не сходиться (генератор и дискриминатор не могут найти баланс). Часто это решают настройкой гиперпараметров, выборами архитектур (например, WGAN, DCGAN).

  • Выбор правильного промпта (для диффузионных моделей): Если фраза непонятная или слишком короткая, результат может быть диким. Рекомендуется экспериментировать с ключевыми словами вроде hyperrealistic, 4k, trending on ArtStation и т.д.

Практика - наше всё!

Начать можно с простого запуска демо-версии, ведь в Google Colab уже есть готовые ноутбуки с настроенным окружением: достаточно лишь внести несколько изменений в код, загрузить собственные данные и наблюдать, как всё это работает. После этого самое время поэкспериментировать с гиперпараметрами, меняя размер batch, количество эпох и архитектуру сети, чтобы увидеть, как они влияют на качество генерации. И, конечно же, не стоит оставаться в одиночестве: Reddit, Kaggle и различные Discord-сервера — отличные места, где можно задать вопросы, поделиться успехами и получить советы от более опытных коллег по увлечению.

Парочка примеров на python

Минимальный GAN на MNIST (PyTorch)

В этом примере мы берём датасет MNIST (чёрно-белые цифры 28x28) и обучаем GAN генерировать «нарисованные» цифры. Код демонстрационный, поэтому детали (архитектура, гиперпараметры и т.д.) можно дорабатывать для лучшего результата.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Устройство (GPU при наличии)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Параметры
batch_size = 128
z_dim = 100     # размер вектора шума
num_epochs = 5  # для демо достаточно нескольких эпох

# Трансформации: переведём в тензор и нормализуем
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # стандартизируем к диапазону [-1,1]
])

# Датасет: MNIST
mnist_data = torchvision.datasets.MNIST(
    root='./data', 
    train=True, 
    transform=transform,
    download=True
)
dataloader = torch.utils.data.DataLoader(
    mnist_data,
    batch_size=batch_size,
    shuffle=True
)

class Generator(nn.Module):
    def __init__(self, z_dim=100, img_dim=28*28):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # выход в диапазоне [-1, 1]
        )
    def forward(self, x):
        return self.gen(x)

class Discriminator(nn.Module):
    def __init__(self, img_dim=28*28):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # выход 0..1 (вероятность подлинности)
        )
    def forward(self, x):
        return self.disc(x)

# Создаём экземпляры сети
gen = Generator(z_dim).to(device)
disc = Discriminator().to(device)

# Функция потерь и оптимизаторы
criterion = nn.BCELoss()  # бинарная кроссэнтропия
lr = 2e-4
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)

# Обучение
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.view(-1, 28*28).to(device)
        batch_size_curr = real.shape[0]

        # ============ Тренируем дискриминатор ============
        # Шум -> фейковые картинки
        noise = torch.randn(batch_size_curr, z_dim).to(device)
        fake = gen(noise)

        # Вычисляем вероятность для реальных и фейковых
        disc_real = disc(real).view(-1)
        disc_fake = disc(fake.detach()).view(-1)

        # Метки (1 - настоящие, 0 - фейк)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # ============ Тренируем генератор ============
        # Новые фейковые картинки, оцениваем дискриминатором
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))  # хотим, чтобы дискрим. сказал 1

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] | Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}")

# Проверка генерации: берём шум и смотрим, что получится
import matplotlib.pyplot as plt

gen.eval()
with torch.no_grad():
    sample_noise = torch.randn(16, z_dim).to(device)
    generated = gen(sample_noise).view(-1, 1, 28, 28)
    generated = (generated + 1) / 2  # денормализация из [-1,1] в [0,1]

# Отобразим 16 сгенерированных цифр
fig, axes = plt.subplots(4, 4, figsize=(6,6))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(generated[i].cpu().squeeze(), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()

Что посмотреть/поменять:

  • Можно добавлять RandomCrop, RandomHorizontalFlip и т.д. для более разнообразной генерации.

  • Увеличить архитектуру (больше слоёв, сверточные слои — см. DCGAN).

  • Использовать другие датасеты (CIFAR-10, CelebA и т.д.).

Запуск Stable Diffusion с помощью diffusers

Чтобы потрогать диффузионную модель, достаточно нескольких строк кода. Ниже пример генерации изображения по текстовому описанию. Для этого нужно установить библиотеку diffusers, а также transformers и accelerate (лучше в отдельном виртуальном окружении).

pip install diffusers transformers accelerate
import torch
from diffusers import StableDiffusionPipeline

# Подключаем модель Stable Diffusion
model_id = "runwayml/stable-diffusion-v1-5"

# Загружаем пайплайн
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # если есть GPU

prompt = "A beautiful cyberpunk cityscape at night, high resolution, trending on artstation"

# Генерируем изображение
image = pipe(prompt).images[0]

image.save("cyberpunk_city.png")
print("Изображение сохранено как 'cyberpunk_city.png'")

Что посмотреть/поменять:

  • prompt. Экспериментируйте с ключевыми словами типа “hyperrealistic, 4k, photorealistic” и прочими деталями.

  • Количество шагов. Можно увеличить num_inference_steps, чтобы добиться более детальной картинки.

  • Разные семплеры (e.g. scheduler = DPMSolverMultistepScheduler) дают немного разные результаты.

Заключение

Нейронные сети, которые создают изображения (а иногда тексты, музыку, видео), - это потрясающий инструмент для всех, кто хочет сам подружиться с ИИ. Хочешь сделать свою нейросеть-художника? Добро пожаловать в мир GAN и диффузионных моделей! Главное - не бойся ошибаться и экспериментировать. Лично я считаю, что за поколением (генеративным!) будущего будущее - и это только начало!

Если у тебя чешутся руки, чтобы попробовать прямо сейчас, советую погуглить GAN tutorial PyTorch или поискать репозиторий huggingface/diffusers. Уверена, уже через пару вечеров у тебя в папке будет куча забавных сгенерированных картинок. Удачи!

Автор: miss_polly

Источник

* - обязательные к заполнению поля


https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js