Всем привет! Меня зовут Вадим, я Data Scientist в компании Raft, и сегодня мы погрузимся в Mojo. Я уже делал обзор данного языка программирования и рассмотрел его преимущества, примеры использования, а также провел сравнение с Python.
Теперь давайте посмотрим, как обучить простую сверточную нейронную сеть, и разберём один из методов машинного обучения — линейную регрессию. В качестве примеров задач возьмем стандартные соревнования машинного обучения: предсказание стоимости жилья и классификацию рукописных цифр MNIST. Для проведения экспериментов на Python используем фреймворк машинного обучения PyTorch. А на Mojo — фреймворк машинного обучения Basalt.
Немного о датасетах
MNIST (Modified National Institute of Standards and Technology) — это набор данных (датасет) для задачи распознавания рукописных цифр от 0 до 9. Набор состоит из 70 тысяч картинок с разрешением 28х28 черного цвета с белой цифрой. Задача состоит в том, чтобы распознать цифру, которая изображена на картинке.
Housing Prices Dataset — это набор данных для предсказания стоимости жилья на основе некоторых признаков. Например, площадь участка, тип жилья, наличия гаража, количества комнат и так далее.
Погружаемся в код
Эксперимент на MNIST
Для решения задачи классификации рукописных цифр напишем простую CNN (convolutional neural network), которая будет состоять из 2-х частей:
-
построение карты признаков (feature map), реализованной через 2 слоя сверток;
-
классификатор, состоящий из 3-х полносвязных слоев.
Более подробно архитектура представлена в таблице 1.
Layer |
Future map |
Size |
Kernel size |
Stride |
Padding |
Activation |
|
Input |
Image |
1 |
28x28 |
- |
- |
- |
- |
1 |
Convolution |
16 |
28x28 |
5x5 |
1 |
2 |
ReLU |
2 |
Maxpool |
16 |
14x14 |
2x2 |
0 |
0 |
|
3 |
Convolution |
32 |
14x14 |
5x5 |
1 |
2 |
ReLU |
4 |
Maxpool |
32 |
7x7 |
2x2 |
0 |
0 |
|
5 |
FC |
- |
120 |
- |
- |
- |
ReLU |
6 |
FC |
- |
184 |
- |
– |
- |
ReLU |
Output |
FC |
- |
10 |
- |
- |
- |
- |
Гиперпараметры обучения:
-
num_epochs = 20
-
batch_size = 8
-
learning_rate = 2e-3
-
оптимизатор Adam
-
функция потерь CrossEntropyLoss
Реализация архитектуры сети на Python и Mojo немного отличается. В первом случае, используя PyTorch, мы могли бы определить архитектуру как последовательность блоков.
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.block1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.block2 = nn.Sequential(
nn.Conv2d(in_channels=16,out_channels= 32, kernel_size=5, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=84)
self.out = nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
x = x.view(x.size(0), -1)
x = nn.ReLU()(self.fc1(x))
x = nn.ReLU()(self.fc2(x))
return self.out(x)
В случае с Mojo необходимо определить структуру Graph, которая реализует так называемый граф вычислений. Он применяется для вычислений в предсказании (feed forward) и обратного распространения ошибки (backpropagation).
fn create_CNN(batch_size: Int) -> Graph:
# инициализируем граф и наш вход
var g = Graph()
var x = g.input(TensorShape(batch_size, 1, 28, 28))
# инициализируем и применяем сверточные слои
var conv1 = nn.Conv2d(g, x, out_channels=16, kernel_size=5, padding=2)
var act_conv1 = nn.ReLU(g, conv1)
var max_pool1 = nn.MaxPool2d(g, act_conv1, kernel_size=2)
var conv2 = nn.Conv2d(g, max_pool1, out_channels=32, kernel_size=5, padding=2)
var act_conv2 = nn.ReLU(g, conv2)
var max_pool2 = nn.MaxPool2d(g, act_conv2, kernel_size=2)
# переводим выходной тензор в вектор
var x_reshape = g.op(
OP.RESHAPE,
max_pool2,
attributes=AttributeVector(
Attribute(
"shape",
TensorShape(max_pool2.shape[0], max_pool2.shape[1] * max_pool2.shape[2] * max_pool2.shape[3]),
)
),
)
# классифицируем, извлеченные признаки, полносвязной сетью
var fc1 = nn.Linear(g, x_reshape, n_outputs=120)
var act_fc1 = nn.ReLU(g, fc1)
var fc2 = nn.Linear(g, act_fc1, n_outputs=84)
var act_fc2 = nn.ReLU(g, fc2)
var out = nn.Linear(g, act_fc2, n_outputs=10)
g.out(out)
# считаем потери, используя CrossEntropyLoss
var y_true = g.input(TensorShape(batch_size, 10))
var loss = nn.CrossEntropyLoss(g, out, y_true)
g.loss(loss)
return g
Инициализация модели вместе с оптимизатором и цикл её обучения на Python достаточно стандартный для PyTorch: прогоняем весь датасет некоторое число раз (эпох) по батчам (пакетам), определяем признаки (images) и метки к ним (labels). Затем предсказываем метку класса, рассчитываем ошибку и обновляем градиенты.
cnn = CNN()
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)
cnn.train()
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(loaders["train"]):
b_x = Variable(images)
b_y = Variable(labels)
output = cnn(b_x)
loss = loss_func(output, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
На Mojo есть небольшие отличия:
-
необходимо определять функцию для выполнения кода;
-
необходимо определить модель и оптимизатор через структуру graph;
-
перед подачей изображений в сеть необходимо произвести one hot encoding меток.
В остальном процесс обучения сети схож со стилем PyTorch, за исключением особенностей синтаксиса языка.
fn main():
alias graph = create_CNN(batch_size)
var model = nn.Model[graph]()
var optim = nn.optim.Adam[graph](Reference(model.parameters), lr=learning_rate)
for epoch in range(num_epochs):
var num_batches: Int = 0
var epoch_loss: Float32 = 0.0
for batch in training_loader:
var labels_one_hot = Tensor[dtype](batch.labels.dim(0), 10)
for bb in range(batch.labels.dim(0)):
labels_one_hot[int((bb * 10 + batch.labels[bb]))] = 1.0
var loss = model.forward(batch.data, labels_one_hot)
optim.zero_grad()
model.backward()
optim.step()
epoch_loss += loss[0]
num_batches += 1
House price prediction
Для решения этой задачи мы применим стандартную линейную регрессию, реализованную через один полносвязный слой.
Гиперпараметры обучения следующие:
-
num_epochs = 500
-
batch_size = 32
-
learning_rate = 0.01
-
оптимизатор Adam
-
функция потерь MSELoss
На Python код с использованием PyTorch будет выглядеть следующим образом.
class LinearRegression(nn.Module):
def __init__(self, input_dim):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(in_features=input_dim, out_features=1)
def forward(self, x):
return self.linear(x)
На Mojo снова необходимо определить структуру Graph и слой с функцией потерь, через которые будут происходить вычисления.
fn linear_regression(batch_size: Int, n_inputs: Int, n_outputs: Int) -> Graph:
var g = Graph()
var x = g.input(TensorShape(batch_size, n_inputs))
var y_true = g.input(TensorShape(batch_size, n_outputs))
var y_pred = nn.Linear(g, x, n_outputs)
g.out(y_pred)
var loss = nn.MSELoss(g, y_pred, y_true)
g.loss(loss)
return g
Цикл обучения совпадает с тем, что был показан на MNIST, за исключением того, что отпадает необходимость в ohe hot encoding, так как метки уже закодированы.
Результаты
Обучая сверточную сеть на MNIST и линейную регрессию на предсказании стоимости жилья, было проведено множество экспериментов с настройкой различных гиперпараметров. Результаты с оптимальными значениями по времени обучения представлены в таблице 2.
|
MNIST |
House Price |
Python |
1.58 сек |
23.18 сек |
Mojo |
4.89 сек |
0.15 сек |
По задаче классификации MNIST язык программирования Python продемонстрировал лучшую производительность в классификации рукописных цифр. В то время как Mojo показал менее удовлетворительные результаты, что можно объяснить недостаточной оптимизацией сверток в текущем фреймворке Mojo Basalt.
Для задачи предсказания стоимости домов Python уступил Mojo в линейной регрессии. Mojo продемонстрировал хорошие результаты, превзойдя Python, что подтверждает высокую производительность языка, особенно в задачах, связанных с линейными вычислениями.
Заключение
Потенциал Mojo раскрывается в задачах, где важна скорость. Хотя на данный момент он пока не так хорош в работе с нейронными сетями, как Python, поскольку имеет ограниченной функционал. Но с развитием фреймворков и различных оптимизаций, вероятно, будут улучшения.
А что думаете вы? Пишите в комментариях!
Ссылки:
Автор: MidavNibush