""" Mind CLI — WebSocket Tunnel 会话管线(无状态)。 连接到 Cloud 端 mindcli_bridge,接收工具调用指令并在本地执行。 采用 Browser-Donated JWT 认证:浏览器授权 CLI,CLI 不需独立认证。 无状态:不持有进程级单例、不反向 import 调用方模块(health.py)。 状态变更通过 on_status 回调通知调用方,工具调用通过 on_dispatch 回调派发。 生命周期由调用方(health.py)管理。 """ import asyncio import json import logging from typing import Callable logger = logging.getLogger("mindcli.pipelines.tunnel_session") # 连接状态 DISCONNECTED = "disconnected" CONNECTING = "connecting" CONNECTED = "connected" async def connect( url: str, jwt: str, on_dispatch: Callable[[dict], None] | None = None, on_status: Callable[[str, int], None] | None = None, ) -> "TunnelHandle": """ 建立 Tunnel 连接,返回句柄。 Args: url: Cloud Tunnel WebSocket URL jwt: MindPass JWT(浏览器提供) on_dispatch: 工具调用派发回调(可选,默认走内部 ToolProxy) on_status: 状态变更回调 (status_str, tool_count) Returns: TunnelHandle(已在后台启动连接循环) """ handle = TunnelHandle(url, jwt, on_dispatch, on_status) handle._start_connect_loop() return handle class TunnelHandle: """ CLI → Cloud WebSocket 隧道句柄。 非单例——生命周期由调用方管理。 生命周期: 1. 调用方调 connect() → 后台启动 _connect_loop 2. 握手 + 能力协商 → Cloud 下发 approved_tools → 创建 ToolProxy 3. 消息循环:接收 tool_call → ToolProxy 执行 → 返回结果 4. 心跳维持 30s / 断线指数退避重连 """ def __init__( self, url: str, jwt: str, on_dispatch: Callable[[dict], None] | None = None, on_status: Callable[[str, int], None] | None = None, ): self._tunnel_url = url self._jwt = jwt self._on_dispatch = on_dispatch self._on_status = on_status self._status = DISCONNECTED self._ws = None self._user_id: str | None = None self._tool_proxy = None # ToolProxy 实例,握手成功后创建 # 后台任务 self._reconnect_task: asyncio.Task | None = None self._heartbeat_task: asyncio.Task | None = None self._message_task: asyncio.Task | None = None # 重连参数 self._reconnect_delay = 1.0 self._max_reconnect_delay = 30.0 self._reconnect_attempts = 0 @property def status(self) -> str: return self._status @property def user_id(self) -> str | None: return self._user_id def _set_status(self, status: str, tools: int = 0) -> None: self._status = status if self._on_status: self._on_status(status, tools) def _start_connect_loop(self) -> None: """在当前 event loop 中后台启动连接循环。""" self._reconnect_task = asyncio.create_task(self._connect_loop()) async def activate(self, jwt: str, tunnel_url: str) -> dict: """ 更新 JWT + URL 并重新连接(由调用方在浏览器重新授权时调用)。 Args: jwt: 新的 MindPass JWT tunnel_url: Cloud Tunnel WebSocket URL Returns: {"ok": True, "status": "connecting"} """ self._jwt = jwt self._tunnel_url = tunnel_url # 取消旧连接 await self.disconnect() # 后台启动新连接 self._reconnect_task = asyncio.create_task(self._connect_loop()) return {"ok": True, "status": "connecting"} async def disconnect(self) -> None: """断开 Tunnel 连接。""" # 取消所有后台任务 for task in [self._reconnect_task, self._heartbeat_task, self._message_task]: if task and not task.done(): task.cancel() try: await task except asyncio.CancelledError: pass if self._ws: try: await self._ws.close() except Exception: pass self._ws = None self._set_status(DISCONNECTED) self._reconnect_attempts = 0 self._reconnect_delay = 1.0 logger.info("[Tunnel] 已断开") def update_approved(self, tools: list[str]) -> None: """热更新工具白名单(Cloud 下发时调用)。""" if self._tool_proxy: self._tool_proxy.update_approved(tools) async def _connect_loop(self) -> None: """连接循环:握手 → 消息循环 → 断线重连。""" while True: try: self._set_status(CONNECTING) await self._connect() # 连接成功,重置重连参数 self._reconnect_delay = 1.0 self._reconnect_attempts = 0 # 进入消息循环(阻塞直到断线) await self._run() except asyncio.CancelledError: return except Exception as e: logger.warning("[Tunnel] 连接异常: %s", e) # 断线,指数退避重连 self._set_status(DISCONNECTED) self._reconnect_attempts += 1 delay = min(self._reconnect_delay, self._max_reconnect_delay) logger.info("[Tunnel] %ds 后重连(第 %d 次)...", delay, self._reconnect_attempts) await asyncio.sleep(delay) self._reconnect_delay = min(self._reconnect_delay * 2, self._max_reconnect_delay) async def _connect(self) -> None: """WebSocket 握手 + 能力协商。""" try: import websockets except ImportError: logger.error("[Tunnel] websockets 未安装。运行: pip install websockets") raise headers = {"Authorization": f"Bearer {self._jwt}"} self._ws = await websockets.connect( self._tunnel_url, additional_headers=headers, ping_interval=30, ping_timeout=10, close_timeout=5, ) # 等待 Cloud 确认 raw = await asyncio.wait_for(self._ws.recv(), timeout=10) msg = json.loads(raw) if msg.get("type") != "connected": raise ConnectionError(f"握手失败: {msg}") self._user_id = msg.get("userId") logger.info("[Tunnel] 已连接,userId=%s", self._user_id) # 发送能力报告 from mindcli.capability import scan_capabilities cap = scan_capabilities() await self._ws.send(json.dumps({ "type": "capability_report", **cap, })) # 接收审批结果 raw = await asyncio.wait_for(self._ws.recv(), timeout=10) approval = json.loads(raw) approved: list[str] = [] if approval.get("type") == "approved_tools": approved = approval.get("tools", []) # 初始化 ToolProxy(从 pipelines.tool_proxy 导入,无状态) from mindcli.pipelines.tool_proxy import ToolProxy self._tool_proxy = ToolProxy(approved_tools=approved) logger.info("[Tunnel] 审批通过工具: %s", approved) # ★ 通过回调通知调用方状态(不反向 import health.py) self._set_status(CONNECTED, len(approved)) async def _run(self) -> None: """主消息循环:接收 Cloud 指令 → 本地执行 → 返回结果。""" try: async for raw in self._ws: msg = json.loads(raw) if msg.get("jsonrpc") == "2.0" and msg.get("method") == "tool_call": await self._handle_tool_call(msg) elif msg.get("type") == "approved_tools": # 热更新白名单 if self._tool_proxy: self._tool_proxy.update_approved(msg.get("tools", [])) elif msg.get("type") == "ping": await self._ws.send(json.dumps({"type": "pong"})) else: logger.debug("[Tunnel] 未知消息: %s", msg.get("type")) except Exception as e: logger.warning("[Tunnel] 消息循环异常: %s", e) raise async def _handle_tool_call(self, msg: dict) -> None: """处理工具调用请求。""" call_id = msg.get("id", "unknown") params = msg.get("params", {}) tool_name = params.get("tool", "") tool_args = params.get("args", {}) logger.info("[Tunnel] 工具调用: %s (id=%s)", tool_name, call_id) if self._on_dispatch: # 调用方自定义派发 result = self._on_dispatch(msg) elif not self._tool_proxy: result = {"error": "ToolProxy not initialized"} else: result = await self._tool_proxy.execute(tool_name, tool_args) # 截断过大的输出(防止 WS 阻塞) output = result.get("output", "") if isinstance(output, str) and len(output) > 50000: result["output"] = output[:50000] + f"\n\n... [截断:原始 {len(output)} 字符]" result["truncated"] = True response = { "jsonrpc": "2.0", "id": call_id, "result": result, } await self._ws.send(json.dumps(response))