QGIS:fine-tuning ou réentraînement d’un modèle U-Net


Le Deep Learning offre des possibilités incroyables pour analyser les images satellites et segmenter automatiquement des éléments d’intérêt, comme les coraux, les zones forestières ou urbaines. Dans cet article, nous explorons comment réentraîner ou affiner un modèle U-Net dans QGIS, afin d’adapter un modèle préexistant à vos propres données locales.

L’exemple illustré ici concerne la détection des coraux, mais il est important de souligner que les mêmes étapes et principes valent pour n’importe quel modèle U‑Net, quelle que soit la nature des images traitées. 



Introduction

Les modèles U-Net sont largement utilisés pour la segmentation d’images en raison de leur architecture efficace, capable de combiner les informations globales et locales d’une image. Bien que des modèles préentraînés existent pour des tâches générales ou spécifiques, leur performance peut être limitée lorsqu’on travaille avec des données locales ou particulières, par exemple des images Sentinel-2 d’une île tropicale.

Le fine-tuning, ou réentraînement partiel, permet de spécialiser un modèle préexistant sur un nouveau jeu de données tout en conservant les connaissances acquises lors du premier entraînement. Cette approche est particulièrement utile lorsqu’on dispose d’un nombre limité d’images annotées.

Dans QGIS, il est possible de combiner la puissance des modèles U-Net avec vos propres images grâce à des scripts Python utilisant PyTorch, ce qui ouvre la voie à une segmentation adaptée à vos besoins spécifiques. L’article détaillera les étapes clés, du choix du modèle au réentraînement sur vos données, en passant par la préparation des images et l’optimisation des hyperparamètres.


Procédure

Préparer vos données

Pour un modèle U-Net destiné à la segmentation de coraux, il vous faut :

  1. Images d’entrée (Sentinel-2 ou drone)

    • Les mêmes bandes que le modèle original (souvent RGB ou NIR inclus).
    • Taille et résolution cohérentes avec le modèle original si possible, sinon il faudra ajuster le prétraitement.

  2. Masques correspondants (ground truth)

    • Fichiers raster binaires ou multilabel où chaque pixel indique « corail », « eau », « autres ».
    • Ces masques doivent être alignés exactement sur vos images.

Pour générer ces masques : annotation manuelle avec QGIS (polygones ou rasters) ou logiciels d’annotation comme LabelMe, CVAT, ou Labelbox.


Charger le modèle U-Net existant

En PyTorch :

chargement du modèle

import torch
from segmentation_models_pytorch import Unet

# Chemin vers votre modèle existant
model_path = "unet_coraux.pth"

# Charger le modèle complet (architecture + poids)
model = torch.load(model_path, map_location='cpu', weights_only=False)
model.train()  # mettre en mode entraînement

✅ L’avantage : vous ne repartez pas de zéro, le modèle connaît déjà des motifs coralliens génériques.


Préparer le DataLoader

Pour entraîner un U-Net, il faut un DataLoader PyTorch pour fournir vos images et masques par lots :

from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision import transforms

class CoralDataset(Dataset):
    def __init__(self, image_files, mask_files, transform=None):
        self.image_files = image_files
        self.mask_files = mask_files
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img = np.load(self.image_files[idx]).astype(np.float32)  # ou utiliser rasterio
        mask = np.load(self.mask_files[idx]).astype(np.float32)
        
        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
        
        return torch.from_numpy(img), torch.from_numpy(mask)

  • image_files et mask_files contiennent vos données locales.
  • Vous pouvez appliquer des transformations (rotations, flips, normalisation) pour augmenter la diversité.


Configurer l’entraînement

  • Critère de perte : pour la segmentation binaire torch.nn.BCEWithLogitsLoss() ou DiceLoss.
  • Optimiseur : torch.optim.Adam(model.parameters(), lr=1e-4) par exemple.
  • Batch size : dépend de la mémoire GPU.
  • Époques : 10‑50 pour fine-tuning, selon vos données.

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


Fine-tuning

fine-tuning

for epoch in range(num_epochs):
    for images, masks in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

  • Important : ne pas dépasser la capacité de votre GPU si les images sont grandes.
  • Vous pouvez geler les premières couches de l’encodeur (pré-entraînement) et n’entraîner que le décodeur pour éviter de tout réapprendre.

for param in model.encoder.parameters():
    param.requires_grad = False


Sauvegarder le modèle affiné

torch.save(model, "unet_coraux_maurice.pth")

  • Ce modèle pourra ensuite être utilisé directement dans votre script QGIS pour la segmentation locale.


Bonnes pratiques

  • Vérifiez que les bandes de vos images correspondent exactement aux attentes du modèle.
  • Normalisez vos images comme pour l’entraînement initial.
  • Commencez par un nombre réduit d’images pour tester le fine-tuning avant de lancer un gros entraînement.
  • Pensez à augmenter vos données (rotations, flips, zooms) pour éviter le surapprentissage.


Script : fine-tuning d’un modèle U-Net sur les coraux mauriciens

Il est conçu pour être simple, pédagogique et exécutable dans un environnement Python classique (hors QGIS, car QGIS n’a pas les dépendances d’entraînement).


# -*- coding: utf-8 -*-
"""
Fine-tuning d’un modèle U-Net PyTorch sur les coraux mauriciens
---------------------------------------------------------------
Ce script prend un modèle U-Net existant (fichier .pth)
et l’adapte à vos propres images et masques de coraux.
"""

import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import rasterio
from segmentation_models_pytorch import Unet
from tqdm import tqdm

# --- Configuration utilisateur ---
image_dir = "data/images/"       # dossier contenant les images Sentinel-2 (tuiles .tif)
mask_dir  = "data/masks/"        # dossier contenant les masques binaires correspondants
pretrained_model_path = "unet_coraux.pth"  # modèle existant
output_model_path = "unet_coraux_maurice.pth"
batch_size = 4
num_epochs = 20
learning_rate = 1e-4
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Dataset personnalisé ---
class CoralDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".tif")])
        self.mask_files  = sorted([os.path.join(mask_dir, f)  for f in os.listdir(mask_dir)  if f.endswith(".tif")])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path, mask_path = self.image_files[idx], self.mask_files[idx]
        with rasterio.open(img_path) as src:
            img = src.read().astype(np.float32)
        with rasterio.open(mask_path) as src:
            mask = src.read(1).astype(np.float32)  # 1 canal

        # Normalisation 0-1
        img = (img - img.min()) / (img.max() - img.min() + 1e-6)

        return torch.tensor(img), torch.tensor(mask).unsqueeze(0)  # (C,H,W), (1,H,W)

