Skip to content

Utility functions

aana.utils.asyncio

run_async

run_async(coro)

Run a coroutine in a thread if the current thread is running an event loop.

Otherwise, run the coroutine in the current asyncio loop.

Useful when you want to run an async function in a non-async context.

From: https://stackoverflow.com/a/75094151

PARAMETER DESCRIPTION
coro

The coroutine to run.

TYPE: Coroutine

RETURNS DESCRIPTION
T

The result of the coroutine.

TYPE: T

Source code in aana/utils/asyncio.py
def run_async(coro: Coroutine[Any, Any, T]) -> T:
    """Run a coroutine to completion from synchronous code.

    If the calling thread already has a running event loop, the coroutine is
    executed with ``asyncio.run`` in a temporary worker thread and the caller
    blocks (``thread.join()``) until it finishes. Otherwise the coroutine is
    executed with ``asyncio.run`` directly in the calling thread.

    Useful when you want to run an async function in a non-async context.

    From: https://stackoverflow.com/a/75094151

    Args:
        coro (Coroutine): The coroutine to run.

    Returns:
        T: The result of the coroutine.

    Raises:
        BaseException: Whatever the coroutine raised, re-raised in the caller.
            This includes ``asyncio.CancelledError`` and other non-``Exception``
            errors, which would otherwise be lost in the worker thread.
    """

    class RunThread(threading.Thread):
        """Run a coroutine in a thread and capture its outcome."""

        def __init__(self, coro: Coroutine[Any, Any, T]):
            """Initialize the thread."""
            self.coro = coro
            self.result: T | None = None
            self.exception: BaseException | None = None
            super().__init__()

        def run(self):
            """Run the coroutine, storing either its result or its exception."""
            try:
                self.result = asyncio.run(self.coro)
            except BaseException as e:  # noqa: BLE001
                # Catch BaseException, not just Exception: CancelledError is a
                # BaseException, and anything left uncaught here would make
                # the caller silently receive ``None``.
                self.exception = e

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises when this thread has no running loop.
        loop_is_running = False
    else:
        loop_is_running = True

    if not loop_is_running:
        return asyncio.run(coro)

    # asyncio.run() cannot be nested inside a running loop, so use a fresh
    # thread that gets its own event loop.
    thread = RunThread(coro)
    thread.start()
    thread.join()
    # Compare against None: an exception object may define a falsy truth value.
    if thread.exception is not None:
        raise thread.exception
    return thread.result

aana.utils.json

json_serializer_default

json_serializer_default(obj)

Default function for json serializer to handle custom objects.

If json serializer does not know how to serialize an object, it calls the default function.

For example, if we see that the object is a pydantic model, we call the dict method to get the dictionary representation of the model that json serializer can deal with.

If the object is not supported, we raise a TypeError.

PARAMETER DESCRIPTION
obj

The object to serialize.

TYPE: object

RETURNS DESCRIPTION
object

The serializable object.

TYPE: object

RAISES DESCRIPTION
TypeError

If the object is not one of the supported types (Engine, pydantic model, Path, type, bytes, or Media object).

Source code in aana/utils/json.py
def json_serializer_default(obj: object) -> object:
    """Default function for json serializer to handle custom objects.

    If json serializer does not know how to serialize an object, it calls the default function.

    Supported objects and their serialized form:

    - ``Engine``: ``None``
    - pydantic ``BaseModel``: the dictionary returned by ``model_dump()``
    - ``Path``: its string form
    - a class (``type`` instance): its string form, e.g. ``"<class 'int'>"``
    - ``bytes``: the decoded string (``bytes.decode()``, UTF-8 by default)
    - ``Media``: its string form

    If the object is not supported, we raise a TypeError.

    Args:
        obj (object): The object to serialize.

    Returns:
        object: The serializable object.

    Raises:
        TypeError: If the object is not one of the supported types listed above.
    """
    if isinstance(obj, Engine):
        return None
    if isinstance(obj, BaseModel):
        return obj.model_dump()
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, type):
        # Stringify the class that was passed in; ``str(type)`` would always
        # produce "<class 'type'>" regardless of the input.
        return str(obj)
    if isinstance(obj, bytes):
        return obj.decode()

    # NOTE(review): imported locally, presumably to avoid a circular import - confirm.
    from aana.core.models.media import Media

    if isinstance(obj, Media):
        return str(obj)

    raise TypeError(type(obj))

jsonify

jsonify(data, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SORT_KEYS, as_bytes=False)

Serialize content using orjson.

PARAMETER DESCRIPTION
data

The content to serialize.

TYPE: Any

option

The option for orjson.dumps.

TYPE: int | None DEFAULT: OPT_SERIALIZE_NUMPY | OPT_SORT_KEYS

as_bytes

Return output as bytes instead of string

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
str | bytes

The serialized data: bytes if as_bytes is True, otherwise a string.

Source code in aana/utils/json.py
def jsonify(
    data: Any,
    option: int | None = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SORT_KEYS,
    as_bytes: bool = False,
) -> str | bytes:
    """Serialize content to JSON with orjson.

    Objects that orjson cannot serialize on its own are passed to
    ``json_serializer_default``.

    Args:
        data (Any): The content to serialize.
        option (int | None): The option flags passed to ``orjson.dumps``.
        as_bytes (bool): Return the raw bytes produced by orjson instead of a string.

    Returns:
        bytes | str: The serialized data, as bytes when ``as_bytes`` is True
            and as a decoded string otherwise.
    """
    serialized = orjson.dumps(data, default=json_serializer_default, option=option)
    if as_bytes:
        return serialized
    return serialized.decode()

aana.utils.download

download_model

download_model(url, model_hash='', model_path=None, check_sum=True)

Download a model from a URL.

PARAMETER DESCRIPTION
url

the URL of the file to download

TYPE: str

model_hash

expected SHA-256 hash of the model file, compared against the file's hash when check_sum is True

TYPE: str DEFAULT: ''

model_path

optional model path where it needs to be downloaded

TYPE: Path DEFAULT: None

check_sum

whether to verify the SHA-256 checksum of the model file

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Path

the downloaded file path

TYPE: Path

RAISES DESCRIPTION
DownloadException

Request does not succeed.

Source code in aana/utils/download.py
def download_model(
    url: str, model_hash: str = "", model_path: Path | None = None, check_sum=True
) -> Path:
    """Download a model from a URL.

    Args:
        url (str): the URL of the file to download
        model_hash (str): hash of the model file for checking sha256 hash if checksum is True
        model_path (Path): optional model path where it needs to be downloaded
        check_sum (bool): boolean to mention whether to check SHA-256 sum or not

    Returns:
        Path: the downloaded file path

    Raises:
        DownloadException: Request does not succeed.
    """
    if model_path is None:
        model_dir = settings.model_dir
        if not model_dir.exists():
            model_dir.mkdir(parents=True)
        model_path = model_dir / "model.bin"

    if model_path.exists() and not model_path.is_file():
        raise RuntimeError(f"Not a regular file: {model_path}")  # noqa: TRY003

    if not model_path.exists():
        try:
            with ExitStack() as stack:
                source = stack.enter_context(urllib.request.urlopen(url))  # noqa: S310
                output = stack.enter_context(Path.open(model_path, "wb"))

                loop = tqdm(
                    total=int(source.info().get("Content-Length")),
                    ncols=80,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                )

                with loop:
                    while True:
                        buffer = source.read(8192)
                        if not buffer:
                            break

                        output.write(buffer)
                        loop.update(len(buffer))
        except Exception as e:
            raise DownloadException(url) from e

    model_sha256_hash = get_sha256_hash_file(model_path)
    if check_sum and model_sha256_hash != model_hash:
        checksum_error = "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
        raise RuntimeError(f"{checksum_error}")

    return model_path

download_file

download_file(url)

Download a file from a URL.

PARAMETER DESCRIPTION
url

the URL of the file to download

TYPE: str

RETURNS DESCRIPTION
bytes

the file content

TYPE: bytes

RAISES DESCRIPTION
DownloadException

Request does not succeed.

Source code in aana/utils/download.py
def download_file(url: str) -> bytes:
    """Download a file from a URL.

    Args:
        url (str): the URL of the file to download

    Returns:
        bytes: the file content

    Raises:
        DownloadException: Request does not succeed, including when the server
            answers with an HTTP error status (4xx/5xx).
    """
    # TODO: add retries, timeout, etc.: add issue link
    try:
        response = requests.get(url)  # noqa: S113 TODO : add issue link
        # Without this check the body of an error page (e.g. a 404) would be
        # returned as if it were the requested file.
        response.raise_for_status()
    except Exception as e:
        raise DownloadException(url) from e
    return response.content

aana.utils.gpu

get_gpu_memory

get_gpu_memory(gpu=0)

Get the total memory of a GPU in bytes.

PARAMETER DESCRIPTION
gpu

the GPU index. Defaults to 0.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
int

the total memory of the GPU in bytes

TYPE: int

Source code in aana/utils/gpu.py
def get_gpu_memory(gpu: int = 0) -> int:
    """Return the total memory of a GPU in bytes.

    Args:
        gpu (int): the GPU index. Defaults to 0.

    Returns:
        int: the total memory of the GPU in bytes
    """
    # torch is imported inside the function, so it is only loaded when this is called.
    import torch

    device_properties = torch.cuda.get_device_properties(gpu)
    return device_properties.total_memory