Source code for engine.oracle

"""BYO decryption oracle loader.

Loads a user-supplied Python file that exposes either:

    Shape 1 — stateless function:
        def verify(candidate: bytes) -> bool: ...

    Shape 2 — stateful factory (amortizes KDF/socket setup):
        def build_oracle(config: dict) -> Oracle:
            return MyOracle(config)

        class MyOracle:
            def verify(self, candidate: bytes) -> bool: ...
            def close(self): ...   # optional

Memdiver auto-detects the shape and always hands the caller a flat
``verify(bytes) -> bool`` callable.

Security: ``--oracle`` runs arbitrary Python with the caller's privileges,
equivalent to ``find -exec``. This loader prints the sha256 of the loaded
module to stderr and refuses to load oracles from world-writable paths.
"""

from __future__ import annotations

import hashlib
import importlib.util
import logging
import stat
import sys
import tomllib
from pathlib import Path
from typing import Any, Callable, Protocol, runtime_checkable

logger = logging.getLogger("memdiver.engine.oracle")

OracleFn = Callable[[bytes], bool]


@runtime_checkable
class Oracle(Protocol):
    """Structural type of the object a user's ``build_oracle(config)`` returns.

    Any object exposing a ``verify`` method satisfies this protocol; because
    the class is ``@runtime_checkable`` it can also be used with
    ``isinstance`` (which only checks that the attribute exists, not its
    signature).
    """

    def verify(self, candidate: bytes) -> bool:
        """Report whether *candidate* is accepted by the oracle."""
        ...
class OracleLoadError(RuntimeError):
    """Signal that a user-supplied oracle file could not be loaded or validated."""

    pass
[docs] def load_oracle_config(path: Path | None) -> dict[str, Any]: """Load an optional TOML config file into a plain dict.""" if path is None: return {} config_path = Path(path) if not config_path.is_file(): raise OracleLoadError(f"oracle config not found: {config_path}") with config_path.open("rb") as fh: return tomllib.load(fh)
def _assert_safe_path(path: Path) -> None: """Refuse to load oracles from world-writable files or directories.""" if not path.is_file(): raise OracleLoadError(f"oracle file not found: {path}") file_mode = path.stat().st_mode if file_mode & stat.S_IWOTH: raise OracleLoadError( f"refusing to load world-writable oracle: {path} " f"(mode={stat.filemode(file_mode)}); tighten permissions first" ) parent_mode = path.parent.stat().st_mode if parent_mode & stat.S_IWOTH: raise OracleLoadError( f"refusing to load oracle from world-writable directory: " f"{path.parent} (mode={stat.filemode(parent_mode)})" ) def _log_module_fingerprint(path: Path) -> str: """Print the sha256 of the loaded file so the user can audit what ran.""" digest = hashlib.sha256(path.read_bytes()).hexdigest() print( f"memdiver: loaded oracle {path} sha256={digest}", file=sys.stderr, ) return digest def _import_user_module(path: Path): """Import the oracle file as an isolated module named ``memdiver_user_oracle``.""" spec = importlib.util.spec_from_file_location("memdiver_user_oracle", path) if spec is None or spec.loader is None: raise OracleLoadError(f"cannot create import spec for {path}") module = importlib.util.module_from_spec(spec) try: spec.loader.exec_module(module) except Exception as exc: raise OracleLoadError(f"failed to import oracle {path}: {exc}") from exc return module def _wrap_stateful(oracle_obj: Any, path: Path) -> OracleFn: """Validate a Shape-2 oracle object and return its bound verify method.""" verify = getattr(oracle_obj, "verify", None) if not callable(verify): raise OracleLoadError( f"{path}: build_oracle() returned {type(oracle_obj).__name__} " f"with no verify() method" ) return verify
def load_oracle(
    path: Path | str,
    config: dict[str, Any] | None = None,
) -> OracleFn:
    """Load a user oracle script and hand back a flat ``verify(bytes) -> bool``.

    The path is resolved, vetted by ``_assert_safe_path``, fingerprinted to
    stderr by ``_log_module_fingerprint`` and then imported. Two export
    shapes are recognised, checked in this order:

    * Shape 2 -- a callable ``build_oracle``: it is called with a fresh
      ``dict`` copy of *config* (empty when *config* is ``None`` or empty) and
      the ``verify`` attribute of the returned object is used. Suited to
      oracles that cache KDF state or open a connection once.
    * Shape 1 -- a callable module-level ``verify``: returned as-is. Fine for
      hot-path-only verifiers.

    Args:
        path: Location of the user's oracle script.
        config: Optional settings passed to ``build_oracle``; ignored for
            Shape 1 oracles.

    Returns:
        A callable taking candidate bytes and returning a bool.

    Raises:
        OracleLoadError: On missing files, unsafe permissions, import
            failures, or missing/invalid exports.
    """
    script = Path(path).resolve()
    _assert_safe_path(script)
    _log_module_fingerprint(script)
    user_module = _import_user_module(script)

    # A factory takes precedence when the module exports both shapes.
    factory = getattr(user_module, "build_oracle", None)
    if callable(factory):
        # Copy so the user's factory cannot mutate the caller's dict.
        settings = dict(config) if config else {}
        return _wrap_stateful(factory(settings), script)

    flat_verify = getattr(user_module, "verify", None)
    if not callable(flat_verify):
        raise OracleLoadError(
            f"{script}: must export either verify(candidate) -> bool "
            f"or build_oracle(config) -> Oracle"
        )
    return flat_verify