mirror of
https://github.com/Mintplex-Labs/pyannote-audio-legacy.git
synced 2026-07-01 20:24:10 -04:00
2 lines
13 KiB
Plaintext
{"cells":[{"cell_type":"markdown","metadata":{"id":"kY1p-wCLHw92"},"source":["# Add your own model"]},{"cell_type":"markdown","metadata":{"id":"iD_DNGmmHs9v"},"source":["<a href=\"https://colab.research.google.com/github/pyannote/pyannote-audio/blob/develop/tutorials/add_your_own_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"]},{"cell_type":"markdown","metadata":{"id":"hhBTSvk6H_JC"},"source":["## Tutorial setup"]},{"cell_type":"markdown","metadata":{"id":"r-ocA5Z8PqNl"},"source":["### `Google Colab` setup"]},{"cell_type":"markdown","metadata":{"id":"I7lc6ctfIBv-"},"source":["If you are running this tutorial on `Colab`, execute the following commands to set up the `Colab` environment. These commands will install `pyannote.audio` and download a mini version of the `AMI` corpus."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"l07Xq_UAIUFE"},"outputs":[],"source":["!pip install -qq pyannote.audio==3.1.1\n","!pip install -qq ipython==7.34.0\n","!git clone https://github.com/pyannote/AMI-diarization-setup.git\n","%cd ./AMI-diarization-setup/pyannote/\n","!bash ./download_ami_mini.sh\n","%cd /content"]},{"cell_type":"markdown","metadata":{"id":"3rjw5hATOv_c"},"source":["⚠ Restart the runtime (Runtime > Restart session)."]},{"cell_type":"markdown","metadata":{},"source":["### Non `Google Colab` setup"]},{"cell_type":"markdown","metadata":{"id":"VdMVQD-9QAto"},"source":["If you are not using `Colab`, this tutorial assumes that\n","* `pyannote.audio` has been installed\n","* the [AMI corpus](https://groups.inf.ed.ac.uk/ami/corpus/) has already been [set up for use with `pyannote`](https://github.com/pyannote/AMI-diarization-setup/tree/main/pyannote)"]},{"cell_type":"markdown","metadata":{"id":"kuemd4PWHeqh"},"source":["## Defining a custom model\n","\n","A collection of models is readily available in `pyannote.audio.models`, but you will eventually want to try your own 
architecture.\n","\n","This tutorial explains how to define (and then use) your own model. "]},{"cell_type":"code","execution_count":18,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":12960,"status":"ok","timestamp":1704802939163,"user":{"displayName":"Clément PAGES","userId":"11757386314069785178"},"user_tz":-60},"id":"kNwQfnTOHeqm","outputId":"5f71bde3-5f13-4431-918f-6e1ca3ddd518"},"outputs":[],"source":["from typing import Optional\n","import torch\n","import torch.nn as nn\n","from pyannote.audio import Model\n","from pyannote.core import SlidingWindow\n","from pyannote.audio.core.task import Task, Resolution\n","from torchaudio.transforms import MFCC\n","\n","# Your custom model must be a subclass of `pyannote.audio.Model`,\n","# which is a subclass of `pytorch_lightning.LightningModule`,\n","# which is a subclass of `torch.nn.Module`.\n","class MyCustomModel(Model):\n","    \"\"\"My custom model\"\"\"\n","\n","\n","    def __init__(\n","        self,\n","        sample_rate: int = 16000,\n","        num_channels: int = 1,\n","        task: Optional[Task] = None,\n","        param1: int = 32,\n","        param2: int = 16,\n","    ):\n","\n","        # First three parameters (sample_rate, num_channels, and task)\n","        # must be there and passed to super().__init__()\n","        super().__init__(sample_rate=sample_rate,\n","                         num_channels=num_channels,\n","                         task=task)\n","\n","        # Mark param1 and param2 as hyper-parameters.\n","        self.save_hyperparameters(\"param1\", \"param2\")\n","\n","        # They will be saved automatically into checkpoints.\n","        # They are now also available in self.hparams:\n","        # - param1 == self.hparams.param1\n","        # - param2 == self.hparams.param2\n","\n","        # Layers that do not depend on the addressed task should be defined in '__init__'.\n","        self.mfcc = MFCC()\n","        self.linear1 = nn.Linear(self.mfcc.n_mfcc, self.hparams.param1)\n","        self.linear2 = nn.Linear(self.hparams.param1, self.hparams.param2)\n","\n","    def num_frames(self, num_samples: int) -> int:\n","        # Compute 
number of output frames for a given number of input samples\n","        hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length\n","        n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft\n","        center = self.mfcc.MelSpectrogram.spectrogram.center\n","        return (\n","            1 + num_samples // hop_length\n","            if center\n","            else 1 + (num_samples - n_fft) // hop_length\n","        )\n","\n","    def receptive_field_size(self, num_frames: int = 1) -> int:\n","        # Compute receptive field size\n","        hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length\n","        n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft\n","        center = self.mfcc.MelSpectrogram.spectrogram.center\n","\n","        if center:\n","            return (num_frames - 1) * hop_length\n","        else:\n","            return (num_frames - 1) * hop_length + n_fft\n","\n","    def receptive_field(self) -> SlidingWindow:\n","        # Compute receptive field\n","\n","        # duration of the receptive field of each output frame\n","        duration = (\n","            self.mfcc.MelSpectrogram.spectrogram.win_length / self.hparams.sample_rate\n","        )\n","\n","        # step between the receptive field region of two consecutive output frames\n","        step = (\n","            self.mfcc.MelSpectrogram.spectrogram.hop_length / self.hparams.sample_rate\n","        )\n","\n","        return SlidingWindow(start=0.0, duration=duration, step=step)\n","\n","    def build(self):\n","        # Add layers that depend on the specifications of the task addressed\n","        # by this model.\n","\n","        # For instance, this simple model could be used for \"speech vs. non-speech\"\n","        # or \"speech vs. music vs. 
other\" classification and the only difference\n","        # would lie in the number of classes (2 or 3) in the final classifier.\n","\n","        # Since task specifications are not available at the time '__init__' is called,\n","        # task-dependent layers can only be added at 'build' time (where task specifications\n","        # are available in the 'specifications' attribute)\n","\n","        num_classes = len(self.specifications.classes)\n","        self.classifier = nn.Linear(self.hparams.param2, num_classes)\n","\n","        # 'specifications' has several attributes describing what the task is:\n","        # - classes: the list of classes\n","        # - problem: the type of machine learning problem (e.g. binary\n","        #   classification or representation learning)\n","        # - duration: the duration of input audio chunks, in seconds\n","        # - resolution: the resolution of the output (e.g. frame-wise scores\n","        #   for voice activity detection or chunk-wise vector for speaker\n","        #   embedding)\n","        # - permutation_invariant: whether classes are permutation-invariant\n","        #   (e.g. in the case of speaker diarization)\n","\n","        # Depending on the type of 'problem', 'default_activation' can be used\n","        # to automatically guess what the final activation should be (e.g. softmax\n","        # for multi-class classification or sigmoid for multi-label classification).\n","        self.activation = self.default_activation()\n","\n","        # You obviously do not _have_ to use 'default_activation' and can choose to\n","        # use any activation you see fit (or even not use any activation layer). But\n","        # note that pyannote.audio tasks also define default loss functions that are\n","        # consistent with `default_activation` (e.g. binary cross entropy with sigmoid\n","        # for binary classification tasks)\n","\n","    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:\n","\n","        # Models are expected to work on batches of audio chunks provided as tensors\n","        # with shape (batch_size, num_channels, num_samples) and using the sample rate\n","        # passed to __init__. 
Resampling will be done automatically for you so you do\n","        # not have to bother about that when preparing the data.\n","\n","        # Extract sequence of MFCCs and pass them through two linear layers.\n","        # After the transpose, 'mfcc' has shape (batch_size, num_frames, n_mfcc).\n","        mfcc = self.mfcc(waveforms).squeeze(dim=1).transpose(1, 2)\n","        output = self.linear1(mfcc)\n","        output = self.linear2(output)\n","\n","        # Apply temporal pooling (average over the frame axis) for tasks which\n","        # need an output at chunk-level.\n","        if self.specifications.resolution == Resolution.CHUNK:\n","            output = torch.mean(output, dim=1)\n","        # Keep 'mfcc' frame resolution for frame-level tasks.\n","        elif self.specifications.resolution == Resolution.FRAME:\n","            pass\n","\n","        # Apply final classifier and activation function\n","        output = self.classifier(output)\n","        return self.activation(output)"]},{"cell_type":"markdown","metadata":{"id":"BuieqViJHeqp"},"source":["## Using your model with `pyannote.audio` API\n","\n","Your model can now be used like any other built-in model."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"qwTjuGuvHeqr"},"outputs":[],"source":["# initialize your experimental protocol\n","from pyannote.database import registry, FileFinder\n","\n","registry.load_database(\"./AMI-diarization-setup/pyannote/database.yml\")\n","protocol = registry.get_protocol('AMI.SpeakerDiarization.mini', preprocessors={\"audio\": FileFinder()})\n","\n","# initialize the task you want to address\n","from pyannote.audio.tasks import VoiceActivityDetection\n","task = VoiceActivityDetection(protocol)\n","\n","# initialize the model\n","model = MyCustomModel(task=task)\n","\n","# train the model\n","from pytorch_lightning import Trainer\n","trainer = Trainer(max_epochs=1)\n","trainer.fit(model)"]},{"cell_type":"markdown","metadata":{"id":"4qidmGQyHeqt"},"source":["## Using your model with `pyannote-audio-train` CLI\n","\n","1. 
Define your model in a proper Python package:\n","\n","```\n","/your/favorite/directory/\n","  your_package_name/\n","    __init__.py      # needs to be here but can be empty\n","    custom_model.py  # contains the above definition of your model\n","```\n","\n","2. Add the package to your `PYTHONPATH`:\n","\n","```bash\n","$ export PYTHONPATH=/your/favorite/directory\n","```\n","\n","3. Check that you can import it from Python:\n","\n","```python\n",">>> from your_package_name.custom_model import MyCustomModel\n","```\n","\n","4. Tell `Hydra` (on which `pyannote-audio-train` is based) about this new model:\n","\n","```\n","/your/favorite/directory/\n","  custom_config/\n","    model/\n","      MyCustomModel.yaml\n","```\n","\n","where the content of `MyCustomModel.yaml` is as follows:\n","\n","```yaml\n","# @package _group_\n","_target_: your_package_name.custom_model.MyCustomModel\n","param1: 32\n","param2: 16\n","```\n","\n","5. Enjoy\n","\n","```bash\n","$ pyannote-audio-train --config-dir=/your/favorite/directory/custom_config \\\n","    protocol=Debug.SpeakerDiarization.Debug \\\n","    task=VoiceActivityDetection \\\n","    model=MyCustomModel \\\n","    model.param2=12\n","```"]},{"cell_type":"markdown","metadata":{"id":"8W2UT1IpHequ"},"source":["## Contributing your model to `pyannote-audio`\n","\n","1. Add your model in `pyannote.audio.models`.\n","\n","```\n","pyannote/\n","  audio/\n","    models/\n","      custom_model.py\n","```\n","\n","2. Check that you can import it from Python:\n","\n","```python\n",">>> from pyannote.audio.models.custom_model import MyCustomModel\n","```\n","\n","3. Add the corresponding `Hydra` configuration file:\n","\n","```\n","pyannote/\n","  audio/\n","    cli/\n","      train_config/\n","        model/\n","          MyCustomModel.yaml\n","```\n","\n","where the content of `MyCustomModel.yaml` is as follows:\n","\n","```yaml\n","# @package _group_\n","_target_: pyannote.audio.models.custom_model.MyCustomModel\n","param1: 32\n","param2: 16\n","```\n","\n","4. 
Enjoy\n","\n","```bash\n","$ pyannote-audio-train protocol=Debug.SpeakerDiarization.Debug \\\n","    task=VoiceActivityDetection \\\n","    model=MyCustomModel \\\n","    model.param2=12\n","```"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"T4","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.13"}},"nbformat":4,"nbformat_minor":0}