chanpinhsd/backend/tools/ude_extract.py
lidf c5e2a58258 refactor: v2.0 完全解耦 — 阿里云内闭环
- 删除 VOC_DATA_DIR / get_voc_conn(不再跨云直读 SQLite)
- 案例 DB 自带 comments 表,自包含所有数据
- 新增 POST /import-voc:通过 VOC 公网 API 导入评论
- VOC_API_BASE 环境变量控制 API 地址
- 新增 httpx 依赖
2026-04-07 19:47:34 +08:00

296 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
黑手党提案 — UDE 提取工具(阿里云内闭环)
流程:本地 comments → LLM 转写 UDE → DashScope 向量化 → DBSCAN 聚类
所有数据读写都在案例 DB 内,不跨云。
"""
from __future__ import annotations
import json
import os
import asyncio
import logging
from pathlib import Path
import numpy as np
from openai import OpenAI, AsyncOpenAI
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
MODEL = os.getenv("MODEL_ID", "qwen-plus")
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.1"))
BATCH_SIZE = 10
CONCURRENCY = 5
EMBED_DIM = 1024
EMBED_BATCH_SIZE = 25
PROMPT_PATH = Path(__file__).parent.parent / "prompts" / "voc_to_ude.txt"
def _get_llm_client() -> AsyncOpenAI:
    """Build an async OpenAI-compatible client pointed at the LiteLLM proxy."""
    proxy_key = os.getenv("LITELLM_MASTER_KEY")
    proxy_url = os.getenv("LITELLM_PROXY_URL")
    return AsyncOpenAI(api_key=proxy_key, base_url=proxy_url)
def _get_embed_client(key: str) -> OpenAI:
if not key:
raise ValueError("DashScope API Key 未配置。请通过 Header 或 .env 传入。")
return OpenAI(
api_key=key,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
# ═══════════ Step 1: 本地评论 → UDE 转写 ═══════════
async def _call_ude_llm(prompt: str, comments: list[dict]) -> list[dict]:
    """Transcribe one batch of comments into UDE sentences via the LLM.

    Returns the parsed result list; returns [] on any request/parse failure
    or when the response has an unexpected shape.
    """
    client = _get_llm_client()
    # Assemble the user message: header plus one line per comment.
    parts = ["请将以下消费者评论转写为 UDE 格式句,返回 JSON\n\n"]
    for comment in comments:
        parts.append(
            f"[{comment['id']}] 平台:{comment['platform']} 原文: \"{comment['text'][:300]}\"\n\n"
        )
    user_msg = "".join(parts)
    try:
        resp = await client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": user_msg},
            ],
            temperature=TEMPERATURE,
            max_tokens=4000,
            response_format={"type": "json_object"},
        )
        raw = (resp.choices[0].message.content or "").strip()
        parsed = json.loads(raw)
    except Exception as e:
        logger.warning(f"[UDE] LLM 转写失败: {str(e)[:80]}")
        return []
    # A bare list is already the result; a dict may wrap it under a known key.
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict):
        for wrapper_key in ("results", "data", "items", "udes"):
            candidate = parsed.get(wrapper_key)
            if isinstance(candidate, list):
                return candidate
    return []
async def _process_ude_batch(comments, prompt, semaphore):
    """Run one transcription batch while holding the shared concurrency slot."""
    async with semaphore:
        batch_result = await _call_ude_llm(prompt, comments)
    return batch_result
async def run_ude_extraction(case_id: str, limit: int = 0) -> dict:
    """Read comments from the local `comments` table, transcribe them into UDE
    sentences via the LLM, and store results in `ude_sentences`.

    Args:
        case_id: identifier of the per-case DB to operate on.
        limit: max number of pending comments to process this run (0 = all).

    Returns:
        A stats dict on success; an {"error": ...} dict if the prompt file is
        missing; a {"message": ...} dict if everything is already transcribed.
    """
    from db import get_case_conn  # function-scope import: avoids import cycle at module load

    prompt = PROMPT_PATH.read_text("utf-8") if PROMPT_PATH.exists() else ""
    if not prompt:
        return {"error": "UDE 转写 prompt 未找到 (prompts/voc_to_ude.txt)"}
    with get_case_conn(case_id) as conn:
        # comment_ids that already have a transcription — skipped so reruns are idempotent
        done_ids = {r[0] for r in conn.execute(
            "SELECT comment_id FROM ude_sentences"
        ).fetchall()}
        # Read from the local comments table; very short texts carry no signal.
        rows = conn.execute("""
            SELECT id, platform, text
            FROM comments WHERE length(text) > 10
            ORDER BY id
        """).fetchall()
    total_comments = len(rows)
    pending = [r for r in rows if r["id"] not in done_ids]
    if not pending:
        with get_case_conn(case_id) as conn:
            total = conn.execute("SELECT count(*) FROM ude_sentences").fetchone()[0]
        return {"message": "全部已转写完成", "totalUdes": total, "new": 0}
    if limit > 0:
        pending = pending[:limit]
    # Slice pending comments into batches of BATCH_SIZE per LLM call.
    batches = []
    for i in range(0, len(pending), BATCH_SIZE):
        chunk = pending[i:i + BATCH_SIZE]
        batches.append([{"id": r["id"], "platform": r["platform"], "text": r["text"]} for r in chunk])
    # Fan out batches concurrently, throttled by the semaphore.
    semaphore = asyncio.Semaphore(CONCURRENCY)
    tasks = [asyncio.create_task(_process_ude_batch(b, prompt, semaphore)) for b in batches]
    all_results = await asyncio.gather(*tasks)
    # Write transcriptions back into the case DB in one transaction.
    ok = 0
    with get_case_conn(case_id) as conn:
        for results in all_results:
            for r in (results or []):
                # Defensive: the LLM may return malformed entries.
                if not isinstance(r, dict):
                    continue
                ude_text = r.get("ude")
                if not ude_text:
                    continue
                cid = r.get("id")
                if not cid:
                    continue
                try:
                    # OR IGNORE keeps the first transcription if a duplicate arrives.
                    conn.execute(
                        "INSERT OR IGNORE INTO ude_sentences (comment_id, ude_text, confidence) VALUES (?, ?, ?)",
                        (int(cid), ude_text, r.get("confidence", 0.5))
                    )
                    # NOTE(review): counts attempts, including rows the OR IGNORE
                    # dropped — newUdes may overstate actual inserts; verify intent.
                    ok += 1
                except Exception as e:
                    logger.warning(f"[UDE] 写入失败 id={cid}: {e}")
        conn.commit()
        total = conn.execute("SELECT count(*) FROM ude_sentences").fetchone()[0]
    return {
        "newUdes": ok,
        "totalUdes": total,
        "totalComments": total_comments,
        # NOTE(review): assumes at most one UDE row per comment; if a comment can
        # yield several UDEs this difference can go negative — confirm schema.
        "remaining": total_comments - total,
        "batches": len(batches),
    }
# ═══════════ Step 2 & 3: 向量化 + 聚类 ═══════════
def _embed_texts(client: OpenAI, texts: list[str]) -> list[list[float]]:
    """Embed *texts* with DashScope text-embedding-v4, batching requests.

    Requests are chunked to EMBED_BATCH_SIZE texts each; vectors come back in
    input order.
    """
    vectors: list[list[float]] = []
    for start in range(0, len(texts), EMBED_BATCH_SIZE):
        chunk = texts[start:start + EMBED_BATCH_SIZE]
        resp = client.embeddings.create(model="text-embedding-v4", input=chunk, dimensions=EMBED_DIM)
        vectors.extend(item.embedding for item in resp.data)
    return vectors
def run_clustering(case_id: str, eps: float = 0.25, min_samples: int = 3,
                   dashscope_key: str | None = None) -> dict:
    """Embed all UDE sentences and cluster them with DBSCAN, entirely inside
    the local case DB.

    Args:
        case_id: identifier of the per-case DB to operate on.
        eps: DBSCAN neighborhood radius in cosine-distance space.
        min_samples: DBSCAN core-point threshold; also the minimum UDE count.
        dashscope_key: DashScope API key; falls back to DASHSCOPE_API_KEY env.

    Returns:
        A cluster summary dict, or {"error": ...} when the key is missing or
        there are too few UDE sentences.
    """
    from sklearn.cluster import DBSCAN
    from sklearn.metrics.pairwise import cosine_distances
    from db import get_case_conn

    key = dashscope_key or os.getenv("DASHSCOPE_API_KEY", "")
    if not key:
        return {"error": "DashScope API Key 未配置。"}
    embed_client = _get_embed_client(key)
    with get_case_conn(case_id) as conn:
        rows = conn.execute("SELECT id, comment_id, ude_text FROM ude_sentences ORDER BY id").fetchall()
        if len(rows) < min_samples:
            return {"error": f"UDE 不足 ({len(rows)} 条),至少需要 {min_samples} 条。"}
        ude_texts = [r["ude_text"] for r in rows]
        ude_ids = [r["id"] for r in rows]
        comment_ids = [r["comment_id"] for r in rows]
        # Embed every UDE sentence.
        # NOTE(review): embeddings are recomputed on every run even when a
        # vector is already stored — confirm that is intended.
        vectors = _embed_texts(embed_client, ude_texts)
        vec_array = np.array(vectors)
        # Persist vectors as JSON alongside each sentence.
        for i, uid in enumerate(ude_ids):
            conn.execute("UPDATE ude_sentences SET vector = ? WHERE id = ?",
                         (json.dumps(vectors[i]), uid))
        # DBSCAN on a precomputed cosine-distance matrix (O(n^2) memory).
        dist_matrix = cosine_distances(vec_array)
        clustering = DBSCAN(eps=eps, min_samples=min_samples, metric="precomputed").fit(dist_matrix)
        labels = clustering.labels_  # label -1 marks noise points
        # Store each sentence's cluster label.
        for i, uid in enumerate(ude_ids):
            conn.execute("UPDATE ude_sentences SET cluster_id = ? WHERE id = ?",
                         (int(labels[i]), uid))
        # Replace the previous clustering wholesale, then rebuild it.
        conn.execute("DELETE FROM ude_clusters")
        clusters = []
        unique_labels = sorted(set(labels) - {-1})  # exclude noise
        for cluster_id in unique_labels:
            member_indices = [i for i, l in enumerate(labels) if l == cluster_id]
            member_texts = [ude_texts[i] for i in member_indices]
            member_vectors = vec_array[member_indices]
            member_cids = [comment_ids[i] for i in member_indices]
            # Representative = the member closest to the cluster centroid.
            centroid = member_vectors.mean(axis=0)
            dists = cosine_distances([centroid], member_vectors)[0]
            representative = member_texts[dists.argmin()]
            # Sample up to 5 original voices from the local comments table.
            sample_voices = []
            for cid in member_cids[:5]:
                voice = conn.execute(
                    "SELECT text, platform FROM comments WHERE id = ?", (cid,)
                ).fetchone()
                if voice:
                    sample_voices.append({"text": voice["text"][:200], "platform": voice["platform"]})
            # NOTE(review): the inserted row carries no explicit cluster label;
            # verify nothing joins ude_sentences.cluster_id to this table's rowid.
            conn.execute(
                "INSERT INTO ude_clusters (representative_ude, coverage, sample_voices) VALUES (?, ?, ?)",
                (representative, len(member_indices), json.dumps(sample_voices, ensure_ascii=False))
            )
            clusters.append({
                "clusterId": int(cluster_id),
                "representativeUde": representative,
                "coverage": len(member_indices),
                "sampleVoices": sample_voices,
            })
        conn.commit()
    # Largest clusters first in the response payload.
    clusters.sort(key=lambda x: x["coverage"], reverse=True)
    noise_count = int((labels == -1).sum())
    return {
        "totalUdes": len(labels),
        "numClusters": len(clusters),
        "noiseCount": noise_count,
        "noisePct": round(noise_count / len(labels) * 100, 1) if len(labels) else 0,
        "clusters": clusters,
        "params": {"eps": eps, "minSamples": min_samples},
    }
# ═══════════ 覆盖扫描 ═══════════
def run_coverage_scan(case_id: str) -> dict:
    """Summarize how well the current UDE clustering covers the raw comments."""
    from db import get_case_conn

    with get_case_conn(case_id) as conn:
        total_comments = conn.execute("SELECT count(*) FROM comments").fetchone()[0]
        total_udes = conn.execute("SELECT count(*) FROM ude_sentences").fetchone()[0]
        clustered = conn.execute("SELECT count(*) FROM ude_sentences WHERE cluster_id >= 0").fetchone()[0]
        noise = conn.execute("SELECT count(*) FROM ude_sentences WHERE cluster_id = -1").fetchone()[0]
        # Per-cluster sentence counts, biggest clusters first.
        cluster_stats = [dict(r) for r in conn.execute(
            "SELECT cluster_id, count(*) as cnt FROM ude_sentences WHERE cluster_id >= 0 GROUP BY cluster_id ORDER BY cnt DESC"
        ).fetchall()]
        # Highest-confidence unclustered sentences, for manual inspection.
        noise_samples = [dict(r) for r in conn.execute(
            "SELECT ude_text, comment_id, confidence FROM ude_sentences WHERE cluster_id = -1 ORDER BY confidence DESC LIMIT 10"
        ).fetchall()]

    # Verdict by noise ratio: <10% sufficient, <20% needs attention, else retune.
    if total_udes > 0 and noise / total_udes < 0.1:
        verdict = "充分"
    elif total_udes > 0 and noise / total_udes < 0.2:
        verdict = "需关注"
    else:
        verdict = "需调参"
    coverage_rate = round(clustered / total_comments * 100, 1) if total_comments else 0
    return {
        "totalComments": total_comments,
        "totalUdes": total_udes,
        "udesClustered": clustered,
        "udesNoise": noise,
        "coverageRate": coverage_rate,
        "clusterDistribution": cluster_stats,
        "noiseSamples": noise_samples,
        "verdict": verdict,
    }