Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,80 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
---
|
4 |
+
|
5 |
+
```python
|
6 |
+
|
7 |
+
from gen import get_answer,get_state
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
def load_state(train_state_path, layer=32, n_embd=2560):
|
12 |
+
train_state = torch.load(pth_file_path, map_location=torch.device('cpu'))
|
13 |
+
state = [None] * (layer * 3)
|
14 |
+
for i in range(layer):
|
15 |
+
state[i*3+0]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda')
|
16 |
+
state[i*3+1]=train_state[f'blocks.{i}.att.time_state'].to(dtype=torch.float,device='cuda')
|
17 |
+
state[i*3+2]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda')
|
18 |
+
return state
|
19 |
+
|
20 |
+
def get_instruction():
|
21 |
+
"""返回固定的指令内容"""
|
22 |
+
return "根据input中的input和entity_types,帮助用户找到文本中每种entity_types的实体,标明实体类型并且简单描述。然后给找到实体之间的关系,并且描述这段关系以及对关系强度打分。 避免使用诸如\"其他\"或\"未知\"的通用实体类型。 非常重要的是:不要生成冗余或重叠的实体类型和关系。用JSON格式输出。"
|
23 |
+
|
24 |
+
def get_content(input_text):
|
25 |
+
"""输入内容文本,返回格式化的content部分"""
|
26 |
+
return f"'{{'input': '{input_text}'}}"
|
27 |
+
|
28 |
+
def get_entity_types(entity_list):
|
29 |
+
"""
|
30 |
+
输入实体类型列表,返回格式化的entity_types部分
|
31 |
+
|
32 |
+
Args:
|
33 |
+
entity_list: 可以是字符串列表 ['领域', '专家', '任务']
|
34 |
+
或者是字符串 '领域, 专家, 任务'
|
35 |
+
"""
|
36 |
+
if isinstance(entity_list, str):
|
37 |
+
# 如果是字符串,按逗号分割
|
38 |
+
entity_list = [item.strip() for item in entity_list.split(',')]
|
39 |
+
|
40 |
+
# 不带引号的格式(和原数据一致)
|
41 |
+
entity_str = ', '.join(entity_list)
|
42 |
+
return f"{{'entity_types': [{entity_str}]}}"
|
43 |
+
|
44 |
+
def generate_prompt(content, entity_types):
|
45 |
+
"""
|
46 |
+
生成完整的prompt
|
47 |
+
|
48 |
+
Args:
|
49 |
+
content: 输入的文本内容
|
50 |
+
entity_types: 实体类型列表或字符串
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
完整的prompt字符串
|
54 |
+
"""
|
55 |
+
instruction = get_instruction()
|
56 |
+
content_part = get_content(content)
|
57 |
+
entity_types_part = get_entity_types(entity_types)
|
58 |
+
input_list_str = f'["content": {content_part}, "entity_types": {entity_types_part}]'
|
59 |
+
# 按照指定格式拼接
|
60 |
+
prompt = (
|
61 |
+
f"{input_list_str}\n\n"
|
62 |
+
f"User: Act as a specialized AI for Knowledge Graph construction. Your task is to extract entities and their relationships from the provided input, based on the given entity_types provided in above content.\nStructure your output as a single, valid JSON object with two top-level keys: entities and relationships.\nentities: A list of objects. Each object must have:\nentity: The exact name of the entity.\ndescription: A brief, context-based summary of the entity.\nrelationships: A list of objects. Each object must have:\nsource: The name of the source entity.\ntarget: The name of the target entity.\nrelationship: A concise description of their connection.\nCritical Rules:\nStrict Typing: Use only the provided entity types. Do not invent types or use generics like \"Other\".\nNo Redundancy: Do not create duplicate or reciprocal relationships (e.g., if A acquired B exists, do not add B was acquired by A).\nYour response must be only the JSON object.\n\n"
|
63 |
+
f"Assistant:"
|
64 |
+
)
|
65 |
+
|
66 |
+
return prompt
|
67 |
+
|
68 |
+
content1 = "根据我国的监狱法令,为了协助监狱囚犯改过自新和重新融入社会,监禁期至少四个星期的囚犯可在服刑至少14天后转入居家宵禁计划,在家服满剩余的刑期"
|
69 |
+
entity_types1 = ["法律法规", "人物类别", "时间条件", "政策措施"]
|
70 |
+
|
71 |
+
ctx = generate_prompt(content1, entity_types1)
|
72 |
+
|
73 |
+
pth_file_path = "/home/rwkv/models/triplets1/rwkv-0.pth"
|
74 |
+
|
75 |
+
|
76 |
+
tt_state = load_state(pth_file_path)
|
77 |
+
|
78 |
+
print(ctx)
|
79 |
+
res1 = get_answer(ctx,state=tt_state)
|
80 |
+
print('train_state :',res1)
|