yyuri commited on
Commit
4b3bda3
·
verified ·
1 Parent(s): ca46696

Upload check_openai.py

Browse files
utils_groupclassification/check_openai.py CHANGED
@@ -4,6 +4,7 @@ from logging import getLogger
4
 
5
  from openai import OpenAI
6
  from openai import AzureOpenAI
 
7
  import fitz
8
 
9
  import requests
@@ -55,26 +56,41 @@ def generate_check_(reference):
55
  client = OpenAI(
56
  api_key=api_key,
57
  )
58
- response = client.chat.completions.create(
59
- model="gpt-3.5-turbo",
60
- messages=[
61
- {
62
- "role": "system",
63
- "content": system_prompt,
64
- },
65
- {
66
- "role": "user",
67
- "content": reference,
68
- },
69
- ],
70
- functions=[{"name": "generate_queries", "parameters": json_schema}],
71
- function_call={"name": "generate_queries"},
72
- temperature=0.0,
73
- top_p=0.0,
74
- )
75
- output = response.choices[0].message.function_call.arguments
76
- time.sleep(1)
77
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  def ch(company_name, reference):
 
4
 
5
  from openai import OpenAI
6
  from openai import AzureOpenAI
7
+ from openai import OpenAIError
8
  import fitz
9
 
10
  import requests
 
56
  client = OpenAI(
57
  api_key=api_key,
58
  )
59
+ retries = 0
60
+ max_retries = 100
61
+ delay = 5
62
+
63
+ while retries < max_retries:
64
+ try:
65
+ response = client.chat.completions.create(
66
+ model="gpt-3.5-turbo",
67
+ messages=[
68
+ {
69
+ "role": "system",
70
+ "content": system_prompt,
71
+ },
72
+ {
73
+ "role": "user",
74
+ "content": reference,
75
+ },
76
+ ],
77
+ functions=[{"name": "generate_queries", "parameters": json_schema}],
78
+ function_call={"name": "generate_queries"},
79
+ temperature=0.0,
80
+ top_p=0.0,
81
+ )
82
+ output = response.choices[0].message.function_call.arguments
83
+ time.sleep(1)
84
+ return output
85
+ except OpenAIError as e:
86
+ print(f"Error occurred: {e}. Retrying in {delay} seconds...")
87
+ retries += 1
88
+ time.sleep(delay)
89
+ except Exception as e:
90
+ print(f"Unexpected error: {e}. Retrying in {delay} seconds...")
91
+ retries += 1
92
+ time.sleep(delay)
93
+ raise RuntimeError("Maximum retries exceeded. Could not get a valid response.")
94
 
95
 
96
  def ch(company_name, reference):