
Losses & Metrics

Permutation invariant training (PIT) made easy

Asteroid supports regular Permutation Invariant Training (PIT), its extension using the Sinkhorn algorithm (SinkPIT), as well as Mixture Invariant Training (MixIT).

PIT

class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]

Bases: torch.nn.Module

Permutation invariant loss wrapper.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs).

  • pit_from (str) –

    Determines how PIT is applied.

    • 'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).

    • 'pw_pt' (pairwise point): loss_func computes the loss for a batch of single-source targets and single-source estimates (the tensors won't have the source axis). Output shape: \((batch,)\). See get_pw_losses().

    • 'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: \((batch,)\). See best_perm_from_perm_avg_loss().

    In terms of efficiency, 'perm_avg' is the least efficient.

  • perm_reduce (Callable) – torch function to reduce the permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs): \((B, n\_src!, n\_src) \to (B, n\_src!)\). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in 'pw_mtx' and 'pw_pt' pit_from modes.

For each of these modes, the best permutation and reordering will be automatically computed. When either 'pw_mtx' or 'pw_pt' is used, and the number of sources is larger than three, the Hungarian algorithm is used to find the best permutation.
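
To make the 'pw_mtx' contract concrete, here is a minimal hand-written pairwise loss (an illustrative sketch, not Asteroid's implementation; Asteroid ships pairwise_neg_sisdr and PairwiseMSE for this purpose):

>>> import torch
>>> def my_pairwise_mse(est_targets, targets):
...     # Insert source axes so broadcasting of (batch, n_src, 1, time) and
...     # (batch, 1, n_src, time) yields a (batch, n_src, n_src, time) tensor.
...     diff = targets.unsqueeze(2) - est_targets.unsqueeze(1)
...     # Entry (b, i, j) is the MSE between targets[b, i] and est_targets[b, j].
...     return diff.pow(2).mean(-1)

Such a function can be passed directly as loss_func with pit_from='pw_mtx'.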

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
...     # perm_loss has shape (batch, n_src!, n_src): one loss per source,
...     # for every permutation. Weight each source's loss by the energy
...     # of the corresponding source before averaging.
...     weighted = perm_loss * src.norm(dim=-1).unsqueeze(1)
...     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
...                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
...                      reduce_kwargs=reduce_kwargs)
forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the loss.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).

  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

  • Best permutation loss for each batch sample, averaged over the batch.

  • The reordered target estimates if return_est is True. torch.Tensor of shape \((batch, nsrc, ...)\).
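
For instance, to also retrieve the reordered estimates, reusing the tensors from the example above (a usage sketch):

>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val, reordered_est = loss_func(est_sources, sources, return_est=True)
>>> reordered_est.shape
torch.Size([10, 3, 16000])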

static get_pw_losses(loss_func, est_targets, targets, **kwargs)[source]

Get pairwise losses between the training targets and their estimates for a given loss function.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get pairwise losses from.

  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

torch.Tensor of size \((batch, nsrc, nsrc)\), losses computed for all pairs of targets and est_targets.

This function can be called on a loss function which returns a tensor of size \((batch,)\). There are more efficient ways to compute pairwise losses using broadcasting.

static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)[source]

Find the best permutation from a loss function that operates over the source axis.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.

  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

  • torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).

  • torch.Tensor – The indices of the best permutations.

static find_best_perm(pair_wise_losses, perm_reduce=None, **kwargs)[source]

Find the best permutation, given the pair-wise losses.

Dispatch between the factorial method when the number of sources is small (three or fewer) and the Hungarian method for more sources. If perm_reduce is not None, the factorial method is always used.

Parameters
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.

  • perm_reduce (Callable) – torch function to reduce the permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs): \((B, n\_src!, n\_src) \to (B, n\_src!)\).

  • **kwargs – additional keyword arguments that will be passed to the permutation reduce function.

Returns

  • torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).

  • torch.Tensor – The indices of the best permutations.

static reorder_source(source, batch_indices)[source]

Reorder sources according to the best permutation.

Parameters
  • source (torch.Tensor) – Tensor of shape \((batch, n\_src, time)\).

  • batch_indices (torch.Tensor) – Tensor of shape \((batch, n\_src)\). Contains the optimal permutation indices for each batch sample.

Returns

torch.Tensor – Reordered sources.
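
A sketch of how find_best_perm and reorder_source fit together, with shapes assuming the 10-sample, 3-source batch from the example above:

>>> pw_losses = pairwise_neg_sisdr(est_sources, sources)  # (10, 3, 3)
>>> min_loss, batch_indices = PITLossWrapper.find_best_perm(pw_losses)
>>> reordered = PITLossWrapper.reorder_source(est_sources, batch_indices)
>>> min_loss.shape, batch_indices.shape, reordered.shape
(torch.Size([10]), torch.Size([10, 3]), torch.Size([10, 3, 16000]))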

static find_best_perm_factorial(pair_wise_losses, perm_reduce=None, **kwargs)[source]

Find the best permutation given the pair-wise losses by looping through all the permutations.

Parameters
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.

  • perm_reduce (Callable) – torch function to reduce the permutation losses. Defaults to None (equivalent to mean). Signature of the func (pwl_set, **kwargs): \((B, n\_src!, n\_src) \to (B, n\_src!)\).

  • **kwargs – additional keyword arguments that will be passed to the permutation reduce function.

Returns

  • torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).

  • torch.Tensor – The indices of the best permutations.

MIT Copyright (c) 2018 Kaituo XU. See Original code and License.

static find_best_perm_hungarian(pair_wise_losses: torch.Tensor)[source]

Find the best permutation given the pair-wise losses, using the Hungarian algorithm.

Returns

  • torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).

  • torch.Tensor – The indices of the best permutations.

class asteroid.losses.pit_wrapper.PITReorder(loss_func, pit_from='pw_mtx', perm_reduce=None)[source]

Bases: asteroid.losses.pit_wrapper.PITLossWrapper

Permutation invariant reorderer. Only returns the reordered estimates. See asteroid.losses.PITLossWrapper.

forward(est_targets, targets, reduce_kwargs=None, **kwargs)[source]

Find the best permutation and return the reordered estimates.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

The reordered target estimates. torch.Tensor of shape \((batch, nsrc, ...)\).
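
A minimal usage sketch (constructed like PITLossWrapper, but forward directly returns the reordered estimates):

>>> import torch
>>> from asteroid.losses import pairwise_neg_sisdr
>>> from asteroid.losses.pit_wrapper import PITReorder
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> reorder = PITReorder(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> reordered_est = reorder(est_sources, sources)
>>> reordered_est.shape
torch.Size([10, 3, 16000])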

MixIT

class asteroid.losses.mixit_wrapper.MixITLossWrapper(loss_func, generalized=True, reduction='mean')[source]

Bases: torch.nn.Module

Mixture invariant loss wrapper.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs).

  • generalized (bool) – Determines how MixIT is applied. If False, apply MixIT for any number of mixtures as long as they contain the same number of sources (see best_part_mixit()). If True (default), apply MixIT for two mixtures, but those mixtures do not necessarily have to contain the same number of sources. See best_part_mixit_generalized().

  • reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output.

For each of these modes, the best partition and reordering will be automatically computed.

Examples

>>> import torch
>>> from asteroid.losses import MixITLossWrapper, multisrc_mse
>>> mixtures = torch.randn(10, 2, 16000)
>>> est_sources = torch.randn(10, 4, 16000)
>>> # Compute MixIT loss from the multi-source MSE
>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val = loss_func(est_sources, mixtures)
References

[1] Scott Wisdom et al. “Unsupervised sound separation using mixtures of mixtures.” arXiv:2006.12701 (2020)

forward(est_targets, targets, return_est=False, **kwargs)[source]

Find the best partition and return the loss.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.

  • return_est – Boolean. Whether to return the estimated mixtures (to compute metrics or to save examples).

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

  • Best partition loss for each batch sample, averaged over the batch. torch.Tensor(loss_value)

  • The estimated mixtures (estimated sources summed according to the partition) if return_est is True. torch.Tensor of shape \((batch, nmix, ...)\).
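
For example, reusing the tensors from the class example above, the estimated mixtures can be retrieved alongside the loss (a usage sketch):

>>> loss_val, est_mixtures = loss_func(est_sources, mixtures, return_est=True)
>>> est_mixtures.shape  # estimates summed according to the best partition
torch.Size([10, 2, 16000])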

static best_part_mixit(loss_func, est_targets, targets, **kwargs)[source]

Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm in [1]. Valid for any number of mixtures as long as they contain the same number of sources.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.

  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

  • torch.Tensor – The loss corresponding to the best partition, of size \((batch,)\).

  • torch.LongTensor – The indices of the best partition.

  • list – The list of possible partitions of the sources.

static best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs)[source]

Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm in [1]. Valid only for two mixtures, but those mixtures do not necessarily have to contain the same number of sources, e.g., the case where one mixture is silent is allowed.

Parameters
  • loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.

  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

  • torch.Tensor – The loss corresponding to the best partition, of size \((batch,)\).

  • torch.LongTensor – The indices of the best partitions.

  • list – The list of possible partitions of the sources.
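
To give an idea of the search space (an illustrative sketch, not Asteroid's internal code): with two mixtures, generalized MixIT considers every assignment of each estimated source to one of the two mixtures.

>>> import itertools
>>> n_est = 4
>>> assignments = list(itertools.product([0, 1], repeat=n_est))
>>> len(assignments)  # 2 ** 4 candidate assignments for 4 estimates
16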

static loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs)[source]

Common loop between best_part_mixit and best_part_mixit_generalized.

static reorder_source(est_targets, targets, min_loss_idx, parts)[source]

Reorder sources according to the best partition.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.

  • min_loss_idx – torch.LongTensor. The indices of the best partitions.

  • parts – list of the possible partitions of the sources.

Returns

torch.Tensor – Reordered sources of shape \((batch, nmix, time)\).

SinkPIT

class asteroid.losses.sinkpit_wrapper.SinkPITLossWrapper(loss_func, n_iter=200, hungarian_validation=True)[source]

Bases: torch.nn.Module

Permutation invariant loss wrapper.

Parameters
  • loss_func – function with signature (targets, est_targets, **kwargs).

  • n_iter (int) – number of Sinkhorn iterations (default: 200). Expected to be an even number.

  • hungarian_validation (boolean) – Whether to use the Hungarian algorithm for validation (default: True).

loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\). The wrapper evaluates an approximate value of the PIT loss using Sinkhorn's iterative algorithm. See best_softperm_sinkhorn() and http://arxiv.org/abs/2010.11871

Examples
>>> import torch
>>> from torch import optim
>>> from torch.utils import data
>>> import pytorch_lightning as pl
>>> from asteroid.losses import pairwise_neg_sisdr
>>> from asteroid.losses.sinkpit_wrapper import SinkPITLossWrapper
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute SinkPIT loss based on pairwise losses
>>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
>>> loss_val = loss_func(est_sources, sources)
>>> # A fixed temperature parameter `beta` (=10) is used
>>> # unless a cooling callback is set. The value can be
>>> # dynamically changed using a cooling callback module as follows.
>>> model = NeuralNetworkModel()
>>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
>>> dataset = YourDataset()
>>> loader = data.DataLoader(dataset, batch_size=16)
>>> system = System(
...     model,
...     optimizer,
...     loss_func=SinkPITLossWrapper(pairwise_neg_sisdr),
...     train_loader=loader,
...     val_loader=loader,
...     )
>>>
>>> trainer = pl.Trainer(
...     max_epochs=100,
...     callbacks=[SinkPITBetaScheduler(lambda epoch: 1.02 ** epoch)],
...     )
>>>
>>> trainer.fit(system)
forward(est_targets, targets, return_est=False, **kwargs)[source]

Evaluate the loss using Sinkhorn’s algorithm.

Parameters
  • est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.

  • targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.

  • return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).

  • **kwargs – additional keyword arguments that will be passed to the loss function.

Returns

  • Best permutation loss for each batch sample, averaged over the batch. torch.Tensor(loss_value)

  • The reordered target estimates if return_est is True. torch.Tensor of shape \((batch, nsrc, ...)\).

static best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200)[source]

Compute an approximate PIT loss using Sinkhorn’s algorithm. See http://arxiv.org/abs/2010.11871

Parameters
  • pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.

  • beta (float) – Inverse temperature parameter (default: 10).

  • n_iter (int) – Number of iterations; expected to be an even number (default: 200).

Returns

  • torch.Tensor – The loss corresponding to the best permutation, of size \((batch,)\).

  • torch.Tensor – A soft permutation matrix.
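
A sketch of calling it directly on a pairwise loss matrix (shapes assume a 10-sample, 3-source batch; that the soft permutation matrix has shape \((batch, n\_src, n\_src)\) is an assumption of this sketch):

>>> import torch
>>> from asteroid.losses import pairwise_neg_sisdr
>>> from asteroid.losses.sinkpit_wrapper import SinkPITLossWrapper
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> pw_losses = pairwise_neg_sisdr(est_sources, sources)  # (10, 3, 3)
>>> loss, soft_perm = SinkPITLossWrapper.best_softperm_sinkhorn(pw_losses, beta=10, n_iter=200)
>>> loss.shape, soft_perm.shape
(torch.Size([10]), torch.Size([10, 3, 3]))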

Available loss functions

PITLossWrapper supports three types of loss functions. For “easy” losses, we implement all three types (pairwise, single-source and multi-source). For others, we only implement the single-source loss, which can be used in both PIT and non-PIT training.
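
For SI-SDR, for example, all three flavors are exported by asteroid.losses and pair with the matching pit_from modes; the three wrappers below should yield the same PIT loss at different computational costs (a usage sketch):

>>> import torch
>>> from asteroid.losses import (PITLossWrapper, pairwise_neg_sisdr,
...                              singlesrc_neg_sisdr, multisrc_neg_sisdr)
>>> targets = torch.randn(10, 3, 16000)
>>> est_targets = torch.randn(10, 3, 16000)
>>> pit_mtx = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> pit_pt = PITLossWrapper(singlesrc_neg_sisdr, pit_from='pw_pt')
>>> pit_avg = PITLossWrapper(multisrc_neg_sisdr, pit_from='perm_avg')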

MSE

asteroid.losses.mse.PairwiseMSE(*args: Any, **kwargs: Any) → Any[source]

Measure pairwise mean square error on a batch.

Shape:
  • est_targets : \((batch, nsrc, ...)\).

  • targets: \((batch, nsrc, ...)\).

Returns

torch.Tensor – with shape \((batch, nsrc, nsrc)\)

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, PairwiseMSE
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.SingleSrcMSE(*args: Any, **kwargs: Any) → Any[source]

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  • est_targets: \((batch, ...)\).

  • targets: \((batch, ...)\).

Returns

torch.Tensor – with shape \((batch,)\)

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, singlesrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse supports 'pw_pt'; multisrc_mse supports 'perm_avg'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.MultiSrcMSE(*args: Any, **kwargs: Any) → Any[source]

Measure mean square error on a batch. Supports both tensors with and without source axis.

Shape:
  • est_targets: \((batch, ...)\).

  • targets: \((batch, ...)\).

Returns

torch.Tensor – with shape \((batch,)\)

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, multisrc_mse
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse supports 'pw_pt'; multisrc_mse supports 'perm_avg'.
>>> loss_func = PITLossWrapper(multisrc_mse, pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)

SDR

asteroid.losses.sdr.PairwiseNegSDR(*args: Any, **kwargs: Any) → Any[source]

Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch.

Parameters
  • sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].

  • zero_mean (bool, optional) – by default it zero-means the target and estimate before computing the loss.

  • take_log (bool, optional) – by default the log10 of the SDR is returned.

Shape:
  • est_targets : \((batch, nsrc, ...)\).

  • targets: \((batch, nsrc, ...)\).

Returns

torch.Tensor – with shape \((batch, nsrc, nsrc)\). Pairwise losses.

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, PairwiseNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"),
...                            pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

asteroid.losses.sdr.SingleSrcNegSDR(*args: Any, **kwargs: Any) → Any[source]

Base class for single-source negative SI-SDR, SD-SDR and SNR.

Parameters
  • sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].

  • zero_mean (bool, optional) – by default it zero-means the target and estimate before computing the loss.

  • take_log (bool, optional) – by default the log10 of the SDR is returned.

  • reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output.

Shape:
  • est_targets : \((batch, time)\).

  • targets: \((batch, time)\).

Returns

torch.Tensor – with shape \((batch,)\) if reduction='none', else a scalar if reduction='mean'.

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, SingleSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"),
...                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

asteroid.losses.sdr.MultiSrcNegSDR(*args: Any, **kwargs: Any) → Any[source]

Base class for computing negative SI-SDR, SD-SDR and SNR for a given permutation of sources and their estimates.

Parameters
  • sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].

  • zero_mean (bool, optional) – by default it zero-means the target and estimate before computing the loss.

  • take_log (bool, optional) – by default the log10 of the SDR is returned.

Shape:
  • est_targets : \((batch, nsrc, time)\).

  • targets: \((batch, nsrc, time)\).

Returns

torch.Tensor – with shape \((batch,)\) if reduction='none', else a scalar if reduction='mean'.

Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper, MultiSrcNegSDR
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(MultiSrcNegSDR("sisdr"),
...                            pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)
References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

PMSQE

asteroid.losses.pmsqe.SingleSrcPMSQE(*args: Any, **kwargs: Any) → Any[source]

Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) as described in [1]. This version is only designed for 16 kHz (512-point DFT). Adaptation to 8 kHz could be done by changing the parameters of the class (see the Tensorflow implementation). The SLL, frequency and gain equalization are applied to each sequence independently.

Parameters
  • window_name (str) – Select the window function used, so that the correct correction factor is applied. Among [‘rect’, ‘hann’, ‘sqrt_hann’, ‘hamming’, ‘flatTop’]; defaults to the sqrt hanning window.

  • window_weight (float, optional) – Correction to the window factor applied.

  • bark_eq (bool, optional) – Whether to apply bark equalization.

  • gain_eq (bool, optional) – Whether to apply gain equalization.

  • sample_rate (int) – Sample rate of the input audio.

References

[1] J.M. Martin, A.M. Gomez, J.A. Gonzalez, A.M. Peinado, “A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality”, IEEE Signal Processing Letters, 2018. Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es

Copyright 2019: University of Granada, Signal Processing, Multimedia Transmission and Speech/Audio Technologies (SigMAT) Group.

Note

Inspired by the Perceptual Evaluation of Speech Quality (PESQ) algorithm, this function consists of two regularization factors: the symmetrical and asymmetrical distortion in the loudness domain.

Examples
>>> import torch
>>> from asteroid_filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_value = loss_func(est_spec, ref_spec)

STOI

MultiScale Spectral Loss

asteroid.losses.multi_scale_spectral.SingleSrcMultiScaleSpectral(*args: Any, **kwargs: Any) → Any[source]

Measure the multi-scale spectral loss as described in [1].

Parameters
  • n_filters (list) – list containing the number of filters desired for each STFT.

  • windows_size (list) – list containing the size of the window desired for each STFT.

  • hops_size (list) – list containing the size of the hop desired for each STFT.

Shape:
  • est_targets : \((batch, time)\).

  • targets: \((batch, time)\).

Returns

torch.Tensor – with shape \((batch,)\)

Examples
>>> import torch
>>> from asteroid.losses.multi_scale_spectral import SingleSrcMultiScaleSpectral
>>> targets = torch.randn(10, 32000)
>>> est_targets = torch.randn(10, 32000)
>>> # Using it by itself on a pair of source/estimate
>>> loss_func = SingleSrcMultiScaleSpectral()
>>> loss = loss_func(est_targets, targets)
>>>
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # Using it with PITLossWrapper with sets of source/estimates
>>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
...                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
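
The STFT resolutions are configurable through the constructor; for instance (hypothetical values, one STFT per list entry):

>>> loss_func = SingleSrcMultiScaleSpectral(n_filters=[2048, 1024, 512],
...                                         windows_size=[2048, 1024, 512],
...                                         hops_size=[1024, 512, 256])
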
References

[1] Jesse Engel, Lamtharn (Hanoi) Hantrakul, Chenjie Gu and Adam Roberts, “DDSP: Differentiable Digital Signal Processing”, ICLR 2020.

Deep clustering (Affinity) loss

asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)[source]

Compute the deep clustering loss defined in [1].

Parameters
  • embedding (torch.Tensor) – Estimated embeddings. Expected shape \((batch, frequency * frame, embedding\_dim)\).

  • tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: \((batch, frequency, frame)\).

  • binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.dsp.vad.ebased_vad.

Returns

torch.Tensor. Deep clustering loss for every batch sample.

Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> targets = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
>>> loss = deep_clustering_loss(embedding, targets)
References

[1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, “Alternative Objective Functions for Deep Clustering”, ICASSP 2018.

Note

Be careful in viewing the embedding tensors. The target indices tgt_index are of shape \((batch, freq, frames)\). Even if the embedding is of shape \((batch, freq * frames, emb)\), the underlying view should be \((batch, freq, frames, emb)\) and not \((batch, frames, freq, emb)\).
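
For instance, a sketch of the safe flattening, assuming a hypothetical network output of shape \((batch, freq, frames, emb)\):

>>> import torch
>>> batch, freq, frames, emb = 10, 400, 5, 20
>>> net_out = torch.randn(batch, freq, frames, emb)
>>> embedding = net_out.reshape(batch, freq * frames, emb)  # keeps (freq, frames) order
>>> embedding.shape
torch.Size([10, 2000, 20])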

Computing metrics

asteroid.metrics.get_metrics(mix, clean, estimate, sample_rate=16000, metrics_list='all', average=True, compute_permutation=False, ignore_metrics_errors=False, filename=None)[source]

Get speech separation/enhancement metrics from mix/clean/estimate.

Parameters
  • mix (np.array) – mixture array.

  • clean (np.array) – reference array.

  • estimate (np.array) – estimate array.

  • sample_rate (int) – sampling rate of the audio clips.

  • metrics_list (Union[List[str], str]) – List of metrics to compute. Defaults to ‘all’ ([‘si_sdr’, ‘sdr’, ‘sir’, ‘sar’, ‘stoi’, ‘pesq’]).

  • average (bool) – Return dict([float]) if True, else dict([array]).

  • compute_permutation (bool) – Whether to compute the permutation on estimate sources for the output metrics (default: False).

  • ignore_metrics_errors (bool) – Whether to ignore errors that occur in computing the metrics. A warning will be printed instead.

  • filename (str, optional) – If computing a metric fails, print this filename along with the exception/warning message for debugging purposes.

Shape:
  • mix: \((D, N)\) or \((N,)\).

  • clean: \((K\_source, N)\) or \((N,)\).

  • estimate: \((K\_target, N)\) or \((N,)\).

Returns

dict – Dictionary with all requested metrics, with ‘input_’ prefix for metrics at the input (mixture against clean), no prefix at the output (estimate against clean). Output format depends on average.

Examples
>>> import numpy as np
>>> import pprint
>>> from asteroid.metrics import get_metrics
>>> mix = np.random.randn(1, 16000)
>>> clean = np.random.randn(2, 16000)
>>> est = np.random.randn(2, 16000)
>>> metrics_dict = get_metrics(mix, clean, est, sample_rate=8000,
...                            metrics_list='all')
>>> pprint.pprint(metrics_dict)
{'input_pesq': 1.924380898475647,
 'input_sar': -11.67667585294225,
 'input_sdr': -14.88667106190552,
 'input_si_sdr': -52.43849784881705,
 'input_sir': -0.10419427290163795,
 'input_stoi': 0.015112115177091223,
 'pesq': 1.7713886499404907,
 'sar': -11.610963379923195,
 'sdr': -14.527246041125844,
 'si_sdr': -46.26557128489802,
 'sir': 0.4799929272243427,
 'stoi': 0.022023073540350643}
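
If the estimates may be permuted with respect to the references, the permutation can be resolved before computing the metrics (a usage sketch based on the parameters above):

>>> metrics_dict = get_metrics(mix, clean, est, sample_rate=8000,
...                            metrics_list=['si_sdr'], compute_permutation=True)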