Losses & Metrics

Permutation invariant training (PIT) made easy

Asteroid supports regular Permutation Invariant Training (PIT), its extension using the Sinkhorn algorithm (SinkPIT), as well as Mixture Invariant Training (MixIT).

PIT
class asteroid.losses.pit_wrapper.PITLossWrapper(loss_func, pit_from='pw_mtx', perm_reduce=None)

Bases: torch.nn.Module

Permutation invariant loss wrapper.
- Parameters
  loss_func – function with signature (est_targets, targets, **kwargs).
  pit_from (str) – Determines how PIT is applied.
    'pw_mtx' (pairwise matrix): loss_func computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).
    'pw_pt' (pairwise point): loss_func computes the loss for a batch of single-source targets and estimates (the tensors have no source axis). Output shape: \((batch)\). See get_pw_losses().
    'perm_avg' (permutation average): loss_func computes the average loss for a given permutation of the sources and estimates. Output shape: \((batch)\). See best_perm_from_perm_avg_loss().
    In terms of efficiency, 'perm_avg' is the least efficient.
  perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function (pwl_set, **kwargs): \((B, n\_src!, n\_src) \rightarrow (B, n\_src!)\). perm_reduce can receive **kwargs during forward using the reduce_kwargs argument (dict). If those arguments are static, consider defining a small function or using functools.partial. Only used in 'pw_mtx' and 'pw_pt' pit_from modes.
For each of these modes, the best permutation and reordering are automatically computed. When either 'pw_mtx' or 'pw_pt' is used and the number of sources is larger than three, the Hungarian algorithm is used to find the best permutation.
- Examples
>>> import torch
>>> from asteroid.losses import pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute PIT loss based on pairwise losses
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val = loss_func(est_sources, sources)
>>>
>>> # Using reduce
>>> def reduce(perm_loss, src):
>>>     weighted = perm_loss * src.norm(dim=-1, keepdim=True)
>>>     return torch.mean(weighted, dim=-1)
>>>
>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
>>>                            perm_reduce=reduce)
>>> reduce_kwargs = {'src': sources}
>>> loss_val = loss_func(est_sources, sources,
>>>                      reduce_kwargs=reduce_kwargs)
forward(est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs)

Find the best permutation and return the loss.
- Parameters
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
  return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).
  reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
  **kwargs – additional keyword arguments that will be passed to the loss function.
- Returns
  Best permutation loss for each batch sample, averaged over the batch.
  The reordered target estimates if return_est is True: torch.Tensor of shape \((batch, nsrc, ...)\).
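A minimal usage sketch for return_est=True, reusing the tensors from the class example above:

>>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> loss_val, reordered = loss_func(est_sources, sources, return_est=True)
>>> reordered.shape
torch.Size([10, 3, 16000])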
static get_pw_losses(loss_func, est_targets, targets, **kwargs)

Get pairwise losses between the training targets and their estimates for a given loss function.
- Parameters
  loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get pairwise losses from.
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
  **kwargs – additional keyword arguments that will be passed to the loss function.
- Returns
  torch.Tensor of shape \((batch, nsrc, nsrc)\), losses computed for all permutations of the targets and est_targets.
This function can be called with a loss function that returns a tensor of shape \((batch)\). There are more efficient ways to compute pairwise losses using broadcasting.
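To make the indexing convention concrete, here is a minimal sketch of the double loop this method performs, assuming a single-source loss that returns a \((batch)\) tensor (the library implementation may differ in details):

>>> def pw_losses_sketch(loss_func, est_targets, targets):
>>>     batch, n_src = targets.shape[:2]
>>>     pair_wise = torch.empty(batch, n_src, n_src)
>>>     for i in range(n_src):  # target index
>>>         for j in range(n_src):  # estimate index
>>>             pair_wise[:, i, j] = loss_func(est_targets[:, j], targets[:, i])
>>>     return pair_wise
>>> from asteroid.losses import singlesrc_neg_sisdr
>>> pw = pw_losses_sketch(singlesrc_neg_sisdr, est_sources, sources)  # (10, 3, 3)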
static best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs)

Find the best permutation from a loss function with a source axis.
- Parameters
  loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
  **kwargs – additional keyword arguments that will be passed to the loss function.
- Returns
  torch.Tensor – The loss corresponding to the best permutation, of shape \((batch,)\).
  torch.Tensor – The indices of the best permutations.
static find_best_perm(pair_wise_losses, perm_reduce=None, **kwargs)

Find the best permutation, given the pairwise losses.

Dispatches between the factorial method when the number of sources is small (three or fewer) and the Hungarian method for more sources. If perm_reduce is not None, the factorial method is always used.
- Parameters
  pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
  perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function (pwl_set, **kwargs): \((B, n\_src!, n\_src) \rightarrow (B, n\_src!)\).
  **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
- Returns
  torch.Tensor – The loss corresponding to the best permutation, of shape \((batch,)\).
  torch.Tensor – The indices of the best permutations.
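A short sketch of calling find_best_perm directly on a pairwise loss matrix (tensors from the class example above):

>>> pw_losses = pairwise_neg_sisdr(est_sources, sources)  # (10, 3, 3)
>>> min_loss, batch_indices = PITLossWrapper.find_best_perm(pw_losses)
>>> batch_indices.shape
torch.Size([10, 3])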
static reorder_source(source, batch_indices)

Reorder sources according to the best permutation.
- Parameters
  source (torch.Tensor) – Tensor of shape \((batch, n\_src, time)\).
  batch_indices (torch.Tensor) – Tensor of shape \((batch, n\_src)\). Contains the optimal permutation indices for each batch.
- Returns
  torch.Tensor – Reordered sources.
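Continuing the find_best_perm sketch above, the returned indices can be used to reorder the estimates:

>>> reordered = PITLossWrapper.reorder_source(est_sources, batch_indices)
>>> reordered.shape
torch.Size([10, 3, 16000])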
static find_best_perm_factorial(pair_wise_losses, perm_reduce=None, **kwargs)

Find the best permutation given the pairwise losses by looping through all the permutations.
- Parameters
  pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
  perm_reduce (Callable) – torch function to reduce permutation losses. Defaults to None (equivalent to mean). Signature of the function (pwl_set, **kwargs): \((B, n\_src!, n\_src) \rightarrow (B, n\_src!)\).
  **kwargs – additional keyword arguments that will be passed to the permutation reduce function.
- Returns
  torch.Tensor – The loss corresponding to the best permutation, of shape \((batch,)\).
  torch.Tensor – The indices of the best permutations.

MIT Copyright (c) 2018 Kaituo XU. See the original code and license.
static find_best_perm_hungarian(pair_wise_losses: torch.Tensor)

Find the best permutation given the pairwise losses, using the Hungarian algorithm.
- Returns
  torch.Tensor – The loss corresponding to the best permutation, of shape \((batch,)\).
  torch.Tensor – The indices of the best permutations.
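For intuition, here is a minimal sketch of solving the same assignment problem with scipy's linear_sum_assignment; the library's actual implementation may differ in details:

>>> import torch
>>> from scipy.optimize import linear_sum_assignment
>>> pwl = torch.randn(10, 5, 5)  # pairwise losses for 5 sources
>>> perms = torch.stack([
>>>     torch.from_numpy(linear_sum_assignment(mat)[1])
>>>     for mat in pwl.numpy()
>>> ])
>>> perms.shape
torch.Size([10, 5])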
class asteroid.losses.pit_wrapper.PITReorder(loss_func, pit_from='pw_mtx', perm_reduce=None)

Bases: asteroid.losses.pit_wrapper.PITLossWrapper

Permutation invariant reorderer. Only returns the reordered estimates. See asteroid.losses.PITLossWrapper.
forward(est_targets, targets, reduce_kwargs=None, **kwargs)

Find the best permutation and return the reordered estimates.
- Parameters
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
  reduce_kwargs (dict or None) – kwargs that will be passed to the pairwise losses reduce function (perm_reduce).
  **kwargs – additional keyword arguments that will be passed to the loss function.
- Returns
  torch.Tensor – The reordered target estimates, of shape \((batch, nsrc, ...)\).
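A short usage sketch (tensors as in the PITLossWrapper example above):

>>> from asteroid.losses.pit_wrapper import PITReorder
>>> reorderer = PITReorder(pairwise_neg_sisdr, pit_from='pw_mtx')
>>> reordered = reorderer(est_sources, sources)
>>> reordered.shape
torch.Size([10, 3, 16000])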
MixIT
class asteroid.losses.mixit_wrapper.MixITLossWrapper(loss_func, generalized=True, reduction='mean')

Bases: torch.nn.Module

Mixture invariant loss wrapper.
- Parameters
  loss_func – function with signature (est_targets, targets, **kwargs).
  generalized (bool) – Determines how MixIT is applied. If False, apply MixIT to any number of mixtures, as long as they contain the same number of sources (see best_part_mixit()). If True (default), apply MixIT to two mixtures, which do not necessarily have to contain the same number of sources. See best_part_mixit_generalized().
  reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output.
In both cases, the best partition and reordering are automatically computed.
- Examples
>>> import torch
>>> from asteroid.losses import multisrc_mse
>>> mixtures = torch.randn(10, 2, 16000)
>>> est_sources = torch.randn(10, 4, 16000)
>>> # Compute MixIT loss based on pairwise losses
>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val = loss_func(est_sources, mixtures)
- References
[1] Scott Wisdom et al. “Unsupervised sound separation using mixtures of mixtures.” arXiv:2006.12701 (2020)
forward(est_targets, targets, return_est=False, **kwargs)

Find the best partition and return the loss.
- Parameters
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
  return_est – Boolean. Whether to return the estimated mixtures (to compute metrics or to save examples).
  **kwargs – additional keyword arguments that will be passed to the loss function.
- Returns
  Best partition loss for each batch sample, averaged over the batch. torch.Tensor(loss_value).
  The estimated mixtures (estimated sources summed according to the partition) if return_est is True: torch.Tensor of shape \((batch, nmix, ...)\).
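A minimal sketch of retrieving the estimated mixtures with return_est=True (tensors from the class example above):

>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val, est_mixtures = loss_func(est_sources, mixtures, return_est=True)
>>> est_mixtures.shape
torch.Size([10, 2, 16000])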
static best_part_mixit(loss_func, est_targets, targets, **kwargs)

Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm in [1]. Valid for any number of mixtures, as long as they contain the same number of sources.
- Parameters
  loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
  **kwargs – additional keyword arguments that will be passed to the loss function.
- Returns
  torch.Tensor – The loss corresponding to the best partition, of shape \((batch,)\).
  torch.LongTensor – The indices of the best partition.
  list – List of the possible partitions of the sources.
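For intuition, the candidate partitions in the equal-size case (here 4 estimates assigned to 2 mixtures) can be enumerated as follows; this is a hedged sketch, not necessarily how the library enumerates them:

>>> from itertools import combinations
>>> n_src, n_mix = 4, 2
>>> all_src = set(range(n_src))
>>> # Each candidate chooses which estimates sum to the first mixture;
>>> # the remaining estimates sum to the second.
>>> parts = [(set(c), all_src - set(c))
>>>          for c in combinations(range(n_src), n_src // n_mix)]
>>> len(parts)
6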
static best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs)

Find the best partition of the estimated sources that gives the minimum loss for the MixIT training paradigm in [1]. Valid only for two mixtures, which do not necessarily have to contain the same number of sources; e.g., the case where one mixture is silent is allowed.
- Parameters
  loss_func – function with signature (est_targets, targets, **kwargs). The loss function to get batch losses from.
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets (mixtures).
  **kwargs – additional keyword arguments that will be passed to the loss function.
- Returns
  torch.Tensor – The loss corresponding to the best permutation, of shape \((batch,)\).
  torch.LongTensor – The indices of the best permutations.
  list – List of the possible partitions of the sources.
static loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs)

Common loop shared by best_part_mixit() and best_part_mixit_generalized().
static reorder_source(est_targets, targets, min_loss_idx, parts)

Reorder sources according to the best partition.
- Parameters
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nmix, ...)\). The batch of training targets.
  min_loss_idx – torch.LongTensor. The indices of the best permutations.
  parts – list of the possible partitions of the sources.
- Returns
  torch.Tensor – Reordered sources of shape \((batch, nmix, time)\).
SinkPIT

class asteroid.losses.sinkpit_wrapper.SinkPITLossWrapper(loss_func, n_iter=200, hungarian_validation=True)

Bases: torch.nn.Module

Permutation invariant loss wrapper.
- Parameters
  loss_func – function that computes pairwise losses and returns a torch.Tensor of shape \((batch, n\_src, n\_src)\). Each element \((batch, i, j)\) corresponds to the loss between \(targets[:, i]\) and \(est\_targets[:, j]\).

The wrapper evaluates an approximate value of the PIT loss using Sinkhorn's iterative algorithm. See best_softperm_sinkhorn() and http://arxiv.org/abs/2010.11871.
- Examples
>>> import torch
>>> import pytorch_lightning as pl
>>> from asteroid.losses import pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute SinkPIT loss based on pairwise losses
>>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
>>> loss_val = loss_func(est_sources, sources)
>>> # A fixed temperature parameter `beta` (=10) is used
>>> # unless a cooling callback is set. The value can be
>>> # dynamically changed using a cooling callback module as follows.
>>> model = NeuralNetworkModel()
>>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
>>> dataset = YourDataset()
>>> loader = data.DataLoader(dataset, batch_size=16)
>>> system = System(
>>>     model,
>>>     optimizer,
>>>     loss_func=SinkPITLossWrapper(pairwise_neg_sisdr),
>>>     train_loader=loader,
>>>     val_loader=loader,
>>> )
>>>
>>> trainer = pl.Trainer(
>>>     max_epochs=100,
>>>     callbacks=[SinkPITBetaScheduler(lambda epoch: 1.02 ** epoch)],
>>> )
>>>
>>> trainer.fit(system)
forward(est_targets, targets, return_est=False, **kwargs)

Evaluate the loss using Sinkhorn's algorithm.
- Parameters
  est_targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of target estimates.
  targets – torch.Tensor. Expected shape \((batch, nsrc, ...)\). The batch of training targets.
  return_est – Boolean. Whether to return the reordered target estimates (to compute metrics or to save examples).
  **kwargs – additional keyword arguments that will be passed to the loss function.
- Returns
  Best permutation loss for each batch sample, averaged over the batch. torch.Tensor(loss_value).
  The reordered target estimates if return_est is True: torch.Tensor of shape \((batch, nsrc, ...)\).
static best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200)

Compute an approximate PIT loss using Sinkhorn's algorithm. See http://arxiv.org/abs/2010.11871.
- Parameters
  pair_wise_losses (torch.Tensor) – Tensor of shape \((batch, n\_src, n\_src)\). Pairwise losses.
  beta (float) – Inverse temperature parameter (default: 10).
  n_iter (int) – Number of iterations; must be even (default: 200).
- Returns
  torch.Tensor – The loss corresponding to the best permutation, of shape \((batch,)\).
  torch.Tensor – A soft permutation matrix.
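A minimal sketch of the Sinkhorn iteration on a pairwise loss matrix, assuming the soft permutation is obtained by alternating row/column normalization of exp(-beta * losses) in log space; the library implementation may differ in details:

>>> import torch
>>> pwl = torch.randn(10, 3, 3)  # (batch, n_src, n_src) pairwise losses
>>> beta, n_iter = 10.0, 200
>>> log_p = -beta * pwl
>>> for _ in range(n_iter // 2):
>>>     log_p = log_p - torch.logsumexp(log_p, dim=2, keepdim=True)  # normalize rows
>>>     log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)  # normalize columns
>>> soft_perm = log_p.exp()  # approximately doubly stochastic
>>> approx_loss = (soft_perm * pwl).sum(dim=(1, 2)) / pwl.shape[-1]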
Available loss functions

PITLossWrapper supports three types of loss function. For "easy" losses, we implement all three types (pairwise, single-source and multi-source). For others, we only implement the single-source loss, which can be used in both PIT and non-PIT training.
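As an illustration, the three variants of the same loss pair with the three pit_from modes (the lowercase names below are the functional instances exported by asteroid.losses):

>>> from asteroid.losses import PITLossWrapper, pairwise_mse, singlesrc_mse, multisrc_mse
>>> pit_from_pw_mtx = PITLossWrapper(pairwise_mse, pit_from='pw_mtx')
>>> pit_from_pw_pt = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> pit_from_perm_avg = PITLossWrapper(multisrc_mse, pit_from='perm_avg')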
MSE

asteroid.losses.mse.PairwiseMSE(*args: Any, **kwargs: Any) → Any

Measure pairwise mean square error on a batch.
- Shape:
  est_targets: \((batch, nsrc, ...)\).
  targets: \((batch, nsrc, ...)\).
- Returns
  torch.Tensor – with shape \((batch, nsrc, nsrc)\).
- Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.SingleSrcMSE(*args: Any, **kwargs: Any) → Any

Measure mean square error on a batch. Supports both tensors with and without a source axis.
- Shape:
  est_targets: \((batch, ...)\).
  targets: \((batch, ...)\).
- Returns
  torch.Tensor – with shape \((batch)\).
- Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # singlesrc_mse is a single-source loss, so use pit_from='pw_pt'.
>>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
asteroid.losses.mse.MultiSrcMSE(*args: Any, **kwargs: Any) → Any

Measure mean square error on a batch. Supports both tensors with and without a source axis.
- Shape:
  est_targets: \((batch, ...)\).
  targets: \((batch, ...)\).
- Returns
  torch.Tensor – with shape \((batch)\).
- Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # multisrc_mse is a multi-source loss, so use pit_from='perm_avg'.
>>> loss_func = PITLossWrapper(multisrc_mse, pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)
SDR

asteroid.losses.sdr.PairwiseNegSDR(*args: Any, **kwargs: Any) → Any

Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch.
- Parameters
  sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].
  zero_mean (bool, optional) – by default, zero-mean the target and estimate before computing the loss.
  take_log (bool, optional) – by default, the log10 of the SDR is returned.
- Shape:
  est_targets: \((batch, nsrc, ...)\).
  targets: \((batch, nsrc, ...)\).
- Returns
  torch.Tensor – with shape \((batch, nsrc, nsrc)\). Pairwise losses.
- Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"),
>>>                            pit_from='pw_mtx')
>>> loss = loss_func(est_targets, targets)
- References
[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
asteroid.losses.sdr.SingleSrcNegSDR(*args: Any, **kwargs: Any) → Any

Base class for single-source negative SI-SDR, SD-SDR and SNR.
- Parameters
  sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].
  zero_mean (bool, optional) – by default, zero-mean the target and estimate before computing the loss.
  take_log (bool, optional) – by default, the log10 of the SDR is returned.
  reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output.
- Shape:
  est_targets: \((batch, time)\).
  targets: \((batch, time)\).
- Returns
  torch.Tensor – with shape \((batch)\) if reduction='none', else a scalar if reduction='mean'.
- Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"),
>>>                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
- References
[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
asteroid.losses.sdr.MultiSrcNegSDR(*args: Any, **kwargs: Any) → Any

Base class for computing negative SI-SDR, SD-SDR and SNR for a given permutation of sources and their estimates.
- Parameters
  sdr_type (str) – choose between snr for plain SNR, sisdr for SI-SDR and sdsdr for SD-SDR [1].
  zero_mean (bool, optional) – by default, zero-mean the target and estimate before computing the loss.
  take_log (bool, optional) – by default, the log10 of the SDR is returned.
- Shape:
  est_targets: \((batch, nsrc, time)\).
  targets: \((batch, nsrc, time)\).
- Returns
  torch.Tensor – with shape \((batch)\) if reduction='none', else a scalar if reduction='mean'.
- Examples
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(MultiSrcNegSDR("sisdr"),
>>>                            pit_from='perm_avg')
>>> loss = loss_func(est_targets, targets)
- References
[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.
PMSQE

asteroid.losses.pmsqe.SingleSrcPMSQE(*args: Any, **kwargs: Any) → Any

Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) as described in [1]. This version is only designed for 16 kHz (512-point DFT). Adaptation to 8 kHz could be done by changing the parameters of the class (see the TensorFlow implementation). The SLL, frequency and gain equalization are applied to each sequence independently.
- Parameters
  window_name (str) – Name of the window function, used to apply the correct window correction factor. Defaults to the sqrt hanning window. One of ['rect', 'hann', 'sqrt_hann', 'hamming', 'flatTop'].
  window_weight (float, optional) – Correction applied to the window factor.
  bark_eq (bool, optional) – Whether to apply Bark equalization.
  gain_eq (bool, optional) – Whether to apply gain equalization.
  sample_rate (int) – Sample rate of the input audio.
- References
  [1] J. M. Martin, A. M. Gomez, J. A. Gonzalez, A. M. Peinado, "A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality", IEEE Signal Processing Letters, 2018.
  Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es
  Copyright 2019: University of Granada, Signal Processing, Multimedia Transmission and Speech/Audio Technologies (SigMAT) Group.
Note
Inspired by the Perceptual Evaluation of Speech Quality (PESQ) algorithm, this loss consists of two regularization factors: the symmetrical and asymmetrical distortions in the loudness domain.
- Examples
>>> import torch
>>> from asteroid_filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_value = loss_func(est_spec, ref_spec)
STOI

MultiScale Spectral Loss

asteroid.losses.multi_scale_spectral.SingleSrcMultiScaleSpectral(*args: Any, **kwargs: Any) → Any

Measure the multi-scale spectral loss as described in [1].
- Parameters
- Shape:
  est_targets: \((batch, time)\).
  targets: \((batch, time)\).
- Returns
  torch.Tensor – with shape \((batch)\).
- Examples
>>> import torch
>>> targets = torch.randn(10, 32000)
>>> est_targets = torch.randn(10, 32000)
>>> # Using it by itself on a pair of source/estimate
>>> loss_func = SingleSrcMultiScaleSpectral()
>>> loss = loss_func(est_targets, targets)

>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> # Using it with PITLossWrapper on sets of sources/estimates
>>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
>>>                            pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
- References
  [1] Jesse Engel, Lamtharn (Hanoi) Hantrakul, Chenjie Gu and Adam Roberts, "DDSP: Differentiable Digital Signal Processing", ICLR 2020.
Deep clustering (Affinity) loss

asteroid.losses.cluster.deep_clustering_loss(embedding, tgt_index, binary_mask=None)

Compute the deep clustering loss defined in [1].
- Parameters
embedding (torch.Tensor) – Estimated embeddings. Expected shape \((batch, frequency * frame, embedding\_dim)\).
tgt_index (torch.Tensor) – Dominating source index in each TF bin. Expected shape: \((batch, frequency, frame)\).
binary_mask (torch.Tensor) – VAD in TF plane. Bool or Float. See asteroid.dsp.vad.ebased_vad.
- Returns
  torch.Tensor – Deep clustering loss for every batch sample.
- Examples
>>> import torch
>>> from asteroid.losses.cluster import deep_clustering_loss
>>> spk_cnt = 3
>>> embedding = torch.randn(10, 5*400, 20)
>>> targets = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
>>> loss = deep_clustering_loss(embedding, targets)
- Reference
  [1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey, "Alternative Objective Functions for Deep Clustering", ICASSP 2018.
Note
Be careful when viewing the embedding tensors. The target indices tgt_index have shape \((batch, freq, frames)\). Even though the embedding has shape \((batch, freq * frames, emb)\), the underlying view should be \((batch, freq, frames, emb)\), not \((batch, frames, freq, emb)\).
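A small sketch of the reshaping convention the note describes (shapes loosely follow the example above):

>>> import torch
>>> batch, freq, frames, emb_dim = 10, 5, 400, 20
>>> embedding = torch.randn(batch, freq * frames, emb_dim)
>>> view_ok = embedding.view(batch, freq, frames, emb_dim)   # correct: freq is the outer axis
>>> view_bad = embedding.view(batch, frames, freq, emb_dim)  # wrong: silently mixes frequency and time bins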
Computing metrics

asteroid.metrics.get_metrics(mix, clean, estimate, sample_rate=16000, metrics_list='all', average=True, compute_permutation=False, ignore_metrics_errors=False, filename=None)

Get speech separation/enhancement metrics from mix/clean/estimate.
- Parameters
  mix (np.array) – mixture array.
  clean (np.array) – reference array.
  estimate (np.array) – estimate array.
  sample_rate (int) – sampling rate of the audio clips.
  metrics_list (Union[List[str], str]) – List of metrics to compute. Defaults to 'all' (['si_sdr', 'sdr', 'sir', 'sar', 'stoi', 'pesq']).
  average (bool) – Return dict([float]) if True, else dict([array]).
  compute_permutation (bool) – Whether to compute the permutation on estimate sources for the output metrics (default: False).
  ignore_metrics_errors (bool) – Whether to ignore errors that occur while computing the metrics. A warning will be printed instead.
  filename (str, optional) – If computing a metric fails, print this filename along with the exception/warning message for debugging purposes.
- Shape:
  mix: \((D, N)\) or \((N,)\).
  clean: \((K\_source, N)\) or \((N,)\).
  estimate: \((K\_target, N)\) or \((N,)\).
- Returns
  dict – Dictionary with all requested metrics, with an 'input_' prefix for metrics computed on the input (mixture against clean) and no prefix for metrics on the output (estimate against clean). The output format depends on average.
- Examples
>>> import numpy as np
>>> import pprint
>>> from asteroid.metrics import get_metrics
>>> mix = np.random.randn(1, 16000)
>>> clean = np.random.randn(2, 16000)
>>> est = np.random.randn(2, 16000)
>>> metrics_dict = get_metrics(mix, clean, est, sample_rate=8000,
...                            metrics_list='all')
>>> pprint.pprint(metrics_dict)
{'input_pesq': 1.924380898475647,
 'input_sar': -11.67667585294225,
 'input_sdr': -14.88667106190552,
 'input_si_sdr': -52.43849784881705,
 'input_sir': -0.10419427290163795,
 'input_stoi': 0.015112115177091223,
 'pesq': 1.7713886499404907,
 'sar': -11.610963379923195,
 'sdr': -14.527246041125844,
 'si_sdr': -46.26557128489802,
 'sir': 0.4799929272243427,
 'stoi': 0.022023073540350643}