egrecho.utils.apply#
- egrecho.utils.apply.apply_to_collection(data, dtype, function, *args, wrong_dtype=None, include_none=True, allow_frozen=False, **kwargs)[source]#
Recursively applies a function to all elements of a certain dtype.
- Parameters:
data (Any) -- the collection to apply the function to
dtype (Union[type, Any, Tuple[Union[type, Any]]]) -- the given function will be applied to all elements of this dtype
function (Callable) -- the function to apply
*args (Any) -- positional arguments (will be forwarded to calls of function)
wrong_dtype (Union[type, Tuple[type, ...], None]) -- the given function won't be applied if this type is specified and the given collection is of the wrong_dtype, even if it is of type dtype
include_none (bool) -- Whether to include an element if the output of function is None.
allow_frozen (bool) -- Whether not to error upon encountering a frozen dataclass instance.
**kwargs (Any) -- keyword arguments (will be forwarded to calls of function)
- Return type:
Any
- Returns:
The resulting collection
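Example
A minimal illustrative sketch (hypothetical data and doubling function; the output assumes the recursive behavior described above):
>>> apply_to_collection({"a": 1, "b": [2, 3]}, int, lambda x: x * 2)
{'a': 2, 'b': [4, 6]}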
egrecho.utils.common#
- egrecho.utils.common.is_in_range(val, max_val=None, min_val=None)[source]#
Judges whether a value lies in a range.
The range is a closed interval (e.g., [1, 2]). If a boundary is None, that condition is skipped.
- Parameters:
val -- value to be judged.
max_val (Optional[Any], optional) -- Defaults to None.
min_val (Optional[Any], optional) -- Defaults to None.
- Return type:
bool
- Returns:
bool
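Example
An illustrative sketch of the closed-interval semantics documented above (values are hypothetical):
>>> is_in_range(5, min_val=1, max_val=10)
True
>>> is_in_range(5, max_val=4)
False
>>> is_in_range(5)  # both boundaries are None, so both conditions are skipped
True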
- class egrecho.utils.common.ObjectDict[source]#
Bases:
dict
Makes a dictionary behave like an object, with attribute-style access.
Here are some examples of how it can be used:
o = ObjectDict(my_dict)
# or like this:
o = ObjectDict(samples=samples, sample_rate=sample_rate)

# Attribute-style access
samples = o.samples

# Dict-style access
samples = o["samples"]
- class egrecho.utils.common.DataclassSerialMixin[source]#
Bases:
object
From/to dict mixin of dataclass.
Example
>>> from dataclasses import dataclass
>>> @dataclass
... class Config(DataclassSerialMixin):
...     a: int = 123
...     b: str = "456"
...
>>> config = Config()
>>> config
Config(a=123, b='456')
>>> config.to_dict()
{'a': 123, 'b': '456'}
>>> config_ = Config.from_dict({"a": 123, "b": 456})
>>> config_
Config(a=123, b='456')
>>> assert config == config_
- to_dict(dict_factory=<class 'dict'>, filt_type=None, init_field_only=True, save_dc_types=False)[source]#
Serializes this dataclass to a dict.
- Return type:
dict
- classmethod from_dict(obj, drop_extra_fields=None)[source]#
Parses an instance of cls from the given dict.
NOTE: If the decode_into_subclasses class attribute is set to True (or if decode_into_subclasses=True was passed in the class definition), then if there are keys in the dict that aren't fields of the dataclass, this will decode the dict into an instance of the first subclass of cls which has all required field names present in the dictionary.
- Passing drop_extra_fields=None (default) will use the class attribute described above.
- Passing drop_extra_fields=True will decode the dict into an instance of cls and drop the extra keys in the dict.
- Passing drop_extra_fields=False forces the above-mentioned behaviour.
- class egrecho.utils.common.GenericSerialMixin[source]#
Bases:
object
Serialization mixin of common class with config attribute.
- egrecho.utils.common.fields_init_var(class_or_instance)[source]#
Return a tuple describing the InitVar fields of this dataclass.
Modified from: https://docs.python.org/3/library/dataclasses.html#dataclasses.fields
Accepts a dataclass or an instance of one. Tuple elements are of type Field.
- egrecho.utils.common.batched(iterable, n)[source]#
Batch data into tuples of length n. The last batch may be shorter.
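Example
An illustrative sketch (hypothetical data; assumes batches are yielded as tuples, per the description above):
>>> list(batched("ABCDEFG", 3))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]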
- egrecho.utils.common.snakecase_to_camelcase(name)[source]#
Convert snake-case string to camel-case string.
- egrecho.utils.common.get_diff_dict(curr_dict, src_dict)[source]#
Compare two dictionaries and return a new dictionary containing the differing key-value pairs.
- Parameters:
curr_dict (Dict) -- The current dictionary to compare.
src_dict (Dict) -- The source dictionary to compare against.
- Returns:
A dictionary containing the key-value pairs that differ between src_dict and curr_dict.
- Return type:
Dict
Example
>>> src_dict = {"name": "John", "age": 30}
>>> curr_dict = {"name": "John", "age": 35, "city": "New York"}
>>> diff_dict = get_diff_dict(curr_dict, src_dict)
>>> diff_dict
{'age': 35, 'city': 'New York'}
- egrecho.utils.common.del_default_dict(data, defaults, recurse=False)[source]#
Removes key-value pairs from a dictionary that match default values.
- Parameters:
data (Dict) -- The dictionary to remove default values from.
defaults (Dict) -- The dictionary containing default values to compare against.
recurse (bool) -- Whether to recursively process nested dictionaries. Defaults to False.
- Returns:
A dictionary with default values removed.
- Return type:
Dict
Example
>>> data = {"name": "John", "age": 30, "address": {"city": "New York", "zip": 10001}}
>>> defaults = {"name": "John", "age": 30, "address": {"city": "Unknown", "zip": None}}
>>> cleaned_data = del_default_dict(data, defaults)
>>> cleaned_data
{'address': {'city': 'New York', 'zip': 10001}}
Note
This function modifies the data dictionary in place and also returns it. If recurse=True, it recursively processes nested dictionaries.
- egrecho.utils.common.dict_union(*dicts, recurse=True, sort_keys=False, dict_factory=<class 'dict'>)[source]#
Combine multiple dictionaries into the first one.
- Parameters:
*dicts (dict) -- One or more dictionaries to combine.
recurse (bool, optional) -- If True, also recursively combine nested dictionaries.
sort_keys (bool, optional) -- If True, sort the keys alphabetically in the resulting dictionary.
dict_factory (callable, optional) -- A callable that creates the output dictionary (default is dict).
- Returns:
A new dictionary containing the union of all input dictionaries.
- Return type:
dict
Example
>>> from collections import OrderedDict
>>> a = OrderedDict(a=1, b=2, c=3)
>>> b = OrderedDict(c=5, d=6, e=7)
>>> dict_union(a, b, dict_factory=OrderedDict)
OrderedDict([('a', 1), ('b', 2), ('c', 5), ('d', 6), ('e', 7)])
>>> a = OrderedDict(a=1, b=OrderedDict(c=2, d=3))
>>> b = OrderedDict(a=2, b=OrderedDict(c=3, e=6))
>>> dict_union(a, b, dict_factory=OrderedDict)
OrderedDict([('a', 2), ('b', OrderedDict([('c', 3), ('d', 3), ('e', 6)]))])
- egrecho.utils.common.list2tuple(func)[source]#
Converts lists in the input parameters to hashable tuples, which is useful for lru_cache.
NOTE: Nested structures are not supported.
Example
>>> def get_input(*args, **kwargs):
...     return args, kwargs
>>> @list2tuple
... def get_input_wrapper(*args, **kwargs):
...     return args, kwargs
>>> arg_input = ([1, 2, 3], 'foo')
>>> kwargs_input = {'bar': True, 'action': ["go", "leave", "break"], 'action1': ['buy', ['sell', 'argue']]}
>>> get_input(*arg_input, **kwargs_input)
(([1, 2, 3], 'foo'), {'bar': True, 'action': ['go', 'leave', 'break'], 'action1': ['buy', ['sell', 'argue']]})
>>> get_input_wrapper(*arg_input, **kwargs_input)
(((1, 2, 3), 'foo'), {'bar': True, 'action': ('go', 'leave', 'break'), 'action1': ('buy', ['sell', 'argue'])})
- egrecho.utils.common.omegaconf_handler(config, omegaconf_resolve=True)[source]#
If omegaconf is available, creates DictConfig and resolves it when omegaconf_resolve=True.
- Parameters:
config (Optional[Dict[Any, Any]]) -- loaded kwargs dict.
omegaconf_resolve -- If False, returns a DictConfig; otherwise returns the dict resolved from the DictConfig.
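Example
A hedged sketch of the intended usage, assuming omegaconf is installed and its standard ${...} interpolation (the config dict is hypothetical):
config = {"root": "/data", "train_dir": "${root}/train"}
resolved = omegaconf_handler(config)
# resolved -> {'root': '/data', 'train_dir': '/data/train'} (a plain dict)
dict_config = omegaconf_handler(config, omegaconf_resolve=False)  # a DictConfig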
egrecho.utils.constants#
egrecho.utils.cuda_utils#
- egrecho.utils.cuda_utils.avoid_float16_autocast_context()[source]#
If the current autocast context is float16, cast it to bfloat16 if available (unless we're in JIT) or float32.
- egrecho.utils.cuda_utils.maybe_set_cuda_expandable_segments(enabled)[source]#
Configures the PyTorch memory allocator to expand existing allocated segments instead of re-allocating them when a tensor's shape grows. This can help speed up training when the sequence length and/or batch size change often, and makes the GPU more robust against OOM.
See here for more details: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
- egrecho.utils.cuda_utils.set_to_cuda(modules)[source]#
Sends modules to GPU.
- Parameters:
modules -- an nn.Module or a list of modules
- egrecho.utils.cuda_utils.get_current_device()[source]#
Returns the currently selected device (GPU/CPU): if CUDA is available, returns the GPU device, otherwise the CPU.
- Return type:
device
- egrecho.utils.cuda_utils.to_device(data, device_object=None)[source]#
Move a tensor or collection of tensors to a specified device by inferring the device from another object.
- Parameters:
data (Any) -- A tensor, collection of tensors, or anything with a .to(...) method.
device_object (Union[torch.device, str, int, torch.Tensor, torch.nn.Module], optional) -- The target device. Can be one of the following:
torch.device, str, int: the target device.
torch.nn.Module: infer the device from a module.
torch.Tensor: infer the device from a tensor.
(default: the current default CUDA device.)
- Return type:
Any
- Returns:
The same collection with all contained tensors residing on the target device.
- egrecho.utils.cuda_utils.parse_gpus(gpus)[source]#
Parse gpus option.
- Parameters:
gpus (Union[str, int, Sequence[int]]) -- GPUs used for training on this machine.
-1: all; N: [0, N); "1,2": comma-separated ids; "0" means invalid while "0," specifies 1 gpu with id:1.
- Return type:
Optional[List[int]]
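Example
Illustrative calls derived from the option grammar above (outputs assume the documented semantics; the machine's GPU count is hypothetical):
>>> parse_gpus(2)
[0, 1]
>>> parse_gpus("1,2")
[1, 2]
>>> parse_gpus(-1)  # all GPUs, e.g., on a 4-GPU machine
[0, 1, 2, 3]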
- egrecho.utils.cuda_utils.parse_gpus_opt(gpus)[source]#
Similar to parse_gpus but combines auto choose a single GPU.
- Parameters:
gpus (Optional[Union[str, int]]) -- What GPUs should be used:
case 0: comma-separated list, e.g., "1," or "0,1", means specified id(s).
case 1: a single negative int (or str) number (-1) means all visible devices [0, N-1].
case 2: '' or None or 0 returns None, meaning no GPU.
case 3: a single int (or str) number equal to 1 means auto-choose a spare GPU.
case 4: a single int (or str) number n greater than 1 returns [0, n-1].
- Returns:
A list of GPU IDs or None.
- Return type:
Optional[List[int]]
- egrecho.utils.cuda_utils.parse_gpu_id(gpu_id='auto')[source]#
Parse single gpu id option.
- Parameters:
gpu_id (Optional[Union[str, int]]) -- Select which GPU:
case 0: "auto" means auto-select a spare GPU.
case 1: a single negative int (or str) number (-1) means CPU.
case 2: a single positive int (or str) number means the specified id.
case 3: '' or None returns None, which means the default behaviour in the same case, e.g., torch.load(...).
case 4: other strings, e.g., "cuda:1".
- Return type:
Union[str, int, None]
- egrecho.utils.cuda_utils.synchronize()[source]#
Similar to cuda.synchronize(). Waits for all kernels in all streams on a CUDA device to complete.
- egrecho.utils.cuda_utils.peak_cuda_memory()[source]#
Return the peak gpu memory statistics.
- Returns:
max_alloc (float) -- the allocated CUDA memory.
max_cached (float) -- the cached CUDA memory.
- egrecho.utils.cuda_utils.release_memory(*objects)[source]#
Triggers garbage collection and releases CUDA cache memory.
This function sets the inputs to None, triggers garbage collection to release CPU memory references, and attempts to clear GPU memory cache.
- Parameters:
*objects -- Variable number of objects to release.
- Returns:
A list of None values, with the same length as the input objects.
- Return type:
List[None]
Example
>>> import torch
>>> a = torch.ones(1024, 1024).cuda()
>>> b = torch.ones(1024, 1024).cuda()
>>> a, b = release_memory(a, b)
- class egrecho.utils.cuda_utils.GPUManager(addtional_qargs=None, mode=AutoGPUMode.MAX_MEM)[source]#
Bases:
object
This class enables automated selection of the most available GPU (or another, based on the specified mode).
- Parameters:
addtional_qargs (Optional) -- Additional query arguments passed to nvidia-smi.
mode (AutoGPUMode) -- mode for GPU selection. Defaults to MAX_MEM (max free memory).
Example:
import os, torch

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
gm = GPUManager()
torch_device = gm.auto_choice()
# or
torch_device = GPUManager.detect()

a = torch.randn(1, 1000)
a.to(torch_device)
- new_query()[source]#
Runs the nvidia-smi command and organizes the results as a list of dictionaries containing the information queried from nvidia-smi.
- egrecho.utils.cuda_utils.device2gpu_id(device_id)[source]#
Given the device index, returns the unmasked real GPU ID.
- Return type:
str
- egrecho.utils.cuda_utils.num_gpus()#
Get visible gpu number.
- Return type:
int
- egrecho.utils.cuda_utils.patch_nvml_check_env()[source]#
A context manager that patches PYTORCH_NVML_BASED_CUDA_CHECK=1 and restores the environment on exit.
egrecho.utils.data_utils#
- egrecho.utils.data_utils.split_sequence(seq, split_num, mode='batch', shuffle=False, drop_last=False)[source]#
Split a sequence into split_num equal parts. The element order can be randomized. Raises a ValueError if split_num is larger than len(seq).
Supports modes 'shard' and 'batch'. In 'batch' mode the splits follow the original sequence order, while 'shard' mode shards the original sequence; e.g., splitting [0, 1, 2, 3] into 2 parts gives the shard-mode result [[0, 2], [1, 3]] and the batch-mode result [[0, 1], [2, 3]].
- Parameters:
seq (Sequence) -- Input iterable.
split_num (int) -- Split num.
mode (str) -- ('shard', 'batch')
shuffle (bool) -- If true, shuffle the input sequence before splitting it.
drop_last (bool) -- If true, drop last items when len(seq) is not divisible by split_num.
- Return type:
List[List[Any]]
- Returns:
List of smaller sequences.
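Example
An illustrative sketch of the two modes, reusing the case described above:
>>> split_sequence([0, 1, 2, 3], 2, mode='batch')
[[0, 1], [2, 3]]
>>> split_sequence([0, 1, 2, 3], 2, mode='shard')
[[0, 2], [1, 3]]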
- class egrecho.utils.data_utils.Dillable[source]#
Bases:
object
Mix-in that will leverage dill instead of pickle when pickling an object.
It is useful when the user can't avoid pickle (e.g. in multiprocessing), but needs to use unpicklable objects such as lambdas.
- egrecho.utils.data_utils.iflatmap_unordered(nj, fn, kwargs_iter)[source]#
Parallelized mapping operation.
Note: Data are in kwargs_iter, and the results of all jobs are flattened to a queue. This operation does not keep the original order since it runs asynchronously.
- Parameters:
nj (int) -- num of jobs.
fn (Callable) -- a function that can yield results from given args.
kwargs_iter (Iterable[dict]) -- kwargs mapped to fn.
- Return type:
Iterable
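Example
A hedged sketch, assuming fn is a generator function that receives the kwargs from each dict (the doubling generator is hypothetical):
def double(x):
    yield x * 2

results = sorted(iflatmap_unordered(2, double, [{"x": i} for i in range(4)]))
# results == [0, 2, 4, 6]; arrival order before sorting is not guaranteed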
- egrecho.utils.data_utils.buffer_shuffle(data, buffer_size=10000, rng=None)[source]#
Buffer shuffle the data.
- Parameters:
data (Iterable) -- data source.
buffer_size (int) -- defaults to 10000.
rng (np.random.Generator) -- a fixed random generator for reproducibility.
- Return type:
Generator
- Returns:
Generator yields data item.
- egrecho.utils.data_utils.ichunk_size(total_len, split_num=None, chunk_size=None, even=True)[source]#
Infers an even-split chunk-size generator before applying a split operation, if needed.
- Parameters:
total_len (int) -- The total length to be divided.
chunk_size (int) -- Chunk size; can be provided to infer chunk sizes.
split_num (int) -- Number of splits; can be provided to infer chunk sizes.
even (bool) -- If True, the max difference between chunk sizes is 1.
- Return type:
Generator
- Returns:
A generator yielding the chunk sizes of the total length.
Example
>>> list(ichunk_size(15, chunk_size=10, even=True))  # chunk_size=10, total_len=15, adapts (10, 5) -> (8, 7)
[8, 7]
>>> list(ichunk_size(10, split_num=4, even=True))  # split_num=4, total_len=10, adapts chunk sizes (3, 3, 3, 1) -> (3, 3, 2, 2)
[3, 3, 2, 2]
- egrecho.utils.data_utils.ilen(iterable)[source]#
Return the number of items in iterable inputs.
This consumes the iterable data, so handle with care.
Example
>>> ilen(x for x in range(1000000) if x % 3 == 0)
333334
- egrecho.utils.data_utils.zip_dict(*dicts)[source]#
Iterate over items of dictionaries grouped by their keys.
- class egrecho.utils.data_utils.ClassLabel(num_classes=None, names=None, names_file=None, id=None)[source]#
Bases:
DictFileMixin
The instance of this class stores the string names of labels, can be used for mapping str2label or label2str.
Modified from HuggingFace Datasets.
- There are 3 ways to define a ClassLabel, which correspond to the 3 arguments:
num_classes: Create 0 to (num_classes - 1) labels.
names: List of label strings.
names_file: File (text) containing the list of labels.
Under the hood the labels are stored as integers. You can use negative integers to represent unknown/missing labels.
Serialization/deserialization of yaml files will be done in a more readable way (from_yaml, to_yaml):
names:        ->  names:
  negative    ->    '0': negative
  positive    ->    '1': positive
- Parameters:
num_classes (int, optional) -- Number of classes. All labels must be < num_classes.
.names (list of str, optional) -- String names for the integer classes. The order in which the names are provided is kept.
names_file (str, optional) -- Path to a file with names for the integer classes, one per line.
Example
>>> label = ClassLabel(num_classes=3, names=['speaker1', 'speaker2', 'speaker3'])
>>> label
ClassLabel(num_classes=3, names=['speaker1', 'speaker2', 'speaker3'], id=None)
>>> label.encode_label('speaker1')
0
>>> label.encode_label(1)
1
>>> label.encode_label('1')
1
- str2int(values)[source]#
Conversion class name string => integer.
- Return type:
Union[int, Iterable]
Example
>>> label = ClassLabel(num_classes=3, names=['speaker1', 'speaker2', 'speaker3'])
>>> label.str2int('speaker1')
0
- int2str(values)[source]#
Conversion integer => class name string.
Regarding unknown/missing labels: passing negative integers raises ValueError.
- Return type:
Union[str, Iterable]
Example
>>> label = ClassLabel(num_classes=3, names=['speaker1', 'speaker2', 'speaker3'])
>>> label.int2str(0)
'speaker1'
- egrecho.utils.data_utils.count_vocab_from_iterator(iterator)[source]#
Build a Vocab counter from an iterator.
- Parameters:
iterator (Iterable) -- Iterator used to count vocab. Must yield lists or iterators of tokens.
- Return type:
VocabCounter
- Returns:
Obj of VocabCounter.
Examples
>>> # generating vocab from text file
>>> from egrecho.utils.data_utils import count_vocab_from_iterator
>>> def yield_tokens(file_path):
...     with io.open(file_path, encoding='utf-8') as f:
...         for line in f:
...             yield line.strip().split()
>>> vocab_counter = count_vocab_from_iterator(yield_tokens(file_path))
- egrecho.utils.data_utils.build_vocab_from_iterator(iterator, savedir=None, fname='vocab', min_freq=1, specials=None, special_first=True, max_tokens=-1)[source]#
Build a Vocab from an iterator.
- Parameters:
iterator (Iterable) -- Iterator used to build the Vocab. Must yield lists or iterators of tokens.
savedir (Union[str, Path, None]) -- Directory to save the vocab file. If None, saving is skipped.
fname (str) -- filename
specials (Optional[List[str]]) -- Special symbols to add. The order of supplied tokens will be preserved.
special_first (bool) -- Indicates whether to insert symbols at the beginning or at the end.
min_freq (int) -- If provided, defines the minimum frequency needed to include a token in the vocabulary.
max_tokens (int) -- If provided > 0, creates the vocab from the max_tokens - len(specials) most frequent tokens.
- Return type:
ClassLabel
- Returns:
Obj of ClassLabel.
Examples
>>> # generating vocab from text file
>>> from egrecho.utils.data_utils import build_vocab_from_iterator
>>> def yield_tokens(file_path):
...     with io.open(file_path, encoding='utf-8') as f:
...         for line in f:
...             yield line.strip().split()
>>> vocab = build_vocab_from_iterator(yield_tokens(file_path), specials=["<unk>"])
>>> # or save files
>>> vocab = build_vocab_from_iterator(yield_tokens(file_path), savedir='./', specials=["<unk>"])
- class egrecho.utils.data_utils.SplitInfo(name='', patterns='', num_examples=0, meta=None)[source]#
Bases:
object
A container of split dataset info.
egrecho.utils.dist#
- egrecho.utils.dist.send_exit_ddp(stop_flag=0)[source]#
Reduces a stop signal across DDP processes, controlled by the main rank.
- egrecho.utils.dist.get_free_port()[source]#
Select a free port for localhost.
Useful in single-node training when we don't want to connect to a real main node but have to set the MASTER_PORT environment variable.
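Example:
A minimal sketch of the use case described above (environment variable name as documented):
import os
os.environ["MASTER_PORT"] = str(get_free_port())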
- egrecho.utils.dist.is_port_in_use(port=None)[source]#
Checks if a port is in use on localhost.
- Return type:
bool
- class egrecho.utils.dist.DistInfo(world_size, rank)[source]#
Bases:
object
Contains the environment for the current dist rank.
- classmethod detect(group=None, allow_env=True)[source]#
Tries to automatically detect the pytorch distributed environment parameters.
- Return type:
DistInfo
Note
If allow_env = True, some other dist environment may be detected. This detection may not work in processes spawned from the distributed processes (e.g. DataLoader workers) as the distributed framework won't be initialized there. It will default to 1 distributed process in this case.
- class egrecho.utils.dist.WorkerInfo(num_workers, id)[source]#
Bases:
object
Contains the environment for the current dataloader within the current training process.
- classmethod detect(allow_env=True)[source]#
Automatically detects the number of pytorch workers and the current rank.
- Return type:
WorkerInfo
Note
If allow_env = True, some other worker environment may be detected. This only works reliably within a dataloader worker, as otherwise the necessary information won't be present. In such a case it will default to 1 worker.
- class egrecho.utils.dist.EnvInfo(dist_info=None, worker_info=None)[source]#
Bases:
object
Container of DistInfo and WorkerInfo.
- classmethod from_args(world_size, rank, num_workers, worker_id)[source]#
Set env info from args.
- Parameters:
world_size (int) -- The world size used for distributed training (equals the total number of distributed processes)
rank (int) -- The distributed global rank of the current process
num_workers (int) -- The number of workers per distributed training process
worker_id (int) -- The rank of the current worker within the number of workers of the current training process
- Return type:
EnvInfo
- property num_shards: int#
Returns the total number of shards.
Note
This may not be accurate in a non-dataloader-worker process like the main training process as it doesn’t necessarily know about the number of dataloader workers.
- property shard_rank: int#
Returns the rank of the current process wrt. the total number of shards.
Note
This may not be accurate in a non-dataloader-worker process like the main training process as it doesn’t necessarily know about the number of dataloader workers.
- egrecho.utils.dist.is_global_rank_zero()[source]#
Helper function to determine if the current process is global rank 0 (the main process).
- class egrecho.utils.dist.TorchMPLauncher(num_processes, port=None, disable_mem_share=False, start_method='spawn')[source]#
Bases:
object
Launches processes that run a given function in parallel, and joins them all at the end.
Worker processes are given a rank in os.environ["LOCAL_RANK"] that ranges from 0 to N - 1.
Referring to lightning fabric.
Note
This launcher requires all objects to be pickleable.
The entry point to the program/script should be guarded by if __name__ == "__main__".
In environments like IPython notebooks where 'spawn' is not available, 'fork' works better.
With start method 'fork', the user must ensure that no CUDA context gets created in the main process before the launcher is invoked, i.e., torch.cuda should be uninitialized.
- Parameters:
num_processes (int) -- number of worker processes.
port (Optional[int]) -- master port; if not set, a free port on localhost is found automatically.
disable_mem_share (bool) -- mem_share is a feature of torch.multiprocessing. Required to be set True when running models on CPU, see _disable_module_memory_sharing().
start_method (Literal['spawn', 'fork', 'forkserver']) -- The method used to start the processes.
Example:
launcher = TorchMPLauncher(num_processes=4)
launcher.launch(my_function, arg1, arg2, kwarg1=value1)
- launch(function, *args, **kwargs)[source]#
Launches processes that run the given function in parallel.
The function is allowed to have a return value. However, when all processes join, only the return value of worker process 0 gets returned from this launch method in the main process.
- Parameters:
function (Callable) -- The entry point for all launched processes.
*args (Any) -- Optional positional arguments to be passed to the given function.
**kwargs (Any) -- Optional keyword arguments to be passed to the given function.
- Return type:
Any
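Example:
A hedged sketch of a complete worker entry point under the documented LOCAL_RANK contract (the function body and names are hypothetical):
import os
import torch

def my_function(greeting):
    local_rank = int(os.environ["LOCAL_RANK"])  # ranges from 0 to num_processes - 1
    device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")
    return f"{greeting} from rank {local_rank} on {device}"

if __name__ == "__main__":  # required guard, see the note above
    launcher = TorchMPLauncher(num_processes=4)
    result = launcher.launch(my_function, "hello")  # only rank 0's return value comes back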
egrecho.utils.imports#
- egrecho.utils.imports.is_package_available(*modules)[source]#
Returns whether a top-level module with name exists without importing it. This is generally safer than a try-except block around an import X. It avoids third party libraries breaking assumptions of some of our tests, e.g., setting the multiprocessing start method when imported (see librosa/#747, torchvision/#544).
- Return type:
bool
- egrecho.utils.imports.is_module_available(module_path)#
Check if a module path is available in your environment. This will try to import it.
- Return type:
bool
Example
>>> is_module_available('torch')
True
>>> is_module_available('fake')
False
>>> is_module_available('torch.utils')
True
>>> is_module_available('torch.util')
False
- egrecho.utils.imports.compare_version(package, op, ver, use_base_version=False)[source]#
Compare package version with some requirements.
- Return type:
bool
Example
>>> compare_version("torch", operator.ge, "0.1")
True
>>> compare_version("does_not_exist", operator.ge, "0.0")
False
- egrecho.utils.imports.check_ort_requirements(version='1.4')#
Checks that onnxruntime is installed and that the installed version is recent enough.
- Raises:
ImportError -- If onnxruntime is not installed or a too-old version is found.
- egrecho.utils.imports.lazy_import(module_name, callback=None)[source]#
Returns a proxy module object that will lazily import the given module the first time it is used.
Copied from lightning utilities.
- Parameters:
module_name -- the fully-qualified module name to import
callback (None) -- a callback function to call before importing the module
- Returns:
a proxy module object that will be lazily imported when first used
Example:
# Lazy version of `import tensorflow as tf`
tf = lazy_import("tensorflow")

# Other commands

# Now the module is loaded
tf.__version__
- class egrecho.utils.imports.LazyModule(module_name, callback=None)[source]#
Bases:
module
Proxy module that lazily imports the underlying module the first time it is actually used.
- Parameters:
module_name -- the fully-qualified module name to import
callback (None) -- a callback function to call before importing the module
egrecho.utils.logging#
- egrecho.utils.logging.get_logger(name='egrecho.utils.logging')[source]#
Get logger singleton instance based on package name.
- Return type:
Logger
- class egrecho.utils.logging.Logger(name)[source]#
Bases:
object
Singleton pattern for logger.
- Parameters:
name (str) -- The name of the logger.
- set_level(level)[source]#
Set the logging level
- Parameters:
level (str) -- Can only be INFO, DEBUG, WARNING and ERROR.
- Return type:
None
- info(message, ranks=None, stack_pos=2, verbose=None)[source]#
Log an info message.
- Parameters:
message (str) -- The message to be logged.
ranks (List[int]) -- List of parallel ranks.
- Return type:
None
- info_once(message, ranks=None)[source]#
Log an info message, but only once.
- Parameters:
message (str) -- Message to display
ranks (List[int]) -- List of parallel ranks.
- Return type:
None
- warning(message, ranks=None, stack_pos=2, verbose=None)[source]#
Log a warning message.
- Parameters:
message (str) -- The message to be logged.
ranks (List[int]) -- List of parallel ranks.
- Return type:
None
- warning_once(message, ranks=None)[source]#
Log a warning, but only once.
- Parameters:
message (str) -- The message to be logged.
ranks (List[int]) -- List of parallel ranks.
- Return type:
None
- debug(message, ranks=None, stack_pos=2, verbose=None)[source]#
Log a debug message.
- Parameters:
message (str) -- The message to be logged.
ranks (List[int]) -- List of parallel ranks.
- Return type:
None
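Example:
An illustrative sketch of the logging API documented above (messages and the rank list are hypothetical):
logger = get_logger(__name__)
logger.set_level("INFO")
logger.info("start training", ranks=[0])   # log only on parallel rank 0
logger.warning_once("deprecated option")   # emitted a single time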
egrecho.utils.mask#
- egrecho.utils.mask.subsequent_chunk_mask(size, chunk_size, num_left_chunks=-1, device=device(type='cpu'))[source]#
Create mask for subsequent steps (size, size) with chunk size
- Parameters:
size (int) -- size of mask
chunk_size (int) -- size of chunk
num_left_chunks (int) -- number of left chunks (<0: use full chunk; >=0: use num_left_chunks)
device (torch.device) -- “cpu” or “cuda” or torch.Tensor.device
- Returns:
mask
- Return type:
torch.Tensor
Examples
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 1],
 [1, 1, 1, 1]]
- egrecho.utils.mask.make_pad_mask(lengths, max_len=0)[source]#
Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
- Parameters:
lengths (torch.Tensor) -- Batch of lengths (B,).
- Returns:
Mask tensor containing indices of padded part.
- Return type:
torch.BoolTensor
Examples
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1],
         [0, 0, 1, 1, 1]]
- egrecho.utils.mask.make_non_pad_mask(lengths, max_len=0)[source]#
Make mask tensor containing indices of non-padded part.
The sequences in a batch may have different lengths. To enable batch computing, padding is needed to make all sequences the same size. This padding part is masked so that it does not pass values to context-dependent blocks such as attention or convolution.
This pad_mask is used in both encoder and decoder.
1 for non-padded part and 0 for padded part.
- Parameters:
lengths (torch.Tensor) -- Batch of lengths (B,).
- Returns:
mask tensor containing indices of the non-padded part.
- Return type:
torch.BoolTensor
Examples
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1, 1],
         [1, 1, 1, 0, 0],
         [1, 1, 0, 0, 0]]
- egrecho.utils.mask.prepare_4d_attention_mask(mask, dtype, tgt_len=None)[source]#
Creates a non-causal 4D mask of shape (batch_size, 1, query_length, key_value_length) from a 2D mask of shape (batch_size, key_value_length). The values in the original 2D mask are 0 (non-padded) or 1 (padded).
- Parameters:
mask (torch.Tensor or None) -- A 2D attention mask of shape (batch_size, key_value_length)
dtype (torch.dtype) -- The torch dtype the created mask shall have.
tgt_len (int) -- The target length or query length the created mask shall have.
- egrecho.utils.mask.make_causal_mask(inputs, dtype, past_key_values_length=0, sliding_window=None)[source]#
Creates a causal 4d mask from given querys.
- Parameters:
inputs (Tensor) -- input query of shape [bsz, seq_len, ...].
dtype (dtype) -- the torch dtype the created mask shall have.
past_key_values_length (int) -- cached past kv length
sliding_window (Optional[int]) -- left chunk size
- Returns:
4d attention mask (batch_size, 1, query_length, key_value_length) that can be multiplied with attention scores.
Examples
>>> make_causal_mask(torch.arange(3).view(1, 3), torch.float)
tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]]])
>>> make_causal_mask(torch.arange(3).view(1, 3), torch.float, past_key_values_length=2)
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]])
>>> make_causal_mask(torch.arange(3).view(1, 3), torch.float, past_key_values_length=2, sliding_window=2)
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
          [-3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]])
egrecho.utils.misc#
- egrecho.utils.misc.parse_bytes(size)[source]#
Parses a size in bytes to the largest possible unit.
- Return type:
float
- egrecho.utils.misc.parse_size(size, unit_type='abbrev')[source]#
Parses a size to the largest possible unit.
- egrecho.utils.misc.kill_name_proc(grep_str, force_kill=False)[source]#
Kills processes based on a grep string.
- Parameters:
grep_str (str) -- The grep string to search for processes.
force_kill (bool) -- If True, uses SIGKILL (-9) to forcefully terminate processes.
The function uses ps aux | grep to search for processes that match the given grep_str. It then requests user confirmation to kill these processes. If the user agrees, it terminates them.
- egrecho.utils.misc.valid_import_clspath(name)[source]#
Import path must be a str with a dot pattern ('calendar.Calendar').
- egrecho.utils.misc.class2str(value)[source]#
Extract path from class type.
Example:
from calendar import Calendar

assert str(Calendar) == "<class 'calendar.Calendar'>"
assert class2str(Calendar) == "calendar.Calendar"
- egrecho.utils.misc.get_import_path(value)[source]#
Returns the shortest dot import path for the given object.
- Return type:
Optional[str]
- egrecho.utils.misc.locate_(path)[source]#
COPIED FROM Hydra.
Locate an object by name or dotted path, importing as necessary. This is similar to the pydoc function locate, except that it checks for the module from the given path from back to front.
Behaves like:
path = "calendar.Calendar"
m, c = path.rsplit('.', 1)
mo = importlib.import_module(m)
cl = getattr(mo, c)
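A short illustrative call (assuming the behavior sketched above):
cal_cls = locate_("calendar.Calendar")
cal = cal_cls()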
egrecho.utils.register#
- class egrecho.utils.register.Register(name)[source]#
Bases:
object
This class is used to register functions or classes.
- Parameters:
name (str) -- The name of the registry.
Example:
LITTLES = Register("littles")

@LITTLES.register(name='litclass')
class LitClass:
    def __init__(self, a, b=123):
        self.a = a
        self.b = b

# or
LITTLES.register(LitClass, name='litclass')
- register(fn_or_cls=None, name=None, override=False, provider_info=None, **metadata)[source]#
Adds a function or class.
- Parameters:
fn_or_cls (Callable) -- The function to be registered.
name (str) -- Name string.
override (bool) -- Whether override if exists.
provider_info (ProviderInfo) -- infos about provider, which will be merged into metadata.
**metadata (dict) -- Additional dict to be saved.
- Return type:
Callable
- get(key, with_metadata=False, strict=True, **metadata)[source]#
Retrieves functions or classes by the key name with which they were registered.
- Parameters:
key (str) -- Name of the registered function.
with_metadata (bool) -- Whether to include the associated metadata in the return value.
strict (bool) -- Whether to return all matches or just one.
metadata -- Metadata used to filter against existing registry item's metadata.
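Example:
An illustrative sketch reusing the LitClass registration above (assumes get returns the registered class when with_metadata=False):
lit_cls = LITTLES.get("litclass")
obj = lit_cls(a=1)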
- class egrecho.utils.register.ConcatRegister(*registers)[source]#
Bases:
Register
This class is used to concatenate multiple registers.
- class egrecho.utils.register.StrRegister(name)[source]#
Bases:
object
Registers multiple strings, which can be used to create alias names.
- Parameters:
name (str) -- The name of the registry.
Example:
TOTAL_STEPS = StrRegister("total_steps")

TOTAL_STEPS.register("num_training_steps")
# or
TOTAL_STEPS.register(["num_training_steps"])

assert TOTAL_STEPS.keys() == ["total_steps", "num_training_steps"]
assert "num_training_steps" in TOTAL_STEPS
- class egrecho.utils.register.ConcatStrRegister(*registers)[source]#
Bases:
StrRegister
This class is used to concatenate multiple registers.
egrecho.utils.torch_utils#
- class egrecho.utils.torch_utils.RandomValue(end, start=0)[source]#
Bases:
object
Generate a uniform distribution in the range [start, end].
- egrecho.utils.torch_utils.batch_pad_right(tensors, mode='constant', value=0, val_index=-1)[source]#
Given a list of torch tensors, batches them together by padding to the right on each dimension, so that they all have the same length.
- Parameters:
tensors (list) -- List of tensor we wish to pad together.
mode (str) -- Padding mode see torch.nn.functional.pad documentation.
value (float) -- Padding value see torch.nn.functional.pad documentation.
- Returns:
tensor (torch.Tensor) -- Padded tensor.
valid_vals (list) -- List containing proportion for each dimension of original, non-padded values.
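Example:
A hedged sketch of the padding behavior described above (tensor shapes are hypothetical):
import torch

tensors = [torch.ones(2, 5), torch.ones(3, 4)]
batch, valid = batch_pad_right(tensors)
# batch.shape == torch.Size([2, 3, 5]): each dimension is padded to the per-dimension maximum
# valid[i] holds, per dimension, the proportion of batch[i] that is real (non-padded) data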
- egrecho.utils.torch_utils.pad_right_to(tensor, target_shape, mode='constant', value=0)[source]#
This function takes a torch tensor of arbitrary shape and pads it to target shape by appending values on the right.
- Parameters:
tensor (input torch tensor) -- Input tensor whose dimension we need to pad.
target_shape ((list, tuple)) -- Target shape we want for the target tensor. Its length must be equal to tensor.ndim.
mode (str) -- Pad mode, please refer to torch.nn.functional.pad documentation.
value (float) -- Pad value, please refer to torch.nn.functional.pad documentation.
- Returns:
tensor (torch.Tensor) -- Padded tensor.
valid_vals (list) -- List containing proportion for each dimension of original, non-padded values.
- egrecho.utils.torch_utils.audio_collate_fn(waveforms)[source]#
Pads a list of waves with shape (..., T); returns a tuple containing the padded tensor and a tensor of its lengths.
- Return type:
Tuple[Tensor, Tensor]
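Example:
A minimal sketch of the documented contract (wave lengths are hypothetical; lengths is presumably returned as a tensor, per the Tuple[Tensor, Tensor] return type):
import torch

waves = [torch.randn(16000), torch.randn(12000)]
padded, lengths = audio_collate_fn(waves)
# padded.shape == torch.Size([2, 16000]); lengths -> tensor([16000, 12000])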
- egrecho.utils.torch_utils.infer_framework_from_repr(x)[source]#
Tries to guess the framework of an object x from its repr (brittle but will help in is_tensor to try the frameworks in a smart order, without the need to import the frameworks).
- egrecho.utils.torch_utils.to_py_obj(obj)[source]#
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
egrecho.utils.types#
- class egrecho.utils.types.ModelOutput[source]#
Bases:
OrderedDict
Base class for all model outputs as dataclass. Copied from huggingface modelout.
Has a __getitem__() that allows indexing by integer or slice (like a tuple) or strings (like a dictionary), ignoring None attributes. Otherwise behaves like a regular python dictionary.
Tip
You can't unpack a ModelOutput directly. Use the [ModelOutput.to_tuple] method to convert it to a tuple beforehand.
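Example:
A hedged sketch of the documented dict/attribute duality, assuming the HuggingFace ModelOutput semantics this class is copied from (the MyOutput dataclass is hypothetical):
from dataclasses import dataclass
import torch

@dataclass
class MyOutput(ModelOutput):
    loss: torch.Tensor = None
    logits: torch.Tensor = None

out = MyOutput(logits=torch.ones(2))
assert out["logits"] is out.logits   # dict-style and attribute access agree
assert "loss" not in out             # None attributes are ignored
(logits,) = out.to_tuple()           # to_tuple() yields only the non-None fields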
- class egrecho.utils.types.SingletonMeta[source]#
Bases:
type
A metaclass for creating singleton classes.
- class egrecho.utils.types.StrEnum(value)[source]#
Bases:
str
,Enum
Type of any enumerator with allowed comparison to string invariant to cases.
>>> class MySE(StrEnum):
...     t1 = "T-1"
...     t2 = "T-2"
>>> MySE("T-1") == MySE.t1
True
>>> MySE.from_str("t-2", source="value") == MySE.t2
True
>>> MySE.from_str("t-2", source="value")
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any")
Traceback (most recent call last):
  ...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
- classmethod from_str(value, source='key')[source]#
Create StrEnum from a string matching the key or value.
- Parameters:
value (str) -- matching string
source (Literal['key', 'value', 'any']) -- compare with:
"key": validates only from the enum keys, typical alphanumeric with "_"
"value": validates only from the values, could be any string
"any": validates with any key or value, but key has priority
- Raises:
ValueError -- if requested string does not match any option based on selected source.
- Return type: