egrecho.utils.apply#

egrecho.utils.apply.apply_to_collection(data, dtype, function, *args, wrong_dtype=None, include_none=True, allow_frozen=False, **kwargs)[source]#

Recursively applies a function to all elements of a certain dtype.

Parameters:
  • data (Any) -- the collection to apply the function to

  • dtype (Union[type, Any, Tuple[Union[type, Any]]]) -- the given function will be applied to all elements of this dtype

  • function (Callable) -- the function to apply

  • *args (Any) -- positional arguments (will be forwarded to calls of function)

  • wrong_dtype (Union[type, Tuple[type, ...], None]) -- the given function won't be applied to elements of this type, even if they are also of type dtype

  • include_none (bool) -- Whether to include an element if the output of function is None.

  • allow_frozen (bool) -- Whether not to error upon encountering a frozen dataclass instance.

  • **kwargs (Any) -- keyword arguments (will be forwarded to calls of function)

Return type:

Any

Returns:

The resulting collection
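
Example (an illustrative sketch of the traversal described above):

>>> apply_to_collection({"a": 1, "b": [2, 3]}, int, lambda x: x * 10)
{'a': 10, 'b': [20, 30]}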

egrecho.utils.common#

egrecho.utils.common.alt_none(item, alt_item)[source]#

Replace None with alt_item.

Return type:

Any
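
Example (illustrative):

>>> alt_none(None, 3)
3
>>> alt_none(2, 3)
2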

egrecho.utils.common.is_in_range(val, max_val=None, min_val=None)[source]#

Judge whether a value is in a range.

The range is a closed interval (e.g., [1, 2]). If a boundary is None, that condition is skipped.

Parameters:
  • val -- value to be judged.

  • max_val (Optional[Any], optional) -- Defaults to None.

  • min_val (Optional[Any], optional) -- Defaults to None.

Return type:

bool

Returns:

bool
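
Example (illustrative, following the closed-interval semantics above):

>>> is_in_range(1.5, max_val=2, min_val=1)
True
>>> is_in_range(3, max_val=2)
False
>>> is_in_range(3, min_val=1)  # max boundary is None, so that condition is skipped
True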

class egrecho.utils.common.ObjectDict[source]#

Bases: dict

Makes a dictionary behave like an object, with attribute-style access.

Here are some examples of how it can be used:

o = ObjectDict(my_dict)
# or like this:
o = ObjectDict(samples=samples, sample_rate=sample_rate)
# Attribute-style access
samples = o.samples
# Dict-style access
samples = o["samples"]
class egrecho.utils.common.DataclassSerialMixin[source]#

Bases: object

From/to dict mixin of dataclass.

Example

>>> from dataclasses import dataclass
>>> @dataclass
... class Config(DataclassSerialMixin):
...   a: int = 123
...   b: str = "456"
...
>>> config = Config()
>>> config
Config(a=123, b='456')
>>> config.to_dict()
{'a': 123, 'b': '456'}
>>> config_ = Config.from_dict({"a": 123, "b": 456})
>>> config_
Config(a=123, b='456')
>>> assert config == config_
to_dict(dict_factory=<class 'dict'>, filt_type=None, init_field_only=True, save_dc_types=False)[source]#

Serializes this dataclass to a dict.

Return type:

dict

classmethod from_dict(obj, drop_extra_fields=None)[source]#

Parses an instance of cls from the given dict.

NOTE: If the decode_into_subclasses class attribute is set to True (or if decode_into_subclasses=True was passed in the class definition), then if there are keys in the dict that aren't fields of the dataclass, this will decode the dict into an instance of the first subclass of cls which has all required field names present in the dictionary.

  • Passing drop_extra_fields=None (default) will use the class attribute described above.

  • Passing drop_extra_fields=True will decode the dict into an instance of cls and drop the extra keys in the dict.

  • Passing drop_extra_fields=False forces the subclass-decoding behaviour described above (see the sketch below).
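
Example (a minimal sketch of the drop_extra_fields options described above):

>>> from dataclasses import dataclass
>>> @dataclass
... class Config(DataclassSerialMixin):
...     a: int = 123
>>> Config.from_dict({"a": 1, "extra": "x"}, drop_extra_fields=True)
Config(a=1)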

class egrecho.utils.common.GenericSerialMixin[source]#

Bases: object

Serialization mixin for common classes with a config attribute.

to_dict(**kwargs)[source]#

Returns the object's configuration as a config dictionary

Return type:

dict

classmethod from_dict(config)[source]#

Instantiates the object from a DictConfig-based configuration

class egrecho.utils.common.SaveLoadMixin[source]#

Bases: object

Mixin for saving to / loading from a repo.

classmethod fetch_from(dirpath, *args, **kwargs)[source]#

Fetch from a repo.

egrecho.utils.common.fields_init_var(class_or_instance)[source]#

Return a tuple describing the InitVar fields of this dataclass.

Modified from: https://docs.python.org/3/library/dataclasses.html#dataclasses.fields

Accepts a dataclass or an instance of one. Tuple elements are of type Field.

egrecho.utils.common.batched(iterable, n)[source]#

Batch data into tuples of length n. The last batch may be shorter.
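
Example (illustrative):

>>> list(batched("ABCDEFG", 3))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]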

egrecho.utils.common.camelcase_to_snakecase(name)[source]#

Convert camel-case string to snake-case.

egrecho.utils.common.snakecase_to_camelcase(name)[source]#

Convert snake-case string to camel-case string.
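
Example (illustrative of both conversions, assuming conventional rules):

>>> camelcase_to_snakecase("MyModelName")
'my_model_name'
>>> snakecase_to_camelcase("my_model_name")
'MyModelName'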

egrecho.utils.common.get_diff_dict(curr_dict, src_dict)[source]#

Compare two dictionaries and return a new dictionary containing the differing key-value pairs.

Parameters:
  • curr_dict (Dict) -- The current dictionary to compare.

  • src_dict (Dict) -- The source dictionary to compare against.

Returns:

A dictionary containing the key-value pairs that differ between src_dict and curr_dict.

Return type:

Dict

Example

>>> src_dict = {"name": "John", "age": 30}
>>> curr_dict = {"name": "John", "age": 35, "city": "New York"}
>>> get_diff_dict(curr_dict, src_dict)
{'age': 35, 'city': 'New York'}
egrecho.utils.common.del_default_dict(data, defaults, recurse=False)[source]#

Removes key-value pairs from a dictionary that match default values.

Parameters:
  • data (Dict) -- The dictionary to remove default values from.

  • defaults (Dict) -- The dictionary containing default values to compare against.

  • recurse (bool) -- Recursively process nested dictionaries. Defaults to False.

Returns:

A dictionary with default values removed.

Return type:

Dict

Example

>>> data = {"name": "John", "age": 30, "address": {"city": "New York", "zip": 10001}}
>>> defaults = {"name": "John", "age": 30, "address": {"city": "Unknown", "zip": None}}
>>> cleaned_data = del_default_dict(data, defaults)
>>> cleaned_data
{'address': {'city': 'New York', 'zip': 10001}}

Note

This function modifies the data dictionary in place and also returns it. If recurse=True it recursively processes nested dictionaries.

egrecho.utils.common.dict_union(*dicts, recurse=True, sort_keys=False, dict_factory=<class 'dict'>)[source]#

Combine multiple dictionaries into a single new one.

Parameters:
  • *dicts (dict) -- One or more dictionaries to combine.

  • recurse (bool, optional) -- If True, also recursively combine nested dictionaries.

  • sort_keys (bool, optional) -- If True, sort the keys alphabetically in the resulting dictionary.

  • dict_factory (callable, optional) -- A callable that creates the output dictionary (default is dict).

Returns:

A new dictionary containing the union of all input dictionaries.

Return type:

dict

Example

>>> from collections import OrderedDict
>>> a = OrderedDict(a=1, b=2, c=3)
>>> b = OrderedDict(c=5, d=6, e=7)
>>> dict_union(a, b, dict_factory=OrderedDict)
OrderedDict([('a', 1), ('b', 2), ('c', 5), ('d', 6), ('e', 7)])
>>> a = OrderedDict(a=1, b=OrderedDict(c=2, d=3))
>>> b = OrderedDict(a=2, b=OrderedDict(c=3, e=6))
>>> dict_union(a, b, dict_factory=OrderedDict)
OrderedDict([('a', 2), ('b', OrderedDict([('c', 3), ('d', 3), ('e', 6)]))])
egrecho.utils.common.list2tuple(func)[source]#

Converts lists in the input parameters to hashable tuples, which is useful for lru_cache.

NOTE: Nested structures are not supported.

Example

>>> def get_input(*args, **kwargs):
...     return args, kwargs
>>> @list2tuple
... def get_input_wrapper(*args, **kwargs):
...     return args, kwargs
>>> arg_input = ([1,2,3], 'foo')
>>> kwargs_input = {'bar': True, 'action': ["go", "leave", "break"], 'action1':['buy', ['sell', 'argue']]}
>>> get_input(*arg_input, **kwargs_input)
(([1, 2, 3], 'foo'),
{'bar': True,
'action': ['go', 'leave', 'break'],
'action1': ['buy', ['sell', 'argue']]})
>>> get_input_wrapper(*arg_input, **kwargs_input)
(((1, 2, 3), 'foo'),
{'bar': True,
'action': ('go', 'leave', 'break'),
'action1': ('buy', ['sell', 'argue'])})
egrecho.utils.common.omegaconf_handler(config, omegaconf_resolve=True)[source]#

If omegaconf is available, creates DictConfig and resolves it when omegaconf_resolve=True.

Parameters:
  • config (Optional[Dict[Any, Any]]) -- loaded kwargs dict.

  • omegaconf_resolve -- If False, returns a DictConfig; otherwise returns the dict resolved from the DictConfig.
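
Example (a minimal sketch; requires omegaconf and uses its "${...}" interpolation):

config = {"a": 1, "b": "${a}"}
omegaconf_handler(config)                           # -> {'a': 1, 'b': 1}
omegaconf_handler(config, omegaconf_resolve=False)  # -> unresolved DictConfig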

egrecho.utils.constants#

egrecho.utils.cuda_utils#

egrecho.utils.cuda_utils.avoid_float16_autocast_context()[source]#

If the current autocast context is float16, casts it to bfloat16 if available (unless running under JIT), otherwise to float32.

egrecho.utils.cuda_utils.maybe_set_cuda_expandable_segments(enabled)[source]#

Configures the PyTorch memory allocator to expand existing allocated segments instead of re-allocating them when the tensor shape grows. This can help speed up training when sequence length and/or batch size change often, and makes the GPU more robust against OOM.

See here for more details: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf

egrecho.utils.cuda_utils.set_to_cuda(modules)[source]#

Send modules to gpu.

Parameters:

modules -- an nn.Module or a list of modules

egrecho.utils.cuda_utils.get_current_device()[source]#

Returns the currently selected device (GPU/CPU): GPU if CUDA is available, otherwise CPU.

Return type:

device

egrecho.utils.cuda_utils.to_device(data, device_object=None)[source]#

Move a tensor or collection of tensors to a specified device by inferring the device from another object.

Parameters:
  • data (Any) -- A tensor, collection of tensors, or anything with a .to(...) method.

  • device_object (Union[torch.device, str, int, torch.Tensor, torch.nn.Module], optional) --

    The target device. Can be one of the following:

    • torch.device, str, or int: the target device itself.

    • torch.nn.Module: infer the device from the module.

    • torch.Tensor: infer the device from the tensor. (default: the current default CUDA device)

Return type:

Any

Returns:

The same collection with all contained tensors residing on the target device.
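
Example (a minimal sketch; assumes a CUDA device is available):

import torch

module = torch.nn.Linear(3, 4).to("cuda:0")
batch = {"x": torch.ones(2, 3), "lens": torch.tensor([3, 3])}
batch = to_device(batch, module)  # all tensors now reside on cuda:0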

egrecho.utils.cuda_utils.parse_gpus(gpus)[source]#

Parse gpus option.

Parameters:

gpus (Union[str, int, Sequence[int]]) --

GPUs used for training on this machine.

-1: all; N: use GPUs [0, N); “1,2”: comma-separated ids; “0” is invalid, while “0,” specifies a single GPU with id 0.

Return type:

Optional[List[int]]
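
Example (illustrative; actual results depend on the visible devices):

parse_gpus(-1)     # all GPUs, e.g., [0, 1, 2, 3] on a 4-GPU machine
parse_gpus(2)      # -> [0, 1]
parse_gpus("1,2")  # -> [1, 2]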

egrecho.utils.cuda_utils.parse_gpus_opt(gpus)[source]#

Similar to parse_gpus, but additionally supports auto-choosing a single GPU.

Parameters:

gpus (Optional[Union[str, int]]) --

What GPUs should be used:

  • case 0: comma-separated list, e.g., “1,” or “0,1” means specified id(s).

  • case 1: a single int (str) negative number (-1) means all visible devices [0, N-1].

  • case 2: ‘’ or None or 0 returns None, which means no GPU.

  • case 3: a single int (str) number equals 1 means auto choose a spare GPU.

  • case 4: a single int (str) number n greater than 1 returns [0, n-1].

Returns:

A list of GPU IDs or None.

Return type:

Optional[List[int]]

egrecho.utils.cuda_utils.parse_gpu_id(gpu_id='auto')[source]#

Parse single gpu id option.

Parameters:

gpu_id (Optional[Union[str, int]]) --

select which GPU:

  • case 0: “auto”, auto-selects a spare GPU.

  • case 1: a single int (str) negative number (-1) means cpu.

  • case 2: a single int (str) positive number means specified id.

  • case 3: ‘’ or None returns None, which means the default behaviour in the given context, e.g., torch.load(…)

  • case 4: other strings, e.g., “cuda:1”

Return type:

Union[str, int, None]

egrecho.utils.cuda_utils.synchronize()[source]#

Similar to cuda.synchronize(). Waits for all kernels in all streams on a CUDA device to complete.

egrecho.utils.cuda_utils.peak_cuda_memory()[source]#

Return the peak gpu memory statistics.

Returns:

  • max_alloc (float) -- the peak allocated CUDA memory.

  • max_cached (float) -- the peak cached CUDA memory.

egrecho.utils.cuda_utils.release_memory(*objects)[source]#

Triggers garbage collection and releases CUDA cache memory.

This function sets the inputs to None, triggers garbage collection to release CPU memory references, and attempts to clear GPU memory cache.

Parameters:

*objects -- Variable number of objects to release.

Returns:

A list of None values, with the same length as the input objects.

Return type:

List[None]

Example

>>> import torch
>>> a = torch.ones(1024, 1024).cuda()
>>> b = torch.ones(1024, 1024).cuda()
>>> a, b = release_memory(a, b)
class egrecho.utils.cuda_utils.AutoGPUMode(value)[source]#

Bases: str, Enum

An enumeration.

class egrecho.utils.cuda_utils.GPUManager(addtional_qargs=None, mode=AutoGPUMode.MAX_MEM)[source]#

Bases: object

This class enables automated selection of the most available GPU, or another GPU based on the specified mode.

Parameters:
  • addtional_qargs (Optional) -- Additional arguments passed to nvidia-smi.

  • mode (AutoGPUMode) -- mode for GPU selection. Defaults to MAX_MEM (maximum free memory).

Example:

import os,torch
os.environ["CUDA_VISIBLE_DEVICES"]="1,2"
gm = GPUManager()
torch_device = gm.auto_choice()

or

torch_device = GPUManager.detect()


a = torch.randn(1,1000)
a.to(torch_device)
new_query()[source]#

Runs the nvidia-smi command and organizes the results as a list of dictionaries containing the information queried from nvidia-smi.

classmethod detect(mode=AutoGPUMode.MAX_MEM)[source]#

A classmethod that calls the auto_choice() method to select a GPU and returns the device id.

Return type:

int

auto_choice(mode=None)[source]#

Auto-chooses a GPU ID based on the specified mode.

Parameters:

mode (str) -- The mode for selecting the GPU.

Return type:

int

Returns:

device id.

egrecho.utils.cuda_utils.device2gpu_id(device_id)[source]#

Given the device index, returns the unmasked real GPU ID.

Return type:

str

egrecho.utils.cuda_utils.num_gpus()#

Get the number of visible GPUs.

Return type:

int

egrecho.utils.cuda_utils.is_cuda_available()[source]#

Checks CUDA availability while leaving CUDA uninitialized.

egrecho.utils.cuda_utils.patch_nvml_check_env()[source]#

A context manager that patches PYTORCH_NVML_BASED_CUDA_CHECK=1 and restores the original value on exit.

class egrecho.utils.cuda_utils.NVMLDeviceCount[source]#

Bases: object

A tool for NVML-based CUDA checks for torch < 2.0, which won't trigger the drivers and leaves CUDA uninitialized.

Copied from pytorch, see:

https://github.com/pytorch/pytorch/pull/84879

egrecho.utils.data_utils#

egrecho.utils.data_utils.split_sequence(seq, split_num, mode='batch', shuffle=False, drop_last=False)[source]#

Split a sequence into split_num equal parts. The element order can be randomized. Raises a ValueError if split_num is larger than len(seq). Supports the modes ‘shard’ and ‘batch’: in ‘batch’ mode the splits follow the original sequence order, while ‘shard’ mode shards the original sequence. E.g., splitting [0, 1, 2, 3] into 2 parts gives [[0, 2], [1, 3]] in shard mode and [[0, 1], [2, 3]] in batch mode.

Parameters:
  • seq (Sequence) -- Input iterable.

  • split_num (int) -- Number of splits.

  • mode (str) -- One of (‘shard’, ‘batch’).

  • shuffle (bool) -- If True, shuffle the input sequence before splitting it.

  • drop_last (bool) -- If True, drop the last items when len(seq) is not divisible by split_num.

Return type:

List[List[Any]]

Returns:

List of smaller sequences.
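
Example (mirrors the mode description above):

>>> split_sequence([0, 1, 2, 3], split_num=2, mode='batch')
[[0, 1], [2, 3]]
>>> split_sequence([0, 1, 2, 3], split_num=2, mode='shard')
[[0, 2], [1, 3]]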

class egrecho.utils.data_utils.Dillable[source]#

Bases: object

Mix-in that will leverage dill instead of pickle when pickling an object.

It is useful when the user can’t avoid pickle (e.g. in multiprocessing), but needs to use unpicklable objects such as lambdas.

egrecho.utils.data_utils.iflatmap_unordered(nj, fn, kwargs_iter)[source]#

Parallelized mapping operation.

Note: Data come from kwargs_iter, and the results of all jobs are flattened to a queue. This operation is asynchronous and does not preserve the original order.

Parameters:
  • nj (int) -- number of jobs.

  • fn (Callable) -- a function that can yield results from the given args.

  • kwargs_iter (Iterable[dict]) -- kwargs mapped to fn.

Return type:

Iterable

egrecho.utils.data_utils.buffer_shuffle(data, buffer_size=10000, rng=None)[source]#

Buffer shuffle the data.

Parameters:
  • data (Iterable) -- data source.

  • buffer_size (int) -- defaults to 10000.

  • rng (np.random.Generator) -- fixes randomness for reproducibility.

Return type:

Generator

Returns:

A generator that yields data items.
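
Example (a minimal sketch; the resulting order depends on the rng):

import numpy as np

rng = np.random.default_rng(42)
shuffled = list(buffer_shuffle(range(10), buffer_size=4, rng=rng))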

egrecho.utils.data_utils.ichunk_size(total_len, split_num=None, chunk_size=None, even=True)[source]#

Infer an even-split chunksize generator before applying a split operation, if needed.

Parameters:
  • total_len (int) -- The lengths to be divided.

  • chunk_size (int) -- Size of each chunk; alternatively provided to infer chunk sizes.

  • split_num (int) -- Number of splits, can be provided to infer chunksizes.

  • even (bool) -- If True, the max difference between chunk sizes is 1.

Return type:

Generator

Returns:

A generator that yields chunk sizes summing to the total length.

Example

>>> list(ichunk_size(15, chunk_size=10, even=True))  # chunk_size=10, total_len=15, adapt (10, 5) -> (8, 7).
[8, 7]
>>> list(ichunk_size(10, split_num=4, even=True))  # split_num=4, total_len=10, adapt chunksize (3, 3, 3, 1) -> (3, 3, 2, 2).
[3, 3, 2, 2]
egrecho.utils.data_utils.ilen(iterable)[source]#

Return the number of items in iterable inputs.

This consumes the iterable data, so handle with care.

Example

>>> ilen(x for x in range(1000000) if x % 3 == 0)
333334
egrecho.utils.data_utils.zip_dict(*dicts)[source]#

Iterate over items of dictionaries grouped by their keys.
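
Example (an illustrative sketch, assuming it yields (key, values) pairs like the HuggingFace helper of the same name):

>>> list(zip_dict({"a": 1, "b": 2}, {"a": 10, "b": 20}))
[('a', (1, 10)), ('b', (2, 20))]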

class egrecho.utils.data_utils.ClassLabel(num_classes=None, names=None, names_file=None, id=None)[source]#

Bases: DictFileMixin

The instance of this class stores the string names of labels, and can be used for mapping str2label or label2str.

Modified from HuggingFace Datasets.

There are 3 ways to define a ClassLabel, which correspond to the 3 arguments:
  • num_classes: Create 0 to (num_classes-1) labels.

  • names: List of label strings.

  • names_file: File (Text) containing the list of labels.

Under the hood the labels are stored as integers. You can use negative integers to represent unknown/missing labels.

Serialization/deserialization of yaml files is done in a more readable way (from_yaml, to_yaml):

names:              names:
- negative    ->      '0': negative
- positive            '1': positive

Parameters:
  • num_classes (int, optional) -- Number of classes. All labels must be < num_classes.

  • names (list of str, optional) -- String names for the integer classes. The order in which the names are provided is kept.

  • names_file (str, optional) -- Path to a file with names for the integer classes, one per line.

Example

>>> label = ClassLabel(num_classes=3, names=['speaker1', 'speaker2', 'speaker3'])
>>> label
ClassLabel(num_classes=3, names=['speaker1', 'speaker2', 'speaker3'], id=None)
>>> label.encode_label('speaker1')
0
>>> label.encode_label(1)
1
>>> label.encode_label('1')
1
str2int(values)[source]#

Conversion class name string => integer.

Return type:

Union[int, Iterable]

Example

>>> label = ClassLabel(num_classes=3, names=['speaker1', 'speaker2', 'speaker3'])
>>> label.str2int('speaker1')
0
int2str(values)[source]#

Conversion integer => class name string.

Regarding unknown/missing labels: passing negative integers raises ValueError.

Return type:

Union[str, Iterable]

Example

>>> label = ClassLabel(num_classes=3, names=['speaker1', 'speaker2', 'speaker3'])
>>> label.int2str(0)
'speaker1'
classmethod from_dict(data)[source]#

Deserialize from dict.

transform dict to list:

names:              names:
'0': negative  ->   - negative
'1': positive       - positive

to_file(path, save_vocab=False)[source]#

Serialize to a dict file.

names:              names:
- negative    ->      '0': negative
- positive            '1': positive

Parameters:
  • path (Union[Path, str]) -- file path

  • save_vocab (bool) -- also saves a clean vocab json, i.e., str2int: {'negative': 0, 'positive': 1}

class egrecho.utils.data_utils.VSORT(value)[source]#

Bases: StrEnum

An enumeration.

egrecho.utils.data_utils.count_vocab_from_iterator(iterator)[source]#

Build a Vocab counter from an iterator.

Parameters:

iterator (Iterable) -- Iterator used to count vocab. Must yield list or iterator of tokens.

Return type:

VocabCounter

Returns:

Obj of VocabCounter.

Examples

>>> # generating vocab from a text file
>>> from egrecho.utils.data_utils import count_vocab_from_iterator
>>> def yield_tokens(file_path):
...     with open(file_path, encoding='utf-8') as f:
...         for line in f:
...             yield line.strip().split()
>>> vocab_counter = count_vocab_from_iterator(yield_tokens(file_path))
egrecho.utils.data_utils.build_vocab_from_iterator(iterator, savedir=None, fname='vocab', min_freq=1, specials=None, special_first=True, max_tokens=-1)[source]#

Build a Vocab from an iterator.

Parameters:
  • iterator (Iterable) -- Iterator used to build Vocab. Must yield list or iterator of tokens.

  • savedir (Union[str, Path, None]) -- Path of the directory to save the vocab file. If None, saving is skipped.

  • fname (str) -- filename

  • specials (Optional[List[str]]) -- Special symbols to add. The order of supplied tokens will be preserved.

  • special_first (bool) -- Indicates whether to insert symbols at the beginning or at the end.

  • min_freq (int) -- If provided, defines the minimum frequency needed to include a token in the vocabulary.

  • max_tokens (int) -- If provided > 0 creates the vocab from the max_tokens - len(specials) most frequent tokens.

Return type:

ClassLabel

Returns:

Obj of ClassLabel.

Examples

>>> # generating vocab from a text file
>>> from egrecho.utils.data_utils import build_vocab_from_iterator
>>> def yield_tokens(file_path):
...     with open(file_path, encoding='utf-8') as f:
...         for line in f:
...             yield line.strip().split()
>>> vocab = build_vocab_from_iterator(yield_tokens(file_path), specials=["<unk>"])
>>> # or save files
>>> vocab = build_vocab_from_iterator(yield_tokens(file_path), savedir='./', specials=["<unk>"])
class egrecho.utils.data_utils.SplitInfo(name='', patterns='', num_examples=0, meta=None)[source]#

Bases: object

A container of split dataset info.

egrecho.utils.data_utils.try_length(data)[source]#

Try to get the length of an object, falling back to None.

Return type:

Optional[int]
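
Example (illustrative):

>>> try_length([1, 2, 3])
3
>>> try_length(x for x in range(3)) is None  # generators have no len()
True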

egrecho.utils.data_utils.wavscp2dicts(wav_file, col_utt_name='id', col_path_name='audio_path')[source]#

Read a wav.scp file into a list of dicts.

Return type:

List[Dict[str, str]]
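
Example (a minimal sketch; assumes the Kaldi-style wav.scp format of "<utt_id> <path>" per line):

# wav.scp:
#   utt1 /data/utt1.wav
#   utt2 /data/utt2.wav
wavscp2dicts("wav.scp")
# -> [{'id': 'utt1', 'audio_path': '/data/utt1.wav'},
#     {'id': 'utt2', 'audio_path': '/data/utt2.wav'}]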

egrecho.utils.dist#

egrecho.utils.dist.send_exit_ddp(stop_flag=0)[source]#

Reduces a stop signal across DDP processes, controlled by the main rank.

egrecho.utils.dist.get_free_port()[source]#

Select a free port for localhost.

Useful in single-node training when we don’t want to connect to a real main node but have to set the MASTER_PORT environment variable.

egrecho.utils.dist.is_port_in_use(port=None)[source]#

Checks if a port is in use on localhost.

Return type:

bool

class egrecho.utils.dist.DistInfo(world_size, rank)[source]#

Bases: object

Contains the environment for the current dist rank.

classmethod detect(group=None, allow_env=True)[source]#

Tries to automatically detect the pytorch distributed environment parameters.

Return type:

DistInfo

Note

If allow_env = True, some other dist environment may be detected. This detection may not work in processes spawned from the distributed processes (e.g. DataLoader workers) as the distributed framework won’t be initialized there. It will default to 1 distributed process in this case.

class egrecho.utils.dist.WorkerInfo(num_workers, id)[source]#

Bases: object

Contains the environment for the current dataloader within the current training process.

classmethod detect(allow_env=True)[source]#

Automatically detects the number of pytorch workers and the current rank.

Return type:

WorkerInfo

Note

If allow_env = True, some other worker environment may be detected. This only works reliably within a dataloader worker, as otherwise the necessary information won't be present. In such a case it will default to 1 worker.

class egrecho.utils.dist.EnvInfo(dist_info=None, worker_info=None)[source]#

Bases: object

Container of DistInfo and WorkerInfo.

classmethod from_args(world_size, rank, num_workers, worker_id)[source]#

Set env info from args.

Parameters:
  • world_size (int) -- The world size used for distributed training (equals the total number of distributed processes)

  • rank (int) -- The distributed global rank of the current process

  • num_workers (int) -- The number of workers per distributed training process

  • worker_id (int) -- The rank of the current worker within the number of workers of the current training process

Return type:

EnvInfo

property num_shards: int#

Returns the total number of shards.

Note

This may not be accurate in a non-dataloader-worker process like the main training process as it doesn’t necessarily know about the number of dataloader workers.

property shard_rank: int#

Returns the rank of the current process wrt. the total number of shards.

Note

This may not be accurate in a non-dataloader-worker process like the main training process as it doesn’t necessarily know about the number of dataloader workers.
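
Example (a hedged sketch; it assumes num_shards equals world_size * num_workers):

env = EnvInfo.from_args(world_size=2, rank=0, num_workers=4, worker_id=1)
env.num_shards  # presumably world_size * num_workers == 8
env.shard_rank  # this worker's index among those shards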

egrecho.utils.dist.is_global_rank_zero()[source]#

Helper function to determine if the current process is global_rank 0 (the main process)

class egrecho.utils.dist.TorchMPLauncher(num_processes, port=None, disable_mem_share=False, start_method='spawn')[source]#

Bases: object

Launches processes that run a given function in parallel, and joins them all at the end.

Worker processes are assigned a rank in os.environ["LOCAL_RANK"], ranging from 0 to N - 1.

Adapted from lightning fabric.

Note

  • This launcher requires all objects to be pickleable.

  • The entry point to the program/script should be guarded by if __name__ == "__main__".

  • In environments like IPython notebooks where ‘spawn’ is not available, ‘fork’ works better.

  • With start method ‘fork’, the user must ensure that no CUDA context gets created in the main process before the launcher is invoked, i.e., torch.cuda should be uninitialized.

Parameters:
  • num_processes (int) -- number of worker processes.

  • port (Optional[int]) -- master port; if not set, a free port on localhost is chosen automatically.

  • disable_mem_share (bool) -- memory sharing is a feature of torch.multiprocessing. Must be set to True when running models on CPU; see _disable_module_memory_sharing().

  • start_method (Literal['spawn', 'fork', 'forkserver']) -- The method used to start the processes.

Example:

launcher = TorchMPLauncher(num_processes=4)
launcher.launch(my_function, arg1, arg2, kwarg1=value1)
launch(function, *args, **kwargs)[source]#

Launches processes that run the given function in parallel.

The function is allowed to have a return value. However, when all processes join, only the return value of worker process 0 gets returned from this launch method in the main process.

Parameters:
  • function (Callable) -- The entry point for all launched processes.

  • *args (Any) -- Optional positional arguments to be passed to the given function.

  • **kwargs (Any) -- Optional keyword arguments to be passed to the given function.

Return type:

Any

egrecho.utils.imports#

egrecho.utils.imports.is_package_available(*modules)[source]#

Returns whether a top-level module with the given name exists, without importing it. This is generally safer than a try/except block around import X. It avoids third-party libraries breaking assumptions of some of our tests, e.g., setting the multiprocessing start method when imported (see librosa/#747, torchvision/#544).

Return type:

bool

egrecho.utils.imports.is_module_available(module_path)#

Check if a module path is available in your environment. This will try to import it.

Return type:

bool

Example

>>> is_module_available('torch')
True
>>> is_module_available('fake')
False
>>> is_module_available('torch.utils')
True
>>> is_module_available('torch.util')
False
egrecho.utils.imports.compare_version(package, op, ver, use_base_version=False)[source]#

Compare package version with some requirements.

Return type:

bool

Example

>>> compare_version("torch", operator.ge, "0.1")
True
>>> compare_version("does_not_exist", operator.ge, "0.0")
False
egrecho.utils.imports.check_ort_requirements(version='1.4')#

Check that onnxruntime is installed and that the installed version is recent enough.

Raises:

ImportError -- If onnxruntime is not installed or the installed version is too old.

egrecho.utils.imports.lazy_import(module_name, callback=None)[source]#

Returns a proxy module object that will lazily import the given module the first time it is used.

Copied from lightning utilities.

Parameters:
  • module_name -- the fully-qualified module name to import

  • callback (None) -- a callback function to call before importing the module

Returns:

a proxy module object that will be lazily imported when first used

Example:

# Lazy version of `import tensorflow as tf`
tf = lazy_import("tensorflow")

# Other commands

# Now the module is loaded
tf.__version__
class egrecho.utils.imports.LazyModule(module_name, callback=None)[source]#

Bases: module

Proxy module that lazily imports the underlying module the first time it is actually used.

Parameters:
  • module_name -- the fully-qualified module name to import

  • callback (None) -- a callback function to call before importing the module

egrecho.utils.logging#

egrecho.utils.logging.get_logger(name='egrecho.utils.logging')[source]#

Get logger singleton instance based on package name.

Return type:

Logger
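
Example (a minimal usage sketch; the exact effect of ranks is described in the Logger methods below):

logger = get_logger(__name__)
logger.info("start training")
logger.info("only logged on rank 0", ranks=[0])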

class egrecho.utils.logging.Logger(name)[source]#

Bases: object

Singleton pattern for the logger.

Parameters:

name (str) -- The name of the logger.

set_level(level)[source]#

Set the logging level

Parameters:

level (str) -- Can only be INFO, DEBUG, WARNING and ERROR.

Return type:

None

info(message, ranks=None, stack_pos=2, verbose=None)[source]#

Log an info message.

Parameters:
  • message (str) -- The message to be logged.

  • ranks (List[int]) -- List of parallel ranks.

Return type:

None

info_once(message, ranks=None)[source]#

Log an info message, but only once.

Parameters:
  • message (str) -- Message to display

  • ranks (List[int]) -- List of parallel ranks.

Return type:

None

warning(message, ranks=None, stack_pos=2, verbose=None)[source]#

Log a warning message.

Parameters:
  • message (str) -- The message to be logged.

  • ranks (List[int]) -- List of parallel ranks.

Return type:

None

warning_once(message, ranks=None)[source]#

Log a warning, but only once.

Parameters:
  • message (str) -- The message to be logged.

  • ranks (List[int]) -- List of parallel ranks.

Return type:

None

debug(message, ranks=None, stack_pos=2, verbose=None)[source]#

Log a debug message.

Parameters:
  • message (str) -- The message to be logged.

  • ranks (List[int]) -- List of parallel ranks.

Return type:

None

error(message, ranks=None, stack_pos=2, verbose=None)[source]#

Log an error message.

Parameters:
  • message (str) -- The message to be logged.

  • ranks (List[int]) -- List of parallel ranks.

Return type:

None

error_once(message, ranks=None, verbose=True)[source]#

Log an error message, but only once.

Parameters:
  • message (str) -- The message to be logged.

  • ranks (List[int]) -- List of parallel ranks.

Return type:

None

egrecho.utils.mask#

egrecho.utils.mask.subsequent_chunk_mask(size, chunk_size, num_left_chunks=-1, device=device(type='cpu'))[source]#

Create mask for subsequent steps (size, size) with chunk size

Parameters:
  • size (int) -- size of mask

  • chunk_size (int) -- size of chunk

  • num_left_chunks (int) -- number of left chunks; <0: use full chunks, >=0: use num_left_chunks left chunks

  • device (torch.device) -- “cpu” or “cuda” or torch.Tensor.device

Returns:

mask

Return type:

torch.Tensor

Examples

>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 1],
 [1, 1, 1, 1]]
egrecho.utils.mask.make_pad_mask(lengths, max_len=0)[source]#

Make mask tensor containing indices of padded part.

See description of make_non_pad_mask.

Parameters:

lengths (torch.Tensor) -- Batch of lengths (B,).

Returns:

Mask tensor containing indices of padded part.

Return type:

torch.BoolTensor

Examples

>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
         [0, 0, 0, 1, 1],
         [0, 0, 1, 1, 1]]
egrecho.utils.mask.make_non_pad_mask(lengths, max_len=0)[source]#

Make mask tensor containing indices of non-padded part.

The sequences in a batch may have different lengths. To enable batch computing, padding is needed to make all sequences the same size. To avoid the padding part passing values to context-dependent blocks such as attention or convolution, this padding part is masked.

This pad_mask is used in both encoder and decoder.

1 for non-padded part and 0 for padded part.

Parameters:

lengths (torch.Tensor) -- Batch of lengths (B,).

Returns:

Mask tensor containing indices of the non-padded part.

Return type:

torch.BoolTensor

Examples

>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
         [1, 1, 1, 0, 0],
         [1, 1, 0, 0, 0]]
egrecho.utils.mask.prepare_4d_attention_mask(mask, dtype, tgt_len=None)[source]#

Creates a non-causal 4D mask of shape (batch_size, 1, query_length, key_value_length) from a 2D mask of shape (batch_size, key_value_length). The values in the original 2D mask are 0 (non-padded) or 1 (padded).

Parameters:
  • mask (torch.Tensor or None) -- A 2D attention mask of shape (batch_size, key_value_length)

  • dtype (torch.dtype) -- The torch dtype the created mask shall have.

  • tgt_len (int) -- The target length or query length the created mask shall have.
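
Example (an illustrative sketch; the exact fill value for padded keys is an implementation detail, typically a large negative number):

import torch

# per the description above: 0 = non-padded, 1 = padded
mask2d = torch.tensor([[0, 0, 1]])
mask4d = prepare_4d_attention_mask(mask2d, torch.float32, tgt_len=3)
# mask4d.shape == (1, 1, 3, 3)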

egrecho.utils.mask.make_causal_mask(inputs, dtype, past_key_values_length=0, sliding_window=None)[source]#

Creates a causal 4D mask from the given queries.

Parameters:
  • inputs (Tensor) -- inputs query of shape [bsz, seq_len, …].

  • dtype (dtype) -- the torch dtype the created mask shall have.

  • past_key_values_length (int) -- cached past kv length

  • sliding_window (Optional[int]) -- left chunk size

Returns:

4d attention mask (batch_size, 1, query_length, key_value_length) that can be multiplied with attention scores.

Examples

>>> make_causal_mask(torch.arange(3).view(1,3),torch.float)
tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]]])
>>> make_causal_mask(torch.arange(3).view(1,3), torch.float, past_key_values_length=2)
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]])
>>> make_causal_mask(torch.arange(3).view(1,3), torch.float, past_key_values_length=2, sliding_window=2)
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
        [-3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
        [-3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]])

egrecho.utils.misc#

egrecho.utils.misc.parse_bytes(size)[source]#

Parse a size in bytes to the largest possible unit.

Return type:

float

egrecho.utils.misc.parse_size(size, unit_type='abbrev')[source]#

Parse a size to the largest possible unit.

egrecho.utils.misc.kill_name_proc(grep_str, force_kill=False)[source]#

Kills processes based on a grep string.

Parameters:
  • grep_str (str) -- The grep string to search for processes.

  • force_kill (bool) -- If True, uses SIGKILL (-9) to forcefully terminate processes.

The function uses ps aux | grep to search for processes that match the given grep_str. It then requests user confirmation to kill these processes. If the user agrees, it terminates the processes.

egrecho.utils.misc.is_picklable(obj)[source]#

Tests if an object can be pickled.

Return type:

bool

egrecho.utils.misc.valid_import_clspath(name)[source]#

Validates that an import path is a string in dotted form (e.g., 'calendar.Calendar').

egrecho.utils.misc.class2str(value)[source]#

Extract path from class type.

Example:

from calendar import Calendar
assert str(Calendar) == "<class 'calendar.Calendar'>"
assert class2str(Calendar) == "calendar.Calendar"
egrecho.utils.misc.get_import_path(value)[source]#

Returns the shortest dot import path for the given object.

Return type:

Optional[str]

egrecho.utils.misc.locate_(path)[source]#

COPIED FROM Hydra.

Locate an object by name or dotted path, importing as necessary. This is similar to the pydoc function locate, except that it checks for the module from the given path from back to front.

Behaviours like:

path = "calendar.Calendar"
m, c = path.rsplit('.', 1)
mo = importlib.import_module(m)
cl = getattr(mo, c)

egrecho.utils.register#

class egrecho.utils.register.Register(name)[source]#

Bases: object

This class is used to register functions or classes.

Parameters:

name (str) -- The name of the registry.

Example:

LITTLES = Register("littles")

@LITTLES.register(name='litclass')
class LitClass:
    def __init__(self, a, b=123):
        self.a = a
        self.b = b

or

LITTLES.register(LitClass, name='litclass')
register(fn_or_cls=None, name=None, override=False, provider_info=None, **metadata)[source]#

Adds a function or class.

Parameters:
  • fn_or_cls (Callable) -- The function or class to be registered.

  • name (str) -- Name string.

  • override (bool) -- Whether to override if the name already exists.

  • provider_info (ProviderInfo) -- info about the provider, which will be merged into the metadata.

  • **metadata (dict) -- Additional dict to be saved.

Return type:

Callable

get(key, with_metadata=False, strict=True, **metadata)[source]#

Retrieves functions or classes registered under the given key name.

Parameters:
  • key (str) -- Name of the registered function.

  • with_metadata (bool) -- Whether to include the associated metadata in the return value.

  • strict (bool) -- Whether to return all matches or just one.

  • metadata -- Metadata used to filter against existing registry item’s metadata.

keys()[source]#

Get all registered names.

class egrecho.utils.register.ExternalRegister(name)[source]#

Bases: Register

TODO

class egrecho.utils.register.ConcatRegister(*registers)[source]#

Bases: Register

This class is used to concatenate multiple registers.

class egrecho.utils.register.StrRegister(name)[source]#

Bases: object

Registers multiple strings, which can be used to create alias names.

Parameters:

name (str) -- The name of the registry.

Example:

TOTAL_STEPS = StrRegister("total_steps")

TOTAL_STEPS.register("num_training_steps")

or

TOTAL_STEPS.register(["num_training_steps",])

assert TOTAL_STEPS.keys() == ["total_steps","num_training_steps"]
assert "num_training_steps" in TOTAL_STEPS
register(str_or_strs)[source]#

Adds a string or sequence of strings.

Parameters:

str_or_strs (Union[str, List, Tuple]) -- String(s) to be registered.

Return type:

Union[str, List[str]]

keys()[source]#

Get all registered names.

class egrecho.utils.register.ConcatStrRegister(*registers)[source]#

Bases: StrRegister

This class is used to concatenate multiple registers.

egrecho.utils.torch_utils#

class egrecho.utils.torch_utils.RandomValue(end, start=0)[source]#

Bases: object

Generate a uniform distribution in the range [start, end].

egrecho.utils.torch_utils.batch_pad_right(tensors, mode='constant', value=0, val_index=-1)[source]#

Given a list of torch tensors, batches them together by padding to the right on each dimension in order to get the same length for all.

Parameters:
  • tensors (list) -- List of tensor we wish to pad together.

  • mode (str) -- Padding mode see torch.nn.functional.pad documentation.

  • value (float) -- Padding value see torch.nn.functional.pad documentation.

Returns:

  • tensor (torch.Tensor) -- Padded tensor.

  • valid_vals (list) -- List containing proportion for each dimension of original, non-padded values.

egrecho.utils.torch_utils.pad_right_to(tensor, target_shape, mode='constant', value=0)[source]#

This function takes a torch tensor of arbitrary shape and pads it to target shape by appending values on the right.

Parameters:
  • tensor (input torch tensor) -- Input tensor whose dimension we need to pad.

  • target_shape ((list, tuple)) -- Target shape we want for the target tensor; its length must be equal to tensor.ndim.

  • mode (str) -- Pad mode, please refer to torch.nn.functional.pad documentation.

  • value (float) -- Pad value, please refer to torch.nn.functional.pad documentation.

Returns:

  • tensor (torch.Tensor) -- Padded tensor.

  • valid_vals (list) -- List containing proportion for each dimension of original, non-padded values.
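
Example (a minimal sketch; valid_vals follows the documented per-dimension proportions):

import torch

x = torch.ones(2, 3)
padded, valid = pad_right_to(x, (2, 5))
# padded.shape == (2, 5); valid == [1.0, 0.6] (2/2 and 3/5 non-padded)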

egrecho.utils.torch_utils.audio_collate_fn(waveforms)[source]#

Pads a list of waves with shape (…, T); returns a tuple containing the padded tensor and a tensor of lengths.

Return type:

Tuple[Tensor, Tensor]
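
Example (an illustrative sketch; the exact form of the second return value is assumed from the description above):

import torch

waves = [torch.randn(16000), torch.randn(12000)]
batch, lengths = audio_collate_fn(waves)
# batch.shape == (2, 16000); `lengths` holds the per-item length information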

egrecho.utils.torch_utils.is_numpy_array(x)[source]#

Tests if x is a numpy array or not.

egrecho.utils.torch_utils.infer_framework_from_repr(x)[source]#

Tries to guess the framework of an object x from its repr (brittle but will help in is_tensor to try the frameworks in a smart order, without the need to import the frameworks).

egrecho.utils.torch_utils.to_py_obj(obj)[source]#

Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.

egrecho.utils.torch_utils.to_numpy(obj)[source]#

Convert a PyTorch tensor, Numpy array or python list to a Numpy array.

egrecho.utils.torch_utils.save_dislike_batch(batch, exstr=None, expdir=None)[source]#

Save the dislike (problematic) batch to disk. Useful for debugging.

Return type:

None

egrecho.utils.types#

class egrecho.utils.types.ModelOutput[source]#

Bases: OrderedDict

Base class for all model outputs as dataclasses. Copied from HuggingFace's ModelOutput.

Has a __getitem__() that allows indexing by integer or slice (like a tuple) or strings (like a dictionary) that will ignore the None attributes. Otherwise behaves like a regular python dictionary.

Tip

You can't unpack a ModelOutput directly. Use the to_tuple() method to convert it to a tuple first.

setdefault(*args, **kwargs)[source]#

Disabled; calling it raises an exception.

pop(*args, **kwargs)[source]#

Disabled; calling it raises an exception.

update(*args, **kwargs)[source]#

Disabled; calling it raises an exception.

to_tuple()[source]#

Convert self to a tuple containing all the attributes/keys that are not None.

Return type:

Tuple[Any]
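
Example (a minimal sketch in the spirit of the HuggingFace original; MyOutput is a hypothetical subclass):

>>> import torch
>>> from typing import Optional
>>> from dataclasses import dataclass
>>> @dataclass
... class MyOutput(ModelOutput):
...     logits: Optional[torch.Tensor] = None
...     loss: Optional[torch.Tensor] = None
>>> out = MyOutput(logits=torch.ones(2))
>>> len(out.to_tuple())  # `loss` is None, so it is skipped
1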

class egrecho.utils.types.SingletonMeta[source]#

Bases: type

A metaclass for creating singleton classes.

class egrecho.utils.types.StrEnum(value)[source]#

Bases: str, Enum

Enum type that allows case-invariant comparison to strings.

>>> class MySE(StrEnum):
...     t1 = "T-1"
...     t2 = "T-2"
>>> MySE("T-1") == MySE.t1
True
>>> MySE.from_str("t-2", source="value") == MySE.t2
True
>>> MySE.from_str("t-2", source="value")
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any")
Traceback (most recent call last):
  ...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
classmethod from_str(value, source='key')[source]#

Create StrEnum from a string matching the key or value.

Parameters:
  • value (str) -- matching string

  • source (Literal['key', 'value', 'any']) --

    compare with:

    • "key": validates only from the enum keys, typical alphanumeric with “_”

    • "value": validates only from the values, could be any string

    • "any": validates with any key or value, but key has priority

Raises:

ValueError -- if requested string does not match any option based on selected source.

Return type:

StrEnum

classmethod try_from_str(value, source='key')[source]#

Try to create an enum from the string; if it does not match any option, return None.

Return type:

Optional[StrEnum]

class egrecho.utils.types.Split(value)[source]#

Bases: StrEnum

Enumeration of dataset splits.

class egrecho.utils.types.FilterType(value)[source]#

Bases: StrEnum

Filter strategy when dumping dicts.

class egrecho.utils.types.PaddingStrategy(value)[source]#

Bases: StrEnum

Possible values for the padding argument. Useful for tab-completion in an IDE.

class egrecho.utils.types.InitWeightType(value)[source]#

Bases: StrEnum

An enumeration.

egrecho.utils.dummy#

egrecho.utils.io#

egrecho.utils.patch#

egrecho.utils.seeder#

egrecho.utils.text#