WildTrain Classification Training Configuration

Detailed reference for the classification training YAML configuration file.

Overview

The classification training config controls the full classifier training pipeline: dataset loading, model architecture, training hyperparameters, MLflow tracking, checkpointing, and curriculum learning.

Usage:

```bash
wildtrain train classifier -c configs/classification/classification_train.yaml
```

Configuration Sections

dataset — Data Configuration

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `root_data_directory` | str | — | Path to the root data directory |
| `dataset_type` | str | `roi` | Dataset type: `roi` (pre-computed crops) or `crop` (dynamic extraction) |
| `input_size` | int | `384` | Input image size for the model |
| `batch_size` | int | `64` | Training batch size |
| `rebalance` | bool | `true` | Rebalance class distribution via oversampling |
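The `rebalance` option oversamples under-represented classes. One common way to realize this is to weight each sample by the inverse frequency of its class and sample with replacement (as with a `WeightedRandomSampler`); the sketch below illustrates that weighting scheme — the function name and exact scheme are assumptions, not WildTrain's actual implementation:

```python
from collections import Counter

def oversampling_weights(labels):
    """Per-sample weights inversely proportional to class frequency.

    Rare classes get larger weights, so sampling with replacement
    using these weights yields a roughly balanced batch stream.
    """
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

weights = oversampling_weights(["deer", "deer", "deer", "fox"])
# The single "fox" sample carries 3x the weight of each "deer" sample.
```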

Dataset Statistics

Used for image normalization:

```yaml
dataset:
  stats:
    mean: [0.554, 0.469, 0.348]
    std: [0.203, 0.173, 0.144]
```

Use `wildtrain dataset stats DATA_DIR` to compute these values.
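Conceptually, these stats are the per-channel mean and standard deviation over the training images, with pixel values scaled to [0, 1]. A minimal pure-Python sketch of that computation (the real CLI presumably streams over the dataset rather than taking nested lists):

```python
import math

def channel_stats(images):
    """Per-channel mean/std over images given as nested lists [H][W][C]."""
    n_channels = len(images[0][0][0])
    sums = [0.0] * n_channels
    sq_sums = [0.0] * n_channels
    count = 0
    for img in images:
        for row in img:
            for pixel in row:
                for c, v in enumerate(pixel):
                    sums[c] += v
                    sq_sums[c] += v * v
                count += 1
    mean = [s / count for s in sums]
    std = [math.sqrt(max(sq / count - m * m, 0.0))
           for sq, m in zip(sq_sums, mean)]
    return mean, std

# Two 1x2 RGB "images" with values already scaled to [0, 1]:
imgs = [[[[0.2, 0.4, 0.6], [0.4, 0.4, 0.6]]],
        [[[0.6, 0.4, 0.6], [0.8, 0.4, 0.6]]]]
mean, std = channel_stats(imgs)
# mean is approximately [0.5, 0.4, 0.6]; channels 1 and 2 have zero std.
```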

Transforms

Define per-split torchvision transforms:

```yaml
dataset:
  transforms:
    train:
      - name: Resize
        params:
          size: ${dataset.input_size}
      - name: RandomHorizontalFlip
        params:
          p: 0.5
      - name: ColorJitter
        params:
          brightness: 0.1
          contrast: 0.0
          saturation: 0.0
      - name: RandomRotation
        params:
          degrees: 45
    val:
      - name: Resize
        params:
          size: ${dataset.input_size}
```
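The loader presumably resolves each `name` to a `torchvision.transforms` class and instantiates it with `params`. That name/params pattern can be sketched without torchvision, using stand-in callables (the registry contents and function name below are illustrative only):

```python
def build_pipeline(specs, registry):
    """Turn a list of {name, params} dicts into one composed callable."""
    steps = [registry[s["name"]](**s.get("params", {})) for s in specs]

    def pipeline(x):
        for step in steps:
            x = step(x)
        return x

    return pipeline

# Stand-in "transforms" acting on plain numbers instead of images:
registry = {
    "Scale": lambda factor: (lambda x: x * factor),
    "Shift": lambda offset: (lambda x: x + offset),
}
specs = [{"name": "Scale", "params": {"factor": 2}},
         {"name": "Shift", "params": {"offset": 1}}]
transform = build_pipeline(specs, registry)
# transform(3) -> 7  (3 * 2 + 1)
```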

Single-Class Mode

Merge all species into a binary classifier (wildlife vs. background):

```yaml
dataset:
  single_class:
    enable: true
    background_class_name: "background"
    single_class_name: "wildlife"
    keep_classes: null  # List of classes to keep (null = all)
    discard_classes: ["rocks", "vegetation", "other"]
```
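The remapping these options imply can be sketched as follows: discarded classes (and anything outside `keep_classes`, when it is set) collapse to the background class, and every remaining class becomes the single wildlife class. The function name and signature are illustrative, not WildTrain's API:

```python
def remap_label(label, *, background_class_name, single_class_name,
                keep_classes=None, discard_classes=()):
    """Collapse species labels into a binary wildlife/background scheme."""
    if label in discard_classes:
        return background_class_name
    if keep_classes is not None and label not in keep_classes:
        return background_class_name
    return single_class_name

def remap(label):
    return remap_label(
        label,
        background_class_name="background",
        single_class_name="wildlife",
        discard_classes=["rocks", "vegetation", "other"],
    )

# remap("deer") -> "wildlife"; remap("rocks") -> "background"
```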

Crop Dataset Parameters

Used when `dataset_type: "crop"`:

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `crop_size` | int | `${dataset.input_size}` | Crop size for dynamic extraction |
| `max_tn_crops` | int | `1` | Max true-negative crops per image |
| `p_draw_annotations` | float | `0.2` | Probability of drawing annotations on crops |
| `compute_difficulties` | bool | `true` | Compute sample difficulty scores |
| `preserve_aspect_ratio` | bool | `true` | Preserve aspect ratio during cropping |
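One common way to preserve aspect ratio during cropping is to expand each bounding box to a square around its center before resizing, so the crop is not distorted when scaled to `crop_size`. A minimal sketch of that idea (WildTrain's exact cropping logic may differ):

```python
def square_crop_box(x1, y1, x2, y2):
    """Expand a bounding box to a square around its center, so that
    resizing the crop does not distort its aspect ratio."""
    w, h = x2 - x1, y2 - y1
    side = max(w, h)
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    return (cx - side / 2, cy - side / 2, cx + side / 2, cy + side / 2)

# A 40x20 box becomes a 40x40 box around the same center:
square_crop_box(0, 0, 40, 20)  # -> (0.0, -10.0, 40.0, 30.0)
```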

Curriculum Learning

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `curriculum_config.enabled` | bool | `false` | Enable curriculum learning |
| `curriculum_config.type` | str | `difficulty` | Curriculum type |
| `curriculum_config.difficulty_strategy` | str | `linear` | Difficulty progression strategy |
| `curriculum_config.start_difficulty` | float | `0.0` | Starting difficulty |
| `curriculum_config.end_difficulty` | float | `1.0` | Ending difficulty |
| `curriculum_config.warmup_epochs` | int | `0` | Warmup epochs before curriculum starts |
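A `linear` difficulty strategy with these parameters would typically hold the threshold at `start_difficulty` during warmup, then ramp linearly to `end_difficulty` by the final epoch. The sketch below shows that shape; the ramp details and function name are assumptions about how WildTrain implements it:

```python
def difficulty_threshold(epoch, *, start_difficulty=0.0, end_difficulty=1.0,
                         warmup_epochs=0, total_epochs=20):
    """Linear difficulty progression: flat during warmup, then a
    straight-line ramp that reaches end_difficulty at the last epoch."""
    if epoch < warmup_epochs:
        return start_difficulty
    span = max(total_epochs - 1 - warmup_epochs, 1)
    frac = min((epoch - warmup_epochs) / span, 1.0)
    return start_difficulty + frac * (end_difficulty - start_difficulty)

# With warmup_epochs=2 and total_epochs=12:
# epoch 2 -> 0.0 (ramp begins), epoch 11 -> 1.0 (full difficulty)
```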

model — Model Architecture

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `backbone` | str | — | Backbone model identifier (e.g., `timm/vit_base_patch14_reg4_dinov2.lvd142m`) |
| `pretrained` | bool | `true` | Use pretrained weights |
| `backbone_source` | str | `timm` | Source library: `timm` |
| `dropout` | float | `0.2` | Dropout rate |
| `freeze_backbone` | bool | `true` | Freeze backbone weights (train head only) |
| `input_size` | int | `${dataset.input_size}` | Model input size |
| `mean` | list | `${dataset.stats.mean}` | Normalization mean |
| `std` | list | `${dataset.stats.std}` | Normalization std |
| `hidden_dim` | int | `128` | Hidden layer dimension |
| `num_layers` | int | `2` | Number of classification head layers |
| `weights` | str | `None` | Path to pretrained checkpoint for warm-start |
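Together, `hidden_dim` and `num_layers` determine the widths of the MLP classification head on top of the (optionally frozen) backbone. The sketch below computes the implied layer shapes, assuming `num_layers` counts the linear layers and all intermediate layers use `hidden_dim` — the actual head construction lives inside WildTrain and may differ:

```python
def head_layer_dims(feature_dim, hidden_dim, num_layers, num_classes):
    """(in, out) shapes of each linear layer in the classification head."""
    dims = [feature_dim] + [hidden_dim] * (num_layers - 1) + [num_classes]
    return list(zip(dims[:-1], dims[1:]))

# A ViT-Base backbone (768-dim features) with hidden_dim=128, num_layers=2,
# and 2 output classes:
head_layer_dims(768, 128, 2, 2)  # -> [(768, 128), (128, 2)]
```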

train — Training Hyperparameters

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `batch_size` | int | `${dataset.batch_size}` | Training batch size |
| `epochs` | int | `20` | Number of training epochs |
| `lr` | float | `1e-3` | Learning rate |
| `label_smoothing` | float | `0.0` | Label smoothing factor |
| `weight_decay` | float | `1e-3` | Weight decay |
| `lrf` | float | `1e-2` | Final learning rate factor |
| `precision` | str | `bf16-mixed` | Training precision: `bf16-mixed`, `16-mixed`, `32` |
| `accelerator` | str | `auto` | PyTorch Lightning accelerator: `auto`, `gpu`, `cpu` |
| `num_workers` | int | `4` | DataLoader workers |
| `val_check_interval` | int | `2` | Validation check interval (epochs) |
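A "final learning rate factor" like `lrf` usually means the schedule decays from `lr` down to `lr * lrf` over training. One common realization is cosine decay, sketched below; whether WildTrain uses cosine or another decay shape is an assumption here:

```python
import math

def cosine_lr(step, total_steps, lr=1e-3, lrf=1e-2):
    """Cosine decay from lr (step 0) to lr * lrf (final step)."""
    cos = (1 + math.cos(math.pi * step / total_steps)) / 2  # 1 -> 0
    return lr * (lrf + (1 - lrf) * cos)

cosine_lr(0, 100)    # starts at lr (approximately 1e-3)
cosine_lr(100, 100)  # ends at lr * lrf (approximately 1e-5)
```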

mlflow — Experiment Tracking

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `experiment_name` | str | — | MLflow experiment name |
| `run_name` | str | — | MLflow run name |
| `log_model` | bool | `true` | Log model artifacts to MLflow |

checkpoint — Model Checkpointing

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `monitor` | str | `val_f1score` | Metric to monitor for checkpointing |
| `save_top_k` | int | `1` | Number of top checkpoints to keep |
| `mode` | str | `max` | Optimization direction: `max` or `min` |
| `save_last` | bool | `true` | Always save last epoch checkpoint |
| `dirpath` | str | `checkpoints/classification` | Checkpoint save directory |
| `patience` | int | `10` | Early stopping patience |
| `save_weights_only` | bool | `true` | Save only model weights (not optimizer state) |
| `filename` | str | `best` | Checkpoint filename pattern |
| `min_delta` | float | `0.001` | Minimum improvement for early stopping |
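The `patience` / `min_delta` pair follows the usual early-stopping semantics: training stops once the monitored metric has failed to improve by at least `min_delta` for `patience` consecutive validation checks. A minimal sketch of that logic (illustrative mirror of Lightning-style behavior, not WildTrain's code):

```python
def epochs_until_stop(metric_history, *, patience=10, min_delta=0.001,
                      mode="max"):
    """Return the epoch index at which early stopping would trigger,
    or None if training runs to the end of metric_history."""
    best = None
    bad_epochs = 0
    for epoch, value in enumerate(metric_history):
        if mode == "max":
            improved = best is None or value - best > min_delta
        else:
            improved = best is None or best - value > min_delta
        if improved:
            best, bad_epochs = value, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None

# f1 plateaus at 0.6 after epoch 1; with patience=3 training stops at epoch 4:
epochs_until_stop([0.5, 0.6, 0.6, 0.6, 0.6], patience=3)  # -> 4
```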

Complete Example

```yaml
dataset:
  root_data_directory: D:/data
  dataset_type: "roi"
  input_size: 384
  batch_size: 64
  rebalance: true
  stats:
    mean: [0.554, 0.469, 0.348]
    std: [0.203, 0.173, 0.144]
  transforms:
    train:
      - name: Resize
        params: { size: ${dataset.input_size} }
      - name: RandomHorizontalFlip
        params: { p: 0.5 }
    val:
      - name: Resize
        params: { size: ${dataset.input_size} }
  single_class:
    enable: true
    background_class_name: "background"
    single_class_name: "wildlife"
    discard_classes: ["rocks", "vegetation"]

model:
  backbone: timm/vit_base_patch14_reg4_dinov2.lvd142m
  pretrained: true
  backbone_source: timm
  dropout: 0.2
  freeze_backbone: true
  hidden_dim: 128
  num_layers: 2

mlflow:
  experiment_name: wildtrain_classification
  run_name: "baseline_dinov2"
  log_model: true

train:
  epochs: 20
  lr: 1e-3
  weight_decay: 1e-3
  precision: bf16-mixed
  accelerator: auto
  num_workers: 4

checkpoint:
  monitor: val_f1score
  save_top_k: 1
  mode: max
  patience: 10
  dirpath: checkpoints/classification
  filename: "best"
```

See also: