2026-03-12 17:23:08 +08:00

335 lines
10 KiB
Python

#!/usr/bin/env python3
"""Shared helpers for Gitea skill scripts."""
from __future__ import annotations
import json
import os
import re
import subprocess
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
# scp-style SSH remote such as ``git@host:owner/repo.git``. Any login name is
# accepted, not only ``git``: a self-hosted Gitea may be configured with a
# different SSH user. A trailing ``.git`` and/or ``/`` is kept out of the
# ``path`` group.
SSH_RE = re.compile(
    r"^(?P<user>[^@/:\s]+)@(?P<host>[^@/:\s]+):(?P<path>.+?)(?:\.git)?/?$"
)
# ``owner/repo`` shorthand: exactly two non-empty, slash-free segments.
REPO_PATH_RE = re.compile(r"^(?P<owner>[^/]+)/(?P<repo>[^/]+)$")
class GitCommandError(RuntimeError):
    """Signal that a ``git`` subprocess run by these helpers exited unsuccessfully.

    The message carries the command line and git's diagnostic output.
    """
@dataclass
class RepoContext:
    """Resolved coordinates of the Gitea repository a command operates on.

    Instances are returned by ``resolve_repo``.
    """

    # Base URL of the Gitea instance: scheme, host and any path prefix, with no
    # trailing slash. ``api_base`` appends ``/api/v1/repos/...`` to it.
    origin: str
    # Owner segment of the repository path.
    owner: str
    # Repository name as parsed from the target. The ``.git`` suffix is
    # removed for URL forms.
    repo: str
    # Browser-style repository URL, ``{origin}/{owner}/{repo}``.
    repo_url: str
    # Name of the git remote consulted when no explicit target was given.
    remote_name: str = "origin"
    # Raw URL of that remote when it was read from git; ``None`` when an
    # explicit repo target was supplied instead.
    remote_url: str | None = None
def normalize_base_url(base_url: str) -> str:
    """Return ``base_url`` with every trailing slash removed.

    For example ``"https://gitea.example.com//"`` becomes
    ``"https://gitea.example.com"``.
    """
    trimmed = base_url
    while trimmed.endswith("/"):
        trimmed = trimmed[:-1]
    return trimmed
def ensure_git_repo() -> None:
    """Exit unless the current working directory is inside a git work tree.

    Raises:
        SystemExit: if the ``git`` executable cannot be found, or if
            ``git rev-parse --is-inside-work-tree`` fails or does not print
            ``true``.
    """
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError as exc:
        # Without this, a missing git binary surfaces as a raw traceback
        # instead of the readable exit message used for the other failures.
        raise SystemExit(
            "git executable not found. Install git and make sure it is on PATH."
        ) from exc
    if result.returncode != 0 or result.stdout.strip() != "true":
        raise SystemExit("Current directory is not a git repository.")
# Matches the ``user[:password]@`` part of any ``scheme://user:pass@host`` URL.
_URL_CREDENTIALS_RE = re.compile(
    r"(?P<scheme>\b[A-Za-z][A-Za-z0-9+.-]*://)[^/\s@]+@"
)


def redact_credentials(text: str) -> str:
    """Replace credentials embedded in ``scheme://user:pass@host`` URLs with ``***``.

    Text without such URLs is returned unchanged.
    """
    return _URL_CREDENTIALS_RE.sub(r"\g<scheme>***@", text)


def run_git(args: list[str], cwd: str | None = None, check: bool = True) -> str:
    """Run ``git <args>`` and return its stdout with surrounding whitespace stripped.

    Args:
        args: Arguments passed to ``git`` (no shell involved).
        cwd: Working directory for the command; ``None`` uses the current one.
        check: When true, a non-zero exit status raises ``GitCommandError``.
            When false, stdout is returned regardless of the exit status.

    Raises:
        GitCommandError: if ``check`` is true and git exits non-zero. The
            message contains the command line and git's stderr (or stdout).
            Any URL credentials in it are redacted.
    """
    result = subprocess.run(
        ["git", *args],
        capture_output=True,
        text=True,
        cwd=cwd,
    )
    if check and result.returncode != 0:
        detail = result.stderr.strip() or result.stdout.strip() or "unknown git error"
        # The arguments (e.g. an authenticated push URL) and git's output may
        # contain tokens; never let them reach error messages or logs verbatim.
        message = f"git {' '.join(args)} failed: {detail}"
        raise GitCommandError(redact_credentials(message))
    return result.stdout.strip()
def parse_http_repo_url(repo_url: str) -> tuple[str, str, str, str]:
    """Split an HTTP(S) repository URL into Gitea coordinates.

    Accepts ``http(s)://[user[:pass]@]host[:port][/prefix]/owner/repo[.git][/]``.
    Any path segments before ``owner/repo`` are treated as the sub-path the
    Gitea instance is served under. Embedded credentials are dropped, so they
    never end up in the returned URLs. The port is kept.

    Returns:
        ``(origin, owner, repo, normalized_repo_url)``, where ``origin`` is
        ``scheme://host[:port][/prefix]`` and ``normalized_repo_url`` is
        ``{origin}/{owner}/{repo}`` without a ``.git`` suffix.

    Raises:
        SystemExit: if the scheme is not http/https, the host is missing,
            or the path has fewer than two segments.
    """
    error = "Invalid repo URL. Expected format: https://host/owner/repo"
    parsed = urllib.parse.urlsplit(repo_url)
    # Drop a ``user[:password]@`` prefix: it would leak the token into
    # ``origin``/``repo_url``, and urllib.request does not strip userinfo from
    # the host, so API URLs built from such an origin cannot be opened.
    host = parsed.netloc.rpartition("@")[2]
    if parsed.scheme not in {"http", "https"} or not host:
        raise SystemExit(error)
    path = parsed.path.rstrip("/")
    if path.endswith(".git"):
        path = path[:-4]
    parts = [part for part in path.split("/") if part]
    if len(parts) < 2:
        raise SystemExit(error)
    owner, repo = parts[-2], parts[-1]
    prefix = "/".join(parts[:-2])
    origin = f"{parsed.scheme}://{host}"
    if prefix:
        origin = f"{origin}/{prefix}"
    normalized_repo_url = f"{origin}/{owner}/{repo}"
    return origin, owner, repo, normalized_repo_url
def _split_repo_path(path: str, error: str) -> tuple[str, str, str]:
    """Split ``[prefix/]owner/repo[.git][/]`` into ``(owner, repo, prefix)``.

    Exits with *error* when fewer than two non-empty segments remain.
    """
    path = path.rstrip("/")
    if path.endswith(".git"):
        path = path[:-4]
    parts = [part for part in path.split("/") if part]
    if len(parts) < 2:
        raise SystemExit(error)
    return parts[-2], parts[-1], "/".join(parts[:-2])


def _ssh_web_origin(host: str, prefix: str, base_url: str | None) -> str:
    """Choose the web origin for an SSH-style target.

    An explicit ``base_url`` wins, since the SSH host may differ from the
    host serving the web UI/API. Otherwise use ``https://<host>[/prefix]``.
    """
    if base_url:
        return normalize_base_url(base_url)
    return f"https://{host}/{prefix}" if prefix else f"https://{host}"


