import asyncio
import orjson # Added for faster JSON processing
import time
import os
from pathlib import Path
from typing import Dict, Any, List, Optional, Union, Set, Tuple
from difflib import get_close_matches
import aiohttp
from .exceptions import InvalidRequestError, VNDBAPIError
from .methods.fetch import _fetch_api # Ensure this import is present
# Forward declaration for type hinting
if "VNDB" not in globals():
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .client import VNDB
[docs]
class SchemaCache:
    """
    Manages the download, caching, and retrieval of the VNDB API schema.

    The schema can live in three places, consulted in this order:

    1. a user-supplied ``local_schema_path`` file (never subject to the TTL),
    2. the cache file ``<cache_dir>/<cache_filename>`` (subject to the TTL),
    3. the remote API, fetched through the client's HTTP session.

    Existence checks deliberately use ``os.path`` on the raw string paths;
    ``Path`` objects are only built lazily by the properties below.
    """

    def __init__(self, cache_dir: str = ".veedb_cache", cache_filename: str = "schema.json", ttl_hours: float = 24.0, local_schema_path: Optional[str] = None):
        """
        Args:
            cache_dir: Directory holding the cache file. Falsy values fall
                back to ``".veedb_cache"``.
            cache_filename: Name of the cache file inside ``cache_dir``.
                Falsy values fall back to ``"schema.json"``.
            ttl_hours: Lifetime of the cache file, in hours.
            local_schema_path: Optional path of a user-managed schema file
                that takes precedence over the cache file.
        """
        # Paths are kept as plain strings; Path objects are created on demand.
        self._cache_dir_str = str(cache_dir) if cache_dir else ".veedb_cache"
        self._cache_filename_str = str(cache_filename) if cache_filename else "schema.json"
        self._local_schema_path_str = str(local_schema_path) if local_schema_path else None
        self._cache_dir: Optional[Path] = None
        self._cache_file: Optional[Path] = None
        self._local_schema_path: Optional[Path] = None
        self.ttl_seconds = ttl_hours * 3600
        self._schema_data: Optional[Dict[str, Any]] = None

    @property
    def cache_dir(self) -> Path:
        """Cache directory as a ``Path`` (built on first access)."""
        if self._cache_dir is None:
            self._cache_dir = Path(self._cache_dir_str)
        return self._cache_dir

    @property
    def cache_file(self) -> Path:
        """Cache file as a ``Path`` (built on first access)."""
        if self._cache_file is None:
            self._cache_file = self.cache_dir / self._cache_filename_str
        return self._cache_file

    @property
    def local_schema_path(self) -> Optional[Path]:
        """User-supplied schema path as a ``Path``, or ``None`` if not configured."""
        if self._local_schema_path is None and self._local_schema_path_str:
            self._local_schema_path = Path(self._local_schema_path_str)
        return self._local_schema_path

    def _cache_path_str(self) -> str:
        """Return the cache file location as a plain string."""
        return os.path.join(self._cache_dir_str, self._cache_filename_str)

    def _has_local_schema(self) -> bool:
        """Return True when a local schema path is configured and is an existing file."""
        return bool(self._local_schema_path_str) and os.path.isfile(self._local_schema_path_str)

    def is_cached(self) -> bool:
        """
        Check whether a schema is available on disk.

        Returns:
            True if the configured local schema file exists, or otherwise if
            the cache file path exists. A configured-but-missing local file
            falls through to the cache check, matching ``get_cache_age`` and
            ``load_schema``.
        """
        if self._has_local_schema():
            return True
        return os.path.exists(self._cache_path_str())

    def get_cache_age(self) -> float:
        """
        Get the age of the schema on disk, in seconds.

        Returns:
            ``0.0`` when an existing local schema file is in use, the time
            elapsed since the cache file was last modified otherwise, and
            ``float('inf')`` when the cache file is missing or unreadable.
        """
        if self._has_local_schema():
            # A local schema is treated as always fresh unless explicitly updated.
            return 0.0
        try:
            return time.time() - os.path.getmtime(self._cache_path_str())
        except (OSError, ValueError):
            return float('inf')

    def is_cache_expired(self) -> bool:
        """Check if the cached schema has outlived the TTL; an existing local schema never expires."""
        if self._has_local_schema():
            return False
        return self.get_cache_age() > self.ttl_seconds

    def save_schema(self, schema_data: Dict[str, Any], to_local_path: bool = False):
        """
        Write the schema to disk and refresh the in-memory copy.

        Args:
            schema_data: Schema to serialize with orjson (2-space indent).
            to_local_path: When True and a local schema path is configured,
                write there; otherwise write to the cache file.

        Raises:
            OSError: If the target directory or file cannot be written.
        """
        local_path = self.local_schema_path
        target_path = local_path if to_local_path and local_path else self.cache_file
        target_path.parent.mkdir(parents=True, exist_ok=True)
        with open(target_path, 'wb') as f:  # orjson produces bytes
            f.write(orjson.dumps(schema_data, option=orjson.OPT_INDENT_2))
        self._schema_data = schema_data

    def load_schema(self) -> Optional[Dict[str, Any]]:
        """
        Load the schema from the local schema file (if present) or the cache file.

        Returns:
            The first candidate that exists, parses, and holds a JSON object;
            ``None`` when neither file yields one. Unreadable or malformed
            files are skipped so the caller can fall back to a download.
        """
        for candidate in (self._local_schema_path_str, self._cache_path_str()):
            if not candidate or not os.path.isfile(candidate):
                continue
            try:
                with open(candidate, 'rb') as f:  # orjson consumes bytes
                    loaded = orjson.loads(f.read())
            except (OSError, ValueError):
                # orjson.JSONDecodeError is a ValueError subclass.
                continue
            if isinstance(loaded, dict):
                return loaded
        return None

    def invalidate_cache(self):
        """Drop the in-memory schema and remove the cache file. Never touches local_schema_path."""
        self._schema_data = None
        try:
            os.remove(self._cache_path_str())
        except FileNotFoundError:
            pass

    async def get_schema(self, client: 'VNDB', force_download: bool = False) -> Dict[str, Any]:
        """
        Get the schema. Prioritizes local_schema_path, then cache, then download.

        Args:
            client: Client providing ``base_url`` and ``_get_session()``.
            force_download: When True, skip every local source, download the
                schema, and overwrite the primary schema location.

        Returns:
            The schema dictionary.

        Raises:
            VNDBAPIError: If a download is required and fails.
        """
        if not force_download:
            # The in-memory copy is only trusted when no local file is in
            # use; a local file is always re-read from disk.
            if self._schema_data and not self.is_cache_expired() and not self._has_local_schema():
                return self._schema_data

            loaded_schema = self.load_schema()
            if loaded_schema and not self.is_cache_expired():
                self._schema_data = loaded_schema
                return loaded_schema

        # Forced, nothing usable on disk, or cache expired: download and
        # persist to the local path when configured, else to the cache file.
        schema = await self._download_schema(client)
        self.save_schema(schema, to_local_path=bool(self.local_schema_path))
        return schema

    async def update_local_schema_from_api(self, client: 'VNDB') -> Dict[str, Any]:
        """Forces a download of the schema and saves it to local_schema_path if configured, else to cache."""
        return await self.get_schema(client, force_download=True)

    async def _download_schema(self, client: 'VNDB') -> Dict[str, Any]:
        """
        Fetch the schema from ``<client.base_url>/schema`` without a token.

        Calls ``_fetch_api`` directly rather than ``client.get_schema()`` to
        avoid recursing back into this cache.

        Raises:
            VNDBAPIError: On network/HTTP failure, any other error during the
                fetch, or a response that is not a JSON object. A
                ``VNDBAPIError`` raised by the fetch itself is propagated
                unchanged rather than re-wrapped.
        """
        try:
            response_data = await _fetch_api(
                session=client._get_session(),
                method="GET",
                url=f"{client.base_url}/schema",
                token=None,  # Explicitly None for public schema endpoint
            )
        except VNDBAPIError:
            raise
        except aiohttp.ClientError as e:
            raise VNDBAPIError(f"Failed to download schema due to network/HTTP error: {e}") from e
        except Exception as e:
            raise VNDBAPIError(f"An unexpected error occurred while downloading schema: {e}") from e

        if not isinstance(response_data, dict):
            raise VNDBAPIError(f"Schema download did not return a valid JSON object. Received type: {type(response_data)}")
        return response_data
[docs]
class FilterValidator:
    """
    Validates filter expressions against the VNDB API schema.

    Field names are read from the schema's ``api_fields`` mapping and cached
    per endpoint for the lifetime of the validator.
    """

    def __init__(self, schema_cache: Optional["SchemaCache"] = None, local_schema_path: Optional[str] = None):
        """
        Args:
            schema_cache: Schema provider to use. When omitted, a new
                ``SchemaCache`` is created.
            local_schema_path: Forwarded to the new ``SchemaCache`` when
                ``schema_cache`` is not given; ignored otherwise.
        """
        self.schema_cache = schema_cache or SchemaCache(local_schema_path=local_schema_path)
        self._field_cache: Dict[str, List[str]] = {}

    def _extract_fields(self, schema: Dict[str, Any], endpoint: str) -> List[str]:
        """
        Recursively extract all valid field names for an endpoint, including nested ones.

        Nested keys are joined with ``.`` (``image.url``). An ``_inherit``
        entry pulls in the fields of the endpoint it names under the current
        prefix; endpoints already visited on the current path are skipped so
        circular inheritance terminates.

        Returns:
            Sorted list of field names; empty if the endpoint is unknown.
            The result is cached per endpoint.
        """
        if endpoint in self._field_cache:
            return self._field_cache[endpoint]

        api_fields = schema.get("api_fields", {})
        all_fields: Set[str] = set()

        def recurse(obj: Dict[str, Any], prefix: str, visited_endpoints: Set[str]):
            inherited_endpoint = obj.get("_inherit") if "_inherit" in obj else None
            if "_inherit" in obj:
                if inherited_endpoint in visited_endpoints:
                    return  # Break the inheritance cycle
                if inherited_endpoint in api_fields:
                    recurse(api_fields[inherited_endpoint], prefix, visited_endpoints | {inherited_endpoint})

            for key, value in obj.items():
                if key == "_inherit":
                    continue
                new_prefix = f"{prefix}.{key}" if prefix else key
                all_fields.add(new_prefix)
                if isinstance(value, dict):
                    # Sibling branches share the same visited set: a cycle is
                    # only a cycle along a single inheritance path.
                    recurse(value, new_prefix, visited_endpoints)

        if endpoint in api_fields:
            recurse(api_fields[endpoint], "", {endpoint})

        field_list = sorted(all_fields)
        self._field_cache[endpoint] = field_list
        return field_list

    def suggest_fields(self, field: str, available_fields: List[str]) -> List[str]:
        """Suggest up to three close matches (similarity cutoff 0.7) for a misspelled field name."""
        return get_close_matches(field, available_fields, n=3, cutoff=0.7)

    async def get_available_fields(self, endpoint: str, client: 'VNDB') -> List[str]:
        """Get all available filterable fields for a given endpoint."""
        schema = await self.schema_cache.get_schema(client)
        return self._extract_fields(schema, endpoint)

    async def list_endpoints(self, client: 'VNDB') -> List[str]:
        """List all available API endpoints from the schema, sorted."""
        schema = await self.schema_cache.get_schema(client)
        return sorted(schema.get("api_fields", {}))

    async def validate_filters(self, endpoint: str, filters: Union[List, str, None], client: 'VNDB') -> Dict[str, Any]:
        """
        Validate a filter expression for a given endpoint.

        A filter is either a combinator ``["and" | "or", sub, sub, ...]`` or
        a predicate ``[field, operator, value]``. Only structure and field
        names are checked; operators and values are not inspected.

        Args:
            endpoint: Key into the schema's ``api_fields`` mapping.
            filters: Filter expression. Falsy input is reported as valid
                without consulting the schema.
            client: Client forwarded to the schema cache.

        Returns:
            A dictionary with ``valid`` (bool), ``errors`` (list of
            messages), ``suggestions`` (sorted close matches for unknown
            fields) and ``available_fields`` (all fields of the endpoint).
        """
        if not filters:
            return {'valid': True, 'errors': [], 'suggestions': [], 'available_fields': []}

        available_fields = await self.get_available_fields(endpoint, client)
        known_fields = set(available_fields)  # O(1) membership per predicate
        errors: List[str] = []
        suggestions: Set[str] = set()

        def _validate_recursive(current_filter):
            if not isinstance(current_filter, list) or len(current_filter) < 1:
                errors.append(f"Invalid filter format: {current_filter}")
                return

            head = current_filter[0]
            if not isinstance(head, str):
                # Neither a combinator keyword nor a field name; report it
                # instead of failing on head.lower().
                errors.append(f"Invalid filter format: {current_filter}")
                return

            operator = head.lower()
            if operator in ("and", "or"):
                if len(current_filter) < 3:
                    errors.append(f"'{operator}' filter requires at least two sub-filters.")
                for sub_filter in current_filter[1:]:
                    _validate_recursive(sub_filter)
                return

            # Simple predicate: [field, operator, value]
            if len(current_filter) != 3:
                errors.append(f"Simple filter predicate must have 3 elements: [field, operator, value]. Found: {current_filter}")
                return
            if head not in known_fields:
                errors.append(f"Invalid field '{head}' for endpoint '{endpoint}'.")
                suggestions.update(self.suggest_fields(head, available_fields))

        _validate_recursive(filters)

        return {
            'valid': not errors,
            'errors': errors,
            'suggestions': sorted(suggestions),
            'available_fields': available_fields
        }