SupYumm commited on
Commit
fac2d2c
·
verified ·
1 Parent(s): 3df9842

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +80 -3
README.md CHANGED
@@ -1,3 +1,80 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ ```python
6
+
7
+ from gen import get_answer,get_state
8
+ import torch
9
+
10
+
11
+ def load_state(train_state_path, layer=32, n_embd=2560):
12
+ train_state = torch.load(pth_file_path, map_location=torch.device('cpu'))
13
+ state = [None] * (layer * 3)
14
+ for i in range(layer):
15
+ state[i*3+0]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda')
16
+ state[i*3+1]=train_state[f'blocks.{i}.att.time_state'].to(dtype=torch.float,device='cuda')
17
+ state[i*3+2]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda')
18
+ return state
19
+
20
+ def get_instruction():
21
+ """返回固定的指令内容"""
22
+ return "根据input中的input和entity_types,帮助用户找到文本中每种entity_types的实体,标明实体类型并且简单描述。然后给找到实体之间的关系,并且描述这段关系以及对关系强度打分。 避免使用诸如\"其他\"或\"未知\"的通用实体类型。 非常重要的是:不要生成冗余或重叠的实体类型和关系。用JSON格式输出。"
23
+
24
+ def get_content(input_text):
25
+ """输入内容文本,返回格式化的content部分"""
26
+ return f"'{{'input': '{input_text}'}}"
27
+
28
+ def get_entity_types(entity_list):
29
+ """
30
+ 输入实体类型列表,返回格式化的entity_types部分
31
+
32
+ Args:
33
+ entity_list: 可以是字符串列表 ['领域', '专家', '任务']
34
+ 或者是字符串 '领域, 专家, 任务'
35
+ """
36
+ if isinstance(entity_list, str):
37
+ # 如果是字符串,按逗号分割
38
+ entity_list = [item.strip() for item in entity_list.split(',')]
39
+
40
+ # 不带引号的格式(和原数据一致)
41
+ entity_str = ', '.join(entity_list)
42
+ return f"{{'entity_types': [{entity_str}]}}"
43
+
44
+ def generate_prompt(content, entity_types):
45
+ """
46
+ 生成完整的prompt
47
+
48
+ Args:
49
+ content: 输入的文本内容
50
+ entity_types: 实体类型列表或字符串
51
+
52
+ Returns:
53
+ 完整的prompt字符串
54
+ """
55
+ instruction = get_instruction()
56
+ content_part = get_content(content)
57
+ entity_types_part = get_entity_types(entity_types)
58
+ input_list_str = f'["content": {content_part}, "entity_types": {entity_types_part}]'
59
+ # 按照指定格式拼接
60
+ prompt = (
61
+ f"{input_list_str}\n\n"
62
+ f"User: Act as a specialized AI for Knowledge Graph construction. Your task is to extract entities and their relationships from the provided input, based on the given entity_types provided in above content.\nStructure your output as a single, valid JSON object with two top-level keys: entities and relationships.\nentities: A list of objects. Each object must have:\nentity: The exact name of the entity.\ndescription: A brief, context-based summary of the entity.\nrelationships: A list of objects. Each object must have:\nsource: The name of the source entity.\ntarget: The name of the target entity.\nrelationship: A concise description of their connection.\nCritical Rules:\nStrict Typing: Use only the provided entity types. Do not invent types or use generics like \"Other\".\nNo Redundancy: Do not create duplicate or reciprocal relationships (e.g., if A acquired B exists, do not add B was acquired by A).\nYour response must be only the JSON object.\n\n"
63
+ f"Assistant:"
64
+ )
65
+
66
+ return prompt
67
+
68
+ content1 = "根据我国的监狱法令,为了协助监狱囚犯改过自新和重新融入社会,监禁期至少四个星期的囚犯可在服刑至少14天后转入居家宵禁计划,在家服满剩余的刑期"
69
+ entity_types1 = ["法律法规", "人物类别", "时间条件", "政策措施"]
70
+
71
+ ctx = generate_prompt(content1, entity_types1)
72
+
73
+ pth_file_path = "/home/rwkv/models/triplets1/rwkv-0.pth"
74
+
75
+
76
+ tt_state = load_state(pth_file_path)
77
+
78
+ print(ctx)
79
+ res1 = get_answer(ctx,state=tt_state)
80
+ print('train_state :',res1)