@addyk/flowernnunet

Federated nnU-Net with Flower

This app implements federated learning for nnU-Net, enabling privacy-preserving medical image segmentation across multiple institutions using the Flower framework. It supports modality-aware aggregation across imaging types (CT, MR, PET, US), multi-dataset federation with heterogeneous anatomies, and nnU-Net v2's native training pipeline including 3D full-resolution segmentation with deep supervision.

Key Features

  • šŸ„ Modality-Aware Aggregation — Automatically detects client imaging modalities (CT, MR, PET, US) and performs hierarchical aggregation: intra-modality first, then inter-modality weighted combination
  • šŸ”€ Multi-Dataset Federation — Train across heterogeneous datasets with different anatomies, modalities, and label sets in a single federation
  • 🧠 Native nnU-Net v2 Integration — Uses nnU-Net's full training pipeline: nnUNetDataLoader3D, native augmentation transforms, deep supervision (6 output scales), and automatic architecture configuration
  • šŸ”’ Privacy-Preserving — Only model parameters and dataset fingerprints are shared; raw medical imaging data never leaves the local site
  • šŸ“Š W&B Integration — Optional Weights & Biases logging for experiment tracking across federated rounds
  • šŸ’¾ Model Checkpointing — Automatic saving of best local and global models based on validation Dice scores
  • šŸ–„ļø GPU Support — Full CUDA acceleration with automatic mixed precision training

Architecture

| Component | Description |
|---|---|
| server_app_modality.py | ModalityAwareFederatedStrategy: multi-phase server with modality grouping and hierarchical aggregation |
| server_app.py | NnUNetFederatedStrategy: standard FedAvg baseline strategy |
| client_app.py | NnUNet3DFullresClient: handles fingerprint collection, local training, backbone parameter filtering, and model saving |
| task.py | FedNnUNetTrainer: extends nnU-Net's nnUNetTrainer for federated scenarios with validation and PyTorch model export |
| wandb_integration.py | W&B logging utilities for federated experiment tracking |
| dataset_compatibility.py | Dataset validation and compatibility checking for multi-dataset federation |
| federation_config.py | Federation configuration management |
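The client's backbone parameter filtering is what lets datasets with different label sets share one model: output heads with different numbers of classes cannot be averaged. A minimal sketch, assuming the class-specific heads are exactly the parameters whose names contain seg_layers (to our knowledge, the name used by the decoder of the default U-Net in nnU-Net's dynamic_network_architectures package). filter_backbone and merge_backbone are hypothetical helpers, and the parameter names in any example are illustrative. A real filter may also need to exclude the first encoder convolution when clients have different numbers of input channels (e.g. two-channel Prostate vs one-channel Spleen).

```python
from typing import Dict, Iterable, Mapping, TypeVar

T = TypeVar("T")

# Assumption: class-specific output heads are the parameters whose names contain
# "seg_layers" (e.g. "decoder.seg_layers.0.weight"). With deep supervision there
# is one such head per output scale, and all of them are excluded.
HEAD_MARKERS = ("seg_layers",)


def filter_backbone(state: Mapping[str, T], head_markers: Iterable[str] = HEAD_MARKERS) -> Dict[str, T]:
    """Return only the shared (non-head) entries of a state dict."""
    markers = tuple(head_markers)
    return {name: value for name, value in state.items() if not any(m in name for m in markers)}


def merge_backbone(local_state: Mapping[str, T], global_backbone: Mapping[str, T]) -> Dict[str, T]:
    """Copy aggregated backbone entries into a local state dict; heads keep their local values."""
    merged = dict(local_state)
    for name, value in global_backbone.items():
        if name in merged:
            merged[name] = value
    return merged
```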

Federated Learning Process

  1. Fingerprint Phase (Round -2) — Clients share dataset statistics (shapes, spacings, intensity properties, modality info)
  2. Initialization Phase (Round -1) — Server merges fingerprints and distributes the global fingerprint + initial model parameters
  3. Training Phases (Round 0+) — Iterative local training with native nnU-Net methods, backbone parameter extraction, and modality-aware global aggregation
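The initialization phase merges client fingerprints into one global fingerprint. How the app does this is not described here; the sketch below assumes the simplest approach: pool the per-case spacings and shapes_after_crop lists (keys found in nnU-Net v2's dataset_fingerprint.json) and take per-axis medians, which is roughly the statistic nnU-Net's planner uses to choose a target spacing. The modality field and the median_* output keys are hypothetical.

```python
import statistics
from typing import Dict, List, Sequence


def merge_fingerprints(client_fps: Sequence[Dict]) -> Dict:
    """Pool per-case statistics from several clients into one global fingerprint (sketch)."""
    spacings: List[Sequence[float]] = []
    shapes: List[Sequence[int]] = []
    modalities: List[str] = []
    for fp in client_fps:
        spacings.extend(fp["spacings"])  # one (z, y, x) spacing per case
        shapes.extend(fp["shapes_after_crop"])  # one (z, y, x) shape per case
        modalities.append(fp.get("modality", "unknown"))  # hypothetical client-added field
    # Per-axis median over all pooled cases, as if the data lived on one site.
    median_spacing = [statistics.median(axis) for axis in zip(*spacings)]
    median_shape = [statistics.median(axis) for axis in zip(*shapes)]
    return {
        "spacings": spacings,
        "shapes_after_crop": shapes,
        "median_spacing": median_spacing,
        "median_shape": median_shape,
        "modalities": sorted(set(modalities)),
    }
```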

Aggregation Strategies

  • Modality-Aware (default): CT clients aggregate → CT model; MR clients aggregate → MR model; weighted combination → global model
  • Standard FedAvg: Traditional weighted average by number of training examples across all clients
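The two strategies can be compared in a small NumPy sketch. It assumes each client update is a tuple (modality, num_examples, list of parameter arrays). modality_aware_aggregate and its default inter-modality weights (total training examples per modality) are illustrative assumptions, not the app's ModalityAwareFederatedStrategy; explicit weights play the role of the --modality-weights option of the deployment script.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np

# One client update: (modality, number of training examples, parameter arrays)
ClientUpdate = Tuple[str, int, List[np.ndarray]]


def weighted_average(updates: Sequence[Tuple[int, List[np.ndarray]]]) -> List[np.ndarray]:
    """FedAvg: average each parameter array, weighted by number of examples."""
    total = sum(n for n, _ in updates)
    num_layers = len(updates[0][1])
    return [sum(n * params[i] for n, params in updates) / total for i in range(num_layers)]


def modality_aware_aggregate(
    updates: Sequence[ClientUpdate],
    modality_weights: Optional[Dict[str, float]] = None,
) -> List[np.ndarray]:
    # Stage 1: intra-modality FedAvg (e.g. all CT clients -> one CT model).
    groups: Dict[str, List[Tuple[int, List[np.ndarray]]]] = defaultdict(list)
    for modality, n, params in updates:
        groups[modality].append((n, params))
    modality_models = {m: weighted_average(g) for m, g in groups.items()}

    # Stage 2: weighted combination of the per-modality models.
    # Assumed default: weight each modality by its total number of examples.
    if modality_weights is None:
        modality_weights = {m: float(sum(n for n, _ in g)) for m, g in groups.items()}
    weights = {m: modality_weights[m] for m in modality_models}  # every modality needs a weight
    total = sum(weights.values())
    num_layers = len(next(iter(modality_models.values())))
    return [
        sum(weights[m] * model[i] for m, model in modality_models.items()) / total
        for i in range(num_layers)
    ]
```

With the default weights the result equals plain FedAvg over all clients, because weighting each modality's average by its example count undoes the grouping. The hierarchy only changes the outcome when explicit modality weights such as {"CT": 0.6, "MR": 0.4} are given.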

Fetch the App

Install Flower:

pip install flwr

Fetch the app:

flwr new @addyk/flowernnunet

This will create a new directory called flowernnunet with the following structure:

flowernnunet
├── flowernnunet
│   ├── __init__.py
│   ├── client_app.py          # Defines your ClientApp
│   ├── server_app.py          # Basic federated strategy (FedAvg)
│   ├── server_app_modality.py # Modality-aware federated strategy
│   ├── task.py                # FedNnUNetTrainer (extends nnUNet)
│   ├── wandb_integration.py   # W&B logging support
│   ├── dataset_compatibility.py
│   └── federation_config.py
├── pyproject.toml             # Project metadata and Flower configs
├── README.md
├── DEPLOYMENT_GUIDE.md        # Detailed deployment instructions
├── MULTI_DATASET_GUIDE.md     # Multi-dataset federation guide
├── setup_flwr_nnunet.sh       # Automated setup & preprocessing script
└── run_federated_deployment.sh # Automated deployment script

Prerequisites

  1. Python 3.10+ with conda or pip
  2. nnU-Net v2 installed and configured (pip install nnunetv2)
  3. Preprocessed data in nnU-Net's standard format (.npz/.b2nd + .pkl)
  4. GPU with CUDA support (required for nnU-Net training)

Environment Setup

# Set required nnU-Net paths
export nnUNet_raw="/path/to/nnUNet_raw"
export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export nnUNet_results="/path/to/nnUNet_results"

# Optional: Set model saving directory
export OUTPUT_ROOT="./federated_models"
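Because nnU-Net reads these paths from the environment, a missing or mistyped variable often surfaces only once data loading starts. A quick pre-flight check before launching clients can catch this; the sketch below is a hypothetical helper, not part of the app.

```python
import os
from pathlib import Path
from typing import Dict, List, Sequence

REQUIRED_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def check_nnunet_env(required: Sequence[str] = REQUIRED_VARS) -> Dict[str, List[str]]:
    """List nnU-Net path variables that are unset/empty or do not point to an existing directory."""
    report: Dict[str, List[str]] = {"missing": [], "not_a_directory": []}
    for var in required:
        value = os.environ.get(var)
        if not value:
            report["missing"].append(var)
        elif not Path(value).is_dir():
            report["not_a_directory"].append(var)
    return report
```

Calling check_nnunet_env() and aborting if any list is non-empty is enough; create the directories first (e.g. with mkdir -p) if nnUNet_results does not exist yet.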

Medical Segmentation Decathlon (MSD) Datasets

This app works with any nnU-Net-preprocessed dataset. The Medical Segmentation Decathlon provides 10 benchmark datasets ideal for testing federated scenarios:

| Task | Dataset | Modality | Targets | Train / Test |
|---|---|---|---|---|
| Task01 | BrainTumour | MR (FLAIR, T1w, T1gd, T2w) | Glioma subregions | 484 / 266 |
| Task02 | Heart | MR (Mono) | Left atrium | 20 / 10 |
| Task03 | Liver | CT (Portal venous) | Liver + tumors | 131 / 70 |
| Task04 | Hippocampus | MR (Mono) | Hippocampus head/body | 260 / 130 |
| Task05 | Prostate | MR (T2, ADC) | Central gland + peripheral zone | 32 / 16 |
| Task06 | Lung | CT | Lung tumors | 63 / 32 |
| Task07 | Pancreas | CT (Portal venous) | Pancreas + tumors | 281 / 139 |
| Task08 | HepaticVessel | CT | Hepatic vessels + tumors | 303 / 140 |
| Task09 | Spleen | CT (Portal venous) | Spleen | 41 / 20 |
| Task10 | Colon | CT (Portal venous) | Colon cancer | 126 / 64 |
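The modality-aware strategy groups clients by the kind of information in the Modality column. This README does not say how the app detects a client's modality; one plausible approach, sketched here as an assumption, is to read the channel_names field of the client's nnU-Net v2 dataset.json (for converted MSD datasets these hold values such as "CT", "T2", "ADC" or "MRI") and map them to a modality group with a keyword table. The keyword lists and detect_modality are hypothetical; for multi-channel datasets the first recognised channel decides.

```python
from typing import Mapping

# Hypothetical keyword table; checked in this order for each channel name.
MODALITY_KEYWORDS = {
    "CT": ("CT",),
    "PET": ("PET",),
    "US": ("US", "ULTRASOUND"),
    "MR": ("MR", "MRI", "T1", "T2", "FLAIR", "ADC", "DWI"),
}


def detect_modality(channel_names: Mapping[str, str]) -> str:
    """Map nnU-Net v2 channel_names (from dataset.json) to CT, MR, PET, US or UNKNOWN."""
    for name in channel_names.values():
        token = name.strip().upper()
        for modality, keywords in MODALITY_KEYWORDS.items():
            if any(token == kw or token.startswith(kw) for kw in keywords):
                return modality
    return "UNKNOWN"
```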

Quick Dataset Setup

Download and preprocess a dataset (e.g., Spleen):

# Download
wget https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar
tar -xf Task09_Spleen.tar

# Convert the MSD task to nnU-Net v2's raw format
# (creates $nnUNet_raw/Dataset009_Spleen with a v2-style dataset.json)
nnUNetv2_convert_MSD_dataset -i Task09_Spleen

# Plan and preprocess
nnUNetv2_plan_and_preprocess -d 9 --verify_dataset_integrity

Or use the automated setup script:

bash setup_flwr_nnunet.sh --dataset Task09_Spleen --name flwr-nnunet-demo
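nnU-Net v2 addresses datasets as DatasetXXX_Name with a three-digit ID, which is why Task09_Spleen becomes Dataset009_Spleen and -d 9 selects it. The same name is what TASK_NAME and dataset-name expect later in this README. A small hypothetical helper that derives it from an MSD folder name:

```python
import re


def msd_to_nnunet_name(task_folder: str) -> str:
    """Convert an MSD folder name like 'Task09_Spleen' to nnU-Net v2's 'Dataset009_Spleen'."""
    match = re.fullmatch(r"Task(\d{2,3})_(\w+)", task_folder)
    if match is None:
        raise ValueError(f"Not an MSD task folder name: {task_folder!r}")
    task_id, name = int(match.group(1)), match.group(2)
    return f"Dataset{task_id:03d}_{name}"  # zero-padded three-digit ID
```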

Multi-Dataset Federated Scenarios

For testing multi-dataset federation, try combining:

  • CT datasets: Spleen (Task09) + Liver (Task03) — same modality, different anatomy
  • Cross-modality: Prostate MR (Task05) + Spleen CT (Task09) — different modalities with modality-aware aggregation
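For these pairings, the aggregation groups follow directly from each dataset's modality. The hypothetical sketch below turns a client-to-dataset JSON mapping (the format of the deployment script's --client-datasets option) into intra-modality groups. It uses modalities from the MSD table and assumes the datasets were converted to nnU-Net v2 names.

```python
import json
from collections import defaultdict
from typing import Dict, List

# Modalities of a few MSD datasets (from the MSD table), keyed by nnU-Net v2 name.
DATASET_MODALITY = {
    "Dataset003_Liver": "CT",
    "Dataset005_Prostate": "MR",
    "Dataset009_Spleen": "CT",
}


def plan_groups(client_datasets_json: str) -> Dict[str, List[str]]:
    """Group client IDs by modality, i.e. the intra-modality aggregation groups."""
    mapping = json.loads(client_datasets_json)
    groups: Dict[str, List[str]] = defaultdict(list)
    for client_id, dataset in sorted(mapping.items()):
        groups[DATASET_MODALITY.get(dataset, "UNKNOWN")].append(client_id)
    return dict(groups)
```

With Spleen + Liver there is a single CT group, so the inter-modality step has nothing to combine; with Prostate + Spleen each client forms its own group.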

Run the App

Run with the Simulation Engine

Install the dependencies defined in pyproject.toml as well as the flowernnunet package:

cd flowernnunet && pip install -e .

Set environment variables for nnU-Net:

export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export TASK_NAME="Dataset009_Spleen"

Run with default settings:

flwr run .

Override settings:

flwr run . --run-config "num-server-rounds=5"

Note: Simulation runs all clients on the same machine. For large 3D medical datasets, ensure sufficient GPU memory or reduce options.num-supernodes in pyproject.toml.

Run with the Deployment Engine

📖 For detailed deployment instructions, see DEPLOYMENT_GUIDE.md.

1. Start the SuperLink (Terminal 1):

flower-superlink --insecure

2. Start SuperNodes (separate terminals):

# Terminal 2: First SuperNode
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9094 \
    --node-config "partition-id=0"

# Terminal 3: Second SuperNode
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9095 \
    --node-config "partition-id=1"

You can also specify datasets and folds per SuperNode:

flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9094 \
    --node-config 'partition-id=0 dataset-name="Dataset005_Prostate" fold=0'

3. Run the federation (Terminal 4):

flwr run . supernode-deployment
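In a Flower ClientApp, the key/value pairs passed to flower-supernode with --node-config are exposed as context.node_config. The sketch below shows one way a client could resolve its partition, dataset, and fold from that mapping, falling back to the TASK_NAME variable used in the simulation setup. resolve_client_settings, ClientSettings, and the defaults are hypothetical, not the app's actual client code.

```python
import os
from typing import Mapping, NamedTuple, Union

ConfigValue = Union[int, float, str, bool]


class ClientSettings(NamedTuple):
    partition_id: int
    dataset_name: str
    fold: int


def resolve_client_settings(node_config: Mapping[str, ConfigValue]) -> ClientSettings:
    """Read per-SuperNode settings; fall back to $TASK_NAME and fold 0 (assumed defaults)."""
    dataset = node_config.get("dataset-name") or os.environ.get("TASK_NAME")
    if not dataset:
        raise ValueError("Set dataset-name in --node-config or export TASK_NAME")
    return ClientSettings(
        partition_id=int(node_config.get("partition-id", 0)),
        dataset_name=str(dataset),
        fold=int(node_config.get("fold", 0)),
    )
```

Inside a client function this would be called as resolve_client_settings(context.node_config).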

Automated Deployment Script

For convenience, use the all-in-one deployment script:

# Single dataset
bash run_federated_deployment.sh \
    --dataset Dataset009_Spleen --clients 2 --rounds 3 --local-epochs 2 --validate

# Multi-dataset with modality-aware aggregation
bash run_federated_deployment.sh \
    --client-datasets '{"0": "Dataset005_Prostate", "1": "Dataset009_Spleen"}' \
    --clients 2 --rounds 5 --enable-modality-aggregation

# Custom modality weights
bash run_federated_deployment.sh \
    --client-datasets '{"0": "Dataset005_Prostate", "1": "Dataset009_Spleen"}' \
    --clients 2 --rounds 5 --enable-modality-aggregation \
    --modality-weights '{"CT": 0.6, "MR": 0.4}'

All deployment script arguments

| Category | Argument | Default | Description |
|---|---|---|---|
| Dataset | --dataset | none | Single dataset for all clients |
| | --client-datasets | none | JSON mapping of client IDs to datasets |
| | --list-datasets | none | List available preprocessed datasets |
| | --validate-datasets | none | Validate multi-dataset compatibility |
| Training | --clients | 2 | Number of federated clients |
| | --rounds | 3 | Number of federated rounds |
| | --local-epochs | 2 | Local epochs per client per round |
| Modality | --enable-modality-aggregation | false | Enable modality-aware aggregation |
| | --modality-weights | none | JSON modality weight overrides |
| Deployment | --mode | run | superlink, supernode, or run |
| | --superlink-host | 127.0.0.1 | SuperLink host address |
| Validation | --validate | true | Enable validation during training |
| | --no-validate | none | Skip validation |
| Output | --output-dir | federated_models | Model output directory |
| | --save-frequency | 1 | Save models every N rounds |
| System | --gpu | 0 | GPU device ID |
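The README does not state whether --modality-weights must sum to 1 or what happens when a participating modality has no weight. The sketch below makes both choices explicit under stated assumptions: weights are normalised over the modalities actually present, and a missing weight is an error. parse_modality_weights is a hypothetical helper.

```python
import json
from typing import Dict, Iterable


def parse_modality_weights(raw: str, present: Iterable[str]) -> Dict[str, float]:
    """Parse a --modality-weights JSON string and normalise it over the modalities present."""
    weights = {str(k): float(v) for k, v in json.loads(raw).items()}
    if any(w < 0 for w in weights.values()):
        raise ValueError("modality weights must be non-negative")
    present = list(present)
    missing = [m for m in present if m not in weights]
    if missing:
        raise ValueError(f"no weight given for modalities: {missing}")
    total = sum(weights[m] for m in present)
    if total <= 0:
        raise ValueError("weights of the participating modalities sum to zero")
    # Weights for modalities with no participating client are ignored.
    return {m: weights[m] / total for m in present}
```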

Configuration

Federation Settings (pyproject.toml)

[tool.flwr.app.config]
num-server-rounds = 100       # Number of training rounds
fraction-fit = 1.0            # Fraction of clients for training
fraction-evaluate = 0.0       # Fraction of clients for evaluation

[tool.flwr.federations.local-simulation]
options.num-supernodes = 1     # Number of simulated clients

[tool.flwr.federations.supernode-deployment]
address = "127.0.0.1:9093"
insecure = true
options.num-supernodes = 2
options.enable-modality-aggregation = true
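As far as we know, Flower's built-in FedAvg turns fraction-fit and fraction-evaluate into a per-round client count roughly as below (its min_fit_clients and min_evaluate_clients default to 2), and it skips federated evaluation entirely when fraction_evaluate is 0.0. The app's own strategies may override this, so treat the sketch as an illustration of the settings above rather than the app's behaviour.

```python
def num_fit_clients(num_available: int, fraction_fit: float, min_fit_clients: int = 2) -> int:
    """Clients sampled for training in one round: a fraction of those available, floored at a minimum."""
    return max(int(num_available * fraction_fit), min_fit_clients)


def num_evaluate_clients(num_available: int, fraction_evaluate: float, min_evaluate_clients: int = 2) -> int:
    """Clients sampled for federated evaluation; a fraction of 0.0 skips the phase."""
    if fraction_evaluate == 0.0:
        return 0
    return max(int(num_available * fraction_evaluate), min_evaluate_clients)
```

With the supernode-deployment settings (2 SuperNodes, fraction-fit = 1.0, fraction-evaluate = 0.0), both clients train every round and no federated evaluation round is requested.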

GPU Configuration

# Enable specific GPU
export CUDA_VISIBLE_DEVICES=0

# Set model saving directory
export OUTPUT_ROOT="./federated_models"

In task.py, the default device is CPU for Ray compatibility. For GPU training, modify:

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # task.py default is "cpu"

Data Format Support

| Format | Extension | Description |
|---|---|---|
| B2ND | .b2nd | Compressed Blosc2 format (preferred, requires blosc2) |
| NPZ | .npz | Standard NumPy compressed format (legacy) |
| Properties | .pkl | Medical imaging metadata per case |

Expected preprocessed data structure:

nnUNet_preprocessed/DatasetXXX_Name/
├── dataset.json
├── dataset_fingerprint.json
├── nnUNetPlans.json
├── splits_final.json
└── nnUNetPlans_3d_fullres/
    ├── case_001.b2nd (or .npz)
    ├── case_001_seg.b2nd
    ├── case_001.pkl
    └── ...
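To sanity-check a folder against this layout before training, a sketch like the following indexes the cases in a configuration folder such as nnUNetPlans_3d_fullres/ from its file names. It prefers .b2nd over .npz when both exist and flags cases whose .pkl properties file is missing; segmentation files (*_seg.*) are skipped rather than checked. index_cases is a hypothetical helper.

```python
from pathlib import PurePath
from typing import Dict, Iterable


def index_cases(filenames: Iterable[str]) -> Dict[str, Dict[str, object]]:
    """Map case IDs to their data format and whether a .pkl properties file exists."""
    names = set(filenames)
    cases: Dict[str, Dict[str, object]] = {}
    for name in sorted(names):
        p = PurePath(name)
        if p.suffix not in (".b2nd", ".npz") or p.stem.endswith("_seg"):
            continue  # ignore .pkl/.json files and segmentation arrays
        entry = cases.setdefault(p.stem, {"format": p.suffix[1:], "has_pkl": f"{p.stem}.pkl" in names})
        if p.suffix == ".b2nd":
            entry["format"] = "b2nd"  # prefer Blosc2 when both formats exist
    return cases
```

For example, index_cases(p.name for p in folder.iterdir()), where folder is the nnUNetPlans_3d_fullres directory of Dataset009_Spleen.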

More Information

Acknowledgments

  • nnU-Net v2 for medical image segmentation
  • Flower Framework for federated learning infrastructure
  • Kaapana — federated learning concepts used in federating nnU-Net

License

Apache License 2.0