Xiao Ling

How to Train a YOLO Segmentation Model with MIDV500 Dataset for ID Document Detection

This tutorial walks through the entire process of training a YOLO segmentation model for ID document detection using the MIDV500 dataset and the Ultralytics framework. We'll cover dataset acquisition, conversion to YOLO format, training, ONNX export, and building a GUI application for document detection.

Demo Video: YOLO Segmentation for ID Document Detection

Getting the MIDV500 Dataset

The MIDV500 dataset is a comprehensive collection of identity documents designed for document detection and recognition tasks. It contains video clips and still frames of 50 identity document types from different countries. The archives, along with the companion MIDV-2019 extension, are hosted on the Smart Engines FTP server:

midv500_links = [
    "ftp://smartengines.com/midv-500/dataset/01_alb_id.zip",
    "ftp://smartengines.com/midv-500/dataset/02_aut_drvlic_new.zip",
    "ftp://smartengines.com/midv-500/dataset/03_aut_id_old.zip",
    "ftp://smartengines.com/midv-500/dataset/04_aut_id.zip",
    "ftp://smartengines.com/midv-500/dataset/05_aze_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/06_bra_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/07_chl_id.zip",
    "ftp://smartengines.com/midv-500/dataset/08_chn_homereturn.zip",
    "ftp://smartengines.com/midv-500/dataset/09_chn_id.zip",
    "ftp://smartengines.com/midv-500/dataset/10_cze_id.zip",
    "ftp://smartengines.com/midv-500/dataset/11_cze_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/12_deu_drvlic_new.zip",
    "ftp://smartengines.com/midv-500/dataset/13_deu_drvlic_old.zip",
    "ftp://smartengines.com/midv-500/dataset/14_deu_id_new.zip",
    "ftp://smartengines.com/midv-500/dataset/15_deu_id_old.zip",
    "ftp://smartengines.com/midv-500/dataset/16_deu_passport_new.zip",
    "ftp://smartengines.com/midv-500/dataset/17_deu_passport_old.zip",
    "ftp://smartengines.com/midv-500/dataset/18_dza_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/19_esp_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/20_esp_id_new.zip",
    "ftp://smartengines.com/midv-500/dataset/21_esp_id_old.zip",
    "ftp://smartengines.com/midv-500/dataset/22_est_id.zip",
    "ftp://smartengines.com/midv-500/dataset/23_fin_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/24_fin_id.zip",
    "ftp://smartengines.com/midv-500/dataset/25_grc_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/26_hrv_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/27_hrv_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/28_hun_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/29_irn_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/30_ita_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/31_jpn_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/32_lva_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/33_mac_id.zip",
    "ftp://smartengines.com/midv-500/dataset/34_mda_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/35_nor_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/36_pol_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/37_prt_id.zip",
    "ftp://smartengines.com/midv-500/dataset/38_rou_drvlic.zip",
    "ftp://smartengines.com/midv-500/dataset/39_rus_internalpassport.zip",
    "ftp://smartengines.com/midv-500/dataset/40_srb_id.zip",
    "ftp://smartengines.com/midv-500/dataset/41_srb_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/42_svk_id.zip",
    "ftp://smartengines.com/midv-500/dataset/43_tur_id.zip",
    "ftp://smartengines.com/midv-500/dataset/44_ukr_id.zip",
    "ftp://smartengines.com/midv-500/dataset/45_ukr_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/46_ury_passport.zip",
    "ftp://smartengines.com/midv-500/dataset/47_usa_bordercrossing.zip",
    "ftp://smartengines.com/midv-500/dataset/48_usa_passportcard.zip",
    "ftp://smartengines.com/midv-500/dataset/49_usa_ssn82.zip",
    "ftp://smartengines.com/midv-500/dataset/50_xpo_id.zip",
]


midv2019_links = [
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/01_alb_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/02_aut_drvlic_new.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/03_aut_id_old.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/04_aut_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/05_aze_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/06_bra_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/07_chl_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/08_chn_homereturn.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/09_chn_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/10_cze_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/11_cze_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/12_deu_drvlic_new.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/13_deu_drvlic_old.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/14_deu_id_new.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/15_deu_id_old.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/16_deu_passport_new.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/17_deu_passport_old.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/18_dza_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/19_esp_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/20_esp_id_new.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/21_esp_id_old.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/22_est_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/23_fin_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/24_fin_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/25_grc_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/26_hrv_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/27_hrv_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/28_hun_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/29_irn_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/30_ita_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/31_jpn_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/32_lva_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/33_mac_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/34_mda_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/35_nor_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/36_pol_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/37_prt_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/38_rou_drvlic.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/39_rus_internalpassport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/40_srb_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/41_srb_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/42_svk_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/43_tur_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/44_ukr_id.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/45_ukr_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/46_ury_passport.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/47_usa_bordercrossing.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/48_usa_passportcard.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/49_usa_ssn82.zip",
    "ftp://smartengines.com/midv-500/extra/midv-2019/dataset/50_xpo_id.zip",
]

Installation

pip install midv500

Download

import midv500

# directory where the archives will be downloaded and extracted
dataset_dir = 'midv500_data/'

# 'all' downloads both the MIDV-500 set and the MIDV-2019 extension
dataset_name = "all"
midv500.download_dataset(dataset_dir, dataset_name)

The expected directory structure after downloading is as follows:

midv500_data/
  midv500/
    01_alb_id/
      ground_truth/
        01_alb_id.json
        CA/, CS/, HA/, HS/, KA/, KS/, PA/, PS/, TA/, TS/
          *.json (annotation files)
      images/
        01_alb_id.tif
        CA/, CS/, HA/, HS/, KA/, KS/, PA/, PS/, TA/, TS/
          *.tif (image files)
      videos/
        *.MOV, *.mp4 (video files)
    02_aut_drvlic_new/
      ... (similar structure)
    ... (50 document types total)

Converting MIDV500 to YOLO Format

The MIDV500 dataset uses JSON annotations, but YOLO requires a specific format. We need to convert the dataset using a custom Python script.
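
For reference, a YOLO segmentation label file contains one line per object: a class index followed by the polygon vertices normalized to the image width and height. The values below are purely illustrative of the format our converter will emit (one quadrilateral per image, class 0 = document):

0 0.183215 0.274109 0.821406 0.269833 0.835120 0.731552 0.176904 0.740215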

Prerequisites

Install the required dependencies:

pip install ultralytics opencv-python numpy tqdm pyyaml pillow

Conversion Script

Create and run the conversion script:

import os, re, json, shutil, random, math
from pathlib import Path
import yaml
import cv2
import numpy as np
from tqdm import tqdm

MIDV_ROOT = Path("midv500_data/midv500")     # root of the extracted MIDV-500 data
OUT_ROOT  = Path("./dataset")                # output folder for the YOLO-format dataset
TRAIN_SPLIT = 0.9                            # fraction of samples used for training
CLASS_NAME = "document"                      # single class: the document itself

IMG_EXT = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
JSON_KEYS = ["quad", "quadrilateral", "polygon", "points"]

def find_images(root: Path):
    return [p for p in root.rglob("*") if p.suffix.lower() in IMG_EXT]

def parse_annotation(img_path: Path):    
    if "images" in img_path.parts:
        parts = list(img_path.parts)
        try:
            images_idx = parts.index("images")
            parts[images_idx] = "ground_truth"
            json_path = Path(*parts).with_suffix(".json")

            if json_path.exists():
                try:
                    data = json.loads(json_path.read_text())
                    if "quad" in data and isinstance(data["quad"], list) and len(data["quad"]) >= 4:
                        quad = data["quad"][:4]
                        result = []
                        for pt in quad:
                            if isinstance(pt, (list, tuple)) and len(pt) >= 2:
                                result.append([float(pt[0]), float(pt[1])])
                        if len(result) == 4:
                            return result

                except Exception as e:
                    print(f"Error parsing {json_path}: {e}")
                    pass
        except ValueError:
            pass

    sidecar = img_path.with_suffix(".json")
    if sidecar.exists():
        try:
            data = json.loads(sidecar.read_text())
            for k in JSON_KEYS:
                if k in data and isinstance(data[k], (list, tuple)) and len(data[k]) >= 4:
                    pts = data[k][:4]
                    quad = []
                    for pt in pts:
                        if isinstance(pt, dict):
                            quad.append([float(pt["x"]), float(pt["y"])])
                        else:
                            quad.append([float(pt[0]), float(pt[1])])
                    return quad
        except Exception:
            pass

    for parent in [img_path.parent, img_path.parent.parent]:
        gt = parent / "ground_truth.txt"
        if gt.exists():
            stem = img_path.stem
            with gt.open("r", encoding="utf-8", errors="ignore") as f:
                for line in f:
                    parts = re.split(r"\s+", line.strip())
                    if len(parts) >= 9:
                        # first token can be filename or id
                        key = Path(parts[0]).stem
                        if key == stem:
                            try:
                                vals = list(map(float, parts[1:9]))
                                quad = [[vals[0], vals[1]],
                                        [vals[2], vals[3]],
                                        [vals[4], vals[5]],
                                        [vals[6], vals[7]]]
                                return quad
                            except Exception:
                                pass
    return None

def order_quad(quad):
    """Order points TL, TR, BR, BL."""
    pts = np.array(quad, dtype=np.float32)
    s = pts.sum(axis=1)
    d = np.diff(pts, axis=1).ravel()
    ordered = np.array([pts[np.argmin(s)],
                        pts[np.argmin(d)],
                        pts[np.argmax(s)],
                        pts[np.argmax(d)]], dtype=np.float32)
    return ordered.tolist()

def main():
    all_imgs = find_images(MIDV_ROOT)
    if not all_imgs:
        raise SystemExit(f"No images found under {MIDV_ROOT}")

    records = []
    for img in tqdm(all_imgs, desc="Scanning"):
        quad = parse_annotation(img)
        if quad is None:
            continue
        im = cv2.imread(str(img))
        if im is None:
            continue
        h, w = im.shape[:2]
        quad = order_quad(quad)
        norm = []
        for x,y in quad:
            x = max(0, min(w-1, x)) / w
            y = max(0, min(h-1, y)) / h
            norm.append([x, y])
        records.append((img, norm, (w,h)))

    if not records:
        raise SystemExit("Found no usable annotations. Adjust parse_annotation().")

    random.shuffle(records)
    n_train = int(len(records) * TRAIN_SPLIT)
    train, val = records[:n_train], records[n_train:]

    (OUT_ROOT / "images" / "train").mkdir(parents=True, exist_ok=True)
    (OUT_ROOT / "images" / "val").mkdir(parents=True, exist_ok=True)
    (OUT_ROOT / "labels" / "train").mkdir(parents=True, exist_ok=True)
    (OUT_ROOT / "labels" / "val").mkdir(parents=True, exist_ok=True)

    def write_example(split_dir, img_path, norm_quad):
        # Frame names (e.g. CA01_01) repeat across document types, so build a
        # unique stem from the path relative to MIDV_ROOT to avoid overwrites.
        stem = "_".join(img_path.relative_to(MIDV_ROOT).with_suffix("").parts)
        dst_img = OUT_ROOT / "images" / split_dir / (stem + ".jpg")
        dst_lbl = OUT_ROOT / "labels" / split_dir / (stem + ".txt")

        img = cv2.imread(str(img_path))
        if img is not None:
            cv2.imwrite(str(dst_img), img, [cv2.IMWRITE_JPEG_QUALITY, 95])
        else:
            shutil.copy2(img_path, dst_img)

        flat = " ".join(f"{x:.6f} {y:.6f}" for x,y in norm_quad)
        dst_lbl.write_text(f"0 {flat}\n")

    for img, quad, _ in tqdm(train, desc="Write train"):
        write_example("train", img, quad)
    for img, quad, _ in tqdm(val, desc="Write val"):
        write_example("val", img, quad)

    data = {
        "path": str(OUT_ROOT.resolve()),  
        "train": "images/train",
        "val": "images/val",
        "names": [CLASS_NAME],
        "nc": 1,
    }
    with open(OUT_ROOT / "doc.yaml", "w") as f:
        yaml.safe_dump(data, f)
    print(f"\nDONE. Wrote dataset at: {OUT_ROOT}\nTotal: {len(records)}, Train: {len(train)}, Val: {len(val)}")

if __name__ == "__main__":
    main()


Running the Conversion

Execute the conversion script:

python prep_midv500_to_yolov11seg.py
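
Before moving on, a quick sanity check (a minimal sketch, assuming the default ./dataset output directory) confirms that every image has a matching label file:

from pathlib import Path

for split in ("train", "val"):
    imgs = {p.stem for p in Path(f"dataset/images/{split}").glob("*.jpg")}
    lbls = {p.stem for p in Path(f"dataset/labels/{split}").glob("*.txt")}
    print(f"{split}: {len(imgs)} images, {len(lbls)} labels, "
          f"{len(imgs - lbls)} images without a label")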

Dataset Visualization and Verification

Before training, it's crucial to verify that the dataset conversion was successful. We'll create a visualization tool to inspect the annotations.

# dataset_viewer.py
import sys
import os
import cv2
import numpy as np
from pathlib import Path
from PySide6.QtWidgets import (QApplication, QMainWindow, QVBoxLayout, 
                               QHBoxLayout, QWidget, QLabel, QPushButton, 
                               QComboBox, QSlider, QCheckBox)
from PySide6.QtCore import Qt
from PySide6.QtGui import QPixmap, QImage

class DatasetViewer(QMainWindow):
    def __init__(self):
        super().__init__()
        self.setWindowTitle("YOLO Dataset Viewer")
        self.setGeometry(100, 100, 1200, 800)

        self.dataset_path = Path("dataset")
        self.current_split = "train"
        self.current_index = 0
        self.show_annotations = True
        self.image_files = []

        self.init_ui()
        self.load_dataset()
        self.display_current_image()

    def init_ui(self):
        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        layout = QVBoxLayout(central_widget)

        controls_layout = QHBoxLayout()

        self.split_combo = QComboBox()
        self.split_combo.addItems(["train", "val"])
        self.split_combo.currentTextChanged.connect(self.change_split)
        controls_layout.addWidget(QLabel("Split:"))
        controls_layout.addWidget(self.split_combo)

        self.prev_btn = QPushButton("Previous")
        self.prev_btn.clicked.connect(self.prev_image)
        controls_layout.addWidget(self.prev_btn)

        self.next_btn = QPushButton("Next")
        self.next_btn.clicked.connect(self.next_image)
        controls_layout.addWidget(self.next_btn)

        self.annotation_checkbox = QCheckBox("Show Annotations")
        self.annotation_checkbox.setChecked(True)
        self.annotation_checkbox.toggled.connect(self.toggle_annotations)
        controls_layout.addWidget(self.annotation_checkbox)

        self.counter_label = QLabel("0 / 0")
        controls_layout.addWidget(self.counter_label)

        controls_layout.addStretch()
        layout.addLayout(controls_layout)

        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.image_label.setStyleSheet("border: 1px solid gray;")
        layout.addWidget(self.image_label)

    def load_dataset(self):
        """Load image files for current split"""
        images_dir = self.dataset_path / "images" / self.current_split
        if images_dir.exists():
            self.image_files = sorted(list(images_dir.glob("*.jpg")))
        else:
            self.image_files = []

        self.current_index = 0
        self.update_counter()

    def display_current_image(self):
        """Display the current image with optional annotations"""
        if not self.image_files:
            self.image_label.setText("No images found")
            return

        if self.current_index >= len(self.image_files):
            self.current_index = 0

        img_path = self.image_files[self.current_index]
        img = cv2.imread(str(img_path))

        if img is None:
            self.image_label.setText("Failed to load image")
            return

        if self.show_annotations:
            img = self.draw_annotations(img, img_path)

        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w, ch = img_rgb.shape
        bytes_per_line = ch * w
        qt_image = QImage(img_rgb.data, w, h, bytes_per_line, QImage.Format.Format_RGB888)

        pixmap = QPixmap.fromImage(qt_image)
        scaled_pixmap = pixmap.scaled(self.image_label.size(), 
                                     Qt.AspectRatioMode.KeepAspectRatio,
                                     Qt.TransformationMode.SmoothTransformation)

        self.image_label.setPixmap(scaled_pixmap)
        self.update_counter()

    def draw_annotations(self, img, img_path):
        """Draw YOLO annotations on image"""
        labels_dir = self.dataset_path / "labels" / self.current_split
        label_path = labels_dir / f"{img_path.stem}.txt"

        if not label_path.exists():
            return img

        h, w = img.shape[:2]

        with open(label_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                parts = line.split()
                if len(parts) < 9:  
                    continue

                class_id = int(parts[0])
                coords = list(map(float, parts[1:]))

                points = []
                for i in range(0, len(coords), 2):
                    x = int(coords[i] * w)
                    y = int(coords[i+1] * h)
                    points.append([x, y])

                points = np.array(points, np.int32)
                cv2.polylines(img, [points], True, (0, 255, 0), 2)

                cv2.putText(img, "document", tuple(points[0]), 
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        return img

    def update_counter(self):
        """Update the image counter display"""
        total = len(self.image_files)
        current = self.current_index + 1 if total > 0 else 0
        self.counter_label.setText(f"{current} / {total}")

    def change_split(self, split):
        """Change the dataset split"""
        self.current_split = split
        self.load_dataset()
        self.display_current_image()

    def prev_image(self):
        """Go to previous image"""
        if self.image_files:
            self.current_index = (self.current_index - 1) % len(self.image_files)
            self.display_current_image()

    def next_image(self):
        """Go to next image"""
        if self.image_files:
            self.current_index = (self.current_index + 1) % len(self.image_files)
            self.display_current_image()

    def toggle_annotations(self, checked):
        """Toggle annotation visibility"""
        self.show_annotations = checked
        self.display_current_image()

if __name__ == "__main__":
    app = QApplication(sys.argv)
    viewer = DatasetViewer()
    viewer.show()
    sys.exit(app.exec())

Run the dataset viewer:

python dataset_viewer.py

YOLO dataset viewer

Training the YOLO Model

Now that we have prepared and verified our dataset, let's train the YOLO segmentation model and export it as an ONNX model.

Training Script

import subprocess
import sys
import os
import yaml
import tempfile
from pathlib import Path

def create_dynamic_config():
    """Create a temporary YAML config with dynamic absolute path."""
    dataset_dir = Path(__file__).parent / "dataset"
    dataset_dir = dataset_dir.resolve()  

    config = {
        'names': ['document'],
        'nc': 1,
        'path': str(dataset_dir),
        'train': 'images/train',
        'val': 'images/val'
    }

    temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False)
    yaml.dump(config, temp_file, default_flow_style=False)
    temp_file.close()

    print(f"Created temporary config: {temp_file.name}")
    print(f"Dataset path: {dataset_dir}")

    return temp_file.name

def cleanup_temp_file(temp_path):
    """Clean up temporary file."""
    try:
        os.unlink(temp_path)
        print(f"Cleaned up temporary config: {temp_path}")
    except OSError:
        pass

def run_command(cmd, description):
    """Run a command and print its output."""
    print(f"\n=== {description} ===")
    print(f"Running: {cmd}")

    try:
        result = subprocess.run(cmd, shell=True, check=True)
        return True
    except subprocess.CalledProcessError as e:
        print(f"Command failed with exit code {e.returncode}")
        return False

def main():
    if not Path("dataset/images/train").exists() or not Path("dataset/images/val").exists():
        print("Error: dataset/images/train or dataset/images/val not found.")
        print("Please run prep_midv500_to_yolov11seg.py first.")
        sys.exit(1)

    temp_config = create_dynamic_config()

    try:
        # Pre-fetch the pretrained weights (training will also auto-download them if missing).
        download_cmd = "yolo task=segment model=yolo11s-seg.pt"
        print("Ensuring YOLO11s-seg model is available...")
        subprocess.run(download_cmd, shell=True, check=False, capture_output=True)

        train_cmd = (
            f"yolo task=segment mode=train model=yolo11s-seg.pt data={temp_config} "
            "imgsz=640 epochs=20 batch=16 device=0 patience=5"
        )

        print("Starting YOLO11 Segmentation Training for Document Detection")
        print("Dataset: MIDV500 (converted)")
        print("Model: YOLO11s-seg")
        print("Image size: 640")
        print("Epochs: 20")
        print("Batch size: 16")
        print("Device: GPU 0")
        print("Early stopping patience: 5 epochs")

        if run_command(train_cmd, "Training YOLO11 Segmentation Model"):
            print("\n=== Training completed successfully! ===")

            runs_dir = Path("runs/segment")
            if runs_dir.exists():
                train_dirs = [d for d in runs_dir.iterdir() if d.is_dir() and d.name.startswith("train")]
                if train_dirs:
                    latest_train = max(train_dirs, key=lambda x: x.stat().st_mtime)
                    best_model = latest_train / "weights" / "best.pt"

                    if best_model.exists():
                        print(f"\nBest model saved at: {best_model}")

                        export_cmd = f"yolo export model={best_model} format=onnx opset=12"

                        if run_command(export_cmd, "Exporting model to ONNX"):
                            onnx_model = best_model.with_suffix(".onnx")
                            print(f"\nONNX model exported to: {onnx_model}")
                            print("\n=== Training and Export Complete! ===")
                            print(f"PyTorch model: {best_model}")
                            print(f"ONNX model: {onnx_model}")
                        else:
                            print("ONNX export failed, but training completed.")
                    else:
                        print("Warning: best.pt not found in expected location.")
        else:
            print("Training failed. Please check the error messages above.")
            sys.exit(1)

    finally:
        cleanup_temp_file(temp_config)

if __name__ == "__main__":
    main()

Training with Python Script

The easiest way to train the model is using the provided Python script that handles all the configuration automatically:

python train_yolo_doc_detection.py

YOLO segmentation training

Alternative: Command Line Training

If you prefer direct YOLO commands, you can also use:

# Basic training command (requires proper path configuration in doc.yaml)
yolo task=segment mode=train model=yolo11s-seg.pt data=dataset/doc.yaml imgsz=640 epochs=30 batch=16 device=0

# Training with GPU acceleration
yolo task=segment mode=train model=yolo11s-seg.pt data=dataset/doc.yaml imgsz=640 epochs=30 batch=16 device=0 amp=True

# Training with CPU (slower but works without GPU)
yolo task=segment mode=train model=yolo11s-seg.pt data=dataset/doc.yaml imgsz=640 epochs=30 batch=8 device=cpu

# Resume training from checkpoint
yolo task=segment mode=train model=runs/segment/train/weights/last.pt resume=True

Note: Direct YOLO commands require absolute paths in the doc.yaml file. The Python script handles this automatically.
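
For example, dataset/doc.yaml should resemble the following (the absolute path below is a placeholder; substitute your own location):

names:
- document
nc: 1
path: /absolute/path/to/dataset
train: images/train
val: images/val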

Training Parameters Explained

  • model: Pre-trained YOLO11 segmentation model
  • data: Path to dataset YAML configuration file
  • epochs: Number of complete passes over the training dataset
  • imgsz: Input image size (640 is standard)
  • batch: Batch size (adjust based on GPU memory)
  • device: Training device (0 for GPU, cpu for CPU)
  • amp: Automatic Mixed Precision for faster training
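
If you prefer to stay in Python rather than the CLI, the same run can be launched through the Ultralytics API (a minimal sketch; the keyword arguments mirror the CLI flags above):

from ultralytics import YOLO

model = YOLO("yolo11s-seg.pt")        # pretrained segmentation weights
model.train(
    data="dataset/doc.yaml",          # must contain an absolute 'path' entry
    imgsz=640,
    epochs=20,
    batch=16,
    device=0,                         # use device="cpu" if no GPU is available
    patience=5,
)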

Exporting to ONNX Format

After training, export the model to ONNX format for deployment and inference optimization.

# Export best model to ONNX
yolo export model=runs/segment/train/weights/best.pt format=onnx opset=12

# Export with optimization
yolo export model=runs/segment/train/weights/best.pt format=onnx opset=12 simplify=True half=False

# Export additional formats (one format per call)
yolo export model=runs/segment/train/weights/best.pt format=torchscript
yolo export model=runs/segment/train/weights/best.pt format=tflite
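
The export can also be done from Python (a minimal sketch using the Ultralytics API; adjust the weights path to your own run):

from ultralytics import YOLO

model = YOLO("runs/segment/train/weights/best.pt")
onnx_path = model.export(format="onnx", opset=12)  # returns the exported file path
print("Exported to:", onnx_path)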

ONNX Model Testing

import os, sys, warnings
import numpy as np
import cv2

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"     # don't let PyTorch probe CUDA
os.environ["ORT_LOGGING_LEVEL"] = "4"         # onnxruntime: errors only
warnings.filterwarnings("ignore")

try:
    import onnxruntime as ort
    ort.set_default_logger_severity(4)
except ImportError:
    print("❌ onnxruntime not installed. Run: pip install onnxruntime")
    sys.exit(1)

from ultralytics import YOLO

MODEL_ONNX = "runs/segment/train/weights/best.onnx"  # your exported ONNX
IMAGE_PATH = "CA01_01.jpg"                            # test image
IMG_SIZE   = 640
CONF       = 0.25

def order_pts(pts: np.ndarray) -> np.ndarray:
    """Order 4 points TL, TR, BR, BL."""
    s = pts.sum(1)
    d = np.diff(pts, axis=1).ravel()
    return np.array([pts[np.argmin(s)],
                     pts[np.argmin(d)],
                     pts[np.argmax(s)],
                     pts[np.argmax(d)]], dtype=np.float32)

def polygon_to_quad(poly_xy: np.ndarray) -> np.ndarray:
    """
    poly_xy: (K,2) float32 in original image coordinates.
    Returns ordered 4x2 quad (float32). Uses approxPolyDP, with minAreaRect fallback.
    """
    cnt = poly_xy.reshape(-1, 1, 2).astype(np.float32)
    peri = cv2.arcLength(cnt, True)

    for eps in (0.03, 0.025, 0.02, 0.015, 0.01):
        approx = cv2.approxPolyDP(cnt, eps * peri, True)
        if len(approx) == 4:
            return order_pts(approx.reshape(-1, 2).astype(np.float32))

    rect = cv2.minAreaRect(cnt)
    box = cv2.boxPoints(rect).astype(np.float32)  
    return order_pts(box)

def main():
    if not os.path.exists(MODEL_ONNX):
        print(f"❌ ONNX model not found: {MODEL_ONNX}")
        sys.exit(1)
    if not os.path.exists(IMAGE_PATH):
        print(f"❌ Image not found: {IMAGE_PATH}")
        sys.exit(1)

    print("Providers (onnxruntime):", ort.get_available_providers())
    print(f"Loading ONNX model: {MODEL_ONNX}")
    model = YOLO(MODEL_ONNX)      
    print("✅ ONNX model loaded")

    img = cv2.imread(IMAGE_PATH)
    if img is None:
        print(f"❌ Failed to read image: {IMAGE_PATH}")
        sys.exit(1)
    H, W = img.shape[:2]
    print(f"Image size: {W}x{H}")

    res = model(img, imgsz=IMG_SIZE, conf=CONF, device='cpu')[0]
    print("✅ Inference completed")

    if res.masks is None or len(res.masks.data) == 0:
        print("⚠️  No document detected.")
        return

    polys = res.masks.xy  # list of (K, 2) polygons in original image coordinates
    areas = [cv2.contourArea(p.astype(np.float32)) for p in polys]
    idx = int(np.argmax(areas))
    poly = polys[idx].astype(np.float32)

    quad = polygon_to_quad(poly)
    w_out = int(max(np.linalg.norm(quad[1]-quad[0]), np.linalg.norm(quad[2]-quad[3])))
    h_out = int(max(np.linalg.norm(quad[2]-quad[1]), np.linalg.norm(quad[3]-quad[0])))
    dst = np.array([[0, 0],
                    [w_out - 1, 0],
                    [w_out - 1, h_out - 1],
                    [0, h_out - 1]], dtype=np.float32)
    M = cv2.getPerspectiveTransform(quad, dst)
    warped = cv2.warpPerspective(img, M, (w_out, h_out))

    overlay = img.copy()
    cv2.polylines(overlay, [quad.astype(int)], True, (0, 255, 0), 3)

    cv2.imwrite("detect_overlay.jpg", overlay)
    cv2.imwrite("document_warped.jpg", warped)

    print("Quad (TL, TR, BR, BL):")
    for i, (x, y) in enumerate(quad):
        print(f"  P{i}: ({x:.1f}, {y:.1f})")
    print("Saved: detect_overlay.jpg, document_warped.jpg")

if __name__ == "__main__":
    main()


Creating a GUI Application

Now let's create a GUI application for document detection using PySide6.

Document Detection Engine

import os
from typing import Optional, Tuple

import cv2
import numpy as np
from ultralytics import YOLO


class DocumentDetector:
    """Core document detection functionality"""

    def __init__(self):
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the trained model (PyTorch first, then ONNX fallback)"""
        model_paths = [
            ("runs/segment/train/weights/best.pt", "PyTorch"),
            ("runs/segment/train/weights/best.onnx", "ONNX"),
            ("yolo11s-seg.pt", "Pre-trained")
        ]

        for model_path, model_type in model_paths:
            if os.path.exists(model_path):
                try:
                    self.model = YOLO(model_path)
                    print(f"{model_type} model loaded: {model_path}")
                    return
                except Exception as e:
                    print(f"❌ Failed to load {model_type} model: {e}")
                    continue

        raise RuntimeError("No working model found!")

    @staticmethod
    def order_points(pts):
        """Order points in clockwise order starting from top-left"""
        s = pts.sum(1)
        d = np.diff(pts, axis=1).ravel()
        return np.array([
            pts[np.argmin(s)],      # top-left
            pts[np.argmin(d)],      # top-right
            pts[np.argmax(s)],      # bottom-right
            pts[np.argmax(d)]       # bottom-left
        ], dtype="float32")

    @staticmethod
    def polygon_to_quad(poly_xy: np.ndarray) -> np.ndarray:
        cnt = poly_xy.reshape(-1, 1, 2).astype(np.float32)
        peri = cv2.arcLength(cnt, True)

        for eps in (0.03, 0.025, 0.02, 0.015, 0.01):
            approx = cv2.approxPolyDP(cnt, eps * peri, True)
            if len(approx) == 4:
                return DocumentDetector.order_points(approx.reshape(-1, 2).astype(np.float32))

        rect = cv2.minAreaRect(cnt)
        box = cv2.boxPoints(rect).astype(np.float32)  
        return DocumentDetector.order_points(box)

    def detect_document(self, image: np.ndarray, conf: float = 0.25) -> Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        if self.model is None:
            raise RuntimeError("Model not loaded")

        orig_h, orig_w = image.shape[:2]
        print(f"Debug - Image size: {orig_w}x{orig_h}")

        try:
            results = self.model(image, imgsz=640, conf=conf, device='cpu')[0]
        except Exception as e:
            raise RuntimeError(f"Inference failed: {e}")

        if results.masks is None or len(results.masks.data) == 0:
            return None

        polys = results.masks.xy  # polygons in original image coordinates
        if not polys:
            return None

        areas = [cv2.contourArea(p.astype(np.float32)) for p in polys]
        idx = int(np.argmax(areas))
        poly = polys[idx].astype(np.float32)

        print(f"Debug - Selected polygon has {len(poly)} points, area: {areas[idx]:.0f}")

        quad = self.polygon_to_quad(poly)
        if quad is None:
            return None

        overlay = image.copy()
        cv2.polylines(overlay, [quad.astype(int)], True, (0, 255, 0), 3)

        for i, point in enumerate(quad):
            x, y = int(round(point[0])), int(round(point[1]))
            cv2.circle(overlay, (x, y), 8, (255, 0, 0), -1)
            cv2.putText(overlay, str(i), (x + 10, y - 10), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        h, w = image.shape[:2]
        w_out = int(max(np.linalg.norm(quad[1] - quad[0]), np.linalg.norm(quad[2] - quad[3])))
        h_out = int(max(np.linalg.norm(quad[2] - quad[1]), np.linalg.norm(quad[3] - quad[0])))

        dst = np.array([[0, 0], [w_out - 1, 0], [w_out - 1, h_out - 1], [0, h_out - 1]], dtype="float32")
        M = cv2.getPerspectiveTransform(quad, dst)
        rectified = cv2.warpPerspective(image, M, (w_out, h_out))

        return overlay, rectified, quad
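
A quick standalone test of the engine, independent of the GUI (a minimal sketch, reusing the test image from the ONNX section; the output file names are arbitrary):

import cv2

detector = DocumentDetector()
image = cv2.imread("CA01_01.jpg")          # same test image as the ONNX section
result = detector.detect_document(image, conf=0.25)
if result is None:
    print("No document detected")
else:
    overlay, rectified, quad = result
    cv2.imwrite("engine_overlay.jpg", overlay)
    cv2.imwrite("engine_rectified.jpg", rectified)
    print("Quad (TL, TR, BR, BL):", quad.tolist())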

Main GUI Application

# Excerpt from the GUI script. The DropLabel (drag-and-drop image label) and
# DetectionWorker (QThread worker wrapping DocumentDetector.detect_document)
# helper classes are defined in the full source code linked at the end.
class DocumentDetectorGUI(QMainWindow):

    def __init__(self):
        super().__init__()
        self.detector = None
        self.current_image = None
        self.current_overlay = None
        self.current_rectified = None
        self.current_quad = None

        self.detection_thread = None
        self.detection_worker = None

        self.init_ui()
        self.init_detector()

    def init_ui(self):
        self.setWindowTitle("Document Detection GUI")
        self.setGeometry(100, 100, 1400, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        main_layout = QHBoxLayout(central_widget)

        left_panel = self.create_left_panel()
        main_layout.addWidget(left_panel)

        right_panel = self.create_right_panel()
        main_layout.addWidget(right_panel, stretch=3)

        self.statusBar().showMessage("Ready - Drop an image or click Browse to start")

    def create_left_panel(self):
        """Create the left control panel"""
        panel = QFrame()
        panel.setMaximumWidth(300)
        panel.setFrameStyle(QFrame.Shape.Box)
        layout = QVBoxLayout(panel)

        title = QLabel("Document Detection")
        title.setAlignment(Qt.AlignmentFlag.AlignCenter)
        title.setFont(QFont("Arial", 16, QFont.Weight.Bold))
        layout.addWidget(title)

        self.browse_btn = QPushButton("📁 Browse Image")
        self.browse_btn.clicked.connect(self.browse_image)
        layout.addWidget(self.browse_btn)

        self.detect_btn = QPushButton("🔍 Detect Document")
        self.detect_btn.clicked.connect(self.detect_document)
        self.detect_btn.setEnabled(False)
        layout.addWidget(self.detect_btn)

        self.save_btn = QPushButton("💾 Save Rectified")
        self.save_btn.clicked.connect(self.save_rectified)
        self.save_btn.setEnabled(False)
        layout.addWidget(self.save_btn)

        layout.addWidget(QLabel())  # Spacer

        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        info_group = QGroupBox("Detection Info")
        info_layout = QVBoxLayout(info_group)

        self.info_text = QTextEdit()
        self.info_text.setMaximumHeight(200)
        self.info_text.setFont(QFont("Consolas", 9))
        info_layout.addWidget(self.info_text)

        layout.addWidget(info_group)

        layout.addItem(QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding))

        return panel

    def create_right_panel(self):
        panel = QWidget()
        layout = QVBoxLayout(panel)

        splitter = QSplitter(Qt.Orientation.Horizontal)

        original_group = QGroupBox("Original with Overlay")
        original_layout = QVBoxLayout(original_group)

        self.original_scroll = QScrollArea()
        self.original_label = DropLabel("Drop image here or click Browse")
        self.original_label.file_dropped.connect(self.load_image)
        self.original_scroll.setWidget(self.original_label)
        self.original_scroll.setWidgetResizable(True)
        original_layout.addWidget(self.original_scroll)

        splitter.addWidget(original_group)

        rectified_group = QGroupBox("Rectified Document")
        rectified_layout = QVBoxLayout(rectified_group)

        self.rectified_scroll = QScrollArea()
        self.rectified_label = QLabel("Rectified document will appear here")
        self.rectified_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.rectified_label.setStyleSheet("""
            QLabel {
                border: 1px solid #ddd;
                background-color: #f5f5f5;
                min-height: 200px;
                color: #888;
            }
        """)
        self.rectified_scroll.setWidget(self.rectified_label)
        self.rectified_scroll.setWidgetResizable(True)
        rectified_layout.addWidget(self.rectified_scroll)

        splitter.addWidget(rectified_group)
        splitter.setSizes([1, 1])  # Equal sizes

        layout.addWidget(splitter)

        return panel

    def init_detector(self):
        try:
            self.detector = DocumentDetector()
            self.statusBar().showMessage("Model loaded successfully - Ready for detection")
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to load model: {e}")
            self.statusBar().showMessage("Error loading model")

    def browse_image(self):
        """Browse for an image file"""
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Image",
            "",
            "Image Files (*.jpg *.jpeg *.png *.bmp *.tiff *.tif)"
        )

        if file_path:
            self.load_image(file_path)

    def load_image(self, file_path: str):
        """Load an image from file path"""
        try:
            self.current_image = cv2.imread(file_path)
            if self.current_image is None:
                raise ValueError("Failed to load image")

            # Keep a reference to the RGB buffer so QImage's data stays valid
            # until QPixmap.fromImage() copies it.
            rgb = cv2.cvtColor(self.current_image, cv2.COLOR_BGR2RGB)
            height, width, channel = self.current_image.shape
            bytes_per_line = 3 * width
            q_image = QImage(rgb.data, width, height, bytes_per_line, QImage.Format.Format_RGB888)

            pixmap = QPixmap.fromImage(q_image)
            self.original_label.set_pixmap_scaled(pixmap)

            self.detect_btn.setEnabled(True)
            self.save_btn.setEnabled(False)
            self.rectified_label.setText("Rectified document will appear here")
            self.rectified_label.setPixmap(QPixmap())

            file_name = Path(file_path).name
            self.info_text.setText(f"Loaded: {file_name}\nDimensions: {width}x{height}\nReady for detection")
            self.statusBar().showMessage(f"Loaded: {file_name}")

            self.current_overlay = None
            self.current_rectified = None
            self.current_quad = None

            self.detect_document()

        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to load image: {e}")

    def detect_document(self):
        """Run document detection"""
        if self.current_image is None or self.detector is None:
            print("Detection skipped: No image or detector")
            return

        try:
            self.progress_bar.setVisible(True)
            self.progress_bar.setRange(0, 0)  
            self.detect_btn.setEnabled(False)
            self.statusBar().showMessage("Detecting document...")

            self.cleanup_detection_thread()

            self.detection_worker = DetectionWorker(self.detector, self.current_image, 0.25)
            self.detection_thread = QThread()
            self.detection_worker.moveToThread(self.detection_thread)

            self.detection_thread.started.connect(self.detection_worker.run)
            self.detection_worker.finished.connect(self.on_detection_finished)
            self.detection_worker.error.connect(self.on_detection_error)
            self.detection_worker.finished.connect(self.cleanup_detection_thread)

            print("Starting detection thread...")
            self.detection_thread.start()

        except Exception as e:
            print(f"Error in detect_document: {e}")
            import traceback
            traceback.print_exc()
            self.progress_bar.setVisible(False)
            self.detect_btn.setEnabled(True)
            QMessageBox.critical(self, "Detection Error", f"Failed to start detection: {e}")

    ...
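
The excerpt above omits the remaining slots and the entry point; the bootstrap follows the same pattern as the dataset viewer (a minimal sketch):

if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = DocumentDetectorGUI()
    window.show()
    sys.exit(app.exec())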

Running the GUI Application

# Install additional dependencies for GUI
pip install PySide6

# Run the GUI application
python document_detector_gui.py

ID detection with YOLO segmentation

Source Code

https://github.com/yushulx/python-barcode-qrcode-sdk/tree/main/examples/document-detection
