Spaces:
Runtime error
Runtime error
File size: 4,344 Bytes
3c30e6f 1ac89b2 3c30e6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import apache_beam as beam
import gradio as gr
import huggingface_hub
import pandas as pd
import plotly.graph_objects as go
import spaces
import textwrap
import torch
import us
from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import logging
import os
import requests
# Hugging Face model id used to write the summaries.
MODEL_NAME = "google/gemma-2-2b-it"

# Summarization prompt. The single "{}" placeholder is filled (via
# str.format) with one state's alert record serialized as JSON.
PROMPT_TEMPLATE = (
    "Write a succinct summary of the following weather alerts. "
    "Do not comment on missing information - just summarize the information "
    "provided/available.\n"
    "```json\n"
    "{}\n"
    "```\n"
    "Summary (In the state...):\n"
)

# Collects one record per state; FetchWeatherAlerts appends to it while the
# Beam pipeline runs.
alerts = []
# Beam DoFn that fetches active weather.gov alerts, one state per element.
class FetchWeatherAlerts(beam.DoFn):
    """Fetch and condense the active weather.gov alerts for a US state.

    Input elements are state abbreviations (the pipeline feeds the
    ``abbr`` of each entry in ``us.states.STATES``). For each state whose
    request succeeds with HTTP 200, a record::

        {"features": [{"event": ..., "headline": ..., "instruction": ...}, ...],
         "state": state}

    keeping only features whose ``messageType`` is ``"Alert"`` is appended to
    the module-level ``alerts`` list and also yielded as this DoFn's output.
    Request failures, non-200 responses and unparsable bodies are logged and
    the state is skipped, so one bad state does not abort the whole pipeline.

    NOTE(review): appending to a module global only reaches the caller's
    ``alerts`` when the runner executes this DoFn in the same process;
    downstream transforms can consume the yielded records instead.
    """

    # Seconds to wait for weather.gov. Without a timeout, requests can block
    # indefinitely on a stalled connection and hang the pipeline.
    REQUEST_TIMEOUT = 30

    def process(self, state):
        logging.info(f"Fetching weather alerts for {state} from weather.gov")
        url = f"https://api.weather.gov/alerts/active?area={state}"
        try:
            response = requests.get(
                url,
                headers={
                    "User-Agent": "(Neal DeBuhr, https://huggingface.co/spaces/ndebuhr/streaming-llm-weather-alerts)",
                    "Accept": "application/geo+json",
                },
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as err:
            logging.warning("Failed to fetch weather alerts for %s: %s", state, err)
            return

        if response.status_code != 200:
            logging.warning(
                "weather.gov returned HTTP %s for %s; skipping",
                response.status_code,
                state,
            )
            return

        try:
            features = response.json().get("features") or []
        except ValueError as err:  # also covers requests' JSONDecodeError
            logging.warning("Invalid JSON from weather.gov for %s: %s", state, err)
            return
        logging.info(f"Fetched weather alerts for {state} from weather.gov")

        condensed = []
        for feature in features:
            properties = feature.get("properties") or {}
            # Keep only features whose messageType is exactly "Alert".
            if properties.get("messageType") != "Alert":
                continue
            # .get() so a missing field becomes null in the prompt JSON
            # instead of a KeyError that fails the pipeline.
            condensed.append(
                {
                    "event": properties.get("event"),
                    "headline": properties.get("headline"),
                    "instruction": properties.get("instruction"),
                }
            )

        record = {"features": condensed, "state": state}
        alerts.append(record)
        yield record
pipeline_options = PipelineOptions()
# FetchWeatherAlerts is defined in __main__. Saving the main session lets
# pickled functions and classes from this script be unpickled by the runner.
pipeline_options.view_as(SetupOptions).save_main_session = True

# One-shot batch pipeline: emit every state abbreviation and fetch that
# state's alerts. The pipeline runs when the `with` block exits, and the
# code after this block reads the `alerts` list that the DoFn fills.
with beam.Pipeline(options=pipeline_options) as p:
    (
        p
        | "Create States" >> beam.Create(
            [us_state.abbr for us_state in us.states.STATES]
        )
        | "Fetch Weather Alerts" >> beam.ParDo(FetchWeatherAlerts())
    )
# Generate alert summaries with transformers on a ZeroGPU-allocated GPU.
@spaces.GPU()
def generate_summaries(alerts):
    """Add an LLM-written ``"summary"`` string to every alert record.

    Logs in to the Hugging Face Hub with the ``HUGGINGFACE_TOKEN`` environment
    variable and loads ``MODEL_NAME`` onto CUDA. For each record it then fills
    ``PROMPT_TEMPLATE`` with the record's pretty-printed JSON and generates up
    to 256 new tokens.

    Args:
        alerts: list of dict records (as built by FetchWeatherAlerts). Each
            dict is mutated in place: a ``"summary"`` key is added.

    Returns:
        The same list object that was passed in.

    Raises:
        KeyError: if ``HUGGINGFACE_TOKEN`` is unset and ``alerts`` is
            non-empty.
    """
    if not alerts:
        # Nothing to summarize: skip the Hub login and the model download.
        return alerts

    huggingface_hub.login(token=os.environ["HUGGINGFACE_TOKEN"])
    device = torch.device("cuda")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

    for alert in alerts:
        prompt = PROMPT_TEMPLATE.format(json.dumps(alert, indent=2))
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id
            )
        # For a causal LM, generate() returns the prompt tokens followed by
        # the continuation. Decode only the continuation. Stripping the prompt
        # from the fully decoded text with str.replace fails whenever the
        # encode/decode round trip does not reproduce the prompt exactly,
        # which leaks the whole prompt into the summary.
        new_tokens = outputs[0][prompt_length:]
        alert["summary"] = tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()
    return alerts
# Summarize once at startup; the resulting table backs every map render.
alerts = generate_summaries(alerts)
# Explicit columns keep "state"/"summary" present even when there are no
# alerts. Without them, an empty list would yield a column-less frame and
# df["summary"] would raise KeyError when the map is built.
df = pd.DataFrame(
    [{"state": alert["state"], "summary": alert["summary"]} for alert in alerts],
    columns=["state", "summary"],
)
def get_map():
    """Build a US choropleth whose hover text shows each state's summary.

    Reads the module-level ``df`` (columns ``state`` and ``summary``) without
    modifying it. Every listed state gets the same z value and a light-grey
    colorscale with no colorbar, so the map acts purely as a hover surface
    for the summaries.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """

    def wrap_text(text, width=50):
        # Plotly hover labels break lines on "<br>". textwrap.wrap on its own
        # turns "\n" into spaces, so wrap each of the summary's lines
        # separately to keep the model's own line breaks.
        wrapped = []
        for line in text.splitlines():
            wrapped.extend(textwrap.wrap(line, width=width))
        return "<br>".join(wrapped)

    # Computed locally instead of writing a column back into the shared
    # global DataFrame on every request.
    hover_text = df["summary"].apply(wrap_text)
    fig = go.Figure(
        go.Choropleth(
            locations=df["state"],
            z=[1] * len(df),
            locationmode="USA-states",
            colorscale=[
                [0, "lightgrey"],
                [1, "lightgrey"],
            ],  # Single color for all states
            showscale=False,
            text=hover_text,
            hoverinfo="text",
            hovertemplate="%{text}<extra></extra>",
        )
    )
    fig.update_layout(title_text="Streaming LLM Weather Alerts", geo_scope="usa")
    return fig
# Gradio UI: get_map takes no inputs, and its figure is rendered in a Plot.
iface = gr.Interface(
    fn=get_map,
    inputs=None,
    outputs=gr.Plot(),
)
# Start serving. This runs at import time; there is no __main__ guard.
iface.launch()
|