def parse_repo_target(
    repo_target: str,
    base_url: str | None = None,
) -> tuple[str, str, str, str]:
    """Resolve a user-supplied repository target into Gitea coordinates.

    Accepted forms:
      * ``http(s)://host[/prefix]/owner/repo[.git]``. Delegated to
        ``parse_http_repo_url``; ``base_url`` is ignored.
      * ``ssh://[user@]host[:port]/[prefix/]owner/repo[.git]``.
      * scp-style ``git@host:[prefix/]owner/repo[.git]`` (whatever ``SSH_RE``
        matches).
      * ``owner/repo[.git]`` shorthand, which requires ``base_url``.

    For SSH forms the web origin is ``base_url`` when given, otherwise
    ``https://<ssh host>[/prefix]``. Scheme prefixes are matched
    case-insensitively, and a ``.git`` suffix is dropped in every form.

    Returns:
        ``(origin, owner, repo, normalized_repo_url)``.

    Raises:
        SystemExit: for an empty or unrecognised target, a path with fewer
            than two segments, an SSH URL without a host, or shorthand used
            without ``base_url``.
    """
    value = repo_target.strip()
    if not value:
        raise SystemExit("Repo target cannot be empty.")
    lowered = value.lower()
    if lowered.startswith(("http://", "https://")):
        return parse_http_repo_url(value)
    if lowered.startswith("ssh://"):
        parsed = urllib.parse.urlsplit(value)
        owner, repo, prefix = _split_repo_path(
            parsed.path,
            "Invalid SSH repo URL. Expected ssh://git@host/owner/repo.git",
        )
        host = parsed.hostname
        if not host:
            raise SystemExit("Invalid SSH repo URL. Missing host.")
        origin = _ssh_web_origin(host, prefix, base_url)
        return origin, owner, repo, f"{origin}/{owner}/{repo}"
    ssh_match = SSH_RE.match(value)
    if ssh_match:
        owner, repo, prefix = _split_repo_path(
            ssh_match.group("path"),
            "Invalid SSH repo target. Expected git@host:owner/repo.git",
        )
        origin = _ssh_web_origin(ssh_match.group("host"), prefix, base_url)
        return origin, owner, repo, f"{origin}/{owner}/{repo}"
    invalid_target = (
        "Invalid repo target. Use https://host/owner/repo, git@host:owner/repo.git, "
        "ssh://git@host/owner/repo.git, or owner/repo with GITEA_BASE_URL."
    )
    if REPO_PATH_RE.match(value):
        if not base_url:
            raise SystemExit(
                "Repo shorthand owner/repo requires GITEA_BASE_URL or a full repo URL."
            )
        # Strip ``.git`` here too, so ``owner/repo.git`` names the same
        # repository as the URL forms instead of a repo called "repo.git".
        owner, repo, _ = _split_repo_path(value, invalid_target)
        origin = normalize_base_url(base_url)
        return origin, owner, repo, f"{origin}/{owner}/{repo}"
    raise SystemExit(invalid_target)
def get_remote_url(remote: str = "origin") -> str:
    """Return the URL configured for git remote ``remote``.

    Raises:
        SystemExit: when not inside a git work tree, or when
            ``git remote get-url`` fails (e.g. the remote does not exist).
    """
    ensure_git_repo()
    try:
        url = run_git(["remote", "get-url", remote])
    except GitCommandError as err:
        raise SystemExit(f"Failed to read git remote '{remote}': {err}") from err
    return url
def resolve_repo(
    repo_url: str | None = None,
    repo: str | None = None,
    remote: str = "origin",
) -> RepoContext:
    """Determine which Gitea repository the current command targets.

    Precedence: ``repo_url``, then ``repo``, then the URL of git remote
    ``remote``. The whitespace-trimmed ``GITEA_BASE_URL`` environment
    variable is passed to ``parse_repo_target`` as ``base_url``, which uses
    it for ``owner/repo`` shorthand and SSH targets.

    Returns:
        A ``RepoContext``. ``remote_url`` is set only when the remote was
        consulted.

    Raises:
        SystemExit: propagated from ``get_remote_url`` / ``parse_repo_target``.
    """
    # Trim like GITEA_TOKEN in load_token, so a stray space or newline in the
    # environment never ends up inside every API URL. Blank counts as unset.
    base_url = os.getenv("GITEA_BASE_URL", "").strip() or None
    remote_url = None
    target = repo_url or repo
    if not target:
        remote_url = get_remote_url(remote)
        target = remote_url
    origin, owner, repo_name, normalized_repo_url = parse_repo_target(
        target,
        base_url=base_url,
    )
    return RepoContext(
        origin=origin,
        owner=owner,
        repo=repo_name,
        repo_url=normalized_repo_url,
        remote_name=remote,
        remote_url=remote_url,
    )
def api_base(context: RepoContext) -> str:
    """Return the REST endpoint root for the repository described by *context*.

    The result has the form ``{origin}/api/v1/repos/{owner}/{repo}``.
    """
    api_root = f"{context.origin}/api/v1"
    return f"{api_root}/repos/{context.owner}/{context.repo}"
def load_token(required: bool = True) -> str:
    """Read the Gitea API token from ``GITEA_TOKEN``, trimming surrounding whitespace.

    Returns an empty string when the variable is unset or blank and
    ``required`` is false.

    Raises:
        SystemExit: when ``required`` is true and no token is available.
    """
    raw_value = os.getenv("GITEA_TOKEN")
    token = (raw_value or "").strip()
    if required and not token:
        raise SystemExit("Missing GITEA_TOKEN. Export it before using Gitea workflows.")
    return token
def request_json(
    url: str,
    token: str,
    method: str = "GET",
    payload: dict | None = None,
    timeout: float | None = 60.0,
) -> dict | list:
    """Send a request to the Gitea API and decode the JSON response.

    Args:
        url: Absolute API URL.
        token: Access token. It is sent as ``Authorization: token <token>``
            when non-empty and omitted otherwise.
        method: HTTP method.
        payload: JSON-serializable request body. When given, it is sent with
            ``Content-Type: application/json``.
        timeout: Socket timeout in seconds; ``None`` waits indefinitely.

    Returns:
        The decoded JSON document, or ``{}`` when the server replies with an
        empty body (e.g. ``204 No Content`` from DELETE endpoints).

    Raises:
        SystemExit: on an HTTP error status (the message includes the status,
            reason and response body), or when the server cannot be reached
            or times out.
    """
    data = None
    headers = {"Accept": "application/json"}
    if token:
        headers["Authorization"] = f"token {token}"
    if payload is not None:
        data = json.dumps(payload).encode("utf-8")
        headers["Content-Type"] = "application/json"
    request = urllib.request.Request(url, data=data, headers=headers, method=method)
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            raw = response.read()
    except urllib.error.HTTPError as exc:
        body = exc.read().decode("utf-8", errors="replace")
        raise SystemExit(
            f"Gitea API request failed: {exc.code} {exc.reason} | {body}"
        ) from exc
    except urllib.error.URLError as exc:
        raise SystemExit(f"Failed to reach Gitea API: {exc.reason}") from exc
    except OSError as exc:
        # A timeout while waiting for or reading the response is raised as a
        # plain socket error rather than being wrapped in URLError.
        raise SystemExit(f"Failed to reach Gitea API: {exc}") from exc
    # json.loads rejects an empty document; treat "no content" as an empty object.
    if not raw.strip():
        return {}
    return json.loads(raw)
def get_current_user_login(origin: str, token: str) -> str:
    """Fetch ``{origin}/api/v1/user`` with ``token`` and return its ``login`` field.

    Raises:
        SystemExit: if the response is not a JSON object, or if ``login``
            is missing or blank.
    """
    profile = request_json(f"{origin}/api/v1/user", token)
    if not isinstance(profile, dict):
        raise SystemExit("Unexpected user API response.")
    login = str(profile.get("login") or "").strip()
    if login:
        return login
    raise SystemExit("Failed to resolve current Gitea user login.")
def current_branch() -> str:
    """Return the short name of the branch checked out in the current repository.

    Raises:
        SystemExit: outside a git work tree; when ``git symbolic-ref HEAD``
            fails (reported as a detached HEAD); or when git prints an empty
            name.
    """
    ensure_git_repo()
    try:
        name = run_git(["symbolic-ref", "--quiet", "--short", "HEAD"])
    except GitCommandError as err:
        raise SystemExit(
            "Detached HEAD. Checkout a branch before running push or PR actions."
        ) from err
    if name:
        return name
    raise SystemExit("Failed to determine current git branch.")
def get_upstream(branch: str) -> str | None:
    """Return the abbreviated upstream ref of ``branch``, or ``None``.

    ``None`` is returned when ``git rev-parse`` cannot resolve
    ``<branch>@{upstream}`` (non-zero exit) or prints nothing.
    """
    upstream_spec = f"{branch}@{{upstream}}"
    command = ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", upstream_spec]
    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode:
        return None
    return completed.stdout.strip() or None
def remote_branch_sha(remote: str, branch: str) -> str | None:
    """Return the commit SHA of ``branch`` on ``remote``, or ``None`` if it does not exist.

    ``git ls-remote`` matches patterns against the *tail* of ref names, so a
    bare ``main`` also matches ``refs/heads/release/main``, which sorts
    first. The full ref name is therefore queried, and the output is
    filtered for an exact match rather than taking the first line blindly.

    Raises:
        SystemExit: if ``git ls-remote`` fails (e.g. the remote is unreachable).
    """
    ref = f"refs/heads/{branch}"
    try:
        output = run_git(["ls-remote", "--heads", remote, ref])
    except GitCommandError as exc:
        raise SystemExit(f"Failed to query remote branch '{remote}/{branch}': {exc}") from exc
    # Each output line is "<sha>\t<refname>".
    for line in output.splitlines():
        fields = line.split()
        if len(fields) >= 2 and fields[1] == ref:
            return fields[0]
    return None
def ahead_behind(compare_ref: str) -> tuple[int, int]:
    """Count how far ``HEAD`` has diverged from ``compare_ref``.

    Returns:
        ``(behind, ahead)``: the left and right counts of
        ``git rev-list --left-right --count <compare_ref>...HEAD``.

    Raises:
        SystemExit: if git does not print exactly two counts.
    """
    raw = run_git(["rev-list", "--left-right", "--count", f"{compare_ref}...HEAD"])
    counts = raw.split()
    if len(counts) == 2:
        left, right = map(int, counts)
        return left, right
    raise SystemExit(f"Unexpected rev-list output: {raw}")
def worktree_changes() -> list[str]:
    """Return the non-blank lines of ``git status --short``, unmodified.

    The command is run directly rather than through ``run_git``. That helper
    strips the whole output, which would eat the leading space of the first
    entry (`` M file`` -> ``M file``) and corrupt its two-column XY status
    code.

    Raises:
        GitCommandError: if ``git status`` exits non-zero.
    """
    result = subprocess.run(
        ["git", "status", "--short"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        detail = result.stderr.strip() or result.stdout.strip() or "unknown git error"
        raise GitCommandError(f"git status --short failed: {detail}")
    return [line for line in result.stdout.splitlines() if line.strip()]
def remote_default_branch(remote: str = "origin") -> str:
    """Infer the default branch of ``remote``.

    First reads the locally recorded ``refs/remotes/<remote>/HEAD`` symref.
    If that is missing or unexpected, probes the remote for ``main`` and
    then ``master``.

    Raises:
        SystemExit: if no branch can be inferred, or if probing the remote
            fails.
    """
    prefix = f"{remote}/"
    try:
        output = run_git(
            ["symbolic-ref", "--quiet", "--short", f"refs/remotes/{remote}/HEAD"]
        )
    except GitCommandError:
        # No recorded remote HEAD; fall back to probing below.
        output = ""
    if output.startswith(prefix):
        # Slice off the exact prefix: splitting on the first "/" would
        # mis-handle remote names that themselves contain "/".
        return output[len(prefix):]
    for candidate in ("main", "master"):
        if remote_branch_sha(remote, candidate):
            return candidate
    raise SystemExit(
        f"Failed to infer default branch for remote '{remote}'. Pass --base explicitly."
    )
def build_authenticated_push_url(repo_url: str, username: str, token: str) -> str:
    """Return ``repo_url`` as a ``.git`` URL that embeds ``username:token``.

    Both credentials are percent-encoded, so characters such as ``@``, ``:``
    or ``/`` cannot break the URL. Credentials already present in
    ``repo_url`` are replaced; prepending to them would produce an invalid
    ``user:token@old:creds@host`` authority. The port is kept. Query string
    and fragment are dropped.
    """
    parsed = urllib.parse.urlsplit(repo_url)
    host = parsed.netloc.rpartition("@")[2]
    path = parsed.path.rstrip("/")
    if not path.endswith(".git"):
        path = f"{path}.git"
    userinfo = (
        f"{urllib.parse.quote(username, safe='')}:"
        f"{urllib.parse.quote(token, safe='')}"
    )
    return urllib.parse.urlunsplit(
        (parsed.scheme or "https", f"{userinfo}@{host}", path, "", "")
    )
def mask_url(url: str) -> str:
    """Hide the userinfo part of ``url`` behind ``***`` for safe display.

    URLs whose authority contains no ``@`` are returned untouched. Masked
    URLs lose any query string and fragment.
    """
    pieces = urllib.parse.urlsplit(url)
    authority = pieces.netloc
    if "@" in authority:
        host_part = authority.rsplit("@", 1)[1]
        return urllib.parse.urlunsplit(
            (pieces.scheme, "***@" + host_part, pieces.path, "", "")
        )
    return url
def is_auth_error(stderr: str) -> bool:
    """Heuristically decide whether git's stderr describes an authentication failure.

    Matching is case-insensitive substring search over a fixed list of
    phrases.
    """
    lowered = stderr.lower()
    for needle in (
        "authentication failed",
        "permission denied",
        "could not read username",
        "could not read password",
        "http basic: access denied",
        "access denied",
        "unauthorized",
    ):
        if needle in lowered:
            return True
    return False