from typing import Dict, List, Union

import numpy as np
import torch
from transformers import AutoModel, Pipeline
from astropy.io import fits


class FlareDetectionPipeline(Pipeline):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.call_count = 0

    def _sanitize_parameters(self, **kwargs):
        # No pipeline-specific parameters are routed yet; return empty kwargs for
        # preprocess, _forward, and postprocess respectively.
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, light_curve: Union[np.ndarray, str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        """Preprocess the input light curve(s).

        Args:
            light_curve: A single FITS file path, a list of FITS file paths, or a
                numpy array of flux values.
        """
        if isinstance(light_curve, str):
            light_curve = [light_curve]

        if isinstance(light_curve, np.ndarray):
            # A bare flux array carries no TIME column, so sample indices are used
            # as the time axis (an assumption made only for array inputs).
            flux = np.asarray(light_curve, dtype=np.float32).reshape(-1)
            flux = flux / np.nanmedian(flux)
            flux_array = flux[np.newaxis, :]
            self.time_series = np.arange(flux_array.shape[1], dtype=np.float32)[np.newaxis, :]
            sequence_mask = np.ones_like(flux_array)
        elif isinstance(light_curve, list) and isinstance(light_curve[0], str):
            flux_data = []
            times_data = []
            lengths = []

            # Read each light curve and track the longest one for padding.
            max_length = 0
            for fits_path in light_curve:
                with fits.open(fits_path) as hdul:
                    time = hdul[1].data['TIME'].astype(np.float32)
                    flux = hdul[1].data['PDCSAP_FLUX'].astype(np.float32)

                # Normalize each light curve by its median flux.
                flux = flux / np.nanmedian(flux)

                max_length = max(max_length, len(flux))
                lengths.append(len(flux))
                flux_data.append(flux)
                times_data.append(time)

            # Pad all light curves to a common length so they can be batched.
            padded_flux = []
            padded_times = []
            sequence_mask = []

            for flux, time, length in zip(flux_data, times_data, lengths):
                pad_length = max_length - length

                padded_f = np.pad(flux, (0, pad_length), mode='constant', constant_values=np.nan)
                padded_t = np.pad(time, (0, pad_length), mode='constant', constant_values=np.nan)

                # The mask is 1 for observed samples and 0 for padding.
                mask = np.ones(length)
                mask = np.pad(mask, (0, pad_length), mode='constant', constant_values=0)

                padded_flux.append(padded_f)
                padded_times.append(padded_t)
                sequence_mask.append(mask)

            # Keep the time axis around; postprocess reports flare times from it.
            self.time_series = np.array(padded_times)

            flux_array = np.array(padded_flux)
            sequence_mask = np.array(sequence_mask)
        else:
            raise TypeError(
                "light_curve must be a FITS file path, a list of FITS file paths, or a numpy array."
            )

        # Add the trailing channel dimension expected by the model: (batch, length, 1).
        flux_array = flux_array.reshape(flux_array.shape[0], flux_array.shape[1], 1)

        inputs = torch.tensor(flux_array, dtype=torch.float32)
        mask = torch.tensor(sequence_mask, dtype=torch.float32)

        return {
            "input_features": inputs,
            "sequence_mask": mask
        }

    def _forward(self, model_inputs, **forward_params):
        """Forward pass through the model.

        Args:
            model_inputs: Dictionary containing the input tensors.
            forward_params: Additional parameters for the forward pass.
        """
        if model_inputs is None:
            raise ValueError(
                "model_inputs cannot be None. Check that the preprocess method returns the expected dictionary."
            )

        if "input_features" not in model_inputs:
            raise KeyError("model_inputs must contain an 'input_features' key.")

        # Keep the input flux around so that postprocess can locate flare peaks.
        self.input_features = model_inputs["input_features"]

        return self.model(
            input_features=model_inputs["input_features"],
            sequence_mask=model_inputs.get("sequence_mask", None),
            return_dict=True
        )

    def postprocess(self, model_outputs, **kwargs):
        """
        Postprocess the model outputs to detect flare events.

        Returns a list of dictionaries, one per detected flare event.
        """
        logits = model_outputs.logits
        predictions = torch.sigmoid(logits).squeeze(-1)
        binary_predictions = (predictions > 0.5).long()

        predictions_np = binary_predictions.cpu().numpy()
        flux_data = self.input_features.cpu().numpy()

        flare_events = []

        def is_valid_flare(flux, start_idx, end_idx, peak_idx):
            """Validate a candidate flare event.

            Args:
                flux: Array of flux values
                start_idx: Start index of the candidate flare
                end_idx: End index of the candidate flare
                peak_idx: Peak index of the candidate flare

            Returns:
                bool: True if the event looks like a real flare, False otherwise
            """
            # Require at least three consecutive flagged points.
            if end_idx - start_idx < 2:
                return False

            # The checks below look two samples outside the candidate, so reject
            # events too close to the array boundaries (this also prevents negative
            # indices from wrapping around to the end of the array).
            if start_idx < 2 or end_idx + 2 >= len(flux):
                return False

            try:
                # A flare should rise to its peak rather than start on a falling slope.
                if peak_idx == start_idx and flux[peak_idx] <= flux[peak_idx - 1]:
                    return False

                # Flares decay more slowly than they rise.
                if end_idx - peak_idx <= peak_idx - start_idx:
                    return False

                # Compare the rise and decay amplitudes measured two points outside
                # the candidate; strongly asymmetric ratios are rejected.
                alter = (flux[peak_idx] - flux[start_idx - 2]) / (flux[peak_idx] - flux[end_idx + 2] + 1e-8)
                if alter < 0.5 or alter > 2 or np.isnan(alter):
                    return False
            except (IndexError, ValueError):
                return False

            return True

        for i in range(predictions_np.shape[0]):
            pred = predictions_np[i]
            flux = flux_data[i, :, 0]
            flare_idx = np.where(pred == 1)[0]

            if len(flare_idx) == 0:
                continue

            # Split the flagged indices into runs of consecutive points.
            splits = np.where(np.diff(flare_idx) > 1)[0] + 1
            segments = np.split(flare_idx, splits)

            for segment in segments:
                if len(segment) < 3:
                    continue

                start_idx = segment[0]
                end_idx = segment[-1]

                segment_flux = flux[start_idx:end_idx + 1]
                peak_idx = np.argmax(segment_flux) + start_idx

                if not is_valid_flare(flux, start_idx, end_idx, peak_idx):
                    continue

                start_time = float(self.time_series[i][start_idx])
                end_time = float(self.time_series[i][end_idx])
                duration = end_time - start_time
                event = {
                    "start_idx": int(start_idx),
                    "peak_idx": int(peak_idx),
                    "end_idx": int(end_idx),
                    "start_time": start_time,
                    "peak_time": float(self.time_series[i][peak_idx]),
                    "end_time": end_time,
                    "duration": duration,
                    "confidence": float(predictions[i, segment].mean()),
                }
                flare_events.append(event)

        return flare_events


def load_flare_detection_pipeline(
    model_name: str = "Maxwell-Jia/fcn4flare",
    device: int = -1,
    **kwargs
) -> FlareDetectionPipeline:
    """
    Load a flare detection pipeline.

    Args:
        model_name (str): The model name or path to load
        device (int): Device to use (-1 for CPU, otherwise the GPU index)
        **kwargs: Additional arguments passed to the pipeline

    Returns:
        FlareDetectionPipeline: A pipeline for flare detection
    """
    # Pipeline expects an instantiated model rather than a checkpoint name, so load
    # it first. trust_remote_code=True assumes the checkpoint ships its own modeling
    # code, as custom architectures on the Hub usually do.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    return FlareDetectionPipeline(
        model=model,
        device=device,
        **kwargs
    )
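

# A minimal usage sketch, kept out of the pipeline itself. It assumes the default
# "Maxwell-Jia/fcn4flare" checkpoint loads as above; the FITS path below is a
# placeholder for a real Kepler/TESS light curve with TIME and PDCSAP_FLUX columns.
if __name__ == "__main__":
    detector = load_flare_detection_pipeline(device=-1)

    # A single path (or a list of paths) is accepted; preprocess pads batches to a
    # common length and masks the padding.
    events = detector("path/to/light_curve.fits")

    for event in events:
        print(
            f"Flare from t={event['start_time']:.4f} to t={event['end_time']:.4f} "
            f"(confidence {event['confidence']:.2f})"
        )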