Сейчас нейросети стали настолько большими, что обучение большой сети на 1 видеокарте технически невозможно или займёт десятки и сотни лет. Кроме того, на большой обучающей выборке всплывают проблемы забывания сетью того, чему её учили вначале.
Одним из способов решения этих проблем является разбивка датасета на куски, и обучение одной и той же нейросети параллельно на разных устройствах. Потом, очевидно, нужно каким-то образом слить обученные нейросети в одну. Обсудим в этой статье детальнее, зачем это вообще может быть нужно, и как это сделать более-менее правильно.
Оглавление
1. Проблемы, решаемые через слияние словарей
1.1 Улучшение генерализации работы нейросети
1.2 Забывание нейросетью первых данных из датасета при последовательном обучении
1.3 Ускорение обучения нейросети за счёт распараллеливания обучения
2. Вводные для успешного слияния словарей
2.1 Исходная сеть всегда одна и та же
2.2 Успех слияния зависит от близости каждого словаря к глобальному минимуму ошибки
2.3. Слияние словарей с разными минимумами ошибки, полученных на разных кусках обучающих данных
3. Техника слияния словарей
3.1. Простейший пример слияния
3.2. Почему простейший пример нельзя использовать для слияния тысяч словарей
3.3. Способы решения проблем с потерей точности
3.4. Кручу-верчу, float64 хочу
3.5 И многопроцессность для скорости!
3.6 Полная версия скрипта для слияния словарей
4. Итоги
1. Проблемы, решаемые через слияние словарей
▍ 1.1. Улучшение генерализации работы нейросети
В работе было показано, что усреднение весов моделей, полученных в ходе одного запуска, но сохранённых в разные моменты времени (чекпойты), повышало качество работы нейросети. В PyTorch эту стратегию можно запустить с использованием модуля SWA из torchcontrib.optim.
Есть статья на сайте PyTorch об использовании SWA.
Единственно напрягает, что код уже около 5 лет не обновлялся. Так что его работоспособность под сомнением. Вот ссылка на GitHub c примерами.
По результатам тестов видно, что после SWA точность нейросети повысилась примерно на 1%. Нужно ли вам такое улучшение — решать только вам.
Для этой стратегии не требуются хитрости и трюки из оставшейся части статьи. Хотя, возможно, предложенные ниже способы слияния словарей будут лучше. Я не проверял, потому что меня интересовало слияние словарей, полученных после обучения на разных кусках данных.
▍ 1.2. Забывание нейросетью первых данных из датасета при последовательном обучении
Эта проблема широко известна разработчикам нейросетей. Она называется Catastrophic interference, что можно перевести как «катастрофическое забывание/размазывание/интерференция». Есть много подходов для её решения. Ещё в 1989 году McCloskey and Cohen (1989) изучали катастрофическое забывание при последовательном обучении нейросетей, и один из выводов звучал так:
Interference was catastrophic in the backpropagation networks when learning was sequential but not concurrent
Перевод: Забывание было катастрофическим в сетях с обратным распространением ошибки при последовательном обучении, но не при параллельном.
Один из способов параллельного обучения состоит в том, чтобы разделить обучающий датасет на куски и обучать исходную сеть на каждом таком куске. По результатам такого обучения мы получим кучу комплектов весов (словарей в терминах PyTorch), которые как-то нужно слить в один словарь.
▍ 1.3. Ускорение обучения нейросети за счёт распараллеливания обучения
С этим способом всё понятно, так как мы не можем масштабировать видеокарту/ускоритель за счёт бесконечного увеличения скорости её работы, нам приходится использовать целые фермы ускорителей, чтобы получить обученную сеть в приемлемое время.
И тут можно использовать тот же подход. Мы клонируем нейросеть и на разных кусках датасета обучаем её, а потом сливаем словари.
2. Вводные для успешного слияния словарей
▍ 2.1. Исходная сеть всегда одна и та же
Обычный подход к обучению нейросети (и так реализовано в PyTorch), что она инициализируется случайными весами. Если мы будем каждый раз обучать разную нейросеть, то ничего хорошего при слиянии не выйдет.
Поэтому мы должны использовать фиксированный seed, а ещё намного лучше, если мы будем использовать одну и ту же предобученную сеть.
Само собой подразумевается, что как топология, так и схема обучения должна быть абсолютно одинаковой. Единственное, что может немного отличаться — это количество эпох, так как заранее неизвестно, сколько их потребуется для достижения заданной точности.
Это можно представить так, что мы создали атомарную копию двух людей и отправили их в разные места примерно на одинаковое время, а потом сливаем их воспоминания через усреднение концентраций нейромедиаторов в синаптических щелях. Только тогда есть шанс, что воспоминания успешно интегрируются.
▍ 2.2. Успех слияния зависит от близости каждого словаря к глобальному минимуму ошибки
При обсуждении успешности обучения нейросетей через сливание словарей на форуме PyTorch инженер Patric Black из Nvidia заметил, что слияние может быть успешным при близости к глобальному минимуму ошибки.
Самый простой способ слить словари — это их усреднить. Об этом дальше в статье.
Самая сложность состоит в том, чтобы понять, какие словари достаточно близки к глобальному минимуму, а какие нет. В случае когда мы сливаем словари из одного запуска, то, почти наверняка можно сказать, что веса будут близки к одному и тому же минимуму (глобальному или локальному).
В результате слияния нам нужно как можно ближе подобраться к глобальному минимуму ошибки, поэтому для слияния мы должны брать словари, которые уже близко к нему находятся: D1, D2 и D3, а иначе результирующие веса не дадут ошибку, близкую к глобальному минимуму.
Возможные подходы проверки близости конкретного словаря к глобальному минимуму ошибки:
- Проверять среднюю ошибку словаря на большом наборе тестовых данных.
- Сравнивать ошибку конкретного словаря с ошибкой объединённого словаря, и 5-10% самых худших словарей не допускать к результирующему слиянию.
- Смотреть, насколько сильно портит точность усреднённого словаря конкретный словарь, допущенный к слиянию.
Совершенно неочевидно какой подход выбрать (я экспериментировал со вторым), и есть ли среди этих подходов самый правильный. Детальнее, как проконтролировать, какие словари допустить к слиянию, а какие нет (или какие из них нужно как-то перевычислить, чтобы они стали ближе к глобальному минимуму ошибки), возможно, мы обсудим в следующей статье.
▍ 2.3. Слияние словарей с разными минимумами ошибки, полученных на разных кусках обучающих данных
При слиянии словарей полученных в результате обучения одной и той же предобученной сети на разных кусках обучающего датасета мы практически всегда вначале имеем словари с разными локальными минимумами. Станет ли от этого после слияния итоговая нейросеть полностью неточной и бесполезной?
Нет. Мы вместо катастрофического забывания получим контролируемое забывание (размазывание), что тоже может быть полезно.
При катастрофическом забывании нейросеть полностью забывает то, что учила вначале. А при контролируемом она забудет в равной мере то, что учила на каждом куске данных. И в итоге она будет работать лучше, чем необученная сеть на всём датасете, что тоже может быть полезно, когда достаточно небольшого превышения вероятности над полной неопределённостью для успешной работы нейросети.
В моих экспериментах по объединению словарей полученных на разных кусках данных после каждого слияния процент ошибки на общем наборе увеличивался примерно в 2 раза за одно пирамидальное (об этом ниже) слияние. Но при этом сеть уже могла выдавать приемлемую точность на всём наборе. При большом количестве пирамидальных слияний (> ~20, это около миллиона словарей) ошибка возрастала настолько, что правильный ответ уже тонул в шумах и нейросеть становилась бесполезной.
3. Техника слияния словарей
▍ 3.1 Простейший пример слияния
# Setup
modelA = nn.Linear(1, 1)
modelB = nn.Linear(1, 1)
sdA = modelA.state_dict()
sdB = modelB.state_dict()
# Average all parameters
for key in sdA:
sdB[key] = (sdB[key] + sdA[key]) / 2.
# Recreate model and load averaged state_dict (or use modelA/B)
model = nn.Linear(1, 1)
model.load_state_dict(sdB)
Как мы видим, используется усреднение. Пример кода взят с форума PyTorch. Согласно научной работе это вполне хороший подход для усреднения словарей в рамках одного прогона для разных чекпойнтов под названием «стохастическое усреднение весов» (Stochastic Weight Averaging).
▍ 3.2 Почему простейший пример нельзя использовать для слияния тысяч словарей
Представим, что нам нужно слить 1 000 весов. Формула для слияния будет выглядеть примерно так:
for key in sdA000:
sdB[key] = (sdA000[key] + sdA001[key]+ sdA002[key] + ... + sdA999[key]) / 1000
Понятно, что тут можно написать в виде цикла, чтобы обойтись без гигантской строки.
Суть в том, что float32 (тип данных по умолчанию в PyTorch) имеет 23 бита для хранения цифр после запятой, это примерно 6-7 десятичных цифр.
Проблема в том, что при усреднении большого числа словарей мы будем сильно, а иногда колоссально терять в точности. Просто потому, что число знаков в мантиссе ограничено, и значащие цифры будут из неё выпадать при добавлении маленьких чисел к большому (сумме).
При увеличении числа слагаемых сумма может стать настолько большой, что добавление любого числа маленьких чисел никак на ней никак не отразится.
Сложим, например, 2 числа:
1.111111 * 10^32
1.111111 * 10^-32
Чтобы их сложить, экспонента второго числа должна быть приведена к экспоненте первого числа.
И второе число превратится в 0.0000000000000000000000000000000000000000000000000000000000000001111111 * 10^32
Поскольку число значащих знаков в мантиссе FLOAT32 всего 23 бита, то второе число после приведения экспоненты будет приравнено к нулю.
Посмотрим код с четырьмя способами усреднения:
import torch
import warnings
warnings.filterwarnings("ignore")
def tree_sum(t):
while t.size(dim=0) > 1:
a, b = torch.chunk(t, 2)
t = a.add(b)
return t
n = 1024
s = 0.0
d = 1/311
for i in range(0, n):
s += d
calc = s/n
print('Float64', calc, (1 - calc/d)*100.0)
t = torch.full((n,), d, dtype=torch.float32)
d = torch.tensor(d, dtype=torch.float32)
calc = t.sum().div(n)
print('Float32 PyTorch Sum() error: ', calc.item(), calc.div(d).add(-1.0).mul(100.0).item())
d = torch.tensor(d, dtype=torch.float32)
s = torch.tensor(0.0, dtype=torch.float32)
for i in range(0, n):
s = torch.add(s, d)
calc = s.div(n)
print('Float32 PyTorch + error: ', calc.item(), calc.div(d).add(-1.0).mul(100.0).item())
t = torch.full((n,), d, dtype=torch.float32)
d = torch.tensor(d, dtype=torch.float32)
calc = tree_sum(t).div(n)
print('Float32 PyTorch TreeSum error: ', calc.item(), calc.div(d).add(-1.0).mul(100.0).item())
exit()
В этом коде мы усредняем значение 1/311 (довольно маленький вес) в 1024 ячейках 4 различными способами. И выводим ошибку в процентах для каждого способа.
Результат:
Float64 0.0032154340836013477 1.9095836023552692e-12
Float32 PyTorch Sum() error: 0.0032154337968677282 5.9604644775390625e-06
Float32 PyTorch + error: 0.003215478966012597 0.0013947486877441406
Float32 PyTorch TreeSum error: 0.003215434029698372 0.0
Первая колонка — само число после усреднения, вторая ошибка в процентах.
- Наивное усреднение с помощью float64, очень маленькая, примерно 2 триллионных процента процента.
- Усреднение с использованием метода PyTorch sum(). В нём, судя по всему, встроены специальные алгоритмы, уменьшающие ошибку. 6/1000000 процента. Но загонять 1024 словаря в память — никакой памяти не хватит. Расточительный метод. Но можно обойти и написать код, который это учтёт.
- Обычное наивное сложение тензоров, а потом деление. Даёт практически 0.0013% ошибки. Это наибольшая ошибка из протестированных способов. Казалось бы ошибка незначительная, но в реальных ситуациях, когда усредняются значения разных порядков, а не одно и то же значение как в тесте, она будет значительно больше. При использовании глубоких сетей даже небольшая ошибка в каждом слое приведёт к экспоненциальному лавинообразному росту ошибок в последнем слое.
- Попарное суммирование (ещё ссылка). Мы складываем попарно все числа, потом повторяем это много раз. Из-за того, что каждый раз мы складываем числа примерно одного масштаба, ошибка минимальна. В этом тестовом примере вообще ноль! Это сложение в виде бинарного дерева. Напоминает по форме дерево Меркла.
Как вы и догадались, я использовал последний способ усреднения словарей.
▍ 3.3 Способы решения проблем с потерей точности
Классический способ — это алгоритм Кэхэна, вот статья на Хабре на тему суммирования.
Попарный древовидный алгоритм сложения (из предыдущего пункта статьи) показался мне более простым, чем алгоритм Кэхэна.
Его, кстати, можно ещё улучшить, если отсортировать исходный массив значений. Это даст более гарантированное совпадение порядков складываемых величин.
▍ 3.4 Кручу-верчу, float64 хочу
И для ещё большей точности мы перед усреднением будем конвертировать словари в float64. А полученный итоговый словарь сконвертируем опять в float32. Это даст дополнительный прирост точности. И наша нейросеть будет точнее.
NN1 = MyNeuralNetwork()
NN2 = MyNeuralNetwork()
NN1.loadDict(path1)
NN2.loadDict(path2)
# конвертируем в float64
NN1.double()
NN2.double()
# Усредняем
# .....
# конвертируем в float32. Это делаем только самым последним после того, как всё уже усредним.
RES.float()
▍ 3.5 И многопроцессность для скорости!
Если словарей много, то целесообразно всё делать параллельно. Для этого прекрасно подойдёт метод pool.map(). Про него немного есть в моей предыдущей статье.
torch.set_num_threads(1)
with mp.Pool(PROCESSES_NUM) as pool:
pool.map(
merge_dict_mp,
[(sDictPath, *a[j*2:(j+1)*2], i+1, j) for j in range(len(a) // 2)]
)
torch.set_num_threads(PROCESSES_NUM)
▍ 3.6 Полная версия скрипта для слияния словарей
Всё в спойлере. Первый аргумент — директория, где лежат словари для усреднения.
#!/usr/bin/env python
# сливаем словари с помощью попарного суммирования. Итоговый словарь во float64.
import time, sys, os
import torch
import torch.nn as nn
from _nn512 import NeuralNetwork512 # моя модель
from torch import multiprocessing as mp
import hashlib
PROCESSES_NUM = 30
USE_FP64_DICT = True
def merge_dict_mp(args):
path, fdict1, fdict2, nOutPostfix, numFile = args
sOutName = f'{nOutPostfix:>02d}_{numFile:>06d}_' + hashlib.md5((fdict1 + fdict2).encode()).hexdigest()[0:16] + '.zip'
sFullOutName = f'{sIniPath}x{nOutPostfix:>02d}/{sOutName}'
if os.path.exists(sFullOutName):
return
NN1 = NeuralNetwork512()
NN2 = NeuralNetwork512()
# Словари со второго уровня уже сохранены как float64
if nOutPostfix > 1 and USE_FP64_DICT:
NN1.double()
NN2.double()
NN1.loadDict(path + fdict1)
NN2.loadDict(path + fdict2)
# Словари первого уровня нужно преобразовать в float64 после загрузки
if nOutPostfix == 1 and USE_FP64_DICT:
NN1.double()
NN1.double()
sd1 = NN1.state_dict()
sd2 = NN2.state_dict()
# Average all parameters
for key in sd2:
sd1[key] = (sd1[key] + sd2[key]) / 2
# Recreate model and load averaged state_dict (or use modelA/B)
NN1 = NeuralNetwork512()
if USE_FP64_DICT:
NN1 = NN1.double()
NN1.load_state_dict(sd1)
torch.save(NN1.state_dict(), sFullOutName)
if __name__ == '__main__':
sIniPath = sys.argv[1] + '/'
# Обработка до 2^10 словарей
for i in range(0, 11):
if i == 0:
sDictPath = sIniPath
else:
sDictPath = f'{sIniPath}x{i:>02d}/'
os.chdir(sDictPath)
a = sorted(filter(os.path.isfile, os.listdir('.')), key= os.path.getmtime if i==0 else os.path.basename)
if (len(a) < 2):
break
sNextDictPath = f'{sIniPath}x{i+1:>02d}/'
if not os.path.exists(sNextDictPath):
os.mkdir(sNextDictPath)
torch.set_num_threads(1)
with mp.Pool(PROCESSES_NUM) as pool:
pool.map(
merge_dict_mp,
[(sDictPath, *a[j*2:(j+1)*2], i+1, j) for j in range(len(a) // 2)]
)
torch.set_num_threads(PROCESSES_NUM)
exit()
4. Итоги
На своих нейросетях (256*16 нейронов) я увидел существенный прирост точности от попарного пирамидального усреднения (с предварительной конверсией в FP64) в сравнении с наивным усреднением в FP32. Фактически при наивном усреднении вообще ничего не работало на большом числе словарей, сеть давала рандомные результаты. Из-за того, что при многослойности ошибки от слоя к слою росли лавинообразно. После внедрения попарного суммирования на всём наборе данных я увидел рост точности работы нейросети.
После каждого объединения процент ошибки на общем наборе увеличивался примерно в 2 раза, по сравнению с ошибкой словаря, полученного на куске данных. Но при этом сеть уже могла выдавать приемлемую точность на всём наборе.
Возможно, чтобы точность не так сильно размывалась, мне следовало использовать бОльшую сеть с бОльшим предобучением. Но об этом, возможно, в следующих статьях.
Автор: Сергей Ю. Каменев