Losses & Metrics

Permutation invariant training (PIT) made easy

Asteroid supports regular Permutation Invariant Training (PIT), its extension using the Sinkhorn algorithm (SinkPIT), and Mixture Invariant Training (MixIT).

PIT

class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]

Bases: torch.nn.Module

Permutation invariant loss wrapper.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs).

  • pit_from (str) –

    Determines how PIT is applied.

    • 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).

    • 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single sources and single estimates (tensors won’t have the source axis). Output shape: \((batch)\). See get_pw_losses().

    • 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: \((batch)\). See best_perm_from_perm_avg_loss().

    In terms of efficiency, 'perm_avg' is the least efficient.

  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs) : \((B, n\_src!, n\_src) --> (B, n\_src!)\). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in the 'pw_mtx' and 'pw_pt' pit_from modes.

For each of these modes, the best permutation and reordering will be automatically computed. When either 'pw_mtx' or 'pw_pt' is used, and the number of sources is larger than three, the Hungarian algorithm is used to find the best permutation.
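
As a rough illustration of what "best permutation" means under the 'pw_mtx' index convention above, the following dependency-free sketch searches every permutation of a single pairwise loss matrix. It works on plain Python lists rather than tensors, and `best_perm_exhaustive` is a hypothetical helper written for this example, not part of Asteroid.

```python
from itertools import permutations


def best_perm_exhaustive(pw_losses):
    """Return (min mean loss, permutation) for one pairwise loss matrix.

    pw_losses[i][j] is the loss between target i and estimate j.
    perm[i] is the index of the estimate assigned to target i.
    """
    n_src = len(pw_losses)
    best_loss, best_perm = float("inf"), None
    for perm in permutations(range(n_src)):
        # Mean loss over sources for this assignment.
        loss = sum(pw_losses[i][perm[i]] for i in range(n_src)) / n_src
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


pw = [
    [0.9, 0.1, 0.8],
    [0.2, 0.7, 0.9],
    [0.8, 0.9, 0.3],
]
loss, perm = best_perm_exhaustive(pw)
print(loss, perm)
```

The loop visits n_src! permutations, which is why an assignment solver such as the Hungarian algorithm becomes preferable as the number of sources grows.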

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
>>>     # perm_loss: (batch, n_src!, n_src); weights: (batch, 1, n_src)
>>>     weighted = perm_loss * src.norm(dim=-1).unsqueeze(1)
>>>     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
>>>                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
>>>                      reduce_kwargs=reduce_kwargs)
forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the loss.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).

  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).

  • **kwargs – additional keyword argument that will be passed to the loss function.

Returns

  • Best permutation loss for each batch sample, averaged over the batch.

  • The reordered target estimates if return_est is True. torch.Tensor of shape \((batch, nsrc, ...)\).

static get_pw_losses(loss_func, est_targets, targets, **kwargs)[source]

Get pair-wise losses between the training targets and their estimates for a given loss function.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get pair-wise losses from.

  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • **kwargs – additional keyword argument that will be passed to the loss function.

Returns

torch.Tensor of size \((batch, nsrc, nsrc)\), losses computed for all pairs of targets and est_targets.

This function can be called on a loss function which returns a tensor of size \((batch)\). There are more efficient ways to compute pair-wise losses using broadcasting.
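
To make the pairwise construction concrete, here is a minimal sketch for a single batch item using plain Python lists. It follows the index convention stated for 'pw_mtx' (entry [i][j] is the loss between target i and estimate j); `mse` and `pairwise_losses` are hypothetical helpers for illustration, not Asteroid functions.

```python
def mse(est, target):
    """Mean squared error between two equal-length sequences."""
    return sum((e - t) ** 2 for e, t in zip(est, target)) / len(target)


def pairwise_losses(loss_func, est_targets, targets):
    """Build the (n_src, n_src) loss matrix for one batch item.

    Entry [i][j] is loss_func(est_targets[j], targets[i]), i.e. the loss
    between target i and estimate j.
    """
    return [[loss_func(est, tgt) for est in est_targets] for tgt in targets]


targets = [[1.0, 1.0], [0.0, 0.0]]
est_targets = [[0.0, 0.0], [1.0, 1.0]]  # estimates come out in swapped order
pw = pairwise_losses(mse, est_targets, targets)
print(pw)
```

The double loop calls the single-source loss n_src * n_src times, which is the inefficiency the note above refers to; a broadcast implementation computes all pairs in one pass.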

static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)[source]

Find best permutation from loss function with source axis.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.

  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • **kwargs – additional keyword argument that will be passed to the loss function.

Returns

  • torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).

  • torch.Tensor – The indices of the best permutations.

static find_best_perm(pair_wise_losses, perm_reduce=None, **kwargs)[source]

Find the best permutation, given the pair-wise losses.

Dispatch between the factorial method if the number of sources is small (<= 3) and the Hungarian method for more sources. If perm_reduce is not None, the factorial method is always used.

Parameters
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.

  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs) : \((B, n\_src!, n\_src) -> (B, n\_src!)\)

  • **kwargs – additional keyword argument that will be passed to the permutation reduce function.

Returns

  • torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).

  • torch.Tensor – The indices of the best permutations.

static reorder_source(source, batch_indices)[source]

Reorder sources according to the best permutation.

Parameters
  • source (torch.Tensor) – Tensor of shape \((batch, n\_src, time)\)

  • batch_indices (torch.Tensor) – Tensor of shape \((batch, n\_src)\). Contains optimal permutation indices for each batch.

Returns

torch.Tensor – Reordered sources.
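
The reordering step can be pictured with a small list-based sketch. It assumes that, for each batch item, the output at position i is the source at position indices[i]; `reorder` is a hypothetical stand-in written for this example and operates on nested lists instead of tensors.

```python
def reorder(sources, batch_indices):
    """Reorder the sources of every batch item.

    sources: list over the batch of lists over sources.
    batch_indices: list over the batch of permutations, such that
        output[b][i] = sources[b][batch_indices[b][i]].
    """
    return [
        [item[i] for i in perm]
        for item, perm in zip(sources, batch_indices)
    ]


est = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
perms = [(1, 0, 2), (2, 0, 1)]
reordered = reorder(est, perms)
print(reordered)
```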

static find_best_perm_factorial(pair_wise_losses, perm_reduce=None, **kwargs)[source]

Find the best permutation given the pair-wise losses by looping through all the permutations.

Parameters
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.

  • perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs) : \((B, n\_src!, n\_src) -> (B, n\_src!)\)

  • **kwargs – additional keyword argument that will be passed to the permutation reduce function.

Returns

  • torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).

  • torch.Tensor – The indices of the best permutations.

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.

static find_best_perm_hungarian(pair_wise_losses: torch.Tensor)[source]

Find the best permutation given the pair-wise losses, using the Hungarian algorithm.

Returns

  • torch.Tensor – The loss corresponding to the best permutation of size (batch,).

  • torch.Tensor: The indices of the best permutations.

class asteroid.losses.pit_wrapper.PITReorder(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]

Bases: asteroid.losses.pit_wrapper.PITLossWrapper

Permutation invariant reorderer. Only returns the reordered estimates. See asteroid.losses.PITLossWrapper.

forward(est_targets, targets, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the reordered estimates.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

torch.Tensor – The reordered target estimates, of shape \((batch, nsrc, ...)\).

MixIT

class asteroid.losses.mixit_wrapper.MixITLossWrapper(loss_func, generalized=True)[source]

Bases: torch.nn.Module

Mixture invariant loss wrapper.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs).

  • generalized (bool) – Determines how MixIT is applied. If False, apply MixIT for any number of mixtures as long as they contain the same number of sources (see best_part_mixit()). If True (default), apply MixIT for two mixtures, which do not necessarily have to contain the same number of sources. See best_part_mixit_generalized().

For each of these modes, the best partition and reordering will be automatically computed.
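
The idea behind the partition search can be sketched without any dependencies. The version below follows the general MixIT formulation of [1], where each estimated source is assigned to exactly one mixture and the sources assigned to a mixture are summed before being compared with it. It handles one batch item, uses plain lists, and does not reproduce the exact set of partitions that Asteroid enumerates in either mode; `best_mixit_assignment` is a hypothetical name.

```python
from itertools import product


def mse(est, target):
    """Mean squared error between two equal-length sequences."""
    return sum((e - t) ** 2 for e, t in zip(est, target)) / len(target)


def best_mixit_assignment(est_sources, mixtures, loss_func=mse):
    """Exhaustive MixIT search for one batch item.

    Returns (min mean loss, assignment), where assignment[k] is the index
    of the mixture that receives estimated source k.
    """
    n_mix, n_samples = len(mixtures), len(mixtures[0])
    best_loss, best_assign = float("inf"), None
    for assign in product(range(n_mix), repeat=len(est_sources)):
        # Sum the estimated sources assigned to each mixture.
        est_mixtures = [[0.0] * n_samples for _ in range(n_mix)]
        for src, mix_idx in zip(est_sources, assign):
            for t, value in enumerate(src):
                est_mixtures[mix_idx][t] += value
        loss = sum(
            loss_func(est_mix, mix)
            for est_mix, mix in zip(est_mixtures, mixtures)
        ) / n_mix
        if loss < best_loss:
            best_loss, best_assign = loss, assign
    return best_loss, best_assign


mixtures = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]
est_sources = [
    [0.0, 1.0, 0.0],  # equals mixture 1 on its own
    [1.0, 0.0, 1.0],  # these two sum to mixture 0
    [0.0, 2.0, 2.0],
    [0.0, 0.0, 0.0],  # silent estimate, fits in either mixture
]
loss, assign = best_mixit_assignment(est_sources, mixtures)
print(loss, assign)
```

The search space has n_mix ** n_src assignments, so this brute-force form is only practical for a handful of estimated sources.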

Examples

>>> import torch
>>> from asteroid.losses import MixITLossWrapper, multisrc_mse
>>> mixtures = torch.randn(10, 2, 16000)
>>> est_sources = torch.randn(10, 4, 16000)
>>> # Compute the MixIT loss based on a multi-source loss
>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val = loss_func(est_sources, mixtures)
References

[1] Scott Wisdom et al. “Unsupervised sound separation using mixtures of mixtures.” arXiv:2006.12701 (2020)

forward(est_targets, targets, return_est=False, **kwargs)[source]

Find the best partition and return the loss.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, *)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.

  • return_est – Boolean. Whether to return the estimated mixtures (to compute metrics or to save examples).

  • **kwargs – additional keyword argument that will be passed to the loss function.

Returns

  • Best partition loss for each batch sample, averaged over the batch. torch.Tensor(loss_value)

  • The estimated mixtures (estimated sources summed according to the partition) if return_est is True. torch.Tensor of shape \((batch, nmix, ...)\).

static best_part_mixit(loss_func, est_targets, targets, **kwargs)[source]

Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm in [1]. Valid for any number of mixtures as long as they contain the same number of sources.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs) The loss function to get batch losses from.

  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).

  • **kwargs – additional keyword argument that will be passed to the loss function.

Returns

  • torch.Tensor – The loss corresponding to the best partition, of size (batch,).

  • torch.LongTensor: The indices of the best partition.

  • list: list of the possible partitions of the sources.

static best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs)[source]

Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm in [1]. Valid only for two mixtures, but those mixtures do not necessarily have to contain the same number of sources, e.g. the case where one mixture is silent is allowed.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs) The loss function to get batch losses from.

  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).

  • **kwargs – additional keyword argument that will be passed to the loss function.

Returns

  • torch.Tensor – The loss corresponding to the best partition, of size (batch,).

  • torch.LongTensor: The indices of the best partitions.

  • list: list of the possible partitions of the sources.

static loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs)[source]

Common loop shared by best_part_mixit() and best_part_mixit_generalized().

static reorder_source(est_targets, targets, min_loss_idx, parts)[source]

Reorder sources according to the best partition.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.

  • min_loss_idx – torch.LongTensor. The indices of the best partitions.

  • parts – list of the possible partitions of the sources.

Returns

torch.Tensor – Reordered sources of shape \((batch, nmix, time)\).

SinkPIT

class asteroid.losses.sinkpit_wrapper.SinkPITLossWrapper(loss_func, n_iter=200, hungarian_validation=True)[source]

Bases: torch.nn.Module

Permutation invariant loss wrapper.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs).

  • n_iter (int) – Number of Sinkhorn iterations (default = 200). Should be an even number.

  • hungarian_validation (boolean) – Whether to use the Hungarian algorithm for validation (default = True).

loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\). The wrapper evaluates an approximate value of the PIT loss using Sinkhorn’s iterative algorithm. See best_softperm_sinkhorn() and http://arxiv.org/abs/2010.11871

Examples
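>>> import torch
>>> import pytorch_lightning as pl
>>> from torch import optim
>>> from torch.utils import data
>>> from asteroid.engine import System
>>> from asteroid.losses import pairwise_neg_sisdr, SinkPITLossWrapper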
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute SinkPIT loss based on pairwise losses
>>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
>>> loss_val = loss_func(est_sources, sources)
>>> # A fixed temperature parameter `beta` (=10) is used
>>> # unless a cooling callback is set. The value can be
>>> # dynamically changed using a cooling callback module as follows.
>>> model = NeuralNetworkModel()
>>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
>>> dataset = YourDataset()
>>> loader = data.DataLoader(dataset, batch_size=16)
>>> system = System(
>>>     model,
>>>     optimizer,
>>>     loss_func=SinkPITLossWrapper(pairwise_neg_sisdr),
>>>     train_loader=loader,
>>>     val_loader=loader,
>>>     )
>>>
>>> trainer = pl.Trainer(
>>>     max_epochs=100,
>>>     callbacks=[SinkPITBetaScheduler(lambda epoch : 1.02 ** epoch)],
>>>     )
>>>
>>> trainer.fit(system)
forward(est_targets, targets, return_est=False, **kwargs)[source]

Evaluate the loss using Sinkhorn’s algorithm.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets

  • return_est – Boolean. Whether to return the reordered targets estimates (To compute metrics or to save example).

  • **kwargs – additional keyword argument that will be passed to the loss function.

Returns

  • Best permutation loss for each batch sample, averaged over the batch. torch.Tensor(loss_value)

  • The reordered target estimates if return_est is True. torch.Tensor of shape \((batch, nsrc, ...)\).

static best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200)[source]

Compute an approximate PIT loss using Sinkhorn’s algorithm. See http://arxiv.org/abs/2010.11871

Parameters
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.

  • beta (float) – Inverse temperature parameter. (default = 10)

  • n_iter (int) – Number of iterations. Should be an even number. (default = 200)

Returns

  • torch.Tensor – The loss corresponding to the best permutation of size (batch,).

  • torch.Tensor: A soft permutation matrix.
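
For intuition about the soft permutation matrix, the sketch below runs Sinkhorn normalization on a single pairwise loss matrix in plain Python. It is a generic log-domain Sinkhorn iteration written for this example under the assumption that the starting point is exp(-beta * loss); it is not Asteroid's implementation and it does not compute the approximate loss value. `soft_perm_sinkhorn` and `logsumexp` are hypothetical helpers.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def soft_perm_sinkhorn(pw_losses, beta=10.0, n_iter=200):
    """Approximate a permutation matrix from one pairwise loss matrix.

    Starts from Z = -beta * pw_losses and alternately normalizes rows and
    columns in the log domain, so that exp(Z) approaches a doubly
    stochastic matrix concentrated on low-loss pairings.
    """
    n = len(pw_losses)
    z = [[-beta * c for c in row] for row in pw_losses]
    for _ in range(n_iter // 2):
        # Row normalization: every row of exp(z) sums to 1.
        row_lse = [logsumexp(row) for row in z]
        z = [[v - lse for v in row] for row, lse in zip(z, row_lse)]
        # Column normalization: every column of exp(z) sums to 1.
        col_lse = [logsumexp([z[i][j] for i in range(n)]) for j in range(n)]
        z = [[z[i][j] - col_lse[j] for j in range(n)] for i in range(n)]
    return [[math.exp(v) for v in row] for row in z]


pw = [[0.1, 2.0], [2.0, 0.3]]
soft_perm = soft_perm_sinkhorn(pw)
print(soft_perm)
```

Each loop iteration performs one row and one column normalization, which is one way to read the even-number requirement on n_iter. A larger beta pushes the result closer to a hard permutation.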

Available loss functions

PITLossWrapper supports three types of loss function. For “easy” losses, we implement the three types (pairwise, single-source and multi-source losses). For others, we only implement the single-source loss, which can be used for both PIT and non-PIT training.

MSE

asteroid.losses.mse.PairwiseMSE(*args: Any, **kwargs: Any) → Any[source]

Measure pairwise mean square error on a batch.

Shape:
  • est_targets : \((batch, nsrc, ...)\).

  • targets: \((batch, nsrc, ...)\).

Returns

torch.Tensor – with shape \((batch, nsrc, nsrc)\)

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.mse import PairwiseMSE
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.SingleSrcMSE(*args: Any, **kwargs: Any) → Any[source]

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  • est_targets: \((batch, ...)\).

  • targets: \((batch, ...)\).

Returns

torch.Tensor – with shape \((batch)\)

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.MultiSrcMSE(*args: Any, **kwargs: Any) → Any[source]

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  • est_targets: \((batch, ...)\).

  • targets: \((batch, ...)\).

Returns

torch.Tensor – with shape \((batch)\)

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, multisrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'.
>>> loss_func = PITLossWrapper(multisrc_mse, pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)

SDR

asteroid.losses.sdr.PairwiseNegSDR(*args: Any, **kwargs: Any) → Any[source]

Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch.

Parameters
  • sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].

  • zero_mean (bool, optional) – by default, the mean is removed from the target and the estimate before computing the loss.

  • take_log (bool, optional) – by default, the log10 of the SDR is returned.

Shape:
  • est_targets : \((batch, nsrc, ...)\).

  • targets: \((batch, nsrc, ...)\).

Returns

torch.Tensor – with shape \((batch, nsrc, nsrc)\). Pairwise losses.

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import PairwiseNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"),
>>>                            pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
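
The three sdr_type options differ only in how the reference and the error are defined. The sketch below writes them out for one estimate/target pair following the definitions in [1], using plain Python lists. It skips the zero_mean step and always takes the log, and the EPS value is an assumption made for numerical safety; `neg_sdr` is a hypothetical helper, not the library class.

```python
import math

EPS = 1e-8  # assumed small constant to avoid division by zero


def neg_sdr(est, target, sdr_type="sisdr"):
    """Negative SNR, SI-SDR or SD-SDR in dB for one estimate/target pair."""
    if sdr_type in ("sisdr", "sdsdr"):
        # Project the estimate onto the target to get the scaled target.
        dot = sum(e * t for e, t in zip(est, target))
        target_energy = sum(t * t for t in target) + EPS
        alpha = dot / target_energy
        reference = [alpha * t for t in target]
    elif sdr_type == "snr":
        reference = list(target)
    else:
        raise ValueError(f"unknown sdr_type: {sdr_type}")
    if sdr_type == "sisdr":
        # Error measured against the rescaled target.
        noise = [e - r for e, r in zip(est, reference)]
    else:
        # Error measured against the original target.
        noise = [e - t for e, t in zip(est, target)]
    ratio = sum(r * r for r in reference) / (sum(n * n for n in noise) + EPS)
    return -10 * math.log10(ratio + EPS)


target = [1.0, 1.0, 1.0, 1.0]
est = [1.0, 1.0, 1.0, 0.0]
scaled_est = [3.0 * e for e in est]

snr_loss = neg_sdr(est, target, "snr")
sisdr_loss = neg_sdr(est, target, "sisdr")
sisdr_loss_scaled = neg_sdr(scaled_est, target, "sisdr")
print(snr_loss, sisdr_loss, sisdr_loss_scaled)
```

Rescaling the estimate leaves the SI-SDR value essentially unchanged, while the SNR and SD-SDR variants penalize the scale mismatch.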

asteroid.losses.sdr.SingleSrcNegSDR(*args: Any, **kwargs: Any) → Any[source]

Base class for single-source negative SI-SDR, SD-SDR and SNR.

Parameters
  • sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].

  • zero_mean (bool, optional) – by default, the mean is removed from the target and the estimate before computing the loss.

  • take_log (bool, optional) – by default, the log10 of the SDR is returned.

  • reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output.

Shape:
  • est_targets : \((batch, time)\).

  • targets: \((batch, time)\).

Returns

torch.Tensor – with shape \((batch)\) if reduction='none', else a scalar (shape []) if reduction='mean'.

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import SingleSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"),
>>>                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

asteroid.losses.sdr.MultiSrcNegSDR(*args: Any, **kwargs: Any) → Any[source]

Base class for computing negative SI-SDR, SD-SDR and SNR for a given permutation of source and their estimates.

Parameters
  • sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].

  • zero_mean (bool, optional) – by default, the mean is removed from the target and the estimate before computing the loss.

  • take_log (bool, optional) – by default, the log10 of the SDR is returned.

Shape:
  • est_targets : \((batch, nsrc, time)\).

  • targets: \((batch, nsrc, time)\).

Returns

torch.Tensor – with shape \((batch)\) if reduction='none', else a scalar (shape []) if reduction='mean'.

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> from asteroid.losses.sdr import MultiSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(MultiSrcNegSDR("sisdr"),
>>>                            pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)
References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

PMSQE

asteroid.losses.pmsqe.SingleSrcPMSQE(*args: Any, **kwargs: Any) → Any[source]

Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) as described in [1]. This version is only designed for 16 kHz (512 length DFT). Adaptation to 8 kHz could be done by changing the parameters of the class (see the TensorFlow implementation). The SLL, frequency and gain equalization are applied to each sequence independently.

Parameters
  • window_name (str) – Select the used window function for the correct factor to be applied. Defaults to sqrt hanning window. Among [‘rect’, ‘hann’, ‘sqrt_hann’, ‘hamming’, ‘flatTop’].

  • window_weight (float, optional) – Correction to the window factor applied.

  • bark_eq (bool, optional) – Whether to apply bark equalization.

  • gain_eq (bool, optional) – Whether to apply gain equalization.

  • sample_rate (int) – Sample rate of the input audio.

References

[1] J.M.Martin, A.M.Gomez, J.A.Gonzalez, A.M.Peinado ‘A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality’, IEEE Signal Processing Letters, 2018. Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es

Copyright 2019: University of Granada, Signal Processing, Multimedia Transmission and Speech/Audio Technologies (SigMAT) Group.

Note

Inspired by the Perceptual Evaluation of the Speech Quality (PESQ) algorithm, this function consists of two regularization factors: the symmetrical and asymmetrical distortions in the loudness domain.

Examples
>>> import torch
>>> from asteroid_filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_value = loss_func(est_spec, ref_spec)

STOI

MultiScale Spectral Loss

asteroid.losses.multi_scale_spectral.SingleSrcMultiScaleSpectral(*args: Any, **kwargs: Any) → Any[source]

Measure multi-scale spectral loss as described in [1]

Parameters
  • n_filters (list) – list containing the number of filters desired for each STFT.

  • windows_size (list) – list containing the window size desired for each STFT.

  • hops_size (list) – list containing the hop size desired for each STFT.

Shape:
  • est_targets : \((batch, time)\).

  • targets: \((batch, time)\).

Returns

torch.Tensor – with shape [batch]

Examples
>>> import torch
>>> from asteroid.losses.multi_scale_spectral import SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 32000)
>>> est_targets = torch.randn(10, 32000)
>>> # Using it by itself on a pair of source/estimate
>>> loss_func = SingleSrcMultiScaleSpectral()
>>> loss = loss_func(est_targets, targets)
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # Using it with PITLossWrapper with sets of source/estimates
>>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
>>>                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
References

[1] Jesse Engel and Lamtharn (Hanoi) Hantrakul and Chenjie Gu and Adam Roberts “DDSP: Differentiable Digital Signal Processing” ICLR 2020.

Deep clustering (Affinity) loss

asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)[source]

Compute the deep clustering loss defined in [1].

Parameters
  • embedding (torch.Tensor) – Estimated embeddings. Expected shape \((batch, frequency * frame, embedding\_dim)\).

  • tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: \((batch, frequency, frame)\).

  • binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.dsp.vad.ebased_vad.

Returns

torch.Tensor – Deep clustering loss for every batch sample.

Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 400 * 5, 20)
>>> # tgt_index must have shape (batch, frequency, frame)
>>> targets = torch.randint(0, spk_cnt, (10, 400, 5))
>>> loss = deep_clustering_loss(embedding, targets)
Reference

[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey. “Alternative Objective Functions for Deep Clustering.”

Note

Be careful in viewing the embedding tensors. The target indices tgt_index are of shape \((batch, freq, frames)\). Even if the embedding is of shape \((batch, freq * frames, emb)\), the underlying view should be \((batch, freq, frames, emb)\) and not \((batch, frames, freq, emb)\).
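
The ordering issue described in this note can be checked on a toy example. The sketch below labels each TF bin with its coordinates and flattens the grid in both orders using plain Python lists; row-major flattening is assumed, as for a contiguous tensor view. The variable and function names are illustrative only.

```python
n_freq, n_frames = 2, 3

# Label each TF bin with its (f, t) coordinates, laid out as (freq, frames).
tgt_index = [[(f, t) for t in range(n_frames)] for f in range(n_freq)]

# Row-major flattening of (freq, frames) into freq * frames.
flat_freq_major = [label for row in tgt_index for label in row]

# Flattening a (frames, freq) layout instead gives a different bin order.
flat_frame_major = [
    tgt_index[f][t] for t in range(n_frames) for f in range(n_freq)
]


def flat_position(f, t):
    """Position of TF bin (f, t) in the freq-major flattening."""
    return f * n_frames + t


print(flat_freq_major)
print(flat_frame_major)
```

If the embedding is flattened in one order and the target indices in the other, each embedding vector is compared against the label of a different TF bin, and no error is raised because the shapes still match.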

Computing metrics

asteroid.metrics.get_metrics(mix, clean, estimate, sample_rate=16000, metrics_list='all', average=True, compute_permutation=False, ignore_metrics_errors=False, filename=None)[source]

Get speech separation/enhancement metrics from mix/clean/estimate.

Parameters
  • mix (np.array) – mixture array.

  • clean (np.array) – reference array.

  • estimate (np.array) – estimate array.

  • sample_rate (int) – sampling rate of the audio clips.

  • metrics_list (Union[List[str], str]) – List of metrics to compute. Defaults to ‘all’ ([‘si_sdr’, ‘sdr’, ‘sir’, ‘sar’, ‘stoi’, ‘pesq’]).

  • average (bool) – Return dict([float]) if True, else dict([array]).

  • compute_permutation (bool) – Whether to compute the permutation on estimate sources for the output metrics (default False)

  • ignore_metrics_errors (bool) – Whether to ignore errors that occur in computing the metrics. A warning will be printed instead.

  • filename (str, optional) – If computing a metric fails, print this filename along with the exception/warning message for debugging purposes.

Shape:
  • mix: \((D, N)\) or \((N,)\).

  • clean: \((K\_source, N)\) or \((N,)\).

  • estimate: \((K\_target, N)\) or \((N,)\).

Returns

dict – Dictionary with all requested metrics, with ‘input_’ prefix for metrics at the input (mixture against clean), no prefix at the output (estimate against clean). Output format depends on average.

Examples
>>> import numpy as np
>>> import pprint
>>> from asteroid.metrics import get_metrics
>>> mix = np.random.randn(1, 16000)
>>> clean = np.random.randn(2, 16000)
>>> est = np.random.randn(2, 16000)
>>> metrics_dict = get_metrics(mix, clean, est, sample_rate=8000,
...                            metrics_list='all')
>>> pprint.pprint(metrics_dict)
{'input_pesq': 1.924380898475647,
 'input_sar': -11.67667585294225,
 'input_sdr': -14.88667106190552,
 'input_si_sdr': -52.43849784881705,
 'input_sir': -0.10419427290163795,
 'input_stoi': 0.015112115177091223,
 'pesq': 1.7713886499404907,
 'sar': -11.610963379923195,
 'sdr': -14.527246041125844,
 'si_sdr': -46.26557128489802,
 'sir': 0.4799929272243427,
 'stoi': 0.022023073540350643}
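
Because input metrics carry the ‘input_’ prefix and output metrics do not, improvements over the mixture can be derived directly from the returned dictionary. The sketch below does this for a subset of the values printed above; the `improvements` helper and the ‘_imp’ key suffix are hypothetical and not part of asteroid.metrics.

```python
metrics_dict = {
    "input_si_sdr": -52.43849784881705,
    "si_sdr": -46.26557128489802,
    "input_stoi": 0.015112115177091223,
    "stoi": 0.022023073540350643,
}


def improvements(metrics):
    """Return output-minus-input differences for each metric that has an
    'input_' counterpart in the dictionary."""
    return {
        name + "_imp": value - metrics["input_" + name]
        for name, value in metrics.items()
        if not name.startswith("input_") and "input_" + name in metrics
    }


imp = improvements(metrics_dict)
print(imp)
```

This assumes average=True, so that every value is a float; with average=False the same subtraction would apply element-wise to the per-source arrays.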