mirror of
https://github.com/Mintplex-Labs/pyannote-audio-legacy.git
synced 2026-07-01 20:24:10 -04:00
2 lines
17 KiB
Plaintext
{"cells":[{"cell_type":"markdown","metadata":{"id":"W7BMj2EZlWqU"},"source":["<a href=\"https://colab.research.google.com/github/pyannote/pyannote-audio/blob/develop/tutorials/add_your_own_task.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"]},{"cell_type":"markdown","metadata":{"id":"HG6OvaE4lWqZ"},"source":["# Defining a custom task"]},{"cell_type":"markdown","metadata":{"id":"c6LwrLYVlWqZ"},"source":["## Tutorial setup"]},{"cell_type":"markdown","metadata":{"id":"6lR9bgJBlWqb"},"source":["### `Google Colab` setup"]},{"cell_type":"markdown","metadata":{},"source":["If you are running this tutorial on `Colab`, run the following commands to set up the `Colab` environment. These commands will install `pyannote.audio` and download a mini version of the `AMI` corpus."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":127254,"status":"ok","timestamp":1704809957597,"user":{"displayName":"Clément PAGES","userId":"11757386314069785178"},"user_tz":-60},"id":"6LoOS-PjlWqd","outputId":"cd92e2d4-83cc-4bb0-ad5c-824cb2ca11ac"},"outputs":[],"source":["!pip install -qq pyannote.audio==3.1.1\n","!pip install -qq ipython==7.34.0\n","!git clone https://github.com/pyannote/AMI-diarization-setup.git\n","%cd ./AMI-diarization-setup/pyannote/\n","!bash ./download_ami_mini.sh\n","%cd /content"]},{"cell_type":"markdown","metadata":{"id":"LsZTSX-ulWqf"},"source":["⚠ Restart the runtime (Runtime > Restart session)."]},{"cell_type":"markdown","metadata":{"id":"904hVjv8lWqg"},"source":["### Non `Google Colab` setup"]},{"cell_type":"markdown","metadata":{"id":"serWAfFxlWqh"},"source":["If you are not using `Colab`, this tutorial assumes that\n","* `pyannote.audio` has been installed\n","* the [AMI corpus](https://groups.inf.ed.ac.uk/ami/corpus/) has already been [set up for use with `pyannote`](https://github.com/pyannote/AMI-diarization-setup/tree/main/pyannote)"]},{"cell_type":"markdown","metadata":{"id":"uWyNce9FlkA3"},"source":["## Task in `pyannote.audio`"]},{"cell_type":"markdown","metadata":{"id":"BK4hbdq6lWqj"},"source":["\n","In `pyannote.audio`, a *task* is a combination of a **_problem_** that needs to be addressed and an **experimental protocol**.\n","\n","For example, one can address **_voice activity detection_** following the **AMI mini** experimental protocol (`AMI.SpeakerDiarization.mini`) by instantiating the following *task*:\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-4B8nLDmlWql"},"outputs":[],"source":["# this assumes that the AMI corpus has been set up for diarization\n","# according to https://github.com/pyannote/AMI-diarization-setup\n","\n","from pyannote.database import registry, FileFinder\n","registry.load_database(\"AMI-diarization-setup/pyannote/database.yml\")\n","ami = registry.get_protocol('AMI.SpeakerDiarization.mini',\n","                            preprocessors={'audio': FileFinder()})\n","\n","# address voice activity detection\n","from pyannote.audio.tasks import VoiceActivityDetection\n","task = VoiceActivityDetection(ami)"]},{"cell_type":"markdown","metadata":{"id":"A9nxwDQGlWqn"},"source":["A growing collection of tasks is readily available in `pyannote.audio.tasks`..."]},{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":232,"status":"ok","timestamp":1704810010556,"user":{"displayName":"Clément PAGES","userId":"11757386314069785178"},"user_tz":-60},"id":"qbbDA2P5lWqp","outputId":"bf1988fb-9140-4d6a-8a2c-970578331f35"},"outputs":[{"name":"stdout","output_type":"stream","text":["SpeakerDiarization\n","VoiceActivityDetection\n","OverlappedSpeechDetection\n","MultiLabelSegmentation\n","SpeakerEmbedding\n","Segmentation\n"]}],"source":["from pyannote.audio.tasks import __all__ as TASKS; print('\\n'.join(TASKS))"]},{"cell_type":"markdown","metadata":{"id":"hihPu4iElWqr"},"source":["... but you will eventually want to use `pyannote.audio` to address a different task.\n","In this example, we will add a new task addressing the **sound event detection** problem.\n","\n"]},{"cell_type":"markdown","metadata":{"id":"RZOs3C4HlWqr"},"source":["## Problem specification\n","\n","A problem is expected to be solved by a model $f$ that takes an audio chunk $X$ as input and returns its predicted solution $\\hat{y} = f(X)$.\n","\n","### Resolution\n","\n","Depending on the addressed problem, you might expect the model to output just one prediction for the whole audio chunk (`Resolution.CHUNK`) or a temporal sequence of predictions (`Resolution.FRAME`).\n","\n","In our particular case, we would like the model to provide one decision for the whole chunk:"]},{"cell_type":"code","execution_count":3,"metadata":{"executionInfo":{"elapsed":234,"status":"ok","timestamp":1704810016464,"user":{"displayName":"Clément PAGES","userId":"11757386314069785178"},"user_tz":-60},"id":"G96Mz8vPlWqs"},"outputs":[],"source":["from pyannote.audio.core.task import Resolution\n","resolution = Resolution.CHUNK"]},{"cell_type":"markdown","metadata":{"id":"_Efd28eclWqt"},"source":["### Type of problem\n","\n","Similarly, the type of your problem may fall into one of these generic machine learning categories:\n","* `Problem.BINARY_CLASSIFICATION` for binary classification\n","* `Problem.MONO_LABEL_CLASSIFICATION` for multi-class classification\n","* `Problem.MULTI_LABEL_CLASSIFICATION` for multi-label classification\n","* `Problem.REGRESSION` for regression\n","* `Problem.REPRESENTATION` for representation learning\n","\n","In our particular case, we would like the model to do multi-label classification because one audio chunk may contain multiple sound events:"]},{"cell_type":"code","execution_count":4,"metadata":{"executionInfo":{"elapsed":315,"status":"ok","timestamp":1704810020230,"user":{"displayName":"Clément PAGES","userId":"11757386314069785178"},"user_tz":-60},"id":"Cl0VqB5jlWqu"},"outputs":[],"source":["from pyannote.audio.core.task import Problem\n","problem = Problem.MULTI_LABEL_CLASSIFICATION"]},{"cell_type":"code","execution_count":5,"metadata":{"executionInfo":{"elapsed":251,"status":"ok","timestamp":1704810021646,"user":{"displayName":"Clément PAGES","userId":"11757386314069785178"},"user_tz":-60},"id":"Hz_B7FCplWqv"},"outputs":[],"source":["from pyannote.audio.core.task import Specifications\n","specifications = Specifications(\n","    problem=problem,\n","    resolution=resolution,\n","    duration=5.0,\n","    classes=[\"Speech\", \"Dog\", \"Cat\", \"Alarm_bell_ringing\", \"Dishes\",\n","             \"Frying\", \"Blender\", \"Running_water\", \"Vacuum_cleaner\",\n","             \"Electric_shaver_toothbrush\"],\n",")"]},{"cell_type":"markdown","metadata":{"id":"5N72ksU7lWqv"},"source":["A task is expected to be solved by a model $f$ that (usually) takes an audio chunk $X$ as input and returns its predicted solution $\\hat{y} = f(X)$.\n","\n","To help train the model $f$, the task $\\mathcal{T}$ is in charge of\n","- generating $(X, y)$ training samples using the **dataset**\n","- defining the loss function $\\mathcal{L}(y, \\hat{y})$\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"lrTD1RwUlWqw"},"outputs":[],"source":["from math import ceil\n","from typing import Dict, Optional, Tuple, Union\n","import numpy as np\n","from pyannote.core import Segment, SlidingWindow\n","from pyannote.audio.utils.random import create_rng_for_worker\n","from pyannote.audio.core.task import Task, Problem, Resolution, Specifications\n","from pyannote.database import Protocol\n","from torchmetrics.classification import MultilabelAUROC\n","\n","# Your custom task must be a subclass of `pyannote.audio.core.task.Task`\n","class SoundEventDetection(Task):\n","    \"\"\"Sound event detection\"\"\"\n","\n","    def __init__(\n","        self,\n","        protocol: Protocol,\n","        duration: float = 5.0,\n","        min_duration: float = 5.0,\n","        warm_up: Union[float, Tuple[float, float]] = 0.0,\n","        batch_size: int = 32,\n","        num_workers: Optional[int] = None,\n","        pin_memory: bool = False,\n","        augmentation=None,\n","        cache: Optional[str] = None,\n","        **other_params,\n","    ):\n","\n","        super().__init__(\n","            protocol,\n","            duration=duration,\n","            min_duration=min_duration,\n","            warm_up=warm_up,\n","            batch_size=batch_size,\n","            num_workers=num_workers,\n","            pin_memory=pin_memory,\n","            augmentation=augmentation,\n","            cache=cache,\n","        )\n","\n","    def prepare_data(self):\n","        # this method is called to prepare data from the specified protocol.\n","        # For most tasks, calling Task.prepare_data() is sufficient. If you\n","        # need to prepare task-specific data, define a post_prepare_data method for your task.\n","        super().prepare_data()\n","\n","    def post_prepare_data(self, prepared_data: Dict):\n","        # this method is called at the end of Task.prepare_data()\n","        # to complete data preparation with task-specific data, here\n","        # the list of classes and some training metadata\n","\n","        # load metadata for training subset\n","        prepared_data[\"train_metadata\"] = list()\n","        for training_file in self.protocol.train():\n","            prepared_data[\"train_metadata\"].append({\n","                # path to audio file (str)\n","                \"audio\": training_file[\"audio\"],\n","                # duration of audio file (float)\n","                \"duration\": training_file[\"torchaudio.info\"].num_frames / training_file[\"torchaudio.info\"].sample_rate,\n","                # reference annotation (pyannote.core.Annotation)\n","                \"annotation\": training_file[\"annotation\"],\n","            })\n","\n","        # gather the list of classes\n","        classes = set()\n","        for training_file in prepared_data[\"train_metadata\"]:\n","            classes.update(training_file[\"annotation\"].labels())\n","        prepared_data[\"classes\"] = sorted(classes)\n","\n","        # `has_validation` is True if protocol defines a development set\n","        if not self.has_validation:\n","            return\n","\n","    def prepare_validation(self, prepared_data: Dict):\n","        # this method is called at the end of Task.prepare_data(), to complete data preparation\n","        # with task validation elements\n","\n","        # load metadata for validation subset\n","        prepared_data[\"validation\"] = list()\n","        for validation_file in self.protocol.development():\n","            info = validation_file[\"torchaudio.info\"]\n","            prepared_data[\"validation\"].append({\n","                \"audio\": validation_file[\"audio\"],\n","                # number of validation samples in this file, i.e. number of\n","                # non-overlapping chunks of `self.duration` seconds that fit in it\n","                \"num_samples\": int((info.num_frames / info.sample_rate) // self.duration),\n","                \"annotation\": validation_file[\"annotation\"],\n","            })\n","\n","    def setup(self, stage: Optional[str] = None):\n","        # this method assigns prepared data from task.prepare_data() to the task\n","        # and declares the task specifications\n","\n","        super().setup(stage)\n","\n","        # specify the addressed problem\n","        self.specifications = Specifications(\n","            # it is a multi-label classification problem\n","            problem=Problem.MULTI_LABEL_CLASSIFICATION,\n","            # we expect the model to output one prediction\n","            # for the whole chunk\n","            resolution=Resolution.CHUNK,\n","            # the model will ingest chunks with that duration (in seconds)\n","            duration=self.duration,\n","            # human-readable names of classes\n","            classes=self.prepared_data[\"classes\"])\n","\n","    def default_metric(self):\n","        # this method defines the default metric used to evaluate the model\n","        # during training\n","        num_classes = len(self.specifications.classes)\n","        return MultilabelAUROC(num_classes, average=\"macro\", compute_on_cpu=True)\n","\n","    def train__iter__(self):\n","        # this method generates training samples, one at a time, \"ad infinitum\". Each worker\n","        # of the dataloader runs it independently from other workers. pyannote.audio and\n","        # pytorch-lightning take care of making batches out of it.\n","\n","        # create worker-specific random number generator (RNG) to avoid this common bug:\n","        # tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/\n","        rng = create_rng_for_worker(self.model)\n","\n","        # load list and number of classes\n","        classes = self.specifications.classes\n","        num_classes = len(classes)\n","\n","        # yield training samples \"ad infinitum\"\n","        while True:\n","\n","            # select training file at random\n","            random_training_file, *_ = rng.choices(self.prepared_data[\"train_metadata\"], k=1)\n","\n","            # select one chunk at random\n","            random_start_time = rng.uniform(0, random_training_file[\"duration\"] - self.duration)\n","            random_chunk = Segment(random_start_time, random_start_time + self.duration)\n","\n","            # load audio excerpt corresponding to random chunk\n","            # (Audio.crop returns a (waveform, sample_rate) tuple)\n","            X, _ = self.model.audio.crop(random_training_file[\"audio\"],\n","                                         random_chunk,\n","                                         duration=self.duration)\n","\n","            # load labels corresponding to random chunk as {0|1} numpy array\n","            # y[k] = 1 means that kth class is active\n","            y = np.zeros((num_classes,))\n","            active_classes = random_training_file[\"annotation\"].crop(random_chunk).labels()\n","            for active_class in active_classes:\n","                y[classes.index(active_class)] = 1\n","\n","            # yield training samples as a dict (use 'X' for input and 'y' for target)\n","            yield {'X': X, 'y': y}\n","\n","    def train__len__(self):\n","        # since train__iter__ runs \"ad infinitum\", we need a way to define what an epoch is.\n","        # This is the purpose of this method: it outputs the number of training samples that\n","        # make an epoch.\n","\n","        # we compute this number as the total duration of the training set divided by the\n","        # duration of training chunks. We make sure that an epoch is at least one batch long,\n","        # or pytorch-lightning will complain\n","        train_duration = sum(training_file[\"duration\"] for training_file in self.prepared_data[\"train_metadata\"])\n","        return max(self.batch_size, ceil(train_duration / self.duration))\n","\n","    def val__getitem__(self, sample_idx):\n","\n","        # load list and number of classes\n","        classes = self.specifications.classes\n","        num_classes = len(classes)\n","\n","        # find which part of the validation set corresponds to sample_idx:\n","        # num_samples[i] is the cumulative number of validation samples in files 0..i,\n","        # so the first file whose cumulative count exceeds sample_idx contains it\n","        num_samples = np.cumsum([\n","            validation_file[\"num_samples\"] for validation_file in self.prepared_data[\"validation\"]])\n","        file_idx = np.where(num_samples > sample_idx)[0][0]\n","        validation_file = self.prepared_data[\"validation\"][file_idx]\n","        idx = sample_idx - (num_samples[file_idx] - validation_file[\"num_samples\"])\n","        chunk = SlidingWindow(start=0., duration=self.duration, step=self.duration)[idx]\n","\n","        # load audio excerpt corresponding to current chunk\n","        X, _ = self.model.audio.crop(validation_file[\"audio\"], chunk, duration=self.duration)\n","\n","        # load labels corresponding to current chunk as {0|1} numpy array\n","        # y[k] = 1 means that kth class is active\n","        y = np.zeros((num_classes,))\n","        active_classes = validation_file[\"annotation\"].crop(chunk).labels()\n","        for active_class in active_classes:\n","            y[classes.index(active_class)] = 1\n","\n","        return {'X': X, 'y': y}\n","\n","    def val__len__(self):\n","        return sum(validation_file[\"num_samples\"]\n","                   for validation_file in self.prepared_data[\"validation\"])\n","\n","    # `pyannote.audio.core.task.Task` base class provides `training_step` and\n","    # `validation_step` methods that rely on self.specifications to guess which\n","    # loss and metrics should be used. You can obviously choose to customize them.\n","    # More details can be found in pytorch-lightning documentation and in\n","    # pyannote.audio.core.task.Task source code.\n","\n","    # def training_step(self, batch, batch_idx: int):\n","    #     return loss\n","\n","    # def validation_step(self, batch, batch_idx: int):\n","    #     return metric\n","\n","    # pyannote.audio.tasks.segmentation.mixins also provides a convenient mixin\n","    # for \"segmentation\" tasks (i.e. with Resolution.FRAME) that already defines\n","    # a bunch of useful methods. You can use it by inheriting your task from\n","    # pyannote.audio.tasks.segmentation.mixins.SegmentationTask"]}],"metadata":{"colab":{"provenance":[]},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.13"}},"nbformat":4,"nbformat_minor":0}