MeriDK committed on
Commit 9a7bc12 · verified · 1 Parent(s): 78c2a17

Update README.md

Files changed (1)
  1. README.md +117 -5
README.md CHANGED
@@ -1,9 +1,121 @@
 ---
 tags:
- - model_hub_mixin
- - pytorch_model_hub_mixin
+ - astronomy
+ - multimodal
+ - classification
 ---
 
- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- - Library: [More Information Needed]
- - Docs: [More Information Needed]
+ # AstroM3-CLIP-photo
+
+ AstroM³ is a self-supervised multimodal model for astronomy that integrates time-series photometry, spectra, and metadata into a unified embedding space
+ for classification and other downstream tasks. This repository hosts AstroM3-CLIP-photo, the AstroM³ variant fine-tuned for photometry-only classification.
+ AstroM³ is trained on [AstroM3Processed](https://huggingface.co/datasets/MeriDK/AstroM3Processed).
+ For more details on the AstroM³ architecture, training, and results, please refer to the [paper](https://arxiv.org/abs/2411.08842).
+
+ <p align="center">
+ <img src="astroclip-architecture.png" width="70%">
+ <br />
+ <span>
+ Figure 1: Overview of the multimodal CLIP framework adapted for astronomy, incorporating three data modalities: photometric time-series, spectra, and metadata.
+ Each modality is processed by a dedicated encoder to create embeddings, which are then mapped into a shared embedding space through projection heads.
+ Pairwise similarity matrices align the embeddings across modalities, and a symmetric cross-entropy loss, computed over these matrices, optimizes the model.
+ The total loss, derived from all pairwise losses, guides the model’s trimodal learning.
+ </span>
+ </p>
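+
+ The caption's pairwise alignment can be made concrete with a short sketch. This is a hedged illustration, not the repository's implementation: the
+ temperature value, the normalization, and the equal weighting of the pairs are assumptions.
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def pair_loss(a, b, temperature=0.07):
+     """Symmetric cross-entropy over the similarity matrix of one modality pair."""
+     a, b = F.normalize(a, dim=-1), F.normalize(b, dim=-1)
+     logits = a @ b.T / temperature                   # pairwise similarity matrix
+     targets = torch.arange(len(a), device=a.device)  # matching pairs lie on the diagonal
+     return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2
+
+ def trimodal_loss(p_emb, s_emb, m_emb):
+     # Total loss sums the symmetric losses of all three modality pairs, as in Figure 1.
+     return pair_loss(p_emb, s_emb) + pair_loss(p_emb, m_emb) + pair_loss(s_emb, m_emb)
+ ```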
+
+ To perform inference with AstroM³, first clone the AstroM3 library from our [GitHub repo](https://github.com/MeriDK/AstroM3):
+ ```sh
+ git clone https://github.com/MeriDK/AstroM3.git
+ cd AstroM3
+ ```
+ Create a virtual environment (tested with Python 3.10.14), then install the required dependencies:
+ ```sh
+ uv venv venv --python 3.10.14
+ source venv/bin/activate
+ uv pip install -r requirements.txt
+ ```
+
+ A simple example to get started:
+ 1. Data Loading & Preprocessing
+ ```python
+ from datasets import load_dataset
+ from src.data import process_photometry
+
+ # Load the test dataset
+ test_dataset = load_dataset('MeriDK/AstroM3Processed', name='full_42', split='test')
+
+ # Process photometry to have a fixed sequence length of 200 (center-cropped)
+ test_dataset = test_dataset.map(process_photometry, batched=True, fn_kwargs={'seq_len': 200, 'how': 'center'})
+ test_dataset = test_dataset.with_format('torch')
+ ```
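+ As a quick sanity check on the preprocessing (the field names come from the snippets in this card; the exact feature dimensions are whatever the dataset
+ provides):
+ ```python
+ sample = test_dataset[0]
+ print(sample['photometry'].shape)       # the sequence dimension should now be 200
+ print(sample['photometry_mask'].shape)  # mask over the cropped/padded time steps
+ ```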
+ 2. Model Loading & Embedding Extraction
+ ```python
+ import torch
+ from src.model import AstroM3
+
+ # Load the base AstroM3-CLIP model
+ model = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP')
+
+ # Retrieve the first sample (batch size = 1)
+ sample = test_dataset[0:1]
+ photometry = sample['photometry']
+ photometry_mask = sample['photometry_mask']
+ spectra = sample['spectra']
+ metadata = sample['metadata']
+
+ # Example 1: Generate embeddings when all modalities are present
+ p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)
+ multimodal_emb = (p_emb + s_emb + m_emb) / 3
+ print('Multimodal Embedding (All Modalities):', multimodal_emb)
+
+ # Example 2: Generate embeddings when the spectra modality is missing
+ dummy_spectra = torch.zeros_like(spectra)  # Dummy tensor for missing spectra
+ p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, dummy_spectra, metadata)
+ multimodal_emb_missing = (p_emb + m_emb) / 2
+ print('Multimodal Embedding (Spectra Missing):', multimodal_emb_missing)
+ ```
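+ Beyond printing them, the embeddings support retrieval-style comparisons. A hedged sketch reusing `multimodal_emb` from above; averaging the three
+ projections as the retrieval representation is an assumption, not a prescribed recipe:
+ ```python
+ import torch.nn.functional as F
+
+ # Embed a small gallery and rank it by cosine similarity to the query embedding.
+ batch = test_dataset[0:8]
+ p_emb, s_emb, m_emb = model.get_embeddings(
+     batch['photometry'], batch['photometry_mask'], batch['spectra'], batch['metadata']
+ )
+ gallery = (p_emb + s_emb + m_emb) / 3
+ scores = F.cosine_similarity(multimodal_emb, gallery)  # one score per gallery sample
+ print('Most similar first:', scores.argsort(descending=True).tolist())
+ ```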
+ 3. Classification Examples
+ ```python
+ from src.model import AstroM3, Informer, GalSpecNet, MetaModel
+
+ # Photometry classification
+ photo_model = Informer.from_pretrained('MeriDK/AstroM3-CLIP-photo')
+ prediction = photo_model(photometry, photometry_mask).argmax(dim=1).item()
+ print('Photometry Classification:', test_dataset.features['label'].int2str(prediction))
+
+ # Spectra classification
+ spectra_model = GalSpecNet.from_pretrained('MeriDK/AstroM3-CLIP-spectra')
+ prediction = spectra_model(spectra).argmax(dim=1).item()
+ print('Spectra Classification:', test_dataset.features['label'].int2str(prediction))
+
+ # Metadata classification
+ meta_model = MetaModel.from_pretrained('MeriDK/AstroM3-CLIP-meta')
+ prediction = meta_model(metadata).argmax(dim=1).item()
+ print('Metadata Classification:', test_dataset.features['label'].int2str(prediction))
+
+ # Multimodal classification
+ all_model = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP-all')
+ prediction = all_model(photometry, photometry_mask, spectra, metadata).argmax(dim=1).item()
+ print('Multimodal Classification:', test_dataset.features['label'].int2str(prediction))
+ ```
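+ The single-sample predictions above extend naturally to a full test-set evaluation. A hedged sketch: it assumes the dataset's `label` column holds integer
+ class ids (consistent with the `int2str` calls above):
+ ```python
+ import torch
+ from torch.utils.data import DataLoader
+
+ # Accuracy of the multimodal classifier over the whole test split.
+ loader = DataLoader(test_dataset, batch_size=64)
+ correct = total = 0
+ with torch.no_grad():
+     for batch in loader:
+         logits = all_model(batch['photometry'], batch['photometry_mask'], batch['spectra'], batch['metadata'])
+         correct += (logits.argmax(dim=1) == batch['label']).sum().item()
+         total += len(batch['label'])
+ print(f'Multimodal test accuracy: {correct / total:.3f}')
+ ```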
+
+ ## The AstroM³ Family
+
+ | Model | Description |
+ | :--- | :--- |
+ | [AstroM3-CLIP](https://huggingface.co/MeriDK/AstroM3-CLIP) | The base model pre-trained using the trimodal CLIP approach. |
+ | [AstroM3-CLIP-meta](https://huggingface.co/MeriDK/AstroM3-CLIP-meta) | Fine-tuned for metadata-only classification. |
+ | [AstroM3-CLIP-spectra](https://huggingface.co/MeriDK/AstroM3-CLIP-spectra) | Fine-tuned for spectra-only classification. |
+ | [AstroM3-CLIP-photo](https://huggingface.co/MeriDK/AstroM3-CLIP-photo) | Fine-tuned for photometry-only classification. |
+ | [AstroM3-CLIP-all](https://huggingface.co/MeriDK/AstroM3-CLIP-all) | Fine-tuned for multimodal classification. |
+
+ ## AstroM3-CLIP Variants
+ These variants of the base AstroM3-CLIP model are trained with different random seeds (42, 0, 66, 12, 123).
+ For consistent results, load the dataset with the seed matching the model, as in the sketch after the table.
+
+ | Model | Description |
+ | :--- | :--- |
+ | [AstroM3-CLIP-42](https://huggingface.co/MeriDK/AstroM3-CLIP-42) | The base model pre-trained with random seed 42 (identical to AstroM3-CLIP). |
+ | [AstroM3-CLIP-0](https://huggingface.co/MeriDK/AstroM3-CLIP-0) | AstroM3-CLIP pre-trained with random seed 0 (use the dataset with seed 0). |
+ | [AstroM3-CLIP-66](https://huggingface.co/MeriDK/AstroM3-CLIP-66) | AstroM3-CLIP pre-trained with random seed 66 (use the dataset with seed 66). |
+ | [AstroM3-CLIP-12](https://huggingface.co/MeriDK/AstroM3-CLIP-12) | AstroM3-CLIP pre-trained with random seed 12 (use the dataset with seed 12). |
+ | [AstroM3-CLIP-123](https://huggingface.co/MeriDK/AstroM3-CLIP-123) | AstroM3-CLIP pre-trained with random seed 123 (use the dataset with seed 123). |
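+
+ For example, pairing a seed variant with its matching dataset configuration might look like this (the `full_0` config name is an assumption, extrapolated
+ from the `full_42` config used earlier):
+ ```python
+ from datasets import load_dataset
+ from src.model import AstroM3
+
+ # Load the seed-0 model together with the seed-0 dataset configuration.
+ model_0 = AstroM3.from_pretrained('MeriDK/AstroM3-CLIP-0')
+ dataset_0 = load_dataset('MeriDK/AstroM3Processed', name='full_0', split='test')  # 'full_0' is assumed
+ ```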