Files
Hervé BREDIN 0b9be93ee1 chore: replace load_from_checkpoint by Model.from_pretrained (#588)
* chore: import Model from package root
* chore: import Inference from package root
2021-01-20 16:39:08 +01:00

80 lines
1.9 KiB
Plaintext
Raw Permalink Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyannote.database import FileFinder, get_protocol\n",
"\n",
"# debug speaker diarization protocol, with FileFinder registered\n",
"# as the preprocessor for the \"audio\" key of each file\n",
"protocol = get_protocol(\n",
"    'Debug.SpeakerDiarization.Debug',\n",
"    preprocessors={\"audio\": FileFinder()},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyannote.audio.tasks import VoiceActivityDetection\n",
"from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel\n",
"import pytorch_lightning as pl\n",
"\n",
"# seed Python, NumPy and PyTorch RNGs so that re-running the notebook is repeatable\n",
"# (dataloader workers and GPU kernels may still add some nondeterminism)\n",
"pl.seed_everything(42)\n",
"\n",
"# voice activity detection on the debug protocol, trained on 2-second chunks\n",
"vad = VoiceActivityDetection(protocol, duration=2., batch_size=32, num_workers=4)\n",
"model = SimpleSegmentationModel(task=vad)\n",
"\n",
"# a single epoch is enough to produce a checkpoint for the loading section below;\n",
"# checkpoints are written under sharing/lightning_logs/version_*/checkpoints/\n",
"trainer = pl.Trainer(max_epochs=1, default_root_dir='sharing/')\n",
"_ = trainer.fit(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
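"## Load a model without knowing its class\n",
"\n",
"`Model.from_pretrained` is called on the generic `Model` base class, yet it returns an instance of the class that was actually trained (here `SimpleSegmentationModel`), which the `assert` below checks."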
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from pyannote.audio import Model\n",
"\n",
"# Each training run gets its own `version_N` directory and checkpoint file names\n",
"# encode the epoch/step reached (e.g. `epoch=0-step=3.ckpt`), so a hard-coded path\n",
"# either breaks or silently loads a stale model: pick the most recent checkpoint.\n",
"checkpoints = sorted(\n",
"    Path('sharing/lightning_logs').glob('version_*/checkpoints/*.ckpt'),\n",
"    key=lambda path: path.stat().st_mtime,\n",
")\n",
"assert checkpoints, 'no checkpoint found in sharing/lightning_logs: run the training cell first'\n",
"checkpoint = checkpoints[-1]\n",
"\n",
"# checkpoint should work with a URL as well (it relies on pl_load),\n",
"# so it is passed as a plain string, the same type a URL would be\n",
"model = Model.from_pretrained(str(checkpoint))\n",
"assert isinstance(model, SimpleSegmentationModel)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}