WildTrain Architecture
WildTrain is a modular training framework that supports both object detection (YOLO, MMDetection) and classification (PyTorch Lightning) with integrated experiment tracking and model management.
Overview
Purpose: Flexible model training and evaluation framework
Key Responsibilities:
- Model training (detection and classification)
- Experiment tracking with MLflow
- Hyperparameter optimization
- Model evaluation and metrics
- Model registration and versioning
Architecture Diagram
graph TB
    subgraph "Configuration Layer"
        A[Hydra Config]
        B[YAML Files]
    end
    subgraph "Data Layer"
        C[WilData Integration]
        D[DataModule]
        E[DataLoaders]
    end
    subgraph "Model Layer"
        F[YOLO Models]
        G[MMDet Models]
        H[Classification Models]
    end
    subgraph "Training Layer"
        I[Trainer]
        J[Training Loop]
        K[Validation]
    end
    subgraph "Tracking Layer"
        L[MLflow]
        M[Metrics]
        N[Artifacts]
    end
    subgraph "Output Layer"
        O[Trained Models]
        P[Model Registry]
        Q[Checkpoints]
    end

    A --> I
    B --> A
    C --> D
    D --> E
    E --> J
    F --> J
    G --> J
    H --> J
    I --> J
    J --> K
    K --> M
    M --> L
    J --> O
    O --> P
    O --> Q
    L --> N

    style I fill:#e1f5ff
    style L fill:#fff4e1
    style P fill:#e8f5e9
Core Components
1. Model Architectures
YOLO Detection
# src/wildtrain/models/detection/yolo_model.py
from ultralytics import YOLO

class YOLODetector:
    """YOLO-based object detector."""

    def __init__(self, model_size: str = "n"):
        self.model = YOLO(f"yolo11{model_size}.pt")

    def train(self, data_yaml: str, **kwargs):
        """Train the YOLO model."""
        results = self.model.train(
            data=data_yaml,
            epochs=kwargs.get("epochs", 100),
            imgsz=kwargs.get("imgsz", 640),
            batch=kwargs.get("batch", 16),
        )
        return results

    def validate(self, data_yaml: str):
        """Validate the model."""
        return self.model.val(data=data_yaml)
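A minimal usage sketch for this wrapper (the dataset YAML path below is a placeholder):

# Usage sketch; "data/dataset.yaml" is a placeholder dataset config
from wildtrain.models.detection.yolo_model import YOLODetector

detector = YOLODetector(model_size="n")
detector.train("data/dataset.yaml", epochs=50, batch=32)
metrics = detector.validate("data/dataset.yaml")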
MMDetection
# src/wildtrain/models/detection/mmdet_model.py
from typing import Optional

from mmcv import Config
from mmdet.apis import inference_detector, init_detector, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector

class MMDetDetector:
    """MMDetection-based detector (MMDetection 2.x API)."""

    def __init__(self, config: str, checkpoint: Optional[str] = None):
        self.config = config
        # Initialize for inference only when a checkpoint is provided
        self.model = init_detector(config, checkpoint) if checkpoint else None

    def train(self, config_file: str, work_dir: str):
        """Train an MMDet model from a config file."""
        cfg = Config.fromfile(config_file)
        cfg.work_dir = work_dir
        # train_detector expects the model and datasets built from the config
        datasets = [build_dataset(cfg.data.train)]
        model = build_detector(cfg.model)
        train_detector(model, datasets, cfg, distributed=False, validate=True)
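The inference_detector import above handles prediction; a brief sketch, with placeholder config and checkpoint paths:

# Inference sketch; config and checkpoint paths are placeholders
detector = MMDetDetector(
    config="configs/mmdet/faster_rcnn.py",
    checkpoint="checkpoints/faster_rcnn.pth",
)
result = inference_detector(detector.model, "images/sample.jpg")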
Classification Models
# src/wildtrain/models/classification/classifier.py
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision.models as models

class ImageClassifier(pl.LightningModule):
    """PyTorch Lightning classifier."""

    def __init__(
        self,
        architecture: str = "resnet50",
        num_classes: int = 10,
        learning_rate: float = 0.001,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Load a pretrained backbone
        self.model = getattr(models, architecture)(pretrained=True)
        # Replace the final layer with one sized for our classes
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
2. Data Modules
Detection DataModule
# src/wildtrain/data/detection_datamodule.py
from typing import Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

class DetectionDataModule(LightningDataModule):
    """DataModule for detection datasets."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 32,
        num_workers: int = 4,
    ):
        super().__init__()
        self.data_root = data_root
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
        # Load datasets using WilData
        from wildata import DataPipeline

        pipeline = DataPipeline(self.data_root)
        if stage in (None, "fit"):
            self.train_dataset = pipeline.load_dataset("train")
            self.val_dataset = pipeline.load_dataset("val")
        if stage in (None, "test"):
            self.test_dataset = pipeline.load_dataset("test")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
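Note that detection batches hold a variable number of boxes per image, so PyTorch's default collate would fail on them; a common workaround (not shown in the module above) is a list-preserving collate function passed to each DataLoader:

# Hypothetical helper: keep variable-length detection targets as lists
def detection_collate(batch):
    images, targets = zip(*batch)
    return list(images), list(targets)

# e.g. DataLoader(..., collate_fn=detection_collate)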
Classification DataModule
# src/wildtrain/data/classification_datamodule.py
from pytorch_lightning import LightningDataModule
from torchvision import transforms

class ClassificationDataModule(LightningDataModule):
    """DataModule for classification datasets."""

    def __init__(
        self,
        data_root: str,
        image_size: int = 224,
        batch_size: int = 32,
        num_workers: int = 4,
    ):
        super().__init__()
        self.data_root = data_root
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Define transforms (ImageNet normalization statistics)
        self.train_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
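The module above only defines transforms. A plausible sketch of the remaining setup and train_dataloader hooks, assuming an ImageFolder-style directory layout under data_root:

# Sketch of the remaining hooks (methods of ClassificationDataModule),
# assuming <data_root>/train holds one subdirectory per class
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

def setup(self, stage=None):
    if stage in (None, "fit"):
        self.train_dataset = ImageFolder(
            f"{self.data_root}/train", transform=self.train_transform
        )

def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=True,
    )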
3. Training Orchestration
Main Trainer
# src/wildtrain/trainers/trainer.py
import mlflow
import pytorch_lightning as pl
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

class Trainer:
    """Main training orchestrator."""

    def __init__(self, config: DictConfig):
        self.config = config
        self.model = instantiate(config.model)
        self.datamodule = instantiate(config.data)
        self.pl_trainer = None

    def train(self):
        """Execute the training loop."""
        # Set up MLflow
        mlflow.set_experiment(self.config.experiment_name)
        with mlflow.start_run():
            # Log parameters
            mlflow.log_params(OmegaConf.to_container(self.config, resolve=True))

            # Create the PyTorch Lightning trainer (kept on self so callers
            # such as the tuner can read callback_metrics afterwards)
            self.pl_trainer = pl.Trainer(
                max_epochs=self.config.training.epochs,
                accelerator=self.config.training.accelerator,
                devices=self.config.training.devices,
                callbacks=self._create_callbacks(),
            )

            # Train
            self.pl_trainer.fit(self.model, self.datamodule)

            # Log metrics
            metrics = self.pl_trainer.callback_metrics
            mlflow.log_metrics({k: v.item() for k, v in metrics.items()})

            # Save model
            self.pl_trainer.save_checkpoint("trained_model.ckpt")
            mlflow.pytorch.log_model(self.model, "model")

        return self.model

    def _create_callbacks(self):
        """Create training callbacks."""
        from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

        return [
            ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3),
            EarlyStopping(monitor="val_loss", patience=10, mode="min"),
        ]
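instantiate(config.model) requires each config node to carry a _target_ key naming the class to construct. An illustrative config (values are examples, not shipped defaults):

# Illustrative Hydra config for instantiate(); values are examples
model:
  _target_: wildtrain.models.classification.classifier.ImageClassifier
  architecture: resnet50
  num_classes: 10
  learning_rate: 0.001
data:
  _target_: wildtrain.data.classification_datamodule.ClassificationDataModule
  data_root: data/classification
  batch_size: 32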
4. Evaluation System
Metrics Computation
# src/wildtrain/evaluation/metrics.py
from torchmetrics import Accuracy, F1Score, Precision, Recall

class ClassificationMetrics:
    """Compute classification metrics."""

    def __init__(self, num_classes: int):
        # torchmetrics >= 0.11 requires an explicit task argument
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = Precision(task="multiclass", num_classes=num_classes, average="macro")
        self.recall = Recall(task="multiclass", num_classes=num_classes, average="macro")
        self.f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    def compute(self, predictions, targets):
        """Compute all metrics."""
        return {
            "accuracy": self.accuracy(predictions, targets),
            "precision": self.precision(predictions, targets),
            "recall": self.recall(predictions, targets),
            "f1": self.f1(predictions, targets),
        }
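A quick usage sketch with dummy tensors (class-index predictions; logits of shape (N, num_classes) also work):

# Usage sketch with random dummy tensors
import torch

metrics = ClassificationMetrics(num_classes=10)
preds = torch.randint(0, 10, (64,))    # predicted class indices
targets = torch.randint(0, 10, (64,))  # ground-truth labels
print(metrics.compute(preds, targets))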
Detection Metrics
# src/wildtrain/evaluation/detection_metrics.py
from torchmetrics.detection import MeanAveragePrecision

class DetectionMetrics:
    """Compute detection metrics."""

    def __init__(self):
        self.map_metric = MeanAveragePrecision()

    def compute(self, predictions, targets):
        """Compute mAP and related metrics."""
        self.map_metric.update(predictions, targets)
        results = self.map_metric.compute()
        return {
            "mAP": results["map"],
            "mAP_50": results["map_50"],
            "mAP_75": results["map_75"],
        }
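MeanAveragePrecision expects one dictionary per image with boxes (xyxy), scores, and labels; a minimal sketch of that format:

# Input format sketch: one dict per image, boxes as xyxy pixel coordinates
import torch

preds = [{
    "boxes": torch.tensor([[10.0, 20.0, 100.0, 120.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([0]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 22.0, 98.0, 118.0]]),
    "labels": torch.tensor([0]),
}]
print(DetectionMetrics().compute(preds, targets))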
5. Hyperparameter Optimization
Optuna Integration
# src/wildtrain/tuning/optuna_tuner.py
import optuna
from omegaconf import DictConfig
from optuna.integration import PyTorchLightningPruningCallback

from wildtrain.trainers.trainer import Trainer

class OptunaTuner:
    """Hyperparameter tuning with Optuna."""

    def __init__(self, config: DictConfig):
        self.config = config
        self.study = optuna.create_study(
            direction="minimize",
            study_name=config.study_name,
            storage=config.storage,
        )

    def objective(self, trial: optuna.Trial) -> float:
        """Optimization objective."""
        # Suggest hyperparameters (suggest_loguniform is deprecated;
        # use suggest_float with log=True instead)
        lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
        batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

        # Update config
        self.config.model.learning_rate = lr
        self.config.data.batch_size = batch_size

        # Train model
        trainer = Trainer(self.config)
        trainer.train()

        # Return validation loss
        return trainer.pl_trainer.callback_metrics["val_loss"].item()

    def tune(self, n_trials: int = 50):
        """Run hyperparameter tuning."""
        self.study.optimize(self.objective, n_trials=n_trials)
        print(f"Best trial: {self.study.best_trial.params}")
        return self.study.best_params
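PyTorchLightningPruningCallback is imported above but not wired in; pruning would require handing the trial's callback to the underlying pl.Trainer. A hedged fragment, assuming the trainer construction were reachable from objective():

# Hypothetical pruning wiring; `trial` is the optuna.Trial from objective()
import pytorch_lightning as pl

pruning_cb = PyTorchLightningPruningCallback(trial, monitor="val_loss")
pl_trainer = pl.Trainer(max_epochs=20, callbacks=[pruning_cb])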
6. Model Registration
MLflow Model Registry
# src/wildtrain/registry/model_registry.py
from typing import Dict, Optional

import mlflow
import torch
from mlflow.tracking import MlflowClient

class ModelRegistry:
    """Manage models in the MLflow Model Registry."""

    def __init__(self, tracking_uri: str):
        mlflow.set_tracking_uri(tracking_uri)
        self.client = MlflowClient()

    def register_model(
        self,
        model: torch.nn.Module,
        model_name: str,
        description: Optional[str] = None,
        tags: Optional[Dict] = None,
    ) -> str:
        """Log a model and register it in MLflow."""
        # Log the model (log_model expects a model object, not a path)
        with mlflow.start_run():
            mlflow.pytorch.log_model(model, "model")
            run_id = mlflow.active_run().info.run_id

        # Register the logged artifact
        model_uri = f"runs:/{run_id}/model"
        mv = mlflow.register_model(model_uri, model_name)

        # Add description and tags
        if description:
            self.client.update_model_version(
                name=model_name,
                version=mv.version,
                description=description,
            )
        if tags:
            for key, value in tags.items():
                self.client.set_model_version_tag(
                    name=model_name,
                    version=mv.version,
                    key=key,
                    value=value,
                )
        return mv.version

    def load_model(self, model_name: str, version: Optional[str] = None):
        """Load a model from the registry."""
        if version:
            model_uri = f"models:/{model_name}/{version}"
        else:
            model_uri = f"models:/{model_name}/latest"
        return mlflow.pytorch.load_model(model_uri)

    def promote_model(self, model_name: str, version: str, stage: str):
        """Promote a model version to a stage."""
        self.client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage=stage,  # "Staging", "Production", "Archived"
        )
Configuration System
Hydra Configuration
WildTrain uses Hydra for flexible configuration management.
# src/wildtrain/main.py
import hydra
from omegaconf import DictConfig

from wildtrain.trainers.trainer import Trainer

@hydra.main(config_path="configs", config_name="main", version_base="1.3")
def main(cfg: DictConfig):
    trainer = Trainer(cfg)
    trainer.train()

if __name__ == "__main__":
    main()
Configuration Structure
# configs/main.yaml
defaults:
  - model: yolo
  - data: detection
  - training: default
  - _self_

experiment_name: wildlife_detection
seed: 42

# Override from CLI:
# python main.py model=custom data.batch_size=64
Model Configs
# configs/detection/yolo.yaml
model:
  framework: "yolo"
  size: "n"  # n, s, m, l, x
  pretrained: true

training:
  epochs: 100
  imgsz: 640
  batch: 16
  optimizer: "AdamW"
  lr0: 0.001
CLI Interface
# src/wildtrain/cli/train.py
from typing import List

import typer

app = typer.Typer()

@app.command()
def train(
    # Positional argument so `wildtrain train detector -c ...` works
    task: str = typer.Argument(..., help="Task type: classifier or detector"),
    config: str = typer.Option(..., "-c", "--config"),
    override: List[str] = typer.Option(None, "-o", "--override"),
):
    """Train a model."""
    # Load config with overrides
    # Start training

@app.command()
def evaluate(
    task: str,
    config: str,
    checkpoint: str,
):
    """Evaluate a trained model."""

@app.command()
def tune(
    config: str,
    n_trials: int = 50,
):
    """Run hyperparameter tuning."""
Use Cases
1. Train YOLO Detector
# Using CLI
wildtrain train detector -c configs/detection/yolo.yaml
# Using Python
from wildtrain import Trainer
trainer = Trainer.from_config("configs/detection/yolo.yaml")
model = trainer.train()
2. Train Classifier
import pytorch_lightning as pl
from wildtrain import ImageClassifier, ClassificationDataModule
# Create model and data
model = ImageClassifier(architecture="resnet50", num_classes=10)
datamodule = ClassificationDataModule(data_root="data/classification")
# Create trainer
trainer = pl.Trainer(max_epochs=50, accelerator="gpu")
# Train
trainer.fit(model, datamodule)
3. Hyperparameter Tuning
from wildtrain.tuning import OptunaTuner
tuner = OptunaTuner(config)
best_params = tuner.tune(n_trials=50)
print(f"Best hyperparameters: {best_params}")
4. Model Registration
from wildtrain.registry import ModelRegistry
registry = ModelRegistry(tracking_uri="http://localhost:5000")
# Register model
version = registry.register_model(
model_path="checkpoints/best.ckpt",
model_name="wildlife_detector",
description="YOLO model trained on aerial images",
tags={"framework": "yolo", "dataset": "wildlife_v1"}
)
# Promote to production
registry.promote_model("wildlife_detector", version, "Production")