Обучение моделей timm. Связка с fastai

в 9:16, , рубрики: искусственные нейронные сети, искусственный интеллект, конвейер, машинное обучение

Привет! Решал задачу поиска оптимальной модели для классификации собственного датасета изображений (в основном искал на HuggingFace) и столкнулся с моделями timm. Старый конвейер dvc не работал с этими моделями и пришлось искать решение. Вообще почему именно timm?

Как сказано в руководстве:

«timm` - это библиотека глубокого обучения, созданная Россом Уайтманом, и представляет собой коллекцию моделей компьютерного зрения SOTA, слоев, утилит, оптимизаторов, планировщиков, загружающих данных, а также обучающих / валидационных скриптов с возможностью воспроизведения результатов обучения ImageNet.

Вопросы сразу же отпали и я стал изучать какие виды там представлены.

Действительно, в библиотеки timm представлено огромное количество моделей и их версий, обученных на разных данных. Так можно увидеть список интересующих вас моделей:

import timm
pd.DataFrame(timm.list_models('vit_*'))
Рис 1. Вывод

Рис 1. Вывод

Указав pretrain=True, выведем предобученные версии моделей.

Приступим к обучению!

Подготовка данных

Исходные данные: csv файле с описанием изображений и принадлежность их к классу и сами изображения. Первым делом я сформировал две директории (train, validate), в которых были папки с названием класса и изображениями соответственно (стандартная структура для любых обработчиков)

├── train
│   ├── ex_class_1
│   │   ├── image_1.png
│   │   ├── image_2.png
│   │   ├── image_3.png
│   │   ├── image_4.png
│   │   └── image_5.png
│   ├── ex_class_2
│   │   ├── image_1.png
│   │   ├── image_2.png
│   │   ├── image_3.png
│   │   ├── image_4.png
│   │   └── image_5.png
│   └── ex_class_3
└── validate
├── ex_class_1
├── ex_class_2
└── ex_class_3

Вот пример кода:

import os
import csv
import argparse
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import pandas as pd

def resize_and_normalize(image_path, output_path, size=(256, 256), format='JPEG'):
    image = Image.open(image_path)
    image = image.resize(size)
    if image.mode in ['RGBA', 'P']:
        image = image.convert('RGB')
    image.save(output_path, format=format)

def preprocess_dataset(input_dir, output_dir, csv_path, test_size=0.2):
  os.makedirs(output_dir, exist_ok=True)
  data = pd.read_csv(csv_path)
  
  for label in data['class'].unique():
      os.makedirs(os.path.join(output_dir, 'train', label), exist_ok=True)
      os.makedirs(os.path.join(output_dir, 'test', label), exist_ok=True)
  
  # Split data into training and testing
  train_df, test_df = train_test_split(data, test_size=test_size, random_state=22)
  
  
  # Process train images with a progress bar
  for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc='Processing train images'):
      image_path = os.path.join(input_dir, row['image'] + '.jpg')
      if os.path.exists(image_path):
          output_path = os.path.join(output_dir, 'train', row['class'], row['image'] + .jpg')
          resize_and_normalize(image_path, output_path)
  for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc='Processing test images'):
      image_path = os.path.join(input_dir, row['image'] + '.jpg')
      if os.path.exists(image_path):
          output_path = os.path.join(output_dir, 'test', row['class'], row['image'] + '.jpg')
          resize_and_normalize(image_path, output_path)

Обучение моделей. Fastai и dvc pipeline

После того, как мы подготовили данные. можно приступать к обучению модели. Я для удобства перенес все интересующие меня модели в текстовый файл.

Импортируем необходимые библиотеки:

import os
import argparse
import timm
from fastai.vision.all import *
from fastai.metrics import *
from dvclive import Live
from dvclive.fastai import DVCLiveCallback
from torchvision import transforms
import torch

После был создан словарь метрик:

METRICS = {
    "accuracy": accuracy,
    "accuracy_multi": accuracy_multi,
    "error_rate": error_rate,
    "top_k_accuracy": top_k_accuracy,
    "Precision": Precision(average='macro'),
    "Recall": Recall(average='macro'),
    "F1Score": F1Score(average='macro'),
    "RocAuc": RocAuc(),  # Note: RocAucScore is used for both binary and multi-class ROC AUC
    "RocAucBinary": RocAucBinary(),  # Note: RocAucScore is used for both binary and multi-class ROC AUC
    "FBeta": FBeta,
    "BalancedAccuracy": BalancedAccuracy(),
    "F1ScoreMulti": F1ScoreMulti,
}

Сделано было для удобного использования и применения интересующих метрик в params.yaml (кому не знаком dvc рекомендую ознакомиться). Вот так выглядит yaml:

model:
  arch: vit_base_patch8_224
train:
  epochs: 50
  batch_size: 16
metrics:
- accuracy
- Precision
- Recall
- error_rate

Функция по созданию блока данных:

def get_dls(data_path, bs, model_cfg):
    print('Start create DataBlock')
    if not os.path.exists(data_path):
        raise ValueError(f"Data path {data_path} does not exist.")

    mean = list(model_cfg['mean'])
    std = list(model_cfg['std'])
    size = model_cfg['input_size'][1]

    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        splitter=GrandparentSplitter(train_name='train', valid_name='test'),
        get_y=parent_label,
        item_tfms=Resize(size),
        batch_tfms=[*aug_transforms(),Normalize.from_stats(mean=mean, std=std, cuda=True),],
    )
    dls = dblock.dataloaders(data_path, bs=bs)
    return dls

model_cfg содержит конфигурации модели, позже будет понятно откуда он берётся.

DataBlock - Высокоуровневый API от Fastai, позволяющий просто и удобно получить данные в DataLoaders.

blocks - указывается кортеж используемых данных.

get_items=get_image_files - получение путей до изображений.

splitter=GrandparentSplitter(train_name='train', valid_name='test') - для разделения данных на обучающую и тестовую выборку будет произведет поиск папок test, train.

get_y=parent_label - метки изображениям будут присваиваться исходя из названия родительской директории.

item_tfms=Resize(size) - Изменения, применяемые к каждым изображениям.

batch_tfms=[*aug_transforms(),Normalize.from_stats(mean=mean, std=std, - cuda=True),] - изменения, применяемые к батчам.

После формирования блока данных, можно приступать к обучению.

def train_model(data_path, output_path, params_file):
    with open(params_file, 'r') as f:
        params = yaml.safe_load(f)

    epochs = params['train']['epochs']
    batch_size = params['train']['batch_size']
    metric_names = params['metrics']
    metrics = [METRICS[name] for name in metric_names if name in METRICS]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = timm.create_model(params['model']['arch'], pretrained=True, num_classes=3)
    model = model.to(device)
    model_cfg = model.default_cfg

    dls = get_dls(data_path,batch_size, model_cfg)
    print('Start fine_tune')
    learn = Learner(dls, model, metrics=metrics)
    learn = Learner(dls, model, metrics=metrics)
    lrs = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide), show_plot = False)
    lr = lrs[3]
    with Live('dvctrain', report='md') as live:
        learn.fine_tune(epochs, lrs.valley, cbs = [SaveModelCallback(min_delta=0.01, monitor = 'accuracy'),DVCLiveCallback(live=live)])

model_cfg = model.default_cfg - возвращает конфигурацию модели

Рис 2. Конфигурация модели

Рис 2. Конфигурация модели

Хотел бы обратить внимание на строку lrs = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide), show_plot = False) , которая позволяет определилить learning rate, наиболее подходящий для ваших данных (более подробно можно ознакомиться тут. Только не забудьте включить VPN).

Вот и весь код! Всего двумя функциями можно без проблем провести эксперименты и подобрать модель классификации изображений! Здесь я не углублялся в процесс создания конвейера dvc, в хабре можно найти материал по этой теме. В следующей статье разберем процесс оценки модели после обучения.

Удачи!

Автор: RadAI

Источник

* - обязательные к заполнению поля


https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js