Spaces:
Sleeping
Sleeping
| import re | |
| import asyncio | |
| import requests | |
| import hivemind | |
| import functools | |
| from async_timeout import timeout | |
| from petals.server.handler import TransformerConnectionHandler | |
| info_cache = hivemind.TimedStorage() | |
| async def check_reachability(peer_id, _, node, *, fetch_info=False, connect_timeout=5, expiration=300, use_cache=True): | |
| if use_cache: | |
| entry = info_cache.get(peer_id) | |
| if entry is not None: | |
| return entry.value | |
| try: | |
| with timeout(connect_timeout): | |
| if fetch_info: # For Petals servers | |
| stub = TransformerConnectionHandler.get_stub(node.p2p, peer_id) | |
| response = await stub.rpc_info(hivemind.proto.runtime_pb2.ExpertUID()) | |
| rpc_info = hivemind.MSGPackSerializer.loads(response.serialized_info) | |
| rpc_info["ok"] = True | |
| else: # For DHT-only bootstrap peers | |
| await node.p2p._client.connect(peer_id, []) | |
| await node.p2p._client.disconnect(peer_id) | |
| rpc_info = {"ok": True} | |
| except Exception as e: | |
| # Actual connection error | |
| if not isinstance(e, asyncio.TimeoutError): | |
| message = str(e) if str(e) else repr(e) | |
| if message == "protocol not supported": | |
| # This may be returned when a server is joining, see https://github.com/petals-infra/health.petals.dev/issues/1 | |
| return {"ok": True} | |
| else: | |
| message = f"Failed to connect in {connect_timeout:.0f} sec. Firewall may be blocking connections" | |
| rpc_info = {"ok": False, "error": message} | |
| info_cache.store(peer_id, rpc_info, hivemind.get_dht_time() + expiration) | |
| return rpc_info | |
| async def check_reachability_parallel(peer_ids, dht, node, *, fetch_info=False): | |
| rpc_infos = await asyncio.gather( | |
| *[check_reachability(peer_id, dht, node, fetch_info=fetch_info) for peer_id in peer_ids] | |
| ) | |
| return dict(zip(peer_ids, rpc_infos)) | |
| async def get_peers_ips(dht, dht_node): | |
| return await dht_node.p2p.list_peers() | |
| def get_location(ip_address): | |
| try: | |
| response = requests.get(f"http://ip-api.com/json/{ip_address}") | |
| if response.status_code == 200: | |
| return response.json() | |
| except Exception: | |
| pass | |
| return {} | |
| def extract_peer_ip_info(multiaddr_str): | |
| if ip_match := re.search(r"/ip4/(\d+\.\d+\.\d+\.\d+)", multiaddr_str): | |
| return get_location(ip_match[1]) | |
| return {} | |