---
tags:
- astronomy
- multimodal
- classification
---

# AstroM3-CLIP-photo

AstroM³ is a self-supervised multimodal model for astronomy that integrates time-series photometry, spectra, and metadata into a unified embedding space for classification and other downstream tasks. AstroM³ is trained on [AstroM3Processed](https://huggingface.co/datasets/MeriDK/AstroM3Processed). For more details on the AstroM³ architecture, training, and results, please refer to the [paper](https://arxiv.org/abs/2411.08842).

<p align="center">
<img src="astroclip-architecture.png" width="70%">
<br />
<span>
Figure 1: Overview of the multimodal CLIP framework adapted for astronomy, incorporating three data modalities: photometric time-series, spectra, and metadata. Each modality is processed by a dedicated encoder to create embeddings, which are then mapped into a shared embedding space through projection heads. Pairwise similarity matrices align the embeddings across modalities, and a symmetric cross-entropy loss, computed over these matrices, optimizes the model. The total loss, derived from all pairwise losses, guides the model's trimodal learning.
</span>
</p>
To perform inference with AstroM³, install the AstroM3 library from our [GitHub repo](https://github.com/MeriDK/AstroM3):

```sh
git clone https://github.com/MeriDK/AstroM3.git
cd AstroM3
```

Create a virtual environment (tested with Python 3.10.14), then install the required dependencies:

```sh
uv venv venv --python 3.10.14
source venv/bin/activate
uv pip install -r requirements.txt
```
A simple example to get started:

1. Data Loading & Preprocessing

```python
from datasets import load_dataset
from src.data import process_photometry

# Load the test dataset
test_dataset = load_dataset('MeriDK/AstroM3Processed', name='full_42', split='test')

# Process photometry to a fixed sequence length of 200 (center-cropped)
test_dataset = test_dataset.map(process_photometry, batched=True, fn_kwargs={'seq_len': 200, 'how': 'center'})
test_dataset = test_dataset.with_format('torch')
```
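The exact preprocessing lives in `src.data.process_photometry`; as a rough, assumed illustration only (not the library's actual implementation), a center crop to a fixed sequence length with zero-padding and a validity mask might look like this:

```python
import numpy as np

# Illustrative sketch of fixed-length center-cropping for a light curve.
# Assumption: shorter sequences are zero-padded and masked; the real
# behavior is defined by process_photometry in the AstroM3 repo.
def center_crop(times, mags, seq_len=200):
    """Center-crop to seq_len points; zero-pad and mask if shorter."""
    n = len(times)
    if n >= seq_len:
        start = (n - seq_len) // 2
        sl = slice(start, start + seq_len)
        return times[sl], mags[sl], np.ones(seq_len, dtype=bool)
    pad = seq_len - n
    mask = np.concatenate([np.ones(n, dtype=bool), np.zeros(pad, dtype=bool)])
    return np.pad(times, (0, pad)), np.pad(mags, (0, pad)), mask

t = np.linspace(0.0, 1.0, 300)
m = np.sin(2 * np.pi * 3 * t)
tc, mc, mask = center_crop(t, m, seq_len=200)
print(tc.shape, mask.all())  # (200,) True
```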
2. Model Loading & Embedding Extraction

```python
import torch
from src.model import AstroM3

# Load the base AstroM3-CLIP model
model = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP')

# Retrieve the first sample (batch size = 1)
sample = test_dataset[0:1]
photometry = sample['photometry']
photometry_mask = sample['photometry_mask']
spectra = sample['spectra']
metadata = sample['metadata']

# Example 1: Generate embeddings when all modalities are present
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)
multimodal_emb = (p_emb + s_emb + m_emb) / 3
print('Multimodal Embedding (All Modalities):', multimodal_emb)

# Example 2: Generate embeddings when the spectra modality is missing
dummy_spectra = torch.zeros_like(spectra)  # Dummy tensor for missing spectra
p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, dummy_spectra, metadata)
multimodal_emb_missing = (p_emb + m_emb) / 2
print('Multimodal Embedding (Spectra Missing):', multimodal_emb_missing)
```
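The fusion above is a plain average over the available modalities. As an illustrative sketch (NumPy, with made-up 3-dimensional vectors, not the model's actual embedding size), fused embeddings can then be compared with cosine similarity for retrieval-style downstream use:

```python
import numpy as np

# Sketch only: average per-modality embeddings into one fused vector,
# then compare fused vectors by cosine similarity. The vectors here are
# toy examples, not real AstroM3 outputs.
def fuse(*embs):
    return np.mean(embs, axis=0)

def cosine_sim(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

p = np.array([1.0, 0.0, 0.0])  # toy photometry embedding
s = np.array([0.0, 1.0, 0.0])  # toy spectra embedding
m = np.array([0.0, 0.0, 1.0])  # toy metadata embedding
fused = fuse(p, s, m)
print(cosine_sim(fused, p))  # each modality contributes equally: ~0.577
```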
3. Classification Examples

```python
from src.model import AstroM3, Informer, GalSpecNet, MetaModel

# Photometry classification
photo_model = Informer.from_pretrained('MeriDK/AstroM3-CLIP-photo')
prediction = photo_model(photometry, photometry_mask).argmax(dim=1).item()
print('Photometry Classification:', test_dataset.features['label'].int2str(prediction))

# Spectra classification
spectra_model = GalSpecNet.from_pretrained('MeriDK/AstroM3-CLIP-spectra')
prediction = spectra_model(spectra).argmax(dim=1).item()
print('Spectra Classification:', test_dataset.features['label'].int2str(prediction))

# Metadata classification
meta_model = MetaModel.from_pretrained('MeriDK/AstroM3-CLIP-meta')
prediction = meta_model(metadata).argmax(dim=1).item()
print('Metadata Classification:', test_dataset.features['label'].int2str(prediction))

# Multimodal classification
all_model = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP-all')
prediction = all_model(photometry, photometry_mask, spectra, metadata).argmax(dim=1).item()
print('Multimodal Classification:', test_dataset.features['label'].int2str(prediction))
```
## The AstroM³ Family

| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP](https://huggingface.co/MeriDK/AstroM3-CLIP) | The base model pre-trained using the trimodal CLIP approach. |
| [AstroM3-CLIP-meta](https://huggingface.co/MeriDK/AstroM3-CLIP-meta) | Fine-tuned for metadata-only classification. |
| [AstroM3-CLIP-spectra](https://huggingface.co/MeriDK/AstroM3-CLIP-spectra) | Fine-tuned for spectra-only classification. |
| [AstroM3-CLIP-photo](https://huggingface.co/MeriDK/AstroM3-CLIP-photo) | Fine-tuned for photometry-only classification. |
| [AstroM3-CLIP-all](https://huggingface.co/MeriDK/AstroM3-CLIP-all) | Fine-tuned for multimodal classification. |
## AstroM3-CLIP Variants

These variants of the base AstroM3-CLIP model were trained with different random seeds (42, 0, 66, 12, 123); make sure to load the dataset with the corresponding seed for consistency.

| Model | Description |
| :--- | :--- |
| [AstroM3-CLIP-42](https://huggingface.co/MeriDK/AstroM3-CLIP-42) | The base model pre-trained with random seed 42 (identical to AstroM3-CLIP). |
| [AstroM3-CLIP-0](https://huggingface.co/MeriDK/AstroM3-CLIP-0) | AstroM3-CLIP pre-trained with random seed 0 (use dataset with seed 0). |
| [AstroM3-CLIP-66](https://huggingface.co/MeriDK/AstroM3-CLIP-66) | AstroM3-CLIP pre-trained with random seed 66 (use dataset with seed 66). |
| [AstroM3-CLIP-12](https://huggingface.co/MeriDK/AstroM3-CLIP-12) | AstroM3-CLIP pre-trained with random seed 12 (use dataset with seed 12). |
| [AstroM3-CLIP-123](https://huggingface.co/MeriDK/AstroM3-CLIP-123) | AstroM3-CLIP pre-trained with random seed 123 (use dataset with seed 123). |
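The quick-start example pairs seed 42 with the `full_42` dataset config, so the configs appear to follow a `full_<seed>` naming pattern; assuming that pattern holds, a small hypothetical helper (not part of the AstroM3 library) keeps model and dataset seeds in sync:

```python
# Hypothetical convenience helper, based on the assumed 'full_<seed>'
# config naming seen in the quick-start example above.
def dataset_config_for_seed(seed: int) -> str:
    return f'full_{seed}'

# e.g. to evaluate AstroM3-CLIP-0, load the matching dataset config:
# load_dataset('MeriDK/AstroM3Processed', name=dataset_config_for_seed(0), split='test')
print(dataset_config_for_seed(0))  # full_0
```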