synyyy committed on
Commit
942de71
·
verified ·
1 Parent(s): eb4d4fd

Add example usage script

Browse files
Files changed (1) hide show
  1. example_usage.py +63 -0
example_usage.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example usage script for T5 Spotify Features model
3
+ """
4
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
5
+ import json
6
+
7
def load_model():
    """Load the fine-tuned T5 model and its tokenizer from the Hugging Face Hub.

    Returns:
        A ``(model, tokenizer)`` pair ready for generation.
    """
    # Both artifacts live in the same Hub repository.
    repo_id = "synyyy/t5-spotify-features-v2"
    tokenizer = T5Tokenizer.from_pretrained(repo_id)
    model = T5ForConditionalGeneration.from_pretrained(repo_id)
    return model, tokenizer
12
+
13
def generate_spotify_features(prompt, model, tokenizer):
    """Generate Spotify audio features from a free-text prompt.

    Args:
        prompt: Natural-language description of the desired music.
        model: Seq2seq model exposing ``generate`` (e.g. T5ForConditionalGeneration).
        tokenizer: Matching tokenizer exposing ``__call__`` and ``decode``.

    Returns:
        A dict parsed from the model's JSON output, or ``None`` when the
        decoded text cannot be parsed as JSON.
    """
    # The model was trained with this "prompt: ..." input format.
    input_text = f"prompt: {prompt}"

    input_ids = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True).input_ids
    # Deterministic beam search: no sampling, so the same prompt always
    # yields the same features.
    outputs = model.generate(
        input_ids,
        max_length=256,
        num_beams=4,
        early_stopping=True,
        do_sample=False
    )

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Post-process JSON if needed: the model sometimes emits the JSON body
    # without the surrounding braces, so wrap it for json.loads.
    stripped = result.strip()
    if not stripped.startswith("{") and not stripped.endswith("}"):
        result = "{" + result + "}"

    try:
        return json.loads(result)
    except json.JSONDecodeError as e:
        # Best-effort: report the failure and fall through to None so the
        # caller can decide how to handle an unparseable generation.
        print(f"JSON parsing failed: {e}")
        print(f"Raw output: {result}")
        return None
38
+
39
if __name__ == "__main__":
    # Fetch the model/tokenizer once up front; everything below reuses them.
    print("Loading model...")
    model, tokenizer = load_model()

    # A handful of prompts covering different moods and use cases.
    sample_prompts = (
        "energetic dance music for parties",
        "calm acoustic music for studying",
        "upbeat pop songs for working out",
        "relaxing instrumental background music",
        "happy music for road trips",
    )

    print("\nGenerating features for test prompts:")
    print("=" * 50)

    for prompt in sample_prompts:
        print(f"\nPrompt: {prompt}")
        features = generate_spotify_features(prompt, model, tokenizer)
        if not features:
            print("Failed to generate valid features")
        else:
            print(f"Features: {json.dumps(features, indent=2)}")
        print("-" * 30)