K00B404 commited on
Commit
8ae56e8
·
verified ·
1 Parent(s): 41ddd56

Create huggingfaceinferenceclient.py

Browse files
Files changed (1) hide show
  1. huggingfaceinferenceclient.py +241 -0
huggingfaceinferenceclient.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import requests
4
+ from typing import Dict, Any, Optional
5
+ from PIL import Image
6
+ import io
7
+
8
+ class HuggingFaceInferenceClient:
9
+ """
10
+ Comprehensive client for interacting with Hugging Face Inference API endpoints.
11
+
12
+ ## Core Features
13
+ - Secure API authentication
14
+ - Flexible image encoding
15
+ - Advanced error handling
16
+ - Configurable generation parameters
17
+
18
+ ## Technical Design Considerations
19
+ - Environment-based configuration
20
+ - Type-hinted method signatures
21
+ - Comprehensive logging and error management
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ api_url: Optional[str] = None,
27
+ api_token: Optional[str] = None
28
+ ):
29
+ """
30
+ Initialize Hugging Face Inference API client.
31
+
32
+ Args:
33
+ api_url (str, optional): Inference endpoint URL
34
+ api_token (str, optional): Authentication token
35
+ """
36
+ self.api_url = api_url or os.getenv('HF_INFERENCE_ENDPOINT')
37
+ self.api_token = api_token or os.getenv('HF_API_TOKEN')
38
+
39
+ if not self.api_url or not self.api_token:
40
+ raise ValueError(
41
+ "Missing Hugging Face Inference endpoint or API token. "
42
+ "Please provide via parameters or environment variables."
43
+ )
44
+
45
+ def encode_image(
46
+ self,
47
+ image_path: str,
48
+ format: str = 'JPEG'
49
+ ) -> str:
50
+ """
51
+ Encode image to base64 data URI.
52
+
53
+ Args:
54
+ image_path (str): Path to input image
55
+ format (str): Output image format
56
+
57
+ Returns:
58
+ str: Base64 encoded data URI
59
+ """
60
+ try:
61
+ with Image.open(image_path) as img:
62
+ # Ensure RGB compatibility
63
+ if img.mode != "RGB":
64
+ img = img.convert("RGB")
65
+
66
+ # Convert to byte array
67
+ img_byte_arr = io.BytesIO()
68
+ img.save(img_byte_arr, format=format)
69
+
70
+ # Encode to base64
71
+ base64_encoded = base64.b64encode(
72
+ img_byte_arr.getvalue()
73
+ ).decode('utf-8')
74
+
75
+ return f"data:image/{format.lower()};base64,{base64_encoded}"
76
+
77
+ except Exception as e:
78
+ raise ValueError(f"Image encoding failed: {e}")
79
+
80
+ def generate_image(
81
+ self,
82
+ payload: Dict[str, Any]
83
+ ) -> Dict[str, Any]:
84
+ """
85
+ Execute image generation request.
86
+
87
+ Args:
88
+ payload (Dict): Generation configuration payload
89
+
90
+ Returns:
91
+ Dict: API response containing generation results
92
+ """
93
+ headers = {
94
+ "Accept": "application/json",
95
+ "Authorization": f"Bearer {self.api_token}",
96
+ "Content-Type": "application/json"
97
+ }
98
+
99
+ try:
100
+ response = requests.post(
101
+ self.api_url,
102
+ headers=headers,
103
+ json=payload
104
+ )
105
+ response.raise_for_status()
106
+ return response.json()
107
+
108
+ except requests.RequestException as e:
109
+ return {
110
+ "error": f"API request failed: {e}",
111
+ "status_code": response.status_code if 'response' in locals() else None
112
+ }
113
+
114
+ def save_generated_media(
115
+ self,
116
+ response: Dict[str, Any],
117
+ output_filename: str
118
+ ) -> Optional[str]:
119
+ """
120
+ Save generated media from API response.
121
+
122
+ Args:
123
+ response (Dict): API generation response
124
+ output_filename (str): Output file path
125
+
126
+ Returns:
127
+ Optional[str]: Path to saved file or None
128
+ """
129
+ media_types = {
130
+ 'image': self._save_image,
131
+ 'video': self._save_video
132
+ }
133
+
134
+ try:
135
+ # Check for errors
136
+ if 'error' in response:
137
+ print(f"Generation Error: {response['error']}")
138
+ return None
139
+
140
+ # Detect media type and save
141
+ for media_type, save_func in media_types.items():
142
+ if media_type in response:
143
+ return save_func(response[media_type], output_filename)
144
+
145
+ raise ValueError("No supported media found in response")
146
+
147
+ except Exception as e:
148
+ print(f"Media saving failed: {e}")
149
+ return None
150
+
151
+ def _save_image(
152
+ self,
153
+ image_data_uri: str,
154
+ output_path: str
155
+ ) -> str:
156
+ """
157
+ Save base64 encoded image data.
158
+
159
+ Args:
160
+ image_data_uri (str): Base64 image data URI
161
+ output_path (str): Output image file path
162
+
163
+ Returns:
164
+ str: Path to saved image
165
+ """
166
+ # Remove data URI prefix
167
+ base64_data = image_data_uri.split(",")[1]
168
+ image_data = base64.b64decode(base64_data)
169
+
170
+ with open(output_path, "wb") as f:
171
+ f.write(image_data)
172
+
173
+ return output_path
174
+
175
+ def _save_video(
176
+ self,
177
+ video_data_uri: str,
178
+ output_path: str
179
+ ) -> str:
180
+ """
181
+ Save base64 encoded video data.
182
+
183
+ Args:
184
+ video_data_uri (str): Base64 video data URI
185
+ output_path (str): Output video file path
186
+
187
+ Returns:
188
+ str: Path to saved video
189
+ """
190
+ # Remove data URI prefix
191
+ base64_data = video_data_uri.split(",")[1]
192
+ video_data = base64.b64decode(base64_data)
193
+
194
+ with open(output_path, "wb") as f:
195
+ f.write(video_data)
196
+
197
+ return output_path
198
+
199
+ def main():
200
+ """
201
+ Example usage demonstrating client capabilities.
202
+ """
203
+ # Initialize client with endpoint and token
204
+ client = HuggingFaceInferenceClient(
205
+ api_url="https://your-endpoint.endpoints.huggingface.cloud",
206
+ api_token="hf_your_token_here"
207
+ )
208
+
209
+ # Prepare generation payload
210
+ image_generation_config = {
211
+ "inputs": {
212
+ "image": client.encode_image("input_image.jpg"),
213
+ "prompt": "Enhance and expand the scene creatively"
214
+ },
215
+ "parameters": {
216
+ # Configurable generation parameters
217
+ "width": 768,
218
+ "height": 480,
219
+ "num_frames": 129, # 8*16 + 1
220
+ "num_inference_steps": 50,
221
+ "guidance_scale": 4.0,
222
+ "double_num_frames": True,
223
+ "fps": 60,
224
+ "super_resolution": True,
225
+ "grain_amount": 12
226
+ }
227
+ }
228
+
229
+ # Generate media
230
+ generation_output = client.generate_image(image_generation_config)
231
+
232
+ # Save generated media
233
+ output_filename = client.save_generated_media(
234
+ generation_output,
235
+ "output_media.mp4"
236
+ )
237
+
238
+ print(f"Media saved to: {output_filename}")
239
+
240
+ if __name__ == "__main__":
241
+ main()