feat: v2.1 BYOK + automatic pipeline

- Unify on DASHSCOPE_API_KEY (one Bailian key covers LLM + Embedding)
- import-voc now auto-triggers UDE extraction + vectorization (background asyncio task)
- Add GET /pipeline-status to report pipeline progress
- run_clustering is now pure CPU (vectors are precomputed)
- Add standalone run_vectorization function
- Fix Python 3.9 type-annotation compatibility
parent 3cd7d4776d
commit f61c255b9d
CHANGELOG.md

@@ -1,5 +1,18 @@
 # Changelog
 
+## v2.1.0 (2026-04-07)
+
+### 🚀 Features
+- **BYOK (Bring Your Own Key) mode**: partners pass their Bailian key via the `X-DashScope-Key` header, so the server bears no token cost. A single Bailian key covers both LLM inference and Embedding vectorization.
+- **Automatic pipeline**: after `import-voc` completes, UDE extraction + vectorization are triggered asynchronously; the CoPaw Agent no longer has to orchestrate each step.
+- **`GET /pipeline-status`**: new endpoint reporting pipeline progress, intended for CoPaw polling.
+- **Zero-call clustering**: `/ude/cluster` is now a pure-CPU operation (millisecond-scale); vectors are precomputed by the pipeline.
+
+### 🛠️ Refactor & Optimization
+- ~~BAILIAN_API_KEY~~ deprecated; unified on `DASHSCOPE_API_KEY` (two keys instead of three).
+- New standalone `run_vectorization` function, split out of the vectorization logic in `run_clustering`.
+- `_get_llm_client` now calls the Bailian API directly by default, dropping the LiteLLM proxy dependency.
+
 ## v2.0.0 (2026-04-07)
 
 ### 🚀 Features
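Taken together, the partner-side flow after this commit is one `import-voc` call with the key header, then polling `pipeline-status`. A minimal sketch with `httpx` — the base URL and case id are placeholders; the endpoints, header name, and stage values come from this commit:

```python
import time
import httpx

BASE = "http://localhost:8093"           # assumed local deployment (PORT=8093)
CASE = "demo-case"                        # hypothetical case id
HEADERS = {"X-DashScope-Key": "sk-..."}   # partner's own Bailian/DashScope key

# One call imports VOC comments and auto-triggers the UDE pipeline.
resp = httpx.post(f"{BASE}/api/cases/{CASE}/import-voc", headers=HEADERS, timeout=120)
print(resp.json())  # {"imported": ..., "pipelineTriggered": true, ...}

# Poll until the background pipeline reports a terminal stage.
while True:
    status = httpx.get(f"{BASE}/api/cases/{CASE}/pipeline-status").json()
    print(status["stage"], status.get("progress"))
    if status["stage"] in ("done", "error", "idle"):
        break
    time.sleep(5)
```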
Environment template:

@@ -1,19 +1,17 @@
-# Mafia Offer backend v2.0 — environment variables
-# Fully standalone, self-contained on Aliyun
+# Mafia Offer backend v2.1 — BYOK mode
+# Partners pass their Bailian key via a header; the server only provides algorithm orchestration
 
-# VOC public API (cross-cloud read-only access, used by import-voc)
+# VOC public API (cross-cloud read-only access)
 VOC_API_BASE=https://brand.brainwork.club/voc/api/research
 
-# LLM routing (via LiteLLM on the same host)
-LITELLM_PROXY_URL=http://127.0.0.1:4000/v1
-LITELLM_MASTER_KEY=sk-xxx
+# Fallback DashScope key (for the platform operator's own testing)
+# Partners pass their own key via the X-DashScope-Key header
+# One Bailian key covers everything: LLM inference + Embedding vectorization
+DASHSCOPE_API_KEY=sk-xxx
 
 # Model
 MODEL_ID=qwen-plus
 TEMPERATURE=0.1
 
-# DashScope (for vectorization)
-DASHSCOPE_API_KEY=sk-xxx
-
 # Port
 PORT=8093
Backend service module:

@@ -1,16 +1,19 @@
 """
-Mafia Offer — standalone backend (self-contained on Aliyun)
+Mafia Offer — standalone backend (BYOK v2.1)
 
 FastAPI service on port 8093.
-VOC data is imported via the public API; the VOC DB is never read directly.
+Partners pass their Bailian key via the X-DashScope-Key header; the server uses it for both LLM and Embedding calls.
+After import-voc completes, the UDE extraction + vectorization pipeline is triggered automatically.
 """
 import os
 import logging
+import asyncio
 import httpx
 
 from fastapi import FastAPI, Header, HTTPException, Query
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
+from typing import Optional
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -20,7 +23,7 @@ from db import get_case_conn, init_case_db, list_cases as _list_cases, DATA_DIR,
 logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(levelname)s %(message)s")
 logger = logging.getLogger("mafia")
 
-app = FastAPI(title="Mafia Offer backend", version="2.0.0", description="Standalone backend: self-contained on Aliyun, VOC imported via API")
+app = FastAPI(title="Mafia Offer backend", version="2.1.0", description="Standalone backend: self-contained on Aliyun + BYOK; partners bring their own API key")
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -29,6 +32,20 @@ app.add_middleware(
 )
 
 
+# ═══════════ BYOK key resolution ═══════════
+
+def resolve_key(header_key: Optional[str]) -> str:
+    """Prefer the partner's DashScope key from the header (one Bailian key covers LLM + Embedding)"""
+    if header_key:
+        logger.info(f"[BYOK] Using partner key: ...{header_key[-6:]}")
+        return header_key
+    env_val = os.getenv("DASHSCOPE_API_KEY", "")
+    if env_val:
+        logger.info(f"[BYOK] Using fallback key: ...{env_val[-6:]}")
+        return env_val
+    raise HTTPException(401, "Missing DashScope API key. Configure DASHSCOPE_API_KEY in CoPaw and pass it via the X-DashScope-Key header")
+
+
 # ═══════════ Models ═══════════
 
 class CreateCaseRequest(BaseModel):
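The resolution contract is: header first, environment fallback second, 401 otherwise. A hypothetical test sketch of that contract — the backend module name `main` is an assumption, since the file path is not shown in this capture:

```python
# Hypothetical pytest sketch; "main" as the module name is an assumption.
import pytest
from fastapi import HTTPException
from main import resolve_key

def test_header_key_wins(monkeypatch):
    monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-env-fallback")
    assert resolve_key("sk-partner") == "sk-partner"   # header beats env

def test_env_fallback(monkeypatch):
    monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-env-fallback")
    assert resolve_key(None) == "sk-env-fallback"

def test_missing_key_raises(monkeypatch):
    monkeypatch.delenv("DASHSCOPE_API_KEY", raising=False)
    with pytest.raises(HTTPException):
        resolve_key(None)
```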
@@ -91,7 +108,35 @@ async def delete_case(case_id: str):
         raise HTTPException(404, "Case not found")
 
 
-# ═══════════ VOC import (cross-cloud API) ═══════════
+# ═══════════ VOC import + automatic pipeline ═══════════
+
+# Background pipeline state (in-memory, per process)
+_pipeline_status: dict[str, dict] = {}
+
+
+async def _auto_pipeline(case_id: str, dashscope_key: str):
+    """Background pipeline: UDE extraction → vectorization"""
+    from tools.ude_extract import run_ude_extraction, run_vectorization
+
+    _pipeline_status[case_id] = {"stage": "extracting", "progress": {}}
+    try:
+        # Stage 1: UDE extraction
+        result = await run_ude_extraction(case_id, limit=0, dashscope_key=dashscope_key)
+        _pipeline_status[case_id]["progress"]["extraction"] = result
+        logger.info(f"[Pipeline] {case_id} UDE extraction done: {result.get('totalUdes', 0)} sentences")
+
+        # Stage 2: vectorization
+        _pipeline_status[case_id]["stage"] = "vectorizing"
+        vec_result = run_vectorization(case_id, dashscope_key=dashscope_key)
+        _pipeline_status[case_id]["progress"]["vectorization"] = vec_result
+        logger.info(f"[Pipeline] {case_id} vectorization done: {vec_result.get('vectorized', 0)} vectors")
+
+        _pipeline_status[case_id]["stage"] = "done"
+    except Exception as e:
+        logger.error(f"[Pipeline] {case_id} failed: {e}")
+        _pipeline_status[case_id]["stage"] = "error"
+        _pipeline_status[case_id]["error"] = str(e)
 
 
 @app.post("/api/cases/{case_id}/link-voc")
 async def link_voc(case_id: str, req: LinkVocRequest):
@@ -106,8 +151,15 @@ async def link_voc(case_id: str, req: LinkVocRequest):
 
 
 @app.post("/api/cases/{case_id}/import-voc")
-async def import_voc(case_id: str, page: int = Query(1), pageSize: int = Query(100)):
-    """Pull comment data from the VOC public API and store it in the local case DB"""
+async def import_voc(
+    case_id: str,
+    page: int = Query(1),
+    pageSize: int = Query(100),
+    x_dashscope_key: str = Header(None),
+):
+    """Pull comments from the VOC public API; on completion, auto-trigger UDE extraction + vectorization"""
+    dashscope_key = resolve_key(x_dashscope_key)
+
     try:
         with get_case_conn(case_id) as conn:
             card = conn.execute("SELECT voc_research_id, voc_api_base FROM case_card LIMIT 1").fetchone()
@@ -120,7 +172,6 @@ async def import_voc(case_id: str, page: int = Query(1), pageSize: int = Query(100)):
         voc_rid = card["voc_research_id"]
         api_base = card["voc_api_base"] or VOC_API_BASE
 
-        # Pull from the VOC API (read-only; no TikHub key needed)
         total_imported = 0
         current_page = page
 
@@ -167,18 +218,45 @@ async def import_voc(case_id: str, page: int = Query(1), pageSize: int = Query(100)):
                 break
             current_page += 1
 
-        # Update statistics
         with get_case_conn(case_id) as conn:
             local_count = conn.execute("SELECT count(*) FROM comments").fetchone()[0]
 
+        # Auto-trigger the background pipeline (UDE extraction + vectorization)
+        pipeline_triggered = False
+        if local_count > 0:
+            asyncio.create_task(_auto_pipeline(case_id, dashscope_key))
+            pipeline_triggered = True
+            logger.info(f"[Import] {case_id} auto pipeline triggered ({local_count} comments)")
+
         return {
             "imported": total_imported,
             "totalLocal": local_count,
             "vocResearchId": voc_rid,
             "pagesProcessed": current_page - page + 1,
+            "pipelineTriggered": pipeline_triggered,
         }
 
 
+@app.get("/api/cases/{case_id}/pipeline-status")
+async def pipeline_status(case_id: str):
+    """Query background pipeline progress"""
+    status = _pipeline_status.get(case_id)
+    if not status:
+        # No pipeline running; infer state from the DB
+        try:
+            with get_case_conn(case_id) as conn:
+                comments = conn.execute("SELECT count(*) FROM comments").fetchone()[0]
+                udes = conn.execute("SELECT count(*) FROM ude_sentences").fetchone()[0]
+                vectorized = conn.execute("SELECT count(*) FROM ude_sentences WHERE vector IS NOT NULL").fetchone()[0]
+        except FileNotFoundError:
+            raise HTTPException(404, "Case not found")
+        return {
+            "stage": "idle",
+            "progress": {"comments": comments, "udesExtracted": udes, "udesVectorized": vectorized},
+        }
+    return status
+
+
 @app.get("/api/cases/{case_id}/comments")
 async def get_comments(case_id: str, page: int = 1, pageSize: int = 50):
     """View locally imported comments"""
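Because `_pipeline_status` lives in process memory, the endpoint has three response shapes: a running entry, a terminal entry, or the DB-derived fallback (e.g. after a restart, or when a poll lands on a worker process that never saw the task). Illustrative shapes, with made-up counts:

```python
# Illustrative /pipeline-status response shapes (all values made up).
running = {"stage": "extracting", "progress": {}}
done = {
    "stage": "done",
    "progress": {
        "extraction": {"totalUdes": 412},
        "vectorization": {"vectorized": 412, "totalUdes": 412, "totalVectorized": 412},
    },
}
# No in-memory entry: state is inferred from the case DB instead.
idle = {
    "stage": "idle",
    "progress": {"comments": 980, "udesExtracted": 412, "udesVectorized": 412},
}
```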
@@ -195,13 +273,19 @@ async def get_comments(case_id: str, page: int = 1, pageSize: int = 50):
     return {"total": total, "page": page, "items": [dict(r) for r in rows]}
 
 
-# ═══════════ UDE analysis ═══════════
+# ═══════════ UDE analysis (BYOK) ═══════════
 
 @app.post("/api/cases/{case_id}/ude/extract")
-async def extract_ude(case_id: str, limit: int = Query(0)):
+async def extract_ude(
+    case_id: str,
+    limit: int = Query(0),
+    x_dashscope_key: str = Header(None),
+):
+    """Manually trigger UDE extraction (normally run automatically by import-voc; use this endpoint for re-runs)"""
     from tools.ude_extract import run_ude_extraction
+    key = resolve_key(x_dashscope_key)
     try:
-        result = await run_ude_extraction(case_id, limit)
+        result = await run_ude_extraction(case_id, limit, dashscope_key=key)
     except FileNotFoundError as e:
         raise HTTPException(404, str(e))
     return result
@@ -212,12 +296,11 @@ async def cluster_ude(
     case_id: str,
     eps: float = Query(0.25),
     minSamples: int = Query(3),
-    x_dashscope_key: str = Header(None),
 ):
+    """Vector clustering (pure CPU, no external API calls; vectors are precomputed by the pipeline)"""
     from tools.ude_extract import run_clustering
-    key = x_dashscope_key or os.getenv("DASHSCOPE_API_KEY", "")
     try:
-        result = run_clustering(case_id, eps, minSamples, dashscope_key=key)
+        result = run_clustering(case_id, eps, minSamples)
     except FileNotFoundError as e:
         raise HTTPException(404, str(e))
     return result
@@ -251,10 +334,14 @@ async def get_coverage(case_id: str):
 async def health():
     return {
         "status": "ok",
-        "version": "2.0.0",
-        "architecture": "independent (Aliyun self-contained)",
+        "version": "2.1.0",
+        "architecture": "BYOK (Bring Your Own Key)",
         "vocApiBase": VOC_API_BASE,
         "caseDataDir": str(DATA_DIR),
+        "byok": {
+            "bailianFallback": bool(os.getenv("BAILIAN_API_KEY")),
+            "dashscopeFallback": bool(os.getenv("DASHSCOPE_API_KEY")),
+        },
     }
 
 
tools/ude_extract.py

@@ -1,9 +1,10 @@
 """
-Mafia Offer — UDE extraction tool (self-contained on Aliyun)
+Mafia Offer — UDE extraction tool (BYOK v2.1)
 
 Flow: local comments → LLM transcribes UDEs → DashScope vectorization → DBSCAN clustering
 
-All data reads/writes stay inside the case DB; nothing crosses clouds.
+All external API calls use DASHSCOPE_API_KEY (one Bailian key covers LLM + Embedding).
+Vectorization is precomputed in the import-voc pipeline; clustering is a pure-CPU operation.
 """
 from __future__ import annotations
 
@@ -29,28 +30,28 @@ EMBED_BATCH_SIZE = 25
 
 PROMPT_PATH = Path(__file__).parent.parent / "prompts" / "voc_to_ude.txt"
 
+# DashScope OpenAI-compatible endpoint (shared by LLM + Embedding)
+DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
 
-def _get_llm_client() -> AsyncOpenAI:
-    return AsyncOpenAI(
-        api_key=os.getenv("LITELLM_MASTER_KEY"),
-        base_url=os.getenv("LITELLM_PROXY_URL"),
-    )
 
+def _get_llm_client(dashscope_key: str = None) -> AsyncOpenAI:
+    """One Bailian key covers LLM + Embedding"""
+    key = dashscope_key or os.getenv("DASHSCOPE_API_KEY", "")
+    return AsyncOpenAI(api_key=key, base_url=DASHSCOPE_BASE)
 
-def _get_embed_client(key: str) -> OpenAI:
+
+def _get_embed_client(dashscope_key: str = None) -> OpenAI:
+    key = dashscope_key or os.getenv("DASHSCOPE_API_KEY", "")
     if not key:
-        raise ValueError("DashScope API key not configured. Pass it via the header or .env.")
-    return OpenAI(
-        api_key=key,
-        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
-    )
+        raise ValueError("DashScope API key not configured.")
+    return OpenAI(api_key=key, base_url=DASHSCOPE_BASE)
 
 
 # ═══════════ Step 1: local comments → UDE transcription ═══════════
 
-async def _call_ude_llm(prompt: str, comments: list[dict]) -> list[dict]:
+async def _call_ude_llm(prompt: str, comments: list[dict], dashscope_key: str = None) -> list[dict]:
     """Transcribe a single batch via the LLM"""
-    client = _get_llm_client()
+    client = _get_llm_client(dashscope_key)
     user_msg = "Transcribe the following consumer comments into UDE-format sentences and return JSON:\n\n"
     for c in comments:
         user_msg += f"[{c['id']}] platform: {c['platform']} original: \"{c['text'][:300]}\"\n\n"
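The "one key, one endpoint" idea in isolation: both client factories above point at the same OpenAI-compatible DashScope base URL, so a single Bailian key serves chat and embeddings. A minimal sketch — the embedding model name `text-embedding-v3` is an assumption, since the actual embedding model id is configured outside the lines shown in this commit:

```python
# Sketch only: one Bailian key, one compatible-mode endpoint, two call types.
from openai import OpenAI

DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
client = OpenAI(api_key="sk-...", base_url=DASHSCOPE_BASE)  # partner's Bailian key

# LLM inference (qwen-plus matches MODEL_ID in the environment template)
chat = client.chat.completions.create(
    model="qwen-plus",
    messages=[{"role": "user", "content": "ping"}],
)

# Embedding via the same key and endpoint (model id is an assumption)
emb = client.embeddings.create(model="text-embedding-v3", input=["a UDE sentence"])

print(chat.choices[0].message.content, len(emb.data[0].embedding))
```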
@@ -80,12 +81,12 @@ async def _call_ude_llm(prompt: str, comments: list[dict]) -> list[dict]:
         return []
 
 
-async def _process_ude_batch(comments, prompt, semaphore):
+async def _process_ude_batch(comments, prompt, semaphore, dashscope_key=None):
     async with semaphore:
-        return await _call_ude_llm(prompt, comments)
+        return await _call_ude_llm(prompt, comments, dashscope_key)
 
 
-async def run_ude_extraction(case_id: str, limit: int = 0) -> dict:
+async def run_ude_extraction(case_id: str, limit: int = 0, dashscope_key: str = None) -> dict:
     """Read comments from the local comments table, transcribe them into UDEs, and store them in ude_sentences"""
     from db import get_case_conn
 
@@ -94,12 +95,10 @@ async def run_ude_extraction(case_id: str, limit: int = 0) -> dict:
         return {"error": "UDE transcription prompt not found (prompts/voc_to_ude.txt)"}
 
     with get_case_conn(case_id) as conn:
-        # Collect comment_ids that are already transcribed
        done_ids = {r[0] for r in conn.execute(
             "SELECT comment_id FROM ude_sentences"
         ).fetchall()}
 
-        # Read from the local comments table
         rows = conn.execute("""
             SELECT id, platform, text
             FROM comments WHERE length(text) > 10
@@ -117,17 +116,15 @@ async def run_ude_extraction(case_id: str, limit: int = 0) -> dict:
     if limit > 0:
         pending = pending[:limit]
 
-    # Split into batches
     batches = []
     for i in range(0, len(pending), BATCH_SIZE):
         chunk = pending[i:i + BATCH_SIZE]
         batches.append([{"id": r["id"], "platform": r["platform"], "text": r["text"]} for r in chunk])
 
     semaphore = asyncio.Semaphore(CONCURRENCY)
-    tasks = [asyncio.create_task(_process_ude_batch(b, prompt, semaphore)) for b in batches]
+    tasks = [asyncio.create_task(_process_ude_batch(b, prompt, semaphore, dashscope_key)) for b in batches]
     all_results = await asyncio.gather(*tasks)
 
-    # Write into the case DB
     ok = 0
     with get_case_conn(case_id) as conn:
         for results in all_results:
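The bounded-concurrency pattern above, shown in isolation: a `Semaphore` caps how many LLM batches run at once while `asyncio.gather` collects every result. Self-contained sketch; the `CONCURRENCY` value and the sleep standing in for an LLM call are assumptions:

```python
# Standalone illustration of semaphore-bounded batch fan-out.
import asyncio

CONCURRENCY = 5  # assumed value; the real constant lives at module top

async def process(batch: list[str], sem: asyncio.Semaphore) -> int:
    async with sem:                    # at most CONCURRENCY batches in flight
        await asyncio.sleep(0.1)       # stand-in for one LLM call
        return len(batch)

async def main() -> None:
    sem = asyncio.Semaphore(CONCURRENCY)
    batches = [[f"c{i}", f"c{i + 1}"] for i in range(0, 20, 2)]
    results = await asyncio.gather(*(process(b, sem) for b in batches))
    print(sum(results), "comments processed")

asyncio.run(main())
```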
@@ -160,7 +157,7 @@ async def run_ude_extraction(case_id: str, limit: int = 0) -> dict:
     }
 
 
-# ═══════════ Step 2 & 3: vectorization + clustering ═══════════
+# ═══════════ Step 2: vectorization (standalone function, called automatically by the pipeline) ═══════════
 
 def _embed_texts(client: OpenAI, texts: list[str]) -> list[list[float]]:
     all_vectors = []
@@ -171,38 +168,75 @@ def _embed_texts(client: OpenAI, texts: list[str]) -> list[list[float]]:
     return all_vectors
 
 
-def run_clustering(case_id: str, eps: float = 0.25, min_samples: int = 3,
-                   dashscope_key: str = None) -> dict:
-    """Vectorization + DBSCAN clustering (entirely inside the local case DB)"""
+def run_vectorization(case_id: str, dashscope_key: str = None) -> dict:
+    """Generate embeddings for all not-yet-vectorized UDEs (independent of clustering; called automatically by the pipeline)"""
+    from db import get_case_conn
+
+    embed_client = _get_embed_client(dashscope_key)
+
+    with get_case_conn(case_id) as conn:
+        # Only process UDEs that have no vector yet
+        rows = conn.execute(
+            "SELECT id, ude_text FROM ude_sentences WHERE vector IS NULL ORDER BY id"
+        ).fetchall()
+
+        if not rows:
+            total = conn.execute("SELECT count(*) FROM ude_sentences").fetchone()[0]
+            return {"vectorized": 0, "totalUdes": total, "message": "already fully vectorized"}
+
+        texts = [r["ude_text"] for r in rows]
+        ids = [r["id"] for r in rows]
+
+        vectors = _embed_texts(embed_client, texts)
+
+        for i, uid in enumerate(ids):
+            conn.execute("UPDATE ude_sentences SET vector = ? WHERE id = ?",
+                         (json.dumps(vectors[i]), uid))
+        conn.commit()
+
+        total = conn.execute("SELECT count(*) FROM ude_sentences").fetchone()[0]
+        vectorized_total = conn.execute("SELECT count(*) FROM ude_sentences WHERE vector IS NOT NULL").fetchone()[0]
+
+    return {
+        "vectorized": len(rows),
+        "totalUdes": total,
+        "totalVectorized": vectorized_total,
+    }
+
+
+# ═══════════ Step 3: clustering (pure CPU, no external API calls) ═══════════
+
+def run_clustering(case_id: str, eps: float = 0.25, min_samples: int = 3) -> dict:
+    """Pure-CPU clustering: read stored vectors from the DB and run DBSCAN. No external API calls."""
     from sklearn.cluster import DBSCAN
     from sklearn.metrics.pairwise import cosine_distances
     from db import get_case_conn
 
-    key = dashscope_key or os.getenv("DASHSCOPE_API_KEY", "")
-    if not key:
-        return {"error": "DashScope API key not configured."}
-
-    embed_client = _get_embed_client(key)
-
     with get_case_conn(case_id) as conn:
-        rows = conn.execute("SELECT id, comment_id, ude_text FROM ude_sentences ORDER BY id").fetchall()
+        rows = conn.execute("SELECT id, comment_id, ude_text, vector FROM ude_sentences ORDER BY id").fetchall()
 
         if len(rows) < min_samples:
            return {"error": f"Not enough UDEs ({len(rows)}); at least {min_samples} required."}
 
+        # Check vector completeness
+        has_vector = [r for r in rows if r["vector"]]
+        missing = len(rows) - len(has_vector)
+        if missing > 0:
+            return {
+                "error": f"{missing} UDEs are not vectorized yet; wait for the pipeline to finish or call /ude/extract manually",
+                "totalUdes": len(rows),
+                "vectorized": len(has_vector),
+                "missing": missing,
+            }
+
         ude_texts = [r["ude_text"] for r in rows]
         ude_ids = [r["id"] for r in rows]
         comment_ids = [r["comment_id"] for r in rows]
 
-        # Vectorize
-        vectors = _embed_texts(embed_client, ude_texts)
-        vec_array = np.array(vectors)
-
-        # Save vectors
-        for i, uid in enumerate(ude_ids):
-            conn.execute("UPDATE ude_sentences SET vector = ? WHERE id = ?",
-                         (json.dumps(vectors[i]), uid))
-
-        # DBSCAN
+        # Load stored vectors from the DB (pure CPU)
+        vec_array = np.array([json.loads(r["vector"]) for r in rows])
+
+        # DBSCAN (pure CPU)
         dist_matrix = cosine_distances(vec_array)
         clustering = DBSCAN(eps=eps, min_samples=min_samples, metric="precomputed").fit(dist_matrix)
         labels = clustering.labels_
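What "zero-call clustering" amounts to, runnable on its own: DBSCAN over a precomputed cosine-distance matrix, exactly the `metric="precomputed"` call used above, with no API in the loop. The random vectors below are stand-ins for the stored embeddings:

```python
# Standalone illustration of the pure-CPU clustering step.
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import cosine_distances

rng = np.random.default_rng(0)
vec_array = rng.normal(size=(50, 16))      # 50 fake UDE vectors, dim 16

dist_matrix = cosine_distances(vec_array)  # pure CPU, O(n^2) memory
labels = DBSCAN(eps=0.25, min_samples=3, metric="precomputed").fit(dist_matrix).labels_

# -1 marks noise; non-negative labels are cluster ids
print(sorted(set(labels)))
```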
@@ -224,12 +258,10 @@ def run_clustering(case_id: str, eps: float = 0.25, min_samples: int = 3,
         member_vectors = vec_array[member_indices]
         member_cids = [comment_ids[i] for i in member_indices]
 
-        # Cluster centroid
         centroid = member_vectors.mean(axis=0)
         dists = cosine_distances([centroid], member_vectors)[0]
         representative = member_texts[dists.argmin()]
 
-        # Sample original voices (from the local comments table)
         sample_voices = []
         for cid in member_cids[:5]:
             voice = conn.execute(