# --- Chargement des données ---
dataset = CoralDataset(image_dir, mask_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Chargement du modèle pré-entraîné ---
print("Chargement du modèle existant...")
model = torch.load(pretrained_model_path, map_location=device)
model = model.to(device)
model.train()

# --- Option : geler l’encodeur (pré-entraînement) ---
for param in model.encoder.parameters():
    param.requires_grad = False

# --- Définir la fonction de perte et l’optimiseur ---
criterion = nn.BCEWithLogitsLoss()  # segmentation binaire
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

# --- Entraînement ---
print(f"Début de l’entraînement ({num_epochs} époques)...")
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, masks in tqdm(dataloader, desc=f"Époque {epoch+1}/{num_epochs}"):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(dataloader)
    print(f"Époque {epoch+1}/{num_epochs} — perte moyenne : {avg_loss:.4f}")

# --- Sauvegarde du modèle affiné ---
torch.save(model, output_model_path)
print(f"✅ Modèle affiné sauvegardé sous : {output_model_path}")


Ce que fait le script

  • Recharge le modèle U-Net pré-entraîné sur d’autres coraux.
  • Gèle ses couches d’encodeur (celles apprises sur ImageNet ou Sentinel-2).
  • N’entraîne que le décodeur pour adapter la reconnaissance aux textures, couleurs et eaux mauriciennes.
  • Enregistre le nouveau modèle .pth prêt à être utilisé dans le script QGIS.


Organisation attendue des fichiers

data/
 ├── images/
 │    ├── 001.tif
 │    ├── 002.tif
 │    └── ...
 └── masks/
      ├── 001.tif
      ├── 002.tif
      └── ...

  • Chaque image doit avoir un mask du même nom et même dimension.
  • Les fichiers peuvent être issus de Sentinel-2, Planet ou drone (même résolution préférable).


Pour aller plus loin

  • Vous pouvez changer la fonction de perte pour mieux distinguer coraux/sable :
    from segmentation_models_pytorch.losses import DiceLoss criterion = DiceLoss(mode="binary")
  • Si vous souhaitez réentraîner tout le modèle (pas seulement le décodeur), supprimez la boucle :
    for param in model.encoder.parameters(): param.requires_grad = False
  • Si vous avez plusieurs classes (coraux, algues, sable, eau), remplacez la couche de sortie et la perte par une version multi-classes (nn.CrossEntropyLoss()).


Script : Validation et visualisation du modèle affiné

# -*- coding: utf-8 -*-
"""
Validation du modèle U-Net affiné sur les coraux mauriciens
-----------------------------------------------------------
Affiche la prédiction de masques coralliens à partir d'images Sentinel-2.
"""

import torch
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# --- Paramètres utilisateur ---
model_path = "unet_coraux_maurice.pth"  # modèle affiné
test_image_dir = "data/test_images/"     # dossier d’images Sentinel-2 à tester
output_dir = "results/"                  # dossier où enregistrer les masques prédits
os.makedirs(output_dir, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Chargement du modèle ---
print("Chargement du modèle...")
model = torch.load(model_path, map_location=device)
model.eval()

# --- Fonction utilitaire pour prédire un masque sur une image Sentinel-2 ---
def predict_mask(image_path):
    with rasterio.open(image_path) as src:
        img = src.read().astype(np.float32)
        profile = src.profile

    img = (img - img.min()) / (img.max() - img.min() + 1e-6)
    tensor = torch.from_numpy(img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = torch.sigmoid(model(tensor))
        mask = pred.squeeze().cpu().numpy()

    return mask, profile

# --- Boucle de test sur les images ---
for file in tqdm(sorted(os.listdir(test_image_dir))):
    if not file.endswith(".tif"):
        continue

    image_path = os.path.join(test_image_dir, file)
    mask, profile = predict_mask(image_path)

    # Seuil pour binaire (ajuster selon ton modèle)
    binary_mask = (mask > 0.5).astype(np.uint8)

    # Enregistrement du masque prédictif
    out_path = os.path.join(output_dir, f"mask_{file}")
    profile.update(count=1, dtype='uint8')
    with rasterio.open(out_path, 'w', **profile) as dst:
        dst.write(binary_mask, 1)

    # --- Visualisation (optionnelle) ---
    plt.figure(figsize=(10,5))
    plt.subplot(1,2,1)
    plt.title("Image Sentinel-2 (RGB)")
    rgb = np.stack([mask for mask in (img[3:0:-1])], axis=-1) if img.shape[0]>=3 else img.transpose(1,2,0)
    plt.imshow(np.clip(rgb, 0, 1))
    plt.axis("off")

    plt.subplot(1,2,2)
    plt.title("Masque prédit (coraux)")
    plt.imshow(mask, cmap="turbo")
    plt.colorbar(label="Probabilité")
    plt.axis("off")
    plt.tight_layout()
    plt.show()

print(f"✅ Résultats enregistrés dans : {output_dir}")


Ce que fait ce script

  • Charge le modèle affiné (unet_coraux_maurice.pth).
  • Lit chaque image Sentinel-2 du dossier data/test_images/.
  • Calcule la carte de probabilité (entre 0 et 1) pour les zones coralliennes.
  • Applique un seuil (0.5 par défaut) pour obtenir un masque binaire.
  • Sauvegarde le masque dans results/.
  • Affiche l’image et le masque côte à côte pour contrôle visuel.


Conseils d’analyse

  • Les zones rouge/orange sur la carte de chaleur sont celles que le modèle identifie comme coraux probables.
  • Ajustez le seuil 0.5 selon vos observations :

    • 0.3 → masque plus large, plus sensible (risque de faux positifs)
    • 0.7 → masque plus restreint, plus précis (mais risque d’oublier certains coraux)

  • Si vous observez de la confusion sable/coraux :

    • Ajoutez des masques d’eau pour filtrer la terre.
    • Ou enrichissez l’entraînement avec des échantillons équilibrés (eau claire, herbiers, sable).


Maintenant la suite logique : un script d’évaluation quantitative pour mesurer la performance du modèle affiné U-Net sur les images coralliennes mauriciennes.
Il calcule plusieurs indicateurs de qualité (IoU, F1-score, précision, rappel, etc.) à partir d’images de test et de leurs masques de référence.


Script : Évaluation du modèle U-Net sur les coraux

# -*- coding: utf-8 -*-
"""
Évaluation du modèle U-Net affiné sur les coraux mauriciens
-----------------------------------------------------------
Compare les masques prédits et les masques de référence.
"""

import os
import numpy as np
import torch
import rasterio
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score
from tqdm import tqdm

# --- Paramètres utilisateur ---
model_path = "unet_coraux_maurice.pth"
test_image_dir = "data/test_images/"     # images Sentinel-2 à tester
ref_mask_dir = "data/test_masks/"        # masques de référence (binaire 0/1)
output_dir = "results_eval/"
os.makedirs(output_dir, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Chargement du modèle ---
print("Chargement du modèle...")
model = torch.load(model_path, map_location=device)
model.eval()

# --- Fonction utilitaire pour prédire un masque ---
def predict_mask(image_path):
    with rasterio.open(image_path) as src:
        img = src.read().astype(np.float32)
    img = (img - img.min()) / (img.max() - img.min() + 1e-6)
    tensor = torch.from_numpy(img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = torch.sigmoid(model(tensor))
        mask = pred.squeeze().cpu().numpy()
    return (mask > 0.5).astype(np.uint8)

# --- Initialisation des listes de scores ---
iou_scores, f1_scores, prec_scores, rec_scores = [], [], [], []

# --- Boucle principale ---
for file in tqdm(sorted(os.listdir(test_image_dir))):
    if not file.endswith(".tif"):
        continue

    image_path = os.path.join(test_image_dir, file)
    ref_path = os.path.join(ref_mask_dir, file.replace(".tif", "_mask.tif"))
    if not os.path.exists(ref_path):
        print(f"⚠️ Pas de masque de référence pour {file}")
        continue

    # Prédiction
    pred_mask = predict_mask(image_path)

    # Chargement du masque de référence
    with rasterio.open(ref_path) as src:
        ref_mask = src.read(1).astype(np.uint8)

    # Mise à la même taille si besoin
    if pred_mask.shape != ref_mask.shape:
        min_h = min(pred_mask.shape[0], ref_mask.shape[0])
        min_w = min(pred_mask.shape[1], ref_mask.shape[1])
        pred_mask = pred_mask[:min_h, :min_w]
        ref_mask = ref_mask[:min_h, :min_w]

    # Vectorisation
    y_true = ref_mask.flatten()
    y_pred = pred_mask.flatten()

    # Calcul des métriques
    iou = jaccard_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)

    iou_scores.append(iou)
    f1_scores.append(f1)
    prec_scores.append(prec)
    rec_scores.append(rec)

# --- Résumé des résultats ---
print("\n=== Évaluation du modèle U-Net ===")
print(f"IoU moyen      : {np.mean(iou_scores):.3f}")
print(f"F1-score moyen : {np.mean(f1_scores):.3f}")
print(f"Précision moy. : {np.mean(prec_scores):.3f}")
print(f"Rappel moyen   : {np.mean(rec_scores):.3f}")

# Sauvegarde des résultats
with open(os.path.join(output_dir, "scores.txt"), "w") as f:
    f.write("=== Évaluation du modèle U-Net ===\n")
    f.write(f"IoU moyen      : {np.mean(iou_scores):.3f}\n")
    f.write(f"F1-score moyen : {np.mean(f1_scores):.3f}\n")
    f.write(f"Précision moy. : {np.mean(prec_scores):.3f}\n")
    f.write(f"Rappel moyen   : {np.mean(rec_scores):.3f}\n")

print(f"\n✅ Résultats enregistrés dans : {output_dir}")


Ce que l’on obtient

Indicateur Interprétation
IoU (Intersection over Union) Taux de recouvrement entre prédiction et vérité terrain.
F1-score Équilibre entre précision et rappel (plus robuste que la simple précision).
Précision Pourcentage de pixels prédits “corail” qui sont réellement des coraux.
Rappel Pourcentage de vrais coraux détectés par le modèle.


Interprétation pratique

Valeur Signification
IoU > 0.7 Excellent modèle
0.5 < IoU ≤ 0.7 Bon résultat, améliorable
IoU < 0.5 Manque de généralisation ou données mal équilibrées


Voici donc la suite directe : un script complémentaire qui compare visuellement les performances de deux modèles U-Net — par exemple le modèle d’origine (entraîné ailleurs) et le modèle affiné sur les coraux mauriciens — à partir des métriques enregistrées.


Script : Comparaison automatique de deux modèles U-Net

# -*- coding: utf-8 -*-
"""
Comparaison de deux modèles U-Net (original vs affiné)
------------------------------------------------------
Lit les résultats d’évaluation (scores.txt) et génère un graphique comparatif.
"""

import matplotlib.pyplot as plt
import numpy as np
import os

# --- Fichiers de résultats à comparer ---
model_names = ["U-Net original", "U-Net affiné (Maurice)"]
score_files = [
    "results_original/scores.txt",
    "results_eval/scores.txt"
]

# --- Lecture des scores ---
def read_scores(file_path):
    metrics = {}
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(":")
            if len(parts) == 2:
                key, val = parts
                try:
                    metrics[key.strip()] = float(val)
                except:
                    pass
    return metrics

results = [read_scores(f) for f in score_files]

# --- Récupération des métriques ---
metrics_names = ["IoU moyen", "F1-score moyen", "Précision moy.", "Rappel moyen"]
values = np.array([[r[m] for m in metrics_names] for r in results])

# --- Création du graphique ---
x = np.arange(len(metrics_names))
width = 0.35

fig, ax = plt.subplots(figsize=(8, 5))
rects1 = ax.bar(x - width/2, values[0], width, label=model_names[0])
rects2 = ax.bar(x + width/2, values[1], width, label=model_names[1])

ax.set_ylabel("Score")
ax.set_title("Comparaison des performances des modèles U-Net")
ax.set_xticks(x)
ax.set_xticklabels(metrics_names, rotation=15)
ax.legend()

# Ajout des valeurs au-dessus des barres
def autolabel(rects):
    for rect in rects:
        height = rect.get_height()
        ax.annotate(f"{height:.2f}",
                    xy=(rect.get_x() + rect.get_width()/2, height),
                    xytext=(0, 3),  # décalage vertical
                    textcoords="offset points",
                    ha='center', va='bottom')
autolabel(rects1)
autolabel(rects2)

plt.tight_layout()
plt.show()


Résultat

Ce script génère un graphique en barres comparatif entre :

  • le U-Net original (par ex. modèle global, entraîné ailleurs) ;
  • le U-Net affiné sur données locales mauriciennes.

L’affichage permet de voir immédiatement :

  • si l’affinage améliore l’IoU ou le rappel (souvent le cas) ;
  • s’il change la précision, révélant parfois un modèle plus sensible mais moins spécifique.


Variante bonus (radar plot)

Si vous souhaitez un visuel plus synthétique (idéal pour un article), vous pouvez remplacer la partie graphique par :

from math import pi

metrics = metrics_names
num_vars = len(metrics)
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
values = np.concatenate((values, values[:, [0]]), axis=1)
angles += angles[:1]

fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))

for i in range(len(model_names)):
    ax.plot(angles, values[i], label=model_names[i])
    ax.fill(angles, values[i], alpha=0.25)

ax.set_xticks(angles[:-1])
ax.set_xticklabels(metrics)
ax.set_yticklabels([])
ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
plt.title("Comparaison radar des performances U-Net")
plt.show()


Si cet article vous a intéressé et que vous pensez qu'il pourrait bénéficier à d'autres personnes, n'hésitez pas à le partager sur vos réseaux sociaux en utilisant les boutons ci-dessous. Votre partage est apprécié !

Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *