diff --git a/src/acemcp/index/manager.py b/src/acemcp/index/manager.py
index 70824a4..dcb459e 100644
--- a/src/acemcp/index/manager.py
+++ b/src/acemcp/index/manager.py
@@ -88,8 +88,26 @@ class IndexManager:
         self.max_lines_per_blob = max_lines_per_blob
         self.exclude_patterns = exclude_patterns or []
         self.projects_file = storage_path / "projects.json"
+        self._client: httpx.AsyncClient | None = None
         logger.info(f"IndexManager initialized with storage path: {storage_path}, batch_size: {batch_size}, max_lines_per_blob: {max_lines_per_blob}, exclude_patterns: {len(self.exclude_patterns)} patterns")

+    def _get_client(self) -> httpx.AsyncClient:
+        """Get or create httpx AsyncClient instance.
+
+        Returns:
+            httpx.AsyncClient instance
+        """
+        if self._client is None or self._client.is_closed:
+            self._client = httpx.AsyncClient(timeout=60.0)
+            logger.debug("Created new httpx.AsyncClient")
+        return self._client
+
+    async def close(self) -> None:
+        """Close httpx client and release resources."""
+        if self._client is not None and not self._client.is_closed:
+            await self._client.aclose()
+            logger.debug("Closed httpx.AsyncClient")
+
     def _normalize_path(self, path: str) -> str:
         """Normalize path to use forward slashes.

@@ -379,42 +397,42 @@ class IndexManager:
         total_batches = (len(blobs_to_upload) + self.batch_size - 1) // self.batch_size
         logger.info(f"Uploading {len(blobs_to_upload)} new blobs in {total_batches} batches (batch_size={self.batch_size})")

-        async with httpx.AsyncClient(timeout=30.0) as client:
-            for batch_idx in range(total_batches):
-                start_idx = batch_idx * self.batch_size
-                end_idx = min(start_idx + self.batch_size, len(blobs_to_upload))
-                batch_blobs = blobs_to_upload[start_idx:end_idx]
+        client = self._get_client()
+        for batch_idx in range(total_batches):
+            start_idx = batch_idx * self.batch_size
+            end_idx = min(start_idx + self.batch_size, len(blobs_to_upload))
+            batch_blobs = blobs_to_upload[start_idx:end_idx]

-                logger.info(f"Uploading batch {batch_idx + 1}/{total_batches} ({len(batch_blobs)} blobs)")
+            logger.info(f"Uploading batch {batch_idx + 1}/{total_batches} ({len(batch_blobs)} blobs)")

-                try:
-                    async def upload_batch():
-                        payload = {"blobs": batch_blobs}
-                        response = await client.post(
-                            f"{self.base_url}/batch-upload",
-                            headers={"Authorization": f"Bearer {self.token}"},
-                            json=payload,
-                        )
-                        response.raise_for_status()
-                        return response.json()
+            try:
+                async def upload_batch():
+                    payload = {"blobs": batch_blobs}
+                    response = await client.post(
+                        f"{self.base_url}/batch-upload",
+                        headers={"Authorization": f"Bearer {self.token}"},
+                        json=payload,
+                    )
+                    response.raise_for_status()
+                    return response.json()

-                    # Retry up to 3 times with exponential backoff
-                    result = await self._retry_request(upload_batch, max_retries=3, retry_delay=1.0)
+                # Retry up to 3 times with exponential backoff
+                result = await self._retry_request(upload_batch, max_retries=3, retry_delay=1.0)

-                    batch_blob_names = result.get("blob_names", [])
-                    if not batch_blob_names:
-                        logger.warning(f"Batch {batch_idx + 1} returned no blob names")
-                        failed_batches.append(batch_idx + 1)
-                        continue
-
-                    uploaded_blob_names.extend(batch_blob_names)
-                    logger.info(f"Batch {batch_idx + 1} uploaded successfully, got {len(batch_blob_names)} blob names")
-
-                except Exception as e:
-                    logger.error(f"Batch {batch_idx + 1} failed after retries: {e}. Continuing with next batch...")
+                batch_blob_names = result.get("blob_names", [])
+                if not batch_blob_names:
+                    logger.warning(f"Batch {batch_idx + 1} returned no blob names")
                     failed_batches.append(batch_idx + 1)
                     continue

+                uploaded_blob_names.extend(batch_blob_names)
+                logger.info(f"Batch {batch_idx + 1} uploaded successfully, got {len(batch_blob_names)} blob names")
+
+            except Exception as e:
+                logger.error(f"Batch {batch_idx + 1} failed after retries: {e}. Continuing with next batch...")
+                failed_batches.append(batch_idx + 1)
+                continue
+
         if not uploaded_blob_names and blobs_to_upload:
             if failed_batches:
                 return {"status": "error", "message": f"All batches failed. Failed batches: {failed_batches}"}
@@ -519,22 +537,22 @@ class IndexManager:
             "enable_commit_retrieval": False,
         }

-        async with httpx.AsyncClient(timeout=60.0) as client:
-            async def search_request():
-                response = await client.post(
-                    f"{self.base_url}/agents/codebase-retrieval",
-                    headers={"Authorization": f"Bearer {self.token}"},
-                    json=payload,
-                )
-                response.raise_for_status()
-                return response.json()
+        client = self._get_client()
+        async def search_request():
+            response = await client.post(
+                f"{self.base_url}/agents/codebase-retrieval",
+                headers={"Authorization": f"Bearer {self.token}"},
+                json=payload,
+            )
+            response.raise_for_status()
+            return response.json()

-            # Retry up to 3 times with exponential backoff
-            try:
-                result = await self._retry_request(search_request, max_retries=3, retry_delay=2.0)
-            except Exception as e:
-                logger.error(f"Search request failed after retries: {e}")
-                return f"Error: Search request failed after 3 retries. {e!s}"
{e!s}" formatted_retrieval = result.get("formatted_retrieval", "") diff --git a/src/acemcp/server.py b/src/acemcp/server.py index 98f728e..5afdd4b 100644 --- a/src/acemcp/server.py +++ b/src/acemcp/server.py @@ -11,7 +11,7 @@ from mcp.types import Tool from acemcp.config import get_config, init_config from acemcp.logging_config import setup_logging -from acemcp.tools import search_context_tool +from acemcp.tools import search_context_tool, shutdown_index_manager from acemcp.web import create_app app = Server("acemcp") @@ -94,6 +94,7 @@ async def main(base_url: str | None = None, token: str | None = None, web_port: token: Override TOKEN from command line web_port: Port for web management interface (None to disable) """ + web_task: asyncio.Task | None = None try: config = init_config(base_url=base_url, token=token) config.validate() @@ -108,12 +109,13 @@ async def main(base_url: str | None = None, token: str | None = None, web_port: async with stdio_server() as (read_stream, write_stream): await app.run(read_stream, write_stream, app.create_initialization_options()) - if web_port: - web_task.cancel() - except Exception: logger.exception("Server error") raise + finally: + if web_task: + web_task.cancel() + await shutdown_index_manager() def run() -> None: @@ -140,4 +142,3 @@ def run() -> None: if __name__ == "__main__": run() - diff --git a/src/acemcp/tools/__init__.py b/src/acemcp/tools/__init__.py index d810cad..ed6acc8 100644 --- a/src/acemcp/tools/__init__.py +++ b/src/acemcp/tools/__init__.py @@ -1,6 +1,5 @@ """MCP tools for codebase indexing.""" -from acemcp.tools.search_context import search_context_tool - -__all__ = ["search_context_tool"] +from acemcp.tools.search_context import search_context_tool, shutdown_index_manager +__all__ = ["search_context_tool", "shutdown_index_manager"] diff --git a/src/acemcp/tools/search_context.py b/src/acemcp/tools/search_context.py index f203232..b5de11c 100644 --- a/src/acemcp/tools/search_context.py +++ b/src/acemcp/tools/search_context.py @@ -1,5 +1,8 @@ """Search context tool for MCP server.""" +from __future__ import annotations + +import asyncio from typing import Any from loguru import logger @@ -7,6 +10,48 @@ from loguru import logger from acemcp.config import get_config from acemcp.index import IndexManager +_index_manager: IndexManager | None = None +_index_manager_lock: asyncio.Lock | None = None + + +async def _get_index_manager() -> IndexManager: + """Create or return the shared IndexManager instance.""" + global _index_manager, _index_manager_lock + + if _index_manager is not None: + return _index_manager + + if _index_manager_lock is None: + _index_manager_lock = asyncio.Lock() + + async with _index_manager_lock: + if _index_manager is None: + config = get_config() + _index_manager = IndexManager( + config.index_storage_path, + config.base_url, + config.token, + config.text_extensions, + config.batch_size, + config.max_lines_per_blob, + config.exclude_patterns, + ) + + return _index_manager + + +async def shutdown_index_manager() -> None: + """Close the shared IndexManager instance.""" + global _index_manager, _index_manager_lock + + if _index_manager_lock is None: + _index_manager_lock = asyncio.Lock() + + async with _index_manager_lock: + if _index_manager is not None: + await _index_manager.close() + _index_manager = None + async def search_context_tool(arguments: dict[str, Any]) -> dict[str, Any]: """Search for code context based on query. 
@@ -31,16 +76,7 @@ async def search_context_tool(arguments: dict[str, Any]) -> dict[str, Any]:

         logger.info(f"Tool invoked: search_context for project {project_root_path} with query: {query}")

-        config = get_config()
-        index_manager = IndexManager(
-            config.index_storage_path,
-            config.base_url,
-            config.token,
-            config.text_extensions,
-            config.batch_size,
-            config.max_lines_per_blob,
-            config.exclude_patterns,
-        )
+        index_manager = await _get_index_manager()

         result = await index_manager.search_context(project_root_path, query)
         return {"type": "text", "text": result}
@@ -48,4 +84,3 @@ async def search_context_tool(arguments: dict[str, Any]) -> dict[str, Any]:
     except Exception as e:
         logger.exception("Error in search_context_tool")
         return {"type": "text", "text": f"Error: {e!s}"}
-
diff --git a/src/acemcp/web/app.py b/src/acemcp/web/app.py
index 8104d13..178e067 100644
--- a/src/acemcp/web/app.py
+++ b/src/acemcp/web/app.py
@@ -206,5 +206,11 @@ def create_app() -> FastAPI:
         finally:
             log_broadcaster.remove_client(queue)

-    return app
+    @app.on_event("shutdown")
+    async def shutdown_tools() -> None:
+        """Release shared tool resources when the web app stops."""
+        from acemcp.tools import shutdown_index_manager
+        await shutdown_index_manager()
+
+    return app
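
Since the tools module now caches a single IndexManager, long-running callers are expected to pair tool calls with a final `shutdown_index_manager()` so the pooled `httpx.AsyncClient` is released, which is what `server.main()` and the web app's shutdown hook do. A minimal usage sketch, assuming the tool reads `project_root_path` and `query` keys from its arguments dict and that BASE_URL/TOKEN are already configured; the paths and queries below are placeholders:

```python
import asyncio

from acemcp.tools import search_context_tool, shutdown_index_manager


async def demo() -> None:
    # First call lazily creates the shared IndexManager; its httpx.AsyncClient
    # is created on the first HTTP request and then reused.
    first = await search_context_tool(
        {"project_root_path": "/path/to/project", "query": "where are blob uploads batched?"}
    )
    print(first["text"])

    # Later calls reuse the same manager, so the connection pool stays warm.
    await search_context_tool(
        {"project_root_path": "/path/to/project", "query": "how are retries handled?"}
    )

    # Mirror server.main()'s finally block: release the shared client on exit.
    await shutdown_index_manager()


if __name__ == "__main__":
    asyncio.run(demo())
```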
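
The lazy `_get_client()` / `close()` pair on IndexManager can also be exercised directly. A rough pytest sketch, assuming pytest-asyncio is installed and that the seven positional constructor arguments used in search_context.py (storage path, base URL, token, text extensions, batch size, max lines per blob, exclude patterns) tolerate the dummy values below; no HTTP request is made:

```python
import pytest

from acemcp.index import IndexManager


@pytest.mark.asyncio
async def test_httpx_client_reused_until_closed(tmp_path) -> None:
    # Constructor arguments mirror the call in search_context.py; the concrete
    # values here are dummies and never trigger a network request.
    manager = IndexManager(
        tmp_path,                    # index_storage_path
        "https://example.invalid",   # base_url
        "dummy-token",               # token
        [".py"],                     # text_extensions
        10,                          # batch_size
        400,                         # max_lines_per_blob
        [],                          # exclude_patterns
    )

    client = manager._get_client()
    assert manager._get_client() is client  # same client while it stays open

    await manager.close()
    assert client.is_closed

    # After close(), the next call transparently builds a fresh client.
    assert manager._get_client() is not client
    await manager.close()
```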
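
One note on the web app hook: `@app.on_event("shutdown")` is deprecated in recent FastAPI releases in favor of a lifespan handler. If that matters here, an equivalent hook could look like the sketch below; it assumes `create_app()` is free to pass `lifespan=` when constructing the FastAPI instance, which the diff does not show:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def _lifespan(app: FastAPI):
    yield  # nothing extra to do on startup
    # On shutdown, release the shared IndexManager, like shutdown_tools() above.
    from acemcp.tools import shutdown_index_manager
    await shutdown_index_manager()


def create_app() -> FastAPI:
    app = FastAPI(lifespan=_lifespan)
    # ... existing routes and middleware from src/acemcp/web/app.py ...
    return app
```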