diff --git a/requirements.txt b/requirements.txt index d995e8f..8a40850 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,10 +12,9 @@ PyYAML>=5.3.1 requests>=2.23.0 scipy>=1.4.1 thop>=0.1.1 # FLOPs computation -torch>=1.7.0 # see https://pytorch.org/get-started/locally (recommended) -torchvision>=0.8.1 +torch>=2.0.0 # see https://pytorch.org/get-started/locally (recommended) +torchvision>=0.15.0 tqdm>=4.64.0 -ultralytics>=8.0.100 # protobuf<=3.20.1 # https://github.com/ultralytics/yolov5/issues/8012 # Logging --------------------------------------------------------------------- @@ -50,7 +49,7 @@ setuptools>=65.5.1 # Snyk vulnerability fix # CLI -fire +fire==0.6.0 # AWS boto3>=1.19.1 # coco to yolov5 conversion diff --git a/yolov5/hubconf.py b/yolov5/hubconf.py index 10295bf..877a644 100644 --- a/yolov5/hubconf.py +++ b/yolov5/hubconf.py @@ -62,7 +62,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path model = DetectionModel(cfg, channels, classes) # create model if pretrained: - ckpt = torch.load(attempt_download(path), map_location=device) # load + ckpt = torch.load(attempt_download(path), weights_only=False, map_location=device) # load csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect model.load_state_dict(csd, strict=False) # load diff --git a/yolov5/models/experimental.py b/yolov5/models/experimental.py index b554401..5afd12a 100644 --- a/yolov5/models/experimental.py +++ b/yolov5/models/experimental.py @@ -76,7 +76,7 @@ def attempt_load(weights, device=None, inplace=True, fuse=True): model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: - ckpt = torch.load(attempt_download(w), map_location='cpu') # load + ckpt = torch.load(attempt_download(w), weights_only=False, map_location='cpu') # load ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model # Model compatibility updates diff --git a/yolov5/segment/train.py b/yolov5/segment/train.py index e1dae24..6a91e69 100644 --- a/yolov5/segment/train.py +++ b/yolov5/segment/train.py @@ -118,7 +118,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio if pretrained: with torch_distributed_zero_first(LOCAL_RANK): weights = attempt_download(weights) # download if not found locally - ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak + ckpt = torch.load(weights, weights_only=False, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak model = SegmentationModel(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 @@ -544,7 +544,7 @@ def main(opt, callbacks=Callbacks()): with open(opt_yaml, errors='ignore') as f: d = yaml.safe_load(f) else: - d = torch.load(last, map_location='cpu')['opt'] + d = torch.load(last, weights_only=False, map_location='cpu')['opt'] opt = argparse.Namespace(**d) # replace opt.cfg, opt.weights, opt.resume = '', str(last), True # reinstate if is_url(opt_data): diff --git a/yolov5/train.py b/yolov5/train.py index 153f504..fe902dc 100644 --- a/yolov5/train.py +++ b/yolov5/train.py @@ -137,7 +137,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio if pretrained: with torch_distributed_zero_first(LOCAL_RANK): weights = attempt_download(weights) # download if not found locally - ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak + ckpt = torch.load(weights, weights_only=False, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 @@ -578,7 +578,7 @@ def main(opt, callbacks=Callbacks()): with open(opt_yaml, errors='ignore') as f: d = yaml.safe_load(f) else: - d = torch.load(last, map_location='cpu')['opt'] + d = torch.load(last, weights_only=False, map_location='cpu')['opt'] if opt.neptune_resume_id is not None: d["neptune_resume_id"] = opt.neptune_resume_id opt = argparse.Namespace(**d) # replace diff --git a/yolov5/utils/general.py b/yolov5/utils/general.py index 16fd865..def4e78 100644 --- a/yolov5/utils/general.py +++ b/yolov5/utils/general.py @@ -4,6 +4,7 @@ """ import contextlib +import functools import glob import inspect import logging @@ -20,11 +21,13 @@ import urllib from copy import deepcopy from datetime import datetime +from importlib import metadata from itertools import repeat from multiprocessing.pool import ThreadPool from pathlib import Path from subprocess import check_output from tarfile import is_tarfile +from types import SimpleNamespace from typing import Optional from zipfile import ZipFile, is_zipfile @@ -36,8 +39,6 @@ import torchvision import yaml -from ultralytics.utils.checks import check_requirements - from yolov5.utils import TryExcept, emojis from yolov5.utils.downloads import curl_download, gsutil_getsize from yolov5.utils.metrics import box_iou, fitness @@ -121,6 +122,316 @@ def is_writeable(dir, test=False): return True except OSError: return False + + +# =========================================================================== +# Until indicated otherwise, the following code has been copied verbatim +# from the ultralytics repository: https://github.com/ultralytics/ultralytics +# which is licensed under the AGPL-3.0 license included in this project. +# =========================================================================== + +def is_online() -> bool: + """ + Check internet connectivity by attempting to connect to a known online host. + + Returns: + (bool): True if connection is successful, False otherwise. + """ + try: + assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True" + import socket + + for dns in ("1.1.1.1", "8.8.8.8"): # check Cloudflare and Google DNS + socket.create_connection(address=(dns, 80), timeout=2.0).close() + return True + except Exception: + return False + + +ONLINE = is_online() +ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans +MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans + + +class Retry(contextlib.ContextDecorator): + """ + Retry class for function execution with exponential backoff. + + This decorator can be used to retry a function on exceptions, up to a specified number of times with an + exponentially increasing delay between retries. It's useful for handling transient failures in network + operations or other unreliable processes. + + Attributes: + times (int): Maximum number of retry attempts. + delay (int): Initial delay between retries in seconds. + + Examples: + Example usage as a decorator: + >>> @Retry(times=3, delay=2) + >>> def test_func(): + >>> # Replace with function logic that may raise exceptions + >>> return True + """ + + def __init__(self, times=3, delay=2): + """Initialize Retry class with specified number of retries and delay.""" + self.times = times + self.delay = delay + self._attempts = 0 + + def __call__(self, func): + """Decorator implementation for Retry with exponential backoff.""" + + def wrapped_func(*args, **kwargs): + """Apply retries to the decorated function or method.""" + self._attempts = 0 + while self._attempts < self.times: + try: + return func(*args, **kwargs) + except Exception as e: + self._attempts += 1 + LOGGER.warning(f"Retry {self._attempts}/{self.times} failed: {e}") + if self._attempts >= self.times: + raise e + time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay + + return wrapped_func + + +def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""): + """ + Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'. + + Args: + file_path (Path): Path to the requirements.txt file. + package (str, optional): Python package to use instead of requirements.txt file. + + Returns: + requirements (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and + `specifier` attributes. + + Examples: + >>> from ultralytics.utils.checks import parse_requirements + >>> parse_requirements(package="ultralytics") + """ + if package: + requires = [x for x in metadata.distribution(package).requires if "extra == " not in x] + else: + requires = Path(file_path).read_text().splitlines() + + requirements = [] + for line in requires: + line = line.strip() + if line and not line.startswith("#"): + line = line.partition("#")[0].strip() # ignore inline comments + match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line) + if match: + requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else "")) + + return requirements + + +@functools.lru_cache +def parse_version(version="0.0.0") -> tuple: + """ + Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. + + Args: + version (str): Version string, i.e. '2.0.1+cpu' + + Returns: + (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1) + """ + try: + return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1) + except Exception as e: + LOGGER.warning(f"failure for parse_version({version}), returning (0, 0, 0): {e}") + return 0, 0, 0 + + +@functools.lru_cache +def check_uv(): + """Check if uv package manager is installed and can run successfully.""" + try: + return subprocess.run(["uv", "-V"], capture_output=True).returncode == 0 + except FileNotFoundError: + return False + + +@functools.lru_cache +def check_version( + current: str = "0.0.0", + required: str = "0.0.0", + name: str = "version", + hard: bool = False, + verbose: bool = False, + msg: str = "", +) -> bool: + """ + Check current version against the required version or range. + + Args: + current (str): Current version or package name to get version from. + required (str): Required version or range (in pip-style format). + name (str): Name to be used in warning message. + hard (bool): If True, raise an AssertionError if the requirement is not met. + verbose (bool): If True, print warning message if requirement is not met. + msg (str): Extra message to display if verbose. + + Returns: + (bool): True if requirement is met, False otherwise. + + Examples: + Check if current version is exactly 22.04 + >>> check_version(current="22.04", required="==22.04") + + Check if current version is greater than or equal to 22.04 + >>> check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed + + Check if current version is less than or equal to 22.04 + >>> check_version(current="22.04", required="<=22.04") + + Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) + >>> check_version(current="21.10", required=">20.04,<22.04") + """ + if not current: # if current is '' or None + LOGGER.warning(f"invalid check_version({current}, {required}) requested, please check values.") + return True + elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics' + try: + name = current # assigned package name to 'name' arg + current = metadata.version(current) # get version string from package name + except metadata.PackageNotFoundError as e: + if hard: + raise ModuleNotFoundError(f"{current} package is required but not installed") from e + else: + return False + + if not required: # if required is '' or None + return True + + if "sys_platform" in required and ( # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"' + (WINDOWS and "win32" not in required) + or (LINUX and "linux" not in required) + or (MACOS and "macos" not in required and "darwin" not in required) + ): + return True + + op = "" + version = "" + result = True + c = parse_version(current) # '1.2.3' -> (1, 2, 3) + for r in required.strip(",").split(","): + op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04') + if not op: + op = ">=" # assume >= if no op passed + v = parse_version(version) # '1.2.3' -> (1, 2, 3) + if op == "==" and c != v: + result = False + elif op == "!=" and c == v: + result = False + elif op == ">=" and not (c >= v): + result = False + elif op == "<=" and not (c <= v): + result = False + elif op == ">" and not (c > v): + result = False + elif op == "<" and not (c < v): + result = False + if not result: + warning = f"{name}{required} is required, but {name}=={current} is currently installed {msg}" + if hard: + raise ModuleNotFoundError(warning) # assert version requirements met + if verbose: + LOGGER.warning(warning) + return result + + +@TryExcept() +def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""): + """ + Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed. + + Args: + requirements (Path | str | List[str]): Path to a requirements.txt file, a single package requirement as a + string, or a list of package requirements as strings. + exclude (tuple): Tuple of package names to exclude from checking. + install (bool): If True, attempt to auto-update packages that don't meet requirements. + cmds (str): Additional commands to pass to the pip install command when auto-updating. + + Examples: + >>> from ultralytics.utils.checks import check_requirements + + Check a requirements.txt file + >>> check_requirements("path/to/requirements.txt") + + Check a single package + >>> check_requirements("ultralytics>=8.0.0") + + Check multiple packages + >>> check_requirements(["numpy", "ultralytics>=8.0.0"]) + """ + prefix = colorstr("red", "bold", "requirements:") + if isinstance(requirements, Path): # requirements.txt file + file = requirements.resolve() + assert file.exists(), f"{prefix} {file} not found, check failed." + requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude] + elif isinstance(requirements, str): + requirements = [requirements] + + pkgs = [] + for r in requirements: + r_stripped = r.rpartition("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo' + match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped) + name, required = match[1], match[2].strip() if match[2] else "" + try: + assert check_version(metadata.version(name), required) # exception if requirements not met + except (AssertionError, metadata.PackageNotFoundError): + pkgs.append(r) + + @Retry(times=2, delay=1) + def attempt_install(packages, commands, use_uv): + """Attempt package installation with uv if available, falling back to pip.""" + if use_uv: + base = f"uv pip install --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow" + try: + return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE).decode() + except subprocess.CalledProcessError as e: + if e.stderr and "No virtual environment found" in e.stderr.decode(): + return subprocess.check_output( + base.replace("uv pip install", "uv pip install --system"), shell=True + ).decode() + raise + return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode() + + s = " ".join(f'"{x}"' for x in pkgs) # console string + if s: + if install and AUTOINSTALL: # check environment variable + # Note uv fails on arm64 macOS and Raspberry Pi runners + n = len(pkgs) # number of packages updates + LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...") + try: + t = time.time() + assert ONLINE, "AutoUpdate skipped (offline)" + LOGGER.info(attempt_install(s, cmds, use_uv=not ARM64 and check_uv())) + dt = time.time() - t + LOGGER.info(f"{prefix} AutoUpdate success ✅ {dt:.1f}s") + LOGGER.warning( + f"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" + ) + except Exception as e: + LOGGER.warning(f"{prefix} ❌ {e}") + return False + else: + return False + + return True + +# =========================================================================== +# This marker indicates the end of the code which has been copied verbatim +# from the ultralytics repository: https://github.com/ultralytics/ultralytics +# which is licensed under the AGPL-3.0 license included in this project. +# =========================================================================== LOGGING_NAME = 'yolov5' @@ -969,7 +1280,7 @@ def non_max_suppression( def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer() # Strip optimizer from 'f' to finalize training, optionally save as 's' - x = torch.load(f, map_location=torch.device('cpu')) + x = torch.load(f, weights_only=False, map_location=torch.device('cpu')) if x.get('ema'): x['model'] = x['ema'] # replace model with ema for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys