@addyk/flowernnunet
Federated nnU-Net with Flower
This app implements federated learning for nnU-Net, enabling privacy-preserving medical image segmentation across multiple institutions using the Flower framework. It supports modality-aware aggregation across imaging types (CT, MR, PET, US), multi-dataset federation with heterogeneous anatomies, and nnU-Net v2's native training pipeline including 3D full-resolution segmentation with deep supervision.
Key Features
- Modality-Aware Aggregation: automatically detects client imaging modalities (CT, MR, PET, US) and aggregates hierarchically, first within each modality and then across modalities with a weighted combination
- Multi-Dataset Federation: train across heterogeneous datasets with different anatomies, modalities, and label sets in a single federation
- Native nnU-Net v2 Integration: uses nnU-Net's full training pipeline (nnUNetDataLoader3D, native augmentation transforms, deep supervision with 6 output scales, and automatic architecture configuration)
- Privacy-Preserving: only model parameters and dataset fingerprints are shared; raw medical imaging data never leaves the local site
- W&B Integration: optional Weights & Biases logging for experiment tracking across federated rounds
- Model Checkpointing: automatic saving of the best local and global models based on validation Dice scores
- GPU Support: full CUDA acceleration with automatic mixed precision training
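Checkpointing on validation Dice can be sketched as follows. This is a minimal illustration only. It assumes binary masks given as flat lists of 0/1 values and uses a hypothetical `BestModelTracker` helper. The app's actual saving logic in client_app.py and task.py may differ.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


def dice_score(pred: Sequence[int], target: Sequence[int]) -> float:
    """Dice = 2|P ∩ T| / (|P| + |T|) for binary masks given as flat 0/1 sequences."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of voxels")
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    # Convention: two empty masks agree perfectly.
    return 1.0 if total == 0 else 2.0 * intersection / total


@dataclass
class BestModelTracker:
    """Hypothetical helper that remembers the best Dice seen so far."""

    best_dice: float = float("-inf")
    best_round: Optional[int] = None

    def update(self, round_num: int, dice: float) -> bool:
        """Return True if this round beats the best Dice (i.e. a checkpoint should be saved)."""
        if dice > self.best_dice:
            self.best_dice = dice
            self.best_round = round_num
            return True
        return False
```

Usage: call `tracker.update(round_num, dice_score(pred, target))` after each validation pass, and save the model only when it returns True.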
Architecture
| Component | Description |
|---|---|
| server_app_modality.py | ModalityAwareFederatedStrategy: multi-phase server with modality grouping and hierarchical aggregation |
| server_app.py | NnUNetFederatedStrategy: standard FedAvg baseline strategy |
| client_app.py | NnUNet3DFullresClient: handles fingerprint collection, local training, backbone parameter filtering, and model saving |
| task.py | FedNnUNetTrainer: extends nnU-Net's nnUNetTrainer for federated scenarios with validation and PyTorch model export |
| wandb_integration.py | W&B logging utilities for federated experiment tracking |
| dataset_compatibility.py | Dataset validation and compatibility checking for multi-dataset federation |
| federation_config.py | Federation configuration management |
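The backbone parameter filtering done in client_app.py can be sketched as below. The sketch assumes the state dict follows the naming of the dynamic_network_architectures U-Net that nnU-Net v2 uses by default, where the per-class output heads live under `decoder.seg_layers`. Presumably these heads are kept local because their shapes depend on each dataset's label set. The function names are hypothetical, and tensors are replaced by plain lists so the sketch runs without PyTorch.

```python
from typing import Dict, List, Tuple

Params = Dict[str, List[float]]

# Assumption: output heads are named like "decoder.seg_layers.<i>.weight".
HEAD_MARKER = "seg_layers"


def split_backbone(state: Params, head_marker: str = HEAD_MARKER) -> Tuple[Params, Params]:
    """Split a state dict into (shared backbone, local head) parameters."""
    backbone = {k: v for k, v in state.items() if head_marker not in k}
    head = {k: v for k, v in state.items() if head_marker in k}
    return backbone, head


def merge_backbone(local_state: Params, global_backbone: Params) -> Params:
    """Overwrite backbone entries with the global values; keep local head weights."""
    merged = dict(local_state)
    for name, value in global_backbone.items():
        if name in merged:
            merged[name] = value
    return merged
```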
Federated Learning Process
- Fingerprint Phase (Round -2): clients share dataset statistics (shapes, spacings, intensity properties, modality info)
- Initialization Phase (Round -1): the server merges fingerprints and distributes the global fingerprint plus initial model parameters
- Training Phases (Round 0+): iterative local training with native nnU-Net methods, backbone parameter extraction, and modality-aware global aggregation
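The server-side fingerprint merge in the initialization phase can be sketched as follows. The sketch makes three assumptions:

- Each client sends the per-case `spacings` and `shapes_after_crop` lists found in nnU-Net's dataset_fingerprint.json.
- Merging means pooling all cases and recomputing the per-axis median spacing.
- `merge_fingerprints` and the `median_spacing` field are hypothetical names.

The app's actual merge rules (for example for intensity properties) are not shown here and may differ.

```python
from statistics import median
from typing import Any, Dict, List


def merge_fingerprints(fingerprints: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Pool per-case statistics from several clients into one global fingerprint."""
    spacings: List[List[float]] = []
    shapes: List[List[int]] = []
    for fp in fingerprints:
        spacings.extend(fp["spacings"])
        shapes.extend(fp["shapes_after_crop"])
    if not spacings:
        raise ValueError("no cases in any fingerprint")
    # Per-axis median over all pooled cases from all clients.
    median_spacing = [median(s[axis] for s in spacings) for axis in range(len(spacings[0]))]
    return {
        "spacings": spacings,
        "shapes_after_crop": shapes,
        "median_spacing": median_spacing,  # hypothetical convenience field
    }
```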
Aggregation Strategies
- Modality-Aware (default): CT clients aggregate → CT model; MR clients aggregate → MR model; weighted combination → global model
- Standard FedAvg: Traditional weighted average by number of training examples across all clients
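The two-level scheme can be sketched as follows. The sketch makes three assumptions:

- Parameters are flat vectors.
- Within each modality, clients are combined with FedAvg weighted by local example counts.
- Across modalities, each modality counts equally unless weights are passed in, as with --modality-weights.

This is an illustration, not the code of ModalityAwareFederatedStrategy. In particular, the equal-weight default is an assumption.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

# One client update: (modality, flat parameter vector, number of training examples)
ClientUpdate = Tuple[str, List[float], int]


def fedavg(updates: List[Tuple[List[float], int]]) -> List[float]:
    """Example-weighted average of parameter vectors (standard FedAvg)."""
    total = sum(n for _, n in updates)
    dim = len(updates[0][0])
    return [sum(p[i] * n for p, n in updates) / total for i in range(dim)]


def modality_aware_aggregate(
    updates: List[ClientUpdate],
    modality_weights: Optional[Dict[str, float]] = None,
) -> List[float]:
    """Two-level aggregation: FedAvg within each modality, then a weighted mix across modalities."""
    # Level 1: intra-modality FedAvg (e.g. all CT clients -> one CT model).
    groups: Dict[str, List[Tuple[List[float], int]]] = defaultdict(list)
    for modality, params, num_examples in updates:
        groups[modality].append((params, num_examples))
    per_modality = {m: fedavg(g) for m, g in groups.items()}

    # Level 2: inter-modality weights. Assumed default: every modality counts equally.
    if modality_weights is None:
        modality_weights = {m: 1.0 for m in per_modality}
    weights = {m: modality_weights.get(m, 0.0) for m in per_modality}
    norm = sum(weights.values())
    if norm <= 0:
        raise ValueError("modality weights must have a positive sum")

    dim = len(next(iter(per_modality.values())))
    return [
        sum(per_modality[m][i] * w for m, w in weights.items()) / norm
        for i in range(dim)
    ]
```

If the inter-modality weights are set to each modality's total example count, the two levels collapse to plain FedAvg over all clients. The hierarchy therefore only changes the result when modality weights differ from the example shares.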
Fetch the App
Install Flower:
pip install flwr
Fetch the app:
flwr new @addyk/flowernnunet
This will create a new directory called flowernnunet with the following structure:
flowernnunet
├── flowernnunet
│   ├── __init__.py
│   ├── client_app.py # Defines your ClientApp
│   ├── server_app.py # Basic federated strategy (FedAvg)
│   ├── server_app_modality.py # Modality-aware federated strategy
│   ├── task.py # FedNnUNetTrainer (extends nnUNet)
│   ├── wandb_integration.py # W&B logging support
│   ├── dataset_compatibility.py
│   └── federation_config.py
├── pyproject.toml # Project metadata and Flower configs
├── README.md
├── DEPLOYMENT_GUIDE.md # Detailed deployment instructions
├── MULTI_DATASET_GUIDE.md # Multi-dataset federation guide
├── setup_flwr_nnunet.sh # Automated setup & preprocessing script
└── run_federated_deployment.sh # Automated deployment script
Prerequisites
- Python 3.10+ with conda or pip
- nnU-Net v2 installed and configured (pip install nnunetv2)
- Preprocessed data in nnU-Net's standard format (.npz/.b2nd + .pkl)
- GPU with CUDA support (required for nnU-Net training)
Environment Setup
# Set required nnU-Net paths
export nnUNet_raw="/path/to/nnUNet_raw"
export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export nnUNet_results="/path/to/nnUNet_results"
# Optional: Set model saving directory
export OUTPUT_ROOT="./federated_models"
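A quick pre-launch check that the nnU-Net path variables are set and point to existing directories can catch a misconfigured environment before a long run. This is a minimal sketch. `check_nnunet_env` is a hypothetical helper, not part of the app or of nnU-Net.

```python
import os
from typing import List, Mapping, Optional

REQUIRED_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def check_nnunet_env(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return a list of problems with the nnU-Net path variables (empty if all is well)."""
    env = os.environ if env is None else env
    problems: List[str] = []
    for var in REQUIRED_VARS:
        value = env.get(var)
        if not value:
            problems.append(f"{var} is not set")
        elif not os.path.isdir(value):
            problems.append(f"{var}={value} is not a directory")
    return problems
```

Calling `check_nnunet_env()` with no argument inspects the current process environment.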
Medical Segmentation Decathlon (MSD) Datasets
This app works with any nnU-Net-preprocessed dataset. The Medical Segmentation Decathlon provides 10 benchmark datasets ideal for testing federated scenarios:
| Task | Dataset | Modality | Targets | Train/Test | Download |
|---|---|---|---|---|---|
| Task01 | BrainTumour | MR (FLAIR, T1w, T1gd, T2w) | Glioma subregions | 484 / 266 | Download |
| Task02 | Heart | MR (Mono) | Left atrium | 20 / 10 | Download |
| Task03 | Liver | CT (Portal venous) | Liver + tumors | 131 / 70 | Download |
| Task04 | Hippocampus | MR (Mono) | Hippocampus head/body | 260 / 130 | Download |
| Task05 | Prostate | MR (T2, ADC) | Central gland + peripheral zone | 32 / 16 | Download |
| Task06 | Lung | CT | Lung tumors | 63 / 32 | Download |
| Task07 | Pancreas | CT (Portal venous) | Pancreas + tumors | 281 / 139 | Download |
| Task08 | HepaticVessel | CT | Hepatic vessels + tumors | 303 / 140 | Download |
| Task09 | Spleen | CT (Portal venous) | Spleen | 41 / 20 | Download |
| Task10 | Colon | CT (Portal venous) | Colon cancer | 126 / 64 | Download |
Quick Dataset Setup
Download and preprocess a dataset (e.g., Spleen):
# Download
wget https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar
tar -xf Task09_Spleen.tar
# Convert to nnU-Net v2 format ($nnUNet_raw/Dataset009_Spleen), then plan and preprocess
nnUNetv2_convert_MSD_dataset -i Task09_Spleen
nnUNetv2_plan_and_preprocess -d 9 --verify_dataset_integrity
Or use the automated setup script:
bash setup_flwr_nnunet.sh --dataset Task09_Spleen --name flwr-nnunet-demo
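nnU-Net v2 names datasets DatasetXXX_Name with a three-digit ID, whereas MSD archives use TaskXX_Name. A small hypothetical helper for scripting around these steps can derive the expected dataset folder name and the `-d` ID for nnUNetv2_plan_and_preprocess. It assumes the numeric ID is kept unchanged, as in the Dataset009_Spleen example.

```python
import re
from typing import Tuple

_TASK_RE = re.compile(r"^Task(\d+)_(\w+)$")


def msd_task_to_dataset(task_name: str) -> Tuple[str, int]:
    """Map e.g. 'Task09_Spleen' to ('Dataset009_Spleen', 9), keeping the numeric ID."""
    match = _TASK_RE.match(task_name)
    if match is None:
        raise ValueError(f"not an MSD task name: {task_name!r}")
    dataset_id = int(match.group(1))
    return f"Dataset{dataset_id:03d}_{match.group(2)}", dataset_id
```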
Multi-Dataset Federated Scenarios
For testing cross-modality federation, try combining:
- CT datasets: Spleen (Task09) + Liver (Task03), same modality, different anatomy
- Cross-modality: Prostate MR (Task05) + Spleen CT (Task09), different modalities with modality-aware aggregation
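For a cross-modality setup, clients could be grouped by reading each dataset's dataset.json. This sketch assumes the nnU-Net v2 `channel_names` field and a simple, hypothetical keyword rule: CT, PET or US when that exact channel name appears, otherwise MR. MR-style names such as T2, ADC or FLAIR therefore fall through to MR. The app's real detection logic may use different cues.

```python
from collections import defaultdict
from typing import Dict, List, Mapping


def detect_modality(dataset_json: Mapping) -> str:
    """Guess CT / PET / US / MR from nnU-Net v2 'channel_names' (hypothetical rule)."""
    names = [str(n).upper() for n in dataset_json.get("channel_names", {}).values()]
    for modality in ("CT", "PET", "US"):
        if modality in names:
            return modality
    # Anything else (T1, T2, FLAIR, ADC, "MRI", or no channel names) is treated as MR.
    return "MR"


def group_clients_by_modality(client_datasets: Mapping[str, Mapping]) -> Dict[str, List[str]]:
    """Map modality -> client IDs, given client ID -> parsed dataset.json."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for client_id, dataset_json in client_datasets.items():
        groups[detect_modality(dataset_json)].append(client_id)
    return dict(groups)
```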
Run the App
Run with the Simulation Engine
Install the dependencies defined in pyproject.toml as well as the flowernnunet package:
cd flowernnunet && pip install -e .
Set environment variables for nnU-Net:
export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export TASK_NAME="Dataset009_Spleen"
Run with default settings:
flwr run .
Override settings:
flwr run . --run-config "num-server-rounds=5"
Note: Simulation runs all clients on the same machine. For large 3D medical datasets, ensure sufficient GPU memory or reduce options.num-supernodes in pyproject.toml.
Run with the Deployment Engine
For detailed deployment instructions, see DEPLOYMENT_GUIDE.md.
1. Start the SuperLink (Terminal 1):
flower-superlink --insecure
2. Start SuperNodes (separate terminals):
# Terminal 2: First SuperNode
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9094 \
    --node-config "partition-id=0"
# Terminal 3: Second SuperNode
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9095 \
    --node-config "partition-id=1"
You can also specify datasets and folds per SuperNode:
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9094 \
    --node-config 'partition-id=0 dataset-name="Dataset005_Prostate" fold=0'
3. Run the federation (Terminal 4):
flwr run . supernode-deployment
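On the client side, the values passed via --node-config are exposed to the ClientApp through Flower's `context.node_config`. Below, a plain dict stands in for it so the code runs without Flower. The sketch resolves those values with fallbacks and then picks the matching fold from nnU-Net's splits_final.json, which is a list with one {"train": [...], "val": [...]} entry per fold. The fallback order (node config, then the TASK_NAME variable, then a default) and the helper names are assumptions.

```python
import json
import os
from typing import Any, Dict, List, Mapping, Tuple


def resolve_client_settings(
    node_config: Mapping[str, Any],
    default_dataset: str = "Dataset009_Spleen",
) -> Tuple[int, str, int]:
    """Return (partition_id, dataset_name, fold) from a node config mapping."""
    partition_id = int(node_config.get("partition-id", 0))
    # Assumed precedence: explicit node config, then TASK_NAME env var, then default.
    dataset_name = str(node_config.get("dataset-name") or os.environ.get("TASK_NAME", default_dataset))
    fold = int(node_config.get("fold", 0))
    return partition_id, dataset_name, fold


def load_fold(splits_path: str, fold: int) -> Tuple[List[str], List[str]]:
    """Read nnU-Net's splits_final.json and return (train_cases, val_cases) for one fold."""
    with open(splits_path) as f:
        splits: List[Dict[str, List[str]]] = json.load(f)
    if not 0 <= fold < len(splits):
        raise IndexError(f"fold {fold} not in splits (have {len(splits)})")
    return splits[fold]["train"], splits[fold]["val"]
```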
Automated Deployment Script
For convenience, use the all-in-one deployment script:
# Single dataset
bash run_federated_deployment.sh \
    --dataset Dataset009_Spleen --clients 2 --rounds 3 --local-epochs 2 --validate
# Multi-dataset with modality-aware aggregation
bash run_federated_deployment.sh \
    --client-datasets '{"0": "Dataset005_Prostate", "1": "Dataset009_Spleen"}' \
    --clients 2 --rounds 5 --enable-modality-aggregation
# Custom modality weights
bash run_federated_deployment.sh \
    --client-datasets '{"0": "Dataset005_Prostate", "1": "Dataset009_Spleen"}' \
    --clients 2 --rounds 5 --enable-modality-aggregation \
    --modality-weights '{"CT": 0.6, "MR": 0.4}'
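The --client-datasets and --modality-weights arguments are JSON strings. This sketch shows how such values could be validated before launch:

- Client IDs must cover exactly 0..clients-1.
- Weights must be non-negative and are normalized to sum to 1.

`parse_deployment_args` is hypothetical. Whether the script itself normalizes relative weights is an assumption.

```python
import json
from typing import Dict, Optional, Tuple


def parse_deployment_args(
    client_datasets: str,
    num_clients: int,
    modality_weights: Optional[str] = None,
) -> Tuple[Dict[int, str], Dict[str, float]]:
    """Validate JSON-valued arguments; return (client ID -> dataset, modality -> weight)."""
    mapping = {int(k): str(v) for k, v in json.loads(client_datasets).items()}
    if set(mapping) != set(range(num_clients)):
        raise ValueError(f"client IDs {sorted(mapping)} do not match 0..{num_clients - 1}")

    weights: Dict[str, float] = {}
    if modality_weights:
        raw = {m: float(w) for m, w in json.loads(modality_weights).items()}
        total = sum(raw.values())
        if any(w < 0 for w in raw.values()) or total <= 0:
            raise ValueError("modality weights must be non-negative with a positive sum")
        # Normalize so the weights sum to 1 (assumption: relative weights are allowed).
        weights = {m: w / total for m, w in raw.items()}
    return mapping, weights
```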
All deployment script arguments
| Category | Argument | Default | Description |
|---|---|---|---|
| Dataset | --dataset | - | Single dataset for all clients |
| | --client-datasets | - | JSON mapping of client IDs to datasets |
| | --list-datasets | - | List available preprocessed datasets |
| | --validate-datasets | - | Validate multi-dataset compatibility |
| Training | --clients | 2 | Number of federated clients |
| | --rounds | 3 | Number of federated rounds |
| | --local-epochs | 2 | Local epochs per client per round |
| Modality | --enable-modality-aggregation | false | Enable modality-aware aggregation |
| | --modality-weights | - | JSON modality weight overrides |
| Deployment | --mode | run | superlink, supernode, or run |
| | --superlink-host | 127.0.0.1 | SuperLink host address |
| Validation | --validate | true | Enable validation during training |
| | --no-validate | - | Skip validation |
| Output | --output-dir | federated_models | Model output directory |
| | --save-frequency | 1 | Save models every N rounds |
| System | --gpu | 0 | GPU device ID |
Configuration
Federation Settings (pyproject.toml)
[tool.flwr.app.config]
num-server-rounds = 100 # Number of training rounds
fraction-fit = 1.0 # Fraction of clients for training
fraction-evaluate = 0.0 # Fraction of clients for evaluation

[tool.flwr.federations.local-simulation]
options.num-supernodes = 1 # Number of simulated clients

[tool.flwr.federations.supernode-deployment]
address = "127.0.0.1:9093"
insecure = true
options.num-supernodes = 2
options.enable-modality-aggregation = true
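fraction-fit and fraction-evaluate control how many connected clients are sampled per round. The sketch below mirrors the logic of Flower's built-in FedAvg:

- The count is the fraction of available clients, rounded down.
- It is never below the strategy minimum (FedAvg's `min_fit_clients` / `min_evaluate_clients` default to 2).
- A fraction-evaluate of 0.0 skips federated evaluation entirely.

The app's strategies may override these minimums, so treat the defaults here as assumptions.

```python
def num_fit_clients(num_available: int, fraction_fit: float, min_fit_clients: int = 2) -> int:
    """Clients sampled for training: floor(available * fraction), but at least the minimum."""
    return max(int(num_available * fraction_fit), min_fit_clients)


def num_evaluate_clients(
    num_available: int, fraction_evaluate: float, min_evaluate_clients: int = 2
) -> int:
    """Clients sampled for federated evaluation; a fraction of 0.0 skips evaluation."""
    if fraction_evaluate == 0.0:
        return 0
    return max(int(num_available * fraction_evaluate), min_evaluate_clients)
```

With the configuration above (fraction-fit = 1.0, fraction-evaluate = 0.0), every available client trains each round and no federated evaluation round is scheduled. Validation then happens locally during training.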
GPU Configuration
# Enable specific GPU
export CUDA_VISIBLE_DEVICES=0
# Set model saving directory
export OUTPUT_ROOT="./federated_models"
In task.py, the default device is CPU for Ray compatibility. For GPU training, modify:
device = torch.device("cuda") # Instead of "cpu"
Data Format Support
| Format | Extension | Description |
|---|---|---|
| B2ND | .b2nd | Compressed Blosc2 format (preferred, requires blosc2) |
| NPZ | .npz | Standard NumPy compressed format (legacy) |
| Properties | .pkl | Medical imaging metadata per case |
Expected preprocessed data structure:
nnUNet_preprocessed/DatasetXXX_Name/
├── dataset.json
├── dataset_fingerprint.json
├── nnUNetPlans.json
├── splits_final.json
└── nnUNetPlans_3d_fullres/
    ├── case_001.b2nd (or .npz)
    ├── case_001_seg.b2nd
    ├── case_001.pkl
    └── ...
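Cases in a preprocessed folder such as nnUNetPlans_3d_fullres/ can be listed with a short helper that prefers .b2nd over .npz, as in the format table above. This sketch assumes every case has a `<case>.pkl` properties file next to its data file. `find_cases` is a hypothetical helper, and actually reading the arrays (via blosc2 or NumPy) is left out.

```python
import os
from typing import Dict

DATA_EXTENSIONS = (".b2nd", ".npz")  # order = preference


def find_cases(folder: str) -> Dict[str, str]:
    """Map case ID -> data file path for every case that has a .pkl properties file."""
    cases: Dict[str, str] = {}
    for name in sorted(os.listdir(folder)):
        if not name.endswith(".pkl"):
            continue  # skips data and _seg files; each case is keyed by its .pkl
        case_id = name[: -len(".pkl")]
        for ext in DATA_EXTENSIONS:
            path = os.path.join(folder, case_id + ext)
            if os.path.isfile(path):
                cases[case_id] = path
                break
        else:
            raise FileNotFoundError(f"{case_id}: properties file found but no data file")
    return cases
```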
More Information
- DEPLOYMENT_GUIDE.md: step-by-step SuperLink/SuperNode deployment
- MULTI_DATASET_GUIDE.md: multi-dataset federation with heterogeneous data
- nnU-Net v2: medical image segmentation framework
- Flower Framework: federated learning infrastructure
- Medical Segmentation Decathlon: benchmark datasets
Acknowledgments
- nnU-Net v2 for medical image segmentation
- Flower Framework for federated learning infrastructure
- Kaapana: federated learning concepts used in federating nnU-Net
License
Apache License 2.0