Utils
Parser utils
Asteroid has its own argument parser (built on argparse) that handles a dict-like structure created from a config YAML file.
asteroid.utils.parser_utils.prepare_parser_from_dict(dic, parser=None)
Prepare an argparser from a dictionary.
- Parameters
dic (dict) – Two-level config dictionary with unique bottom-level keys.
parser (argparse.ArgumentParser, optional) – If a parser already exists, add the keys from the dictionary on top of it.
- Returns
argparse.ArgumentParser – Parser instance with groups corresponding to the first-level keys and arguments corresponding to the second-level keys, with default values given by the dictionary values.
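As an illustration of the dictionary-to-parser mapping, here is a minimal sketch written with plain argparse. It is not Asteroid's implementation: the names `parser_from_dict_sketch` and `_infer_type` are hypothetical, and inferring each argument's type from its default value is an assumption made for the example.

```python
import argparse


def _infer_type(default):
    # Assumption: use the default's own type. bool is excluded (it is a
    # subclass of int, and type=bool treats any non-empty string as True);
    # None defaults carry no type information. Both fall back to str.
    if default is None or isinstance(default, bool):
        return str
    return type(default)


def parser_from_dict_sketch(dic, parser=None):
    """Hypothetical sketch: one argparse group per top-level key."""
    if parser is None:
        parser = argparse.ArgumentParser()
    for group_name, params in dic.items():
        group = parser.add_argument_group(group_name)
        for key, default in params.items():
            # Bottom-level keys become --key options with the value as default.
            group.add_argument("--" + key, default=default, type=_infer_type(default))
    return parser


conf = {
    "filterbank": {"n_filters": 64, "kernel_size": 16},
    "training": {"lr": 1e-3, "optimizer": "adam"},
}
parser = parser_from_dict_sketch(conf)
args = parser.parse_args(["--n_filters", "128", "--lr", "0.01"])
print(args.n_filters, args.kernel_size, args.lr, args.optimizer)  # 128 16 0.01 adam
```

Booleans are the awkward case: `type=bool` would turn the string "False" into True, which is why the sketch falls back to plain strings and why a dedicated converter such as str2bool_arg is a better fit.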
asteroid.utils.parser_utils.str_int_float(value)
Type to convert strings to int, then float (in this order), if possible; otherwise the string is returned unchanged.
- Parameters
value (str) – Value to convert.
- Returns
int, float, str – Converted value.
asteroid.utils.parser_utils.str2bool(value)
Type to convert strings to Boolean (returns the input unchanged if it cannot be interpreted as a Boolean).
asteroid.utils.parser_utils.str2bool_arg(value)
Argparse type to convert strings to Boolean.
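The two converters differ only in what happens to unrecognized strings: one passes them through, the other must reject them so argparse can report an error. A minimal sketch of that difference follows; the `*_sketch` names are hypothetical, and the set of accepted spellings is an assumption, not taken from Asteroid.

```python
import argparse

# Assumption: illustrative spellings, not Asteroid's exact list.
_TRUE = {"yes", "true", "y", "1"}
_FALSE = {"no", "false", "n", "0"}


def str2bool_sketch(value):
    """Map common true/false spellings to bool; return anything else unchanged."""
    if not isinstance(value, str):
        return value
    lowered = value.lower()
    if lowered in _TRUE:
        return True
    if lowered in _FALSE:
        return False
    return value


def str2bool_arg_sketch(value):
    """Argparse ``type=`` callable: like str2bool_sketch, but reject non-Booleans."""
    result = str2bool_sketch(value)
    if isinstance(result, bool):
        return result
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


parser = argparse.ArgumentParser()
parser.add_argument("--use_cuda", type=str2bool_arg_sketch, default=True)
print(parser.parse_args(["--use_cuda", "False"]).use_cuda)  # False
```

Raising argparse.ArgumentTypeError (rather than returning the raw string) lets argparse print a clean usage error for `--use_cuda maybe`.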
asteroid.utils.parser_utils.isfloat(value)
Computes whether value can be cast to a float.
- Parameters
value (str) – Value to check.
- Returns
bool – Whether value can be cast to a float.
asteroid.utils.parser_utils.isint(value)
Computes whether value can be cast to an int.
- Parameters
value (str) – Value to check.
- Returns
bool – Whether value can be cast to an int.
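Since str_int_float tries int before float, the isint and isfloat predicates are enough to implement it. A sketch under that reading (hypothetical `*_sketch` names, intended for string inputs as in the documented signatures):

```python
def isint_sketch(value):
    """True if int(value) succeeds."""
    try:
        int(value)
        return True
    except (TypeError, ValueError):
        return False


def isfloat_sketch(value):
    """True if float(value) succeeds."""
    try:
        float(value)
        return True
    except (TypeError, ValueError):
        return False


def str_int_float_sketch(value):
    """Try int first, then float, else return the input unchanged."""
    if isint_sketch(value):
        return int(value)
    if isfloat_sketch(value):
        return float(value)
    return value


print([str_int_float_sketch(v) for v in ["3", "3.5", "1e-3", "adam"]])
# [3, 3.5, 0.001, 'adam']
```

The order matters: "3" stays an int rather than becoming 3.0. For non-string inputs this sketch would behave differently (int(3.7) succeeds and truncates to 3), which is one reason to keep it to strings.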
asteroid.utils.parser_utils.parse_args_as_dict(parser, return_plain_args=False, args=None)
Get a dict of dicts out of the output of parser.parse_args().
Top-level keys correspond to groups and bottom-level keys correspond to arguments. Under ‘main_args’, the arguments that don’t belong to an argparse group (i.e. main arguments defined before parsing from a dict) can be found.
- Parameters
parser (argparse.ArgumentParser) – ArgumentParser instance containing groups. Output of prepare_parser_from_dict.
return_plain_args (bool) – Whether to also return the output of parser.parse_args().
args (list) – List of arguments as read from the command line. Used for unit testing.
- Returns
dict – Dictionary of dictionaries containing the arguments. Optionally, also the direct output of parser.parse_args().
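The grouping can be reproduced with the standard library alone. The sketch below is hypothetical (`args_as_dict_sketch`) and reads argparse's private `_action_groups` and `_group_actions` attributes, which exist in CPython but are not a public API; renaming the default group to "main_args" mirrors the description above.

```python
import argparse


def args_as_dict_sketch(parser, args=None):
    """Hypothetical sketch: parse, then group values by argparse group title."""
    namespace = parser.parse_args(args=args)
    out = {}
    # _action_groups / _group_actions are private CPython argparse attributes.
    for group in parser._action_groups:
        out[group.title] = {
            action.dest: getattr(namespace, action.dest)
            for action in group._group_actions
            if hasattr(namespace, action.dest)  # skips --help, which is never stored
        }
    # Ungrouped options live in a default group titled "options" on
    # Python >= 3.10 and "optional arguments" before; expose it as main_args.
    for default_title in ("options", "optional arguments"):
        if default_title in out:
            out["main_args"] = out.pop(default_title)
    return out


parser = argparse.ArgumentParser()
parser.add_argument("--exp_dir", default="exp/tmp")  # main argument, no group
group = parser.add_argument_group("training")
group.add_argument("--lr", type=float, default=1e-3)
result = args_as_dict_sketch(parser, ["--lr", "0.01"])
```

Here `result["main_args"]` is `{"exp_dir": "exp/tmp"}` and `result["training"]` is `{"lr": 0.01}`; the `--help` action is skipped because argparse never stores it on the namespace.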
Torch utils
asteroid.utils.torch_utils.to_cuda(tensors)
Transfer tensor, dict or list of tensors to GPU.
- Parameters
tensors (torch.Tensor, list or dict) – May be a single tensor, a list or a dictionary of tensors.
- Returns
torch.Tensor, list or dict – Same as input but transferred to CUDA. Goes through lists and dicts and transfers each torch.Tensor to CUDA. Leaves the rest untouched.
asteroid.utils.torch_utils.tensors_to_device(tensors, device)
Transfer tensor, dict or list of tensors to device.
- Parameters
tensors (Union[torch.Tensor, list, tuple, dict]) – May be a single tensor, a list or a dictionary of tensors.
device (torch.device) – The device where to place the tensors.
- Returns
Union[torch.Tensor, list, tuple, dict] – Same as input but transferred to device. Goes through lists and dicts and transfers each torch.Tensor to device. Leaves the rest untouched.
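The traversal described above is a short recursion. To keep the sketch dependency-free, it treats any object with a callable `.to()` method as a tensor (an assumption: the documented function checks for torch.Tensor, and nn.Module would also match this duck-typed test). `to_device_sketch` and `FakeTensor` are hypothetical names used only for the demo; with PyTorch installed, real tensors and a `torch.device` can be passed the same way.

```python
from typing import Any


def to_device_sketch(obj: Any, device: Any) -> Any:
    """Hypothetical sketch: move tensor-like leaves of nested containers."""
    # Assumption: anything with a callable ``.to`` counts as a tensor here.
    if callable(getattr(obj, "to", None)):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: to_device_sketch(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_device_sketch(value, device) for value in obj]
    if isinstance(obj, tuple):
        return tuple(to_device_sketch(value, device) for value in obj)
    return obj  # strings, numbers, None, ... are left untouched


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, used only for this demo."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


batch = {"mix": FakeTensor(), "sources": [FakeTensor(), FakeTensor()], "ids": ("a", 3)}
moved = to_device_sketch(batch, "cuda:0")
print(moved["mix"].device, [t.device for t in moved["sources"]], moved["ids"])
# cuda:0 ['cuda:0', 'cuda:0'] ('a', 3)
```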
asteroid.utils.torch_utils.get_device(tensor_or_module, default=None)
Get the device of a tensor or a module.
- Parameters
tensor_or_module (Union[torch.Tensor, torch.nn.Module]) – The object to get the device from. Can be a torch.Tensor, a torch.nn.Module, or anything else that has a device attribute or a parameters() -> Iterator[torch.Tensor] method.
default (Optional[Union[str, torch.device]]) – If the device cannot be determined, return this device instead. If None (the default), raise a TypeError instead.
- Returns
torch.device – The device that tensor_or_module is on.
asteroid.utils.torch_utils.is_tracing()
Returns True when tracing (i.e. if the function is called during the tracing of code with torch.jit.trace) and False otherwise.
asteroid.utils.torch_utils.script_if_tracing(fn)
Compiles fn when it is first called during tracing. torch.jit.script has a non-negligible start-up time when it is first called, due to lazy initialization of many compiler builtins. Therefore you should not use it in library code. However, you may want to have parts of your library work in tracing even if they use control flow. In these cases, you should use @torch.jit.script_if_tracing to substitute for torch.jit.script.
- Parameters
fn – A function to compile.
- Returns
If called during tracing, a ScriptFunction created by torch.jit.script is returned. Otherwise, the original function fn is returned.
asteroid.utils.torch_utils.pad_x_to_y(x: torch.Tensor, y: torch.Tensor, axis: int = -1) → torch.Tensor
Right-pad or right-trim the first argument to have the same size as the second argument.
- Parameters
x (torch.Tensor) – Tensor to be padded.
y (torch.Tensor) – Tensor to pad x to.
axis (int) – Axis to pad on.
- Returns
torch.Tensor – x padded or trimmed to match y’s size along axis.
asteroid.utils.torch_utils.load_state_dict_in(state_dict, model)
Strictly loads state_dict in model, or the next submodel. Useful to load a standalone model after training it with System.
- Parameters
state_dict (OrderedDict) – The state_dict to load.
model (torch.nn.Module) – The model to load it into.
- Returns
torch.nn.Module – Model with loaded weights.
Note
Keys in a state_dict look like object1.object2.layer_name.weight.etc. We first try to load the model in the classic way. If this fails, we remove the leftmost part of the key to obtain object2.layer_name.weight.etc. Blindly loading with strict=False should be done with some logging of the missing keys in the state_dict and the model.
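The fallback in the note amounts to stripping the leftmost dotted component of every key. The sketch below shows this on plain dictionaries; `strip_first_key_level` and `load_with_fallback` are hypothetical helpers, and comparing against a set of expected keys stands in for the strict `model.load_state_dict` call.

```python
from collections import OrderedDict


def strip_first_key_level(state_dict):
    """Hypothetical helper: 'object1.object2.layer.weight' -> 'object2.layer.weight'."""
    return OrderedDict(
        (key.split(".", 1)[1] if "." in key else key, value)
        for key, value in state_dict.items()
    )


def load_with_fallback(state_dict, expected_keys):
    """Hypothetical sketch of the two-step strategy on plain dicts.

    ``expected_keys`` stands in for ``model.state_dict().keys()``; exact key
    equality stands in for a strict ``load_state_dict`` call.
    """
    if set(state_dict) == set(expected_keys):
        return state_dict  # the classic, strict load would succeed
    stripped = strip_first_key_level(state_dict)
    if set(stripped) == set(expected_keys):
        return stripped  # loading into the next submodel would succeed
    raise KeyError("state_dict keys do not match the model, even after stripping")


# Hypothetical keys as saved by a wrapper that stores the network as `model`.
system_sd = OrderedDict([("model.encoder.weight", 1.0), ("model.decoder.bias", 0.0)])
print(list(load_with_fallback(system_sd, ["encoder.weight", "decoder.bias"])))
# ['encoder.weight', 'decoder.bias']
```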
asteroid.utils.torch_utils.are_models_equal(model1, model2)
Check for weights equality between models.
- Parameters
model1 (nn.Module) – model instance to be compared.
model2 (nn.Module) – second model instance to be compared.
- Returns
bool – Whether all model weights are equal.
asteroid.utils.torch_utils.jitable_shape(tensor)
Gets the shape of tensor as a torch.Tensor, for the JIT compiler.
Note
Returning tensor.shape or tensor.size() directly is not TorchScript compatible, as the return type would not be supported.
- Parameters
tensor (torch.Tensor) – Tensor.
- Returns
torch.Tensor – Shape of tensor.
Hub utils
asteroid.utils.hub_utils.cached_download(filename_or_url)
Download from URL and cache the result in ASTEROID_CACHE.
- Parameters
filename_or_url (str) – Name of a model as named on the Zenodo Community page (ex: "mpariente/ConvTasNet_WHAM!_sepclean"), or model id from the Hugging Face model hub (ex: "julien-c/DPRNNTasNet-ks16_WHAM_sepclean"), or a URL to a model file (ex: "https://zenodo.org/.../model.pth"), or a filename that exists locally (ex: "local/tmp_model.pth").
- Returns
str – Normalized path to the downloaded (or not) model.
Generic utils
asteroid.utils.generic_utils.has_arg(fn, name)
Checks if a callable accepts a given keyword argument.
- Parameters
fn (callable) – Callable to inspect.
name (str) – Check if fn can be called with name as a keyword argument.
- Returns
bool – Whether fn accepts a name keyword argument.
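One way to implement this check uses inspect.signature. This is a sketch (hypothetical `has_arg_sketch`), not necessarily Asteroid's code; it deliberately ignores `**kwargs` catch-alls, so a function that would swallow an unknown keyword still reports False, and whether the documented function does the same is not stated here.

```python
import inspect


def has_arg_sketch(fn, name):
    """Hypothetical sketch: can ``fn`` be called with ``name=...``?"""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):  # e.g. some builtins expose no signature
        return False
    param = params.get(name)
    if param is None:
        return False  # note: a **kwargs catch-all is deliberately not considered
    # Positional-only parameters cannot be passed by keyword.
    return param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def conv(in_chan, out_chan, *, kernel_size=3, **extra):
    pass


def strict(x, /, y):
    pass


print(has_arg_sketch(conv, "kernel_size"), has_arg_sketch(conv, "stride"),
      has_arg_sketch(strict, "x"), has_arg_sketch(strict, "y"))
# True False False True
```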
asteroid.utils.generic_utils.flatten_dict(d, parent_key='', sep='_')
Flattens a dictionary into a single-level dictionary while preserving parent keys. Taken from Stack Overflow.
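A sketch of the usual recursive recipe (hypothetical name `flatten_dict_sketch`), joining the path of keys with `sep`:

```python
from collections.abc import MutableMapping


def flatten_dict_sketch(d, parent_key="", sep="_"):
    """Hypothetical sketch: flatten nested mappings, joining key paths with sep."""
    items = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, MutableMapping):
            # Recurse, carrying the joined key as the new prefix.
            items.update(flatten_dict_sketch(value, new_key, sep=sep))
        else:
            items[new_key] = value
    return items


conf = {"filterbank": {"n_filters": 64, "stft": {"window": "hann"}}, "lr": 1e-3}
print(flatten_dict_sketch(conf))
# {'filterbank_n_filters': 64, 'filterbank_stft_window': 'hann', 'lr': 0.001}
```

Distinct paths can collide after joining (for example `{"a_b": 1, "a": {"b": 2}}` maps both entries to "a_b"); in this sketch the later entry wins.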
asteroid.utils.generic_utils.average_arrays_in_dic(dic)
Take the average of numpy arrays in a dictionary.
- Parameters
dic (dict) – Input dictionary to take the average from.
- Returns
dict – New dictionary with arrays averaged.
asteroid.utils.generic_utils.get_wav_random_start_stop(signal_len, desired_len=32000)
Get indexes for a chunk of signal of a given length.
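A plausible sketch of the index computation (hypothetical `random_start_stop_sketch`; returning the whole signal when it is shorter than `desired_len` is an assumption). For scale, the default of 32000 samples corresponds to 4 s of audio at 8 kHz.

```python
import random


def random_start_stop_sketch(signal_len, desired_len=32000, rng=random):
    """Hypothetical sketch: random [start, stop) of at most desired_len samples."""
    # Latest start that still leaves room for a full chunk (0 if the signal is short).
    max_start = max(signal_len - desired_len, 0)
    start = rng.randint(0, max_start)  # randint is inclusive on both ends
    # Assumption: short signals yield the whole signal (start=0, stop=signal_len).
    stop = min(start + desired_len, signal_len)
    return start, stop


start, stop = random_start_stop_sketch(100_000, rng=random.Random(0))
chunk_len = stop - start
```

Passing a seeded `random.Random` makes the chunk choice reproducible; `chunk_len` is 32000 here because the signal is longer than `desired_len`.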
asteroid.utils.generic_utils.unet_decoder_args(encoders, *, skip_connections)
Get the list of decoder arguments for the upsampling (right) side of a symmetric U-Net, given the arguments used to construct the encoders.
- Parameters
encoders (tuple of length N of tuples of (in_chan, out_chan, kernel_size, stride, padding)) – List of arguments used to construct the encoders.
skip_connections (bool) – Whether to include skip connections in the calculation of decoder input channels.
- Returns
tuple of length N of tuples of (in_chan, out_chan, kernel_size, stride, padding) – Arguments to be used to construct the decoders.
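The mirroring can be sketched as follows (hypothetical `unet_decoder_args_sketch`). It assumes skip connections are concatenated along the channel axis, so every decoder except the deepest receives twice the matching encoder's output channels; the documented function's exact arithmetic may differ.

```python
def unet_decoder_args_sketch(encoders, *, skip_connections):
    """Hypothetical sketch: mirror (in, out, k, s, p) encoder specs into decoders."""
    decoders = []
    n = len(encoders)
    for idx in reversed(range(n)):  # decoders run deepest-first
        in_chan, out_chan, kernel_size, stride, padding = encoders[idx]
        dec_in = out_chan
        # Assumption: skips are concatenated, doubling channels everywhere
        # except the deepest decoder, which only sees the bottleneck output.
        if skip_connections and idx != n - 1:
            dec_in *= 2
        # In/out channels are swapped relative to the matching encoder.
        decoders.append((dec_in, in_chan, kernel_size, stride, padding))
    return tuple(decoders)


encoders = ((1, 16, 5, 2, 2), (16, 32, 5, 2, 2), (32, 64, 5, 2, 2))
print(unet_decoder_args_sketch(encoders, skip_connections=True))
# ((64, 32, 5, 2, 2), (64, 16, 5, 2, 2), (32, 1, 5, 2, 2))
```

The last decoder outputs the first encoder's `in_chan`, so the network maps back to the input channel count.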