# Federated SpeechLLM with Flower
This app implements federated learning for a Speech Large Language Model (SpeechLLM), enabling privacy-preserving training of a multimodal speech understanding system across multiple clients using the Flower framework. It combines a WavLM audio encoder, a lightweight connector module, and a TinyLlama LLM to perform joint audio understanding tasks, including transcription, speaker gender, emotion, age, accent, and speech activity detection, without raw audio data ever leaving the local client.
## Key Features

- 🎙️ **Multimodal Architecture**: combines a WavLM audio encoder, a connector, and a TinyLlama LLM for end-to-end speech understanding across multiple tasks simultaneously
- 🔒 **Privacy-Preserving**: only trainable model parameters (LoRA weights + connector) are shared; raw audio data never leaves the client
- 🔄 **Custom FedAvg Strategy**: server-side learning rate decay per round, configurable client sampling, and automatic model checkpointing after every aggregation
- ⚡ **LoRA Fine-Tuning**: parameter-efficient federation; only the LoRA adapters and the connector are trained and communicated, drastically reducing communication cost
- 📊 **Rich Metric Logging**: tracks WER (transcription) plus gender, emotion, age, accent, and speech-activity accuracy per validation round via W&B
- 💾 **Checkpoint Resumption**: supports resuming the federation from a pretrained `.ckpt` file via the `pyproject.toml` config
- 🖥️ **GPU Support**: full CUDA acceleration with gradient clipping and gradient accumulation
## Architecture

| File | Description |
|---|---|
| `client_app.py` | `ClientApp` with `@app.train()` and `@app.evaluate()` handlers: loads weights, runs local PyTorch Lightning training, returns updated parameters and metrics |
| `server_app.py` | `ServerApp` with the `SpeechLLMFedAvg` strategy: manages per-round LR decay, hierarchical aggregation, and checkpoint saving |
| `trainer.py` | `SpeechLLMLightning`: PyTorch Lightning module defining the full SpeechLLM model, forward pass, training/validation/test steps, and metric logging |
| `dataset.py` | `InstructionalAudioDataset`, `MyCollator`, and `build_dataloaders_from_csvs`: load partitioned audio CSV datasets per client |
| `pyproject.toml` | All federation, model, training, data, and checkpoint config in one place |
## Federated Learning Process

1. **Initialization**: the server loads the global `SpeechLLMLightning` model (optionally from a pretrained checkpoint) and extracts only the trainable parameters (LoRA + connector)
2. **Round Config**: the server computes a decayed learning rate for the round and broadcasts it alongside the model weights to the sampled clients
3. **Local Training**: each client loads the received weights and trains locally for `local-epochs` epochs of `train-batch-per-epoch` steps using PyTorch Lightning
4. **Aggregation**: the server performs weighted FedAvg over client updates, proportional to dataset sizes
5. **Checkpointing**: the aggregated model is saved to disk after every round; the final model is saved as `final_model.pt`
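The aggregation step can be sketched as weighted averaging over the trainable tensors. This is an illustrative NumPy version, not the actual `SpeechLLMFedAvg` code:

```python
import numpy as np

def fedavg(updates):
    """Weighted FedAvg.

    `updates` is a list of (params, num_examples) pairs, where `params` is a
    client's list of trainable tensors (LoRA adapters + connector weights).
    """
    total_examples = sum(n for _, n in updates)
    num_tensors = len(updates[0][0])
    # Average each tensor across clients, weighted by local dataset size
    return [
        sum(params[i] * (n / total_examples) for params, n in updates)
        for i in range(num_tensors)
    ]
```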
## Model Architecture

```
Audio Input (waveform)
        │
        ▼
WavLM Encoder (frozen or fine-tuned)
        │
        ▼
Connector (Linear / LinearPool / CNN)
        │
        ▼
[Pre-prompt embeddings] + [Speech embeddings] + [Post-prompt embeddings]
        │
        ▼
TinyLlama LLM (LoRA fine-tuned)
        │
        ▼
Structured JSON output:
{ "Transcript": "...", "Gender": "male", "Emotion": "neutral", ... }
```
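The embedding concatenation in the diagram can be sketched at the shape level. The dimensions below are small stand-ins (the real `llm-dim` is 2048), and this is an illustration rather than the actual `trainer.py` code:

```python
import numpy as np

llm_dim = 8  # small stand-in for the real llm-dim of 2048

# Token embeddings for the instruction text surrounding the audio
pre_prompt = np.random.randn(3, llm_dim)   # e.g. 3 pre-prompt tokens
post_prompt = np.random.randn(4, llm_dim)  # e.g. 4 post-prompt tokens

# Connector output: audio frames projected from audio-enc-dim to llm-dim
speech = np.random.randn(50, llm_dim)      # e.g. 50 speech frames

# The LLM consumes this sequence in place of ordinary token embeddings
inputs_embeds = np.concatenate([pre_prompt, speech, post_prompt], axis=0)
assert inputs_embeds.shape == (3 + 50 + 4, llm_dim)
```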
## Fetch the App

Install Flower:

```shell
pip install flwr
```

Fetch the app:

```shell
flwr new @mnabih/speech-llm-fl
```

This will create the following structure:

```
speech_llm_fl/
├── speech_llm_fl/
│   ├── __init__.py
│   ├── client_app.py   # ClientApp: local train & evaluate handlers
│   ├── server_app.py   # ServerApp: SpeechLLMFedAvg strategy + main
│   ├── trainer.py      # SpeechLLMLightning model definition
│   └── dataset.py      # Dataset, collator, and dataloader utilities
├── pyproject.toml      # All project metadata and Flower config
└── README.md
```
## Prerequisites

- Python 3.10+
- CUDA-capable GPU (strongly recommended for WavLM + LLM training)
- Audio data partitioned as one CSV file per client, each row containing an `audio_path` and label columns (`transcript`, `gender`, `emotion`, `age`, `accent`, `isspeech`)
- Pretrained model weights accessible (WavLM and TinyLlama will be downloaded from HuggingFace on first run)
## Install Dependencies

```shell
cd speech_llm_fl && pip install -e .
```
## Data Preparation

Each client needs a CSV file with the following columns:

| Column | Description |
|---|---|
| `audio_path` | Absolute path to the `.wav` audio file (16 kHz mono) |
| `transcript` | Ground-truth transcription text |
| `gender` | Speaker gender (male / female) |
| `emotion` | Emotion label (e.g. neutral, happy, sad) |
| `age` | Age group label |
| `accent` | Accent label |
| `isspeech` | Boolean: whether the audio contains speech |
Organize client partitions into a directory, one CSV per client:

```
fl_multilingual/
├── client_0.csv
├── client_1.csv
├── client_2.csv
└── ...
```
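Before launching a federation, it can help to verify that every partition CSV has the required columns. `check_partitions` below is a hypothetical helper sketch, not part of the app:

```python
import csv
from pathlib import Path

# Columns every client partition must provide (see the table above)
REQUIRED = {"audio_path", "transcript", "gender", "emotion", "age", "accent", "isspeech"}

def check_partitions(csv_dir):
    """Map each malformed client CSV to its sorted list of missing columns."""
    bad = {}
    for path in sorted(Path(csv_dir).glob("client_*.csv")):
        with open(path, newline="") as f:
            header = set(next(csv.reader(f), []))
        missing = REQUIRED - header
        if missing:
            bad[path.name] = sorted(missing)
    return bad
```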
Set the paths in `pyproject.toml`:

```toml
csv-train-dir = "./fl_multilingual"
csv-dev-dir = "./fl_MLS_dev_speaker"
```
## Run the App

### Simulation (Single Machine)

Run with default settings:

```shell
flwr run .
```

Override specific settings at runtime:

```shell
flwr run . --run-config "num-server-rounds=10 local-epochs=5 max-lr=0.00005"
```

Resume from a pretrained checkpoint:

```shell
flwr run . --run-config "pretrained-checkpoint=/path/to/Checkpoint-round-420.ckpt checkpoint-offset=420"
```

> **Note:** Simulation runs all clients on the same machine via Ray. Ensure sufficient GPU memory, or reduce `train-batch-per-epoch` and `train-batch-size` for smaller GPUs.
### Deployment Engine (Multi-Machine)

1. Start the SuperLink (Terminal 1):

```shell
flower-superlink --insecure
```

2. Start the SuperNodes, one per client machine (separate terminals):

```shell
# Client 0
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9094 \
    --node-config "partition-id=0 num-partitions=4"

# Client 1
flower-supernode --insecure --superlink 127.0.0.1:9092 \
    --clientappio-api-address 127.0.0.1:9095 \
    --node-config "partition-id=1 num-partitions=4"
```

3. Run the federation (Terminal 4):

```shell
flwr run . supernode-deployment
```
## Configuration

All settings are controlled via `pyproject.toml`; no code changes are needed to run experiments.

### Federation Settings

```toml
[tool.flwr.app.config]
num-server-rounds = 200    # Total FL rounds
fraction-fit = 0.3         # Fraction of clients sampled per round
fraction-evaluate = 0.0    # Fraction of clients used for evaluation
min-fit-clients = 2        # Minimum clients required to start a round
min-evaluate-clients = 2
```
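For intuition on how these values interact: Flower's FedAvg samples `fraction-fit` of the currently available clients, but never fewer than `min-fit-clients`. A sketch of that arithmetic (not the actual Flower code):

```python
def clients_per_round(num_available, fraction_fit=0.3, min_fit_clients=2):
    """Clients sampled for training each round, per FedAvg's sampling rule."""
    return max(int(num_available * fraction_fit), min_fit_clients)

# With 316 simulated supernodes and the defaults above, 94 clients train per
# round; with only 5 available, the min-fit-clients floor of 2 applies.
```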
### Model Settings

```toml
audio-encoder-name = "microsoft/wavlm-large"     # HuggingFace model ID
llm-name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
connector-name = "linear"                        # "linear", "linear-pool", or "cnn"
audio-enc-dim = 1024
llm-dim = 2048
use-lora = true
lora-r = 8
lora-alpha = 16
finetune-encoder = false
```
### Training Settings

```toml
local-epochs = 10            # Epochs per client per round
train-batch-size = 4
train-batch-per-epoch = 200  # Steps per epoch (limits dataset length)
grad-accumulate-steps = 4
max-lr = 0.0001
```
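As a quick sanity check, the training knobs above combine as follows (plain arithmetic, assuming `train-batch-per-epoch` counts micro-batches before gradient accumulation):

```python
train_batch_size = 4
grad_accumulate_steps = 4
train_batch_per_epoch = 200
local_epochs = 10

# Gradients accumulate over 4 micro-batches before each optimizer step
effective_batch = train_batch_size * grad_accumulate_steps                   # 16
optimizer_steps_per_epoch = train_batch_per_epoch // grad_accumulate_steps   # 50
samples_per_round = train_batch_size * train_batch_per_epoch * local_epochs  # 8000
```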
### Learning Rate Decay

```toml
lr-decay-factor = 0.9  # Multiply LR by this value every N rounds
lr-decay-every = 10    # Decay interval in rounds
```

The effective LR at round `r` is `max-lr × lr-decay-factor ^ (r // lr-decay-every)`.
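In code, the schedule amounts to the following (using the default values above):

```python
def effective_lr(round_num, max_lr=1e-4, decay_factor=0.9, decay_every=10):
    """Stepwise-decayed learning rate broadcast to clients at a given round."""
    return max_lr * decay_factor ** (round_num // decay_every)

# Rounds 0-9 train at 1e-4, rounds 10-19 at 9e-5, rounds 20-29 at 8.1e-5, ...
```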
### Checkpoint Settings

```toml
checkpoint-dir = "FL_SLAM_checkpoints"  # Directory for round checkpoints
checkpoint-offset = 0                   # Added to the round number in filenames
pretrained-checkpoint = ""              # Path to a .ckpt to resume from
```
### Simulation Settings

```toml
[tool.flwr.federations.local-simulation]
options.num-supernodes = 316  # Total number of simulated clients

[tool.flwr.federations.local-simulation.options]
backend.client-resources.num-cpus = 1
backend.client-resources.num-gpus = 1
```
## Metrics Tracked

| Metric | Description |
|---|---|
| `train_loss` | Cross-entropy loss on local training data |
| `val/loss` | Validation loss |
| `val/wer` | Word Error Rate on transcript predictions |
| `val/gender` | Gender classification accuracy |
| `val/emotion` | Emotion classification accuracy |
| `val/age` | Age group classification accuracy |
| `val/accent` | Accent classification accuracy |
| `val/speech_activity` | Speech activity detection accuracy |
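For reference, WER is the word-level edit distance between reference and hypothesis, divided by the reference length. A minimal self-contained implementation (the app may well use a library such as `torchmetrics` or `jiwer` instead):

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word Error Rate: word-level Levenshtein distance / reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    d = list(range(len(hyp) + 1))  # one rolling row of the DP table
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            # deletion (d[j] + 1), insertion (d[j-1] + 1), substitution/match (prev)
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
    return d[len(hyp)] / max(len(ref), 1)
```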
## Results

Performance comparison of WavLM vs. Whisper encoders, measured in Word Error Rate (WER, lower is better) on the LibriSpeech (LS) and Multilingual LibriSpeech (MLS) test sets. Central training serves as the upper bound.

| Setting | WavLM (Round=100) LS | WavLM (Round=100) MLS | Whisper (Round=40) LS | Whisper (Round=40) MLS |
|---|---|---|---|---|
| Central Training † | 6.1 | 18.4 | 6.4 | 16.4 |
| FL Sample Cluster | 9.7 | 19.6 | 7.7 | 16.4 |

† Central training is the non-federated upper bound.
**Key takeaway:** federated WavLM (FL Sample Cluster) achieves competitive WER on both benchmarks, with only a modest gap vs. central training, demonstrating that federation does not significantly degrade model quality while preserving data privacy.
## More Information

- Flower Framework Docs: Flower framework documentation
- WavLM (Microsoft): audio encoder used in this app
- TinyLlama: LLM backbone
- LoRA / PEFT: parameter-efficient fine-tuning library
- PyTorch Lightning: training framework
## License

Apache License 2.0