Models
Base classes
class asteroid.models.base_models.BaseModel(sample_rate: float, in_channels: Optional[int] = 1)
Bases: torch.nn.Module
Base class for serializable models.
Defines saving/loading procedures and the separation interface. Subclasses need to overwrite the forward and get_model_args methods.
Models inheriting from BaseModel can be used by asteroid.separate and by the asteroid-infer CLI. For models whose forward doesn't go from waveform to waveform tensors, overwrite forward_wav to return waveform tensors.
- Parameters
sample_rate (float) – Operating sample rate of the model.
in_channels – Number of input channels in the signal. If None, no checks will be performed.
forward_wav(wav, *args, **kwargs)
Separation method for waveforms.
In case the network’s forward doesn’t have waveforms as input/output, overwrite this method to separate from waveform to waveform. Should return a single torch.Tensor, the separated waveforms.
- Parameters
wav (torch.Tensor) – waveform array/tensor. Shape: 1D, 2D or 3D tensor, time last.
classmethod from_pretrained(pretrained_model_conf_or_path, *args, **kwargs)
Instantiate separation model from a model config (file or dict).
- Parameters
pretrained_model_conf_or_path (Union[dict, str]) – model conf as returned by serialize, or path to it. Needs to contain the model_args and state_dict keys.
*args – Positional arguments to be passed to the model.
**kwargs – Keyword arguments to be passed to the model. They overwrite the ones in the model package.
- Returns
nn.Module corresponding to the pretrained model conf/URL.
- Raises
ValueError – If the input config file doesn't contain the keys model_name, model_args or state_dict.
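For example, a minimal loading sketch (the checkpoint path is a hypothetical placeholder; any file produced by serialize works):
>>> from asteroid.models.base_models import BaseModel
>>> model = BaseModel.from_pretrained('exp/train_convtasnet/best_model.pth')
>>> # Keyword arguments override those stored in the model package:
>>> model = BaseModel.from_pretrained('exp/train_convtasnet/best_model.pth', sample_rate=16000)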
class asteroid.models.base_models.BaseEncoderMaskerDecoder(encoder, masker, decoder, encoder_activation=None)
Bases: asteroid.models.base_models.BaseModel
Base class for encoder-masker-decoder separation models.
- Parameters
encoder (Encoder) – Encoder instance.
masker (nn.Module) – Masker network.
decoder (Decoder) – Decoder instance.
encoder_activation (str, optional) – Activation to apply after the encoder.
forward(wav)
Enc/Mask/Dec model forward.
- Parameters
wav (torch.Tensor) – waveform tensor. 1D, 2D or 3D tensor, time last.
- Returns
torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
forward_encoder(wav: torch.Tensor) → torch.Tensor
Computes time-frequency representation of wav.
- Parameters
wav (torch.Tensor) – waveform tensor in 3D shape, time last.
- Returns
torch.Tensor, of shape (batch, feat, seq).
forward_masker(tf_rep: torch.Tensor) → torch.Tensor
Estimates masks from time-frequency representation.
- Parameters
tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq).
- Returns
torch.Tensor – Estimated masks
apply_masks(tf_rep: torch.Tensor, est_masks: torch.Tensor) → torch.Tensor
Applies masks to time-frequency representation.
- Parameters
tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq) shape.
est_masks (torch.Tensor) – Estimated masks.
- Returns
torch.Tensor – Masked time-frequency representations.
forward_decoder(masked_tf_rep: torch.Tensor) → torch.Tensor
Reconstructs time-domain waveforms from masked representations.
- Parameters
masked_tf_rep (torch.Tensor) – Masked time-frequency representation.
- Returns
torch.Tensor – Time-domain waveforms.
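Taken together, these hooks form the forward pass. A rough sketch of how they chain, assuming model is any BaseEncoderMaskerDecoder and wav is a 3D waveform tensor, time last (the real forward also normalizes 1D/2D/3D input shapes and reshapes the output):
>>> tf_rep = model.forward_encoder(wav)                   # (batch, feat, seq)
>>> est_masks = model.forward_masker(tf_rep)              # one mask per source
>>> masked_tf_rep = model.apply_masks(tf_rep, est_masks)  # masked representation
>>> est_wavs = model.forward_decoder(masked_tf_rep)       # time-domain waveforms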
asteroid.models.base_models.BaseTasNet
Alias of asteroid.models.base_models.BaseEncoderMaskerDecoder.
Ready-to-use models
class asteroid.models.conv_tasnet.ConvTasNet(n_src, out_chan=None, n_blocks=8, n_repeats=3, bn_chan=128, hid_chan=512, skip_chan=128, conv_kernel_size=3, norm_type='gLN', mask_act='sigmoid', in_chan=None, causal=False, fb_name='free', kernel_size=16, n_filters=512, stride=8, encoder_activation=None, sample_rate=8000, **fb_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
ConvTasNet separation model, as described in [1].
- Parameters
n_src (int) – Number of sources in the input mixtures.
out_chan (int, optional) – Number of bins in the estimated masks. If None, out_chan = in_chan.
n_blocks (int, optional) – Number of convolutional blocks in each repeat. Defaults to 8.
n_repeats (int, optional) – Number of repeats. Defaults to 3.
bn_chan (int, optional) – Number of channels after the bottleneck.
hid_chan (int, optional) – Number of channels in the convolutional blocks.
skip_chan (int, optional) – Number of channels in the skip connections. If 0 or None, TDConvNet won't have any skip connections and the masks will be computed from the residual output. Corresponds to the ConvTasNet architecture in v1 of the paper.
conv_kernel_size (int, optional) – Kernel size in convolutional blocks.
norm_type (str, optional) – To choose from 'BN', 'gLN', 'cLN'.
mask_act (str, optional) – Which non-linear function to generate mask.
in_chan (int, optional) – Number of input channels, should be equal to n_filters.
causal (bool, optional) – Whether or not the convolutions are causal.
fb_name (str, className) – Filterbank family from which to make encoder and decoder. To choose among ['free', 'analytic_free', 'param_sinc', 'stft'].
n_filters (int) – Number of filters / Input dimension of the masker net.
kernel_size (int) – Length of the filters.
stride (int, optional) – Stride of the convolution. If None (default), set to kernel_size // 2.
sample_rate (float) – Sampling rate of the model.
**fb_kwargs (dict) – Additional kwargs to pass to the filterbank creation.
- References
[1]: Yi Luo, Nima Mesgarani, "Conv-TasNet: Surpassing ideal time-frequency magnitude masking for speech separation", TASLP 2019. https://arxiv.org/abs/1809.07454
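A minimal usage sketch (random tensors stand in for real audio; shapes follow the forward contract above):
>>> import torch
>>> from asteroid.models import ConvTasNet
>>> model = ConvTasNet(n_src=2)      # defaults to an 8 kHz operating rate
>>> mixture = torch.randn(4, 8000)   # (batch, time)
>>> est_sources = model(mixture)     # (batch, n_src, time)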
class asteroid.models.conv_tasnet.VADNet(n_src, out_chan=None, n_blocks=8, n_repeats=3, bn_chan=128, hid_chan=512, skip_chan=128, conv_kernel_size=3, norm_type='gLN', mask_act='sigmoid', in_chan=None, causal=False, fb_name='free', kernel_size=16, n_filters=512, stride=8, encoder_activation=None, sample_rate=8000, **fb_kwargs)
Bases: asteroid.models.conv_tasnet.ConvTasNet
forward_decoder(masked_tf_rep: torch.Tensor) → torch.Tensor
Reconstructs time-domain waveforms from masked representations.
- Parameters
masked_tf_rep (torch.Tensor) – Masked time-frequency representation.
- Returns
torch.Tensor – Time-domain waveforms.
class asteroid.models.dccrnet.DCCRNet(*args, stft_n_filters=512, stft_kernel_size=400, stft_stride=100, **masknet_kwargs)
Bases: asteroid.models.dcunet.BaseDCUNet
DCCRNet as proposed in [1].
- Parameters
stft_n_filters (int) – Number of STFT filters.
stft_kernel_size (int) – STFT frame length to use.
stft_stride (int, optional) – STFT hop length to use.
masknet_kwargs (optional) – Passed to the masknet constructor.
- References
[1] : “DCCRN: Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement”, Yanxin Hu et al. https://arxiv.org/abs/2008.00264
forward_encoder(wav)
Computes time-frequency representation of wav.
- Parameters
wav (torch.Tensor) – waveform tensor in 3D shape, time last.
- Returns
torch.Tensor, of shape (batch, feat, seq).
apply_masks(tf_rep, est_masks)
Applies masks to time-frequency representation.
- Parameters
tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq) shape.
est_masks (torch.Tensor) – Estimated masks.
- Returns
torch.Tensor – Masked time-frequency representations.
class asteroid.models.dcunet.BaseDCUNet(architecture, stft_n_filters=1024, stft_kernel_size=1024, stft_stride=256, sample_rate=16000.0, **masknet_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
Base class for the DCUNet and DCCRNet classes.
- Parameters
architecture (str) – The architecture to use. Overridden by subclasses.
stft_n_filters (int) – Number of STFT filters.
stft_kernel_size (int) – STFT frame length to use.
stft_stride (int, optional) – STFT hop length to use.
sample_rate (float) – Sampling rate of the model.
masknet_kwargs (optional) – Passed to the masknet constructor.
forward_encoder(wav)
Computes time-frequency representation of wav.
- Parameters
wav (torch.Tensor) – waveform tensor in 3D shape, time last.
- Returns
torch.Tensor, of shape (batch, feat, seq).
apply_masks(tf_rep, est_masks)
Applies masks to time-frequency representation.
- Parameters
tf_rep (torch.Tensor) – Time-frequency representation in (batch, feat, seq) shape.
est_masks (torch.Tensor) – Estimated masks.
- Returns
torch.Tensor – Masked time-frequency representations.
class asteroid.models.dcunet.DCUNet(architecture, stft_n_filters=1024, stft_kernel_size=1024, stft_stride=256, sample_rate=16000.0, **masknet_kwargs)
Bases: asteroid.models.dcunet.BaseDCUNet
DCUNet as proposed in [1].
- Parameters
architecture (str) – The architecture to use, any of “DCUNet-10”, “DCUNet-16”, “DCUNet-20”, “Large-DCUNet-20”.
stft_n_filters (int) – Number of STFT filters.
stft_kernel_size (int) – STFT frame length to use.
stft_stride (int, optional) – STFT hop length to use.
sample_rate (float) – Sampling rate of the model.
masknet_kwargs (optional) – Passed to DCUMaskNet.
- References
[1] : “Phase-aware Speech Enhancement with Deep Complex U-Net”, Hyeong-Seok Choi et al. https://arxiv.org/abs/1903.03107
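A minimal enhancement sketch (the input length is only illustrative; depending on the architecture, inputs may need padding to fit the STFT stride and the network's downsampling factors):
>>> import torch
>>> from asteroid.models import DCUNet
>>> model = DCUNet('DCUNet-16')      # 16 kHz STFT defaults
>>> noisy = torch.randn(1, 16384)    # (batch, time)
>>> enhanced = model(noisy)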
class asteroid.models.demask.DeMask(input_type='mag', output_type='mag', hidden_dims=(1024, ), dropout=0.0, activation='relu', mask_act='relu', norm_type='gLN', fb_name='stft', n_filters=512, stride=256, kernel_size=512, sample_rate=16000, **fb_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
Simple MLP model for surgical mask speech enhancement. A transformed-domain masking approach is used.
- Parameters
input_type (str, optional) – whether the magnitude spectrogram "mag" or both real and imaginary parts "reim" are passed as features to the masker network. Concatenation of "mag" and "reim" can also be selected with "cat".
output_type (str, optional) – whether the masker outputs a mask for the magnitude spectrogram "mag" or for both real and imaginary parts "reim".
hidden_dims (list, optional) – list of MLP hidden layer sizes.
dropout (float, optional) – dropout probability.
activation (str, optional) – type of activation used in hidden MLP layers.
mask_act (str, optional) – Which non-linear function to generate mask.
norm_type (str, optional) – To choose from 'BN', 'gLN', 'cLN'.
fb_name (str) – type of analysis and synthesis filterbanks used, choose between ["stft", "free", "analytic_free"].
n_filters (int) – number of filters in the analysis and synthesis filterbanks.
stride (int) – filterbank filters stride.
kernel_size (int) – length of filters in the filterbank.
encoder_activation (str) – Activation to apply after the encoder.
sample_rate (float) – Sampling rate of the model.
**fb_kwargs (dict) – Additional kwargs to pass to the filterbank creation.
forward_masker(tf_rep)
Estimates masks based on time-frequency representations.
- Parameters
tf_rep (torch.Tensor) – Time-frequency representation in (batch, freq, seq).
- Returns
torch.Tensor – Estimated masks in (batch, freq, seq).
apply_masks(tf_rep, est_masks)
Applies masks to time-frequency representations.
- Parameters
tf_rep (torch.Tensor) – Time-frequency representations in (batch, freq, seq).
est_masks (torch.Tensor) – Estimated masks in (batch, freq, seq).
- Returns
torch.Tensor – Masked time-frequency representations.
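A minimal sketch with the default magnitude-masking configuration (the random tensor stands in for a real recording):
>>> import torch
>>> from asteroid.models import DeMask
>>> model = DeMask()                 # 16 kHz STFT defaults
>>> muffled = torch.randn(1, 16000)  # (batch, time)
>>> enhanced = model(muffled)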
class asteroid.models.dprnn_tasnet.DPRNNTasNet(n_src, out_chan=None, bn_chan=128, hid_size=128, chunk_size=100, hop_size=None, n_repeats=6, norm_type='gLN', mask_act='sigmoid', bidirectional=True, rnn_type='LSTM', num_layers=1, dropout=0, in_chan=None, fb_name='free', kernel_size=16, n_filters=64, stride=8, encoder_activation=None, sample_rate=8000, use_mulcat=False, **fb_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
DPRNN separation model, as described in [1].
- Parameters
n_src (int) – Number of masks to estimate.
out_chan (int or None) – Number of bins in the estimated masks. Defaults to in_chan.
bn_chan (int) – Number of channels after the bottleneck. Defaults to 128.
hid_size (int) – Number of neurons in the RNNs cell state. Defaults to 128.
chunk_size (int) – window size of overlap and add processing. Defaults to 100.
hop_size (int or None) – hop size (stride) of overlap and add processing. Default to chunk_size // 2 (50% overlap).
n_repeats (int) – Number of repeats. Defaults to 6.
norm_type (str, optional) – Type of normalization to use. To choose from
'gLN': global Layernorm
'cLN': channelwise Layernorm
mask_act (str, optional) – Which non-linear function to generate mask.
bidirectional (bool, optional) – True for bidirectional Inter-Chunk RNN (Intra-Chunk is always bidirectional).
rnn_type (str, optional) – Type of RNN used. Choose between 'RNN', 'LSTM' and 'GRU'.
num_layers (int, optional) – Number of layers in each RNN.
dropout (float, optional) – Dropout ratio, must be in [0,1].
in_chan (int, optional) – Number of input channels, should be equal to n_filters.
fb_name (str, className) – Filterbank family from which to make encoder and decoder. To choose among ['free', 'analytic_free', 'param_sinc', 'stft'].
n_filters (int) – Number of filters / Input dimension of the masker net.
kernel_size (int) – Length of the filters.
stride (int, optional) – Stride of the convolution. If None (default), set to kernel_size // 2.
sample_rate (float) – Sampling rate of the model.
**fb_kwargs (dict) – Additional kwargs to pass to the filterbank creation.
- References
[1]: Yi Luo, Zhuo Chen, Takuya Yoshioka, "Dual-path RNN: efficient long sequence modeling for time-domain single-channel speech separation". https://arxiv.org/abs/1910.06379
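A minimal permutation-invariant training step, assuming PITLossWrapper and pairwise_neg_sisdr from asteroid.losses (random tensors stand in for real data):
>>> import torch
>>> from asteroid.models import DPRNNTasNet
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> model = DPRNNTasNet(n_src=2)
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> mixture = torch.randn(4, 8000)      # (batch, time)
>>> sources = torch.randn(4, 2, 8000)   # (batch, n_src, time)
>>> loss = loss_func(model(mixture), sources)
>>> loss.backward()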
class asteroid.models.dptnet.DPTNet(n_src, n_heads=4, ff_hid=256, chunk_size=100, hop_size=None, n_repeats=6, norm_type='gLN', ff_activation='relu', encoder_activation='relu', mask_act='relu', bidirectional=True, dropout=0, in_chan=None, fb_name='free', kernel_size=16, n_filters=64, stride=8, sample_rate=8000, **fb_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
DPTNet separation model, as described in [1].
- Parameters
n_src (int) – Number of masks to estimate.
out_chan (int or None) – Number of bins in the estimated masks. Defaults to in_chan.
bn_chan (int) – Number of channels after the bottleneck. Defaults to 128.
hid_size (int) – Number of neurons in the RNNs cell state. Defaults to 128.
chunk_size (int) – window size of overlap and add processing. Defaults to 100.
hop_size (int or None) – hop size (stride) of overlap and add processing. Default to chunk_size // 2 (50% overlap).
n_repeats (int) – Number of repeats. Defaults to 6.
norm_type (str, optional) – Type of normalization to use. To choose from
'gLN': global Layernorm
'cLN': channelwise Layernorm
mask_act (str, optional) – Which non-linear function to generate mask.
bidirectional (bool, optional) – True for bidirectional Inter-Chunk RNN (Intra-Chunk is always bidirectional).
rnn_type (str, optional) – Type of RNN used. Choose between 'RNN', 'LSTM' and 'GRU'.
num_layers (int, optional) – Number of layers in each RNN.
dropout (float, optional) – Dropout ratio, must be in [0,1].
in_chan (int, optional) – Number of input channels, should be equal to n_filters.
fb_name (str, className) – Filterbank family from which to make encoder and decoder. To choose among ['free', 'analytic_free', 'param_sinc', 'stft'].
n_filters (int) – Number of filters / Input dimension of the masker net.
kernel_size (int) – Length of the filters.
stride (int, optional) – Stride of the convolution. If None (default), set to kernel_size // 2.
sample_rate (float) – Sampling rate of the model.
**fb_kwargs (dict) – Additional kwargs to pass to the filterbank creation.
- References
[1]: Jingjing Chen et al. “Dual-Path Transformer Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation” Interspeech 2020.
class asteroid.models.lstm_tasnet.LSTMTasNet(n_src, out_chan=None, rnn_type='lstm', n_layers=4, hid_size=512, dropout=0.3, mask_act='sigmoid', bidirectional=True, in_chan=None, fb_name='free', n_filters=64, kernel_size=16, stride=8, encoder_activation=None, sample_rate=8000, **fb_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
TasNet separation model, as described in [1].
- Parameters
n_src (int) – Number of masks to estimate.
out_chan (int or None) – Number of bins in the estimated masks. Defaults to in_chan.
hid_size (int) – Number of neurons in the RNNs cell state. Defaults to 128.
mask_act (str, optional) – Which non-linear function to generate mask.
bidirectional (bool, optional) – True for bidirectional Inter-Chunk RNN (Intra-Chunk is always bidirectional).
rnn_type (str, optional) – Type of RNN used. Choose between 'RNN', 'LSTM' and 'GRU'.
n_layers (int, optional) – Number of layers in each RNN.
dropout (float, optional) – Dropout ratio, must be in [0,1].
in_chan (int, optional) – Number of input channels, should be equal to n_filters.
fb_name (str, className) – Filterbank family from which to make encoder and decoder. To choose among ['free', 'analytic_free', 'param_sinc', 'stft'].
n_filters (int) – Number of filters / Input dimension of the masker net.
kernel_size (int) – Length of the filters.
stride (int, optional) – Stride of the convolution. If None (default), set to kernel_size // 2.
sample_rate (float) – Sampling rate of the model.
**fb_kwargs (dict) – Additional kwargs to pass to the filterbank creation.
- References
[1]: Yi Luo et al. “Real-time Single-channel Dereverberation and Separation with Time-domain Audio Separation Network”, Interspeech 2018
class asteroid.models.sudormrf.SuDORMRFNet(n_src, bn_chan=128, num_blocks=16, upsampling_depth=4, mask_act='softmax', in_chan=None, fb_name='free', kernel_size=21, n_filters=512, stride=None, sample_rate=8000, **fb_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
SuDORMRF separation model, as described in [1].
- Parameters
n_src (int) – Number of sources in the input mixtures.
bn_chan (int, optional) – Number of bins in the bottleneck layer and the UNet blocks.
num_blocks (int) – Number of UBlocks.
upsampling_depth (int) – Depth of upsampling.
mask_act (str) – Name of output activation.
in_chan (int, optional) – Number of input channels, should be equal to n_filters.
fb_name (str, className) – Filterbank family from which to make encoder and decoder. To choose among ['free', 'analytic_free', 'param_sinc', 'stft'].
n_filters (int) – Number of filters / Input dimension of the masker net.
kernel_size (int) – Length of the filters.
stride (int, optional) – Stride of the convolution. If None (default), set to kernel_size // 2.
sample_rate (float) – Sampling rate of the model.
**fb_kwargs (dict) – Additional kwargs to pass to the filterbank creation.
- References
[1] : “Sudo rm -rf: Efficient Networks for Universal Audio Source Separation”, Tzinis et al. MLSP 2020.
class asteroid.models.sudormrf.SuDORMRFImprovedNet(n_src, bn_chan=128, num_blocks=16, upsampling_depth=4, mask_act='relu', in_chan=None, fb_name='free', kernel_size=21, n_filters=512, stride=None, sample_rate=8000, **fb_kwargs)
Bases: asteroid.models.base_models.BaseEncoderMaskerDecoder
Improved SuDORMRF separation model, as described in [1].
- Parameters
n_src (int) – Number of sources in the input mixtures.
bn_chan (int, optional) – Number of bins in the bottleneck layer and the UNet blocks.
num_blocks (int) – Number of UBlocks.
upsampling_depth (int) – Depth of upsampling.
mask_act (str) – Name of output activation.
in_chan (int, optional) – Number of input channels, should be equal to n_filters.
fb_name (str, className) – Filterbank family from which to make encoder and decoder. To choose among ['free', 'analytic_free', 'param_sinc', 'stft'].
n_filters (int) – Number of filters / Input dimension of the masker net.
kernel_size (int) – Length of the filters.
stride (int, optional) – Stride of the convolution. If None (default), set to kernel_size // 2.
**fb_kwargs (dict) – Additional kwargs to pass to the filterbank creation.
- References
[1] : “Sudo rm -rf: Efficient Networks for Universal Audio Source Separation”, Tzinis et al. MLSP 2020.
Publishing models
class asteroid.models.zenodo.Zenodo(api_key=None, use_sandbox=True)
Bases: object
Facilitate the use of Zenodo's REST API.
- Parameters
api_key (str, optional) – Access token generated on Zenodo to upload depositions.
use_sandbox (bool) – Whether to use Zenodo's sandbox instead of the official Zenodo.
All methods return the requests response.
Note
A Zenodo record is something that is public and cannot be deleted. A Zenodo deposit has not yet been published, is private and can be deleted.
create_new_deposition(metadata=None)
Creates a new deposition.
- Parameters
metadata (dict, optional) – Metadata dict to upload on the new deposition.
change_metadata_in_deposition(dep_id, metadata)
Set or replace metadata in given deposition.
- Parameters
dep_id (int) – deposition id.
metadata (dict) – Metadata dict to upload on the deposition.
- Examples
>>> metadata = {
...     'title': 'My first upload',
...     'upload_type': 'poster',
...     'description': 'This is my first upload',
...     'creators': [{'name': 'Doe, John',
...                   'affiliation': 'Zenodo'}]
... }
upload_new_file_to_deposition(dep_id, file, name=None)
Upload one file to existing deposition.
- Parameters
dep_id (int) – deposition id. You can get it with r = create_new_deposition(); dep_id = r.json()['id']
file (str or io.BufferedReader) – path to a file, or already opened file (path preferred).
name (str, optional) – name given to the uploaded file. Defaults to the path.
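Putting the three methods together, a sandbox upload sketch (the token and file path are placeholders; metadata as in the example above):
>>> from asteroid.models.zenodo import Zenodo
>>> zen = Zenodo(api_key='MY_SANDBOX_TOKEN', use_sandbox=True)
>>> r = zen.create_new_deposition()
>>> dep_id = r.json()['id']
>>> zen.change_metadata_in_deposition(dep_id, metadata)
>>> zen.upload_new_file_to_deposition(dep_id, 'exp/publish_dir/model.pth')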
asteroid.models.publisher.save_publishable(publish_dir, model_dict, metrics=None, train_conf=None, recipe=None)
Save models to prepare for publication / model sharing.
- Parameters
publish_dir (str) – Path to the publishing directory. Usually under exp/exp_name/publish_dir
model_dict (dict) – dict with at least the keys model_args, state_dict, dataset and licenses.
metrics (dict) – dict with evaluation metrics.
train_conf (dict) – Training configuration dict (from conf.yml).
recipe (str) – Name of the recipe.
- Returns
dict, same as model_dict with added fields.
- Raises
AssertionError – If any of model_args, state_dict, dataset or licenses is missing from model_dict.keys().
asteroid.models.publisher.upload_publishable(publish_dir, uploader=None, affiliation=None, git_username=None, token=None, force_publish=False, use_sandbox=False, unit_test=False)
Entry point to upload publishable model.
- Parameters
publish_dir (str) – Path to the publishing directory. Usually under exp/exp_name/publish_dir
uploader (str) – Full name of the uploader (Ex: Manuel Pariente)
affiliation (str, optional) – Affiliation (no accent).
git_username (str, optional) – GitHub username.
token (str) – Access token generated to upload depositions.
force_publish (bool) – Whether to publish directly without asking for confirmation first. Defaults to False.
use_sandbox (bool) – Whether to use Zenodo’s sandbox instead of the official Zenodo.
unit_test (bool) – If True, we do not ask user input and do not publish.
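A typical publishing flow, sketched under the assumption that model_dict (with the four required keys of save_publishable), metrics and train_conf come from the training run; the directory, uploader and token are placeholders:
>>> from asteroid.models.publisher import save_publishable, upload_publishable
>>> publishable = save_publishable(
...     'exp/train_convtasnet/publish_dir', model_dict,
...     metrics=metrics, train_conf=train_conf)
>>> upload_publishable(
...     'exp/train_convtasnet/publish_dir', uploader='Jane Doe',
...     use_sandbox=True, token='MY_TOKEN')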
asteroid.models.publisher.make_license_notice(model_name, licenses, uploader=None)
Make license notice based on license dicts.
- Parameters
model_name (str) – Name of the model.
licenses (list) – List of license dicts.
uploader (str, optional) – Name of the uploader.
- Returns
str, the license notice describing the model, its attribution, the original licenses, what we license it under, and the licensor.
asteroid.models.publisher.zenodo_upload(model, token, model_path=None, use_sandbox=False)
Create deposit and upload metadata + model.
asteroid.models.publisher.make_metadata_from_model(model)
Create Zenodo deposit metadata for a given publishable model.
- Parameters
model (dict) – Dictionary with all the information needed to publish. More info to come.
- Returns
dict, the metadata to create the Zenodo deposit with.
Note
We remove the PESQ from the final results as a license is needed to use it.