Привет! Решал задачу поиска оптимальной модели для классификации собственного датасета изображений (в основном искал на HuggingFace) и столкнулся с моделями timm. Старый конвейер dvc не работал с этими моделями и пришлось искать решение. Вообще почему именно timm?
Как сказано в руководстве:
«timm` - это библиотека глубокого обучения, созданная Россом Уайтманом, и представляет собой коллекцию моделей компьютерного зрения SOTA, слоев, утилит, оптимизаторов, планировщиков, загружающих данных, а также обучающих / валидационных скриптов с возможностью воспроизведения результатов обучения ImageNet.
Вопросы сразу же отпали и я стал изучать какие виды там представлены.
Действительно, в библиотеки timm представлено огромное количество моделей и их версий, обученных на разных данных. Так можно увидеть список интересующих вас моделей:
import timm
pd.DataFrame(timm.list_models('vit_*'))

Указав pretrain=True
, выведем предобученные версии моделей.
Приступим к обучению!
Подготовка данных
Исходные данные: csv файле с описанием изображений и принадлежность их к классу и сами изображения. Первым делом я сформировал две директории (train, validate), в которых были папки с названием класса и изображениями соответственно (стандартная структура для любых обработчиков)
├── train
│ ├── ex_class_1
│ │ ├── image_1.png
│ │ ├── image_2.png
│ │ ├── image_3.png
│ │ ├── image_4.png
│ │ └── image_5.png
│ ├── ex_class_2
│ │ ├── image_1.png
│ │ ├── image_2.png
│ │ ├── image_3.png
│ │ ├── image_4.png
│ │ └── image_5.png
│ └── ex_class_3
└── validate
├── ex_class_1
├── ex_class_2
└── ex_class_3
Вот пример кода:
import os
import csv
import argparse
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import pandas as pd
def resize_and_normalize(image_path, output_path, size=(256, 256), format='JPEG'):
image = Image.open(image_path)
image = image.resize(size)
if image.mode in ['RGBA', 'P']:
image = image.convert('RGB')
image.save(output_path, format=format)
def preprocess_dataset(input_dir, output_dir, csv_path, test_size=0.2):
os.makedirs(output_dir, exist_ok=True)
data = pd.read_csv(csv_path)
for label in data['class'].unique():
os.makedirs(os.path.join(output_dir, 'train', label), exist_ok=True)
os.makedirs(os.path.join(output_dir, 'test', label), exist_ok=True)
# Split data into training and testing
train_df, test_df = train_test_split(data, test_size=test_size, random_state=22)
# Process train images with a progress bar
for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc='Processing train images'):
image_path = os.path.join(input_dir, row['image'] + '.jpg')
if os.path.exists(image_path):
output_path = os.path.join(output_dir, 'train', row['class'], row['image'] + .jpg')
resize_and_normalize(image_path, output_path)
for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc='Processing test images'):
image_path = os.path.join(input_dir, row['image'] + '.jpg')
if os.path.exists(image_path):
output_path = os.path.join(output_dir, 'test', row['class'], row['image'] + '.jpg')
resize_and_normalize(image_path, output_path)
Обучение моделей. Fastai и dvc pipeline
После того, как мы подготовили данные. можно приступать к обучению модели. Я для удобства перенес все интересующие меня модели в текстовый файл.
Импортируем необходимые библиотеки:
import os
import argparse
import timm
from fastai.vision.all import *
from fastai.metrics import *
from dvclive import Live
from dvclive.fastai import DVCLiveCallback
from torchvision import transforms
import torch
После был создан словарь метрик:
METRICS = {
"accuracy": accuracy,
"accuracy_multi": accuracy_multi,
"error_rate": error_rate,
"top_k_accuracy": top_k_accuracy,
"Precision": Precision(average='macro'),
"Recall": Recall(average='macro'),
"F1Score": F1Score(average='macro'),
"RocAuc": RocAuc(), # Note: RocAucScore is used for both binary and multi-class ROC AUC
"RocAucBinary": RocAucBinary(), # Note: RocAucScore is used for both binary and multi-class ROC AUC
"FBeta": FBeta,
"BalancedAccuracy": BalancedAccuracy(),
"F1ScoreMulti": F1ScoreMulti,
}
Сделано было для удобного использования и применения интересующих метрик в params.yaml (кому не знаком dvc рекомендую ознакомиться). Вот так выглядит yaml:
model:
arch: vit_base_patch8_224
train:
epochs: 50
batch_size: 16
metrics:
- accuracy
- Precision
- Recall
- error_rate
Функция по созданию блока данных:
def get_dls(data_path, bs, model_cfg):
print('Start create DataBlock')
if not os.path.exists(data_path):
raise ValueError(f"Data path {data_path} does not exist.")
mean = list(model_cfg['mean'])
std = list(model_cfg['std'])
size = model_cfg['input_size'][1]
dblock = DataBlock(
blocks=(ImageBlock, CategoryBlock),
get_items=get_image_files,
splitter=GrandparentSplitter(train_name='train', valid_name='test'),
get_y=parent_label,
item_tfms=Resize(size),
batch_tfms=[*aug_transforms(),Normalize.from_stats(mean=mean, std=std, cuda=True),],
)
dls = dblock.dataloaders(data_path, bs=bs)
return dls
model_cfg
содержит конфигурации модели, позже будет понятно откуда он берётся.
DataBlock - Высокоуровневый API от Fastai, позволяющий просто и удобно получить данные в DataLoaders.
blocks
- указывается кортеж используемых данных.
get_items=get_image_files
- получение путей до изображений.
splitter=GrandparentSplitter(train_name='train', valid_name='test')
- для разделения данных на обучающую и тестовую выборку будет произведет поиск папок test, train.
get_y=parent_label
- метки изображениям будут присваиваться исходя из названия родительской директории.
item_tfms=Resize(size)
- Изменения, применяемые к каждым изображениям.
batch_tfms=[*aug_transforms(),Normalize.from_stats(mean=mean, std=std, - cuda=True),]
- изменения, применяемые к батчам.
После формирования блока данных, можно приступать к обучению.
def train_model(data_path, output_path, params_file):
with open(params_file, 'r') as f:
params = yaml.safe_load(f)
epochs = params['train']['epochs']
batch_size = params['train']['batch_size']
metric_names = params['metrics']
metrics = [METRICS[name] for name in metric_names if name in METRICS]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model(params['model']['arch'], pretrained=True, num_classes=3)
model = model.to(device)
model_cfg = model.default_cfg
dls = get_dls(data_path,batch_size, model_cfg)
print('Start fine_tune')
learn = Learner(dls, model, metrics=metrics)
learn = Learner(dls, model, metrics=metrics)
lrs = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide), show_plot = False)
lr = lrs[3]
with Live('dvctrain', report='md') as live:
learn.fine_tune(epochs, lrs.valley, cbs = [SaveModelCallback(min_delta=0.01, monitor = 'accuracy'),DVCLiveCallback(live=live)])
model_cfg = model.default_cfg
- возвращает конфигурацию модели

Хотел бы обратить внимание на строку lrs =
learn.lr
_find(suggest_funcs=(minimum, steep, valley, slide), show_plot = False)
, которая позволяет определилить learning rate, наиболее подходящий для ваших данных (более подробно можно ознакомиться тут. Только не забудьте включить VPN).
Вот и весь код! Всего двумя функциями можно без проблем провести эксперименты и подобрать модель классификации изображений! Здесь я не углублялся в процесс создания конвейера dvc, в хабре можно найти материал по этой теме. В следующей статье разберем процесс оценки модели после обучения.
Удачи!
Автор: RadAI