Spaces:
Running
Running
Commit
·
1bd39bb
1
Parent(s):
a528625
update for random ref voice
Browse files- app.py +8 -254
- templates/arena.html +0 -0
app.py
CHANGED
|
@@ -850,260 +850,6 @@ def cleanup_session(session_id):
|
|
| 850 |
# Remove session
|
| 851 |
del app.tts_sessions[session_id]
|
| 852 |
|
| 853 |
-
|
| 854 |
-
@app.route("/api/conversational/generate", methods=["POST"])
@limiter.limit("5 per minute")
def generate_podcast():
    """Generate audio for one script with two conversational models.

    Reads a JSON body with a ``script`` key: a list of at least two dicts,
    each carrying a non-empty string ``text`` and an integer ``speaker_id``
    of 0 or 1. Two active conversational models are picked with
    ``get_weighted_random_models``, each is run through ``predict_tts`` in
    its own worker thread, and the resulting bytes are written to
    ``TEMP_AUDIO_DIR``. A session entry is then stored in
    ``app.conversational_sessions``.

    Returns:
        A JSON response with ``session_id``, the two audio URLs and
        ``expires_in`` on success; otherwise a JSON ``error`` with status
        403 (Turnstile not verified), 400 (invalid script) or 500 (too few
        models or a generation failure).
    """
    # If verification not setup, handle it first
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    # A missing, malformed or non-object JSON body is treated as an empty
    # payload so it ends in the 400 below instead of an AttributeError.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    script = data.get("script")

    if not script or not isinstance(script, list) or len(script) < 2:
        return jsonify({"error": "Invalid script format or too short"}), 400

    # Validate script format
    for line in script:
        if not isinstance(line, dict) or "text" not in line or "speaker_id" not in line:
            return (
                jsonify(
                    {
                        "error": "Invalid script line format. Each line must have text and speaker_id"
                    }
                ),
                400,
            )
        if (
            # The texts are joined into one string further down, so a
            # non-string value has to be rejected here.
            not isinstance(line["text"], str)
            or not line["text"]
            or not isinstance(line["speaker_id"], int)
            or line["speaker_id"] not in [0, 1]
        ):
            return (
                jsonify({"error": "Invalid script content. Speaker ID must be 0 or 1"}),
                400,
            )

    # Get two conversational models (currently only CSM and PlayDialog)
    available_models = Model.query.filter_by(
        model_type=ModelType.CONVERSATIONAL, is_active=True
    ).all()

    if len(available_models) < 2:
        return jsonify({"error": "Not enough conversational models available"}), 500

    selected_models = get_weighted_random_models(available_models, 2, ModelType.CONVERSATIONAL)

    try:
        # Function to process a single model
        def process_model(model):
            # Call conversational TTS service
            audio_content = predict_tts(script, model.id)

            # Save to temp file with unique name
            file_uuid = str(uuid.uuid4())
            dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")

            with open(dest_path, "wb") as f:
                f.write(audio_content)

            return {"model_id": model.id, "audio_path": dest_path}

        # Generate audio for both models concurrently. Every future has
        # finished by the time the executor context exits.
        with ThreadPoolExecutor(max_workers=2) as executor:
            futures = [
                executor.submit(process_model, model) for model in selected_models
            ]

        first_error = next(
            (
                future.exception()
                for future in futures
                if future.exception() is not None
            ),
            None,
        )
        if first_error is not None:
            # No session will be created, so nothing would ever clean up
            # the file written by a model that did succeed; remove it now.
            for future in futures:
                if future.exception() is None:
                    try:
                        os.remove(future.result()["audio_path"])
                    except OSError as cleanup_error:
                        app.logger.error(
                            f"Error removing conversational audio file: {str(cleanup_error)}"
                        )
            raise first_error

        results = [future.result() for future in futures]
        model_ids = [result["model_id"] for result in results]
        audio_files = [result["audio_path"] for result in results]

        # Create session
        session_id = str(uuid.uuid4())
        script_text = " ".join([line["text"] for line in script])
        app.conversational_sessions[session_id] = {
            "model_a": model_ids[0],
            "model_b": model_ids[1],
            "audio_a": audio_files[0],
            "audio_b": audio_files[1],
            "text": script_text[:1000],  # Limit text length
            "created_at": datetime.utcnow(),
            "expires_at": datetime.utcnow() + timedelta(minutes=30),
            "voted": False,
            "script": script,
        }

        # Return audio file paths and session
        return jsonify(
            {
                "session_id": session_id,
                "audio_a": f"/api/conversational/audio/{session_id}/a",
                "audio_b": f"/api/conversational/audio/{session_id}/b",
                "expires_in": 1800,  # 30 minutes in seconds
            }
        )

    except Exception as e:
        app.logger.error(f"Conversational generation error: {str(e)}")
        return jsonify({"error": f"Failed to generate podcast: {str(e)}"}), 500
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
@app.route("/api/conversational/audio/<session_id>/<model_key>")
def get_podcast_audio(session_id, model_key):
    """Send the stored WAV file for side "a" or "b" of a session.

    Responds with 403 when Turnstile is enabled but not verified, 404 for
    an unknown session or a missing file, 410 for an expired session
    (which is cleaned up first) and 400 for any other ``model_key``.
    """
    turnstile_enabled = app.config["TURNSTILE_ENABLED"]
    if turnstile_enabled and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    entry = app.conversational_sessions.get(session_id)
    if entry is None:
        return jsonify({"error": "Invalid or expired session"}), 404

    # Expiry is checked before the key so an expired session is always
    # cleaned up, whatever key was requested.
    if entry["expires_at"] < datetime.utcnow():
        cleanup_conversational_session(session_id)
        return jsonify({"error": "Session expired"}), 410

    # Map the URL key onto the session field that holds the file path.
    path_field = {"a": "audio_a", "b": "audio_b"}.get(model_key)
    if path_field is None:
        return jsonify({"error": "Invalid model key"}), 400
    wav_path = entry[path_field]

    if os.path.exists(wav_path):
        return send_file(wav_path, mimetype="audio/wav")
    return jsonify({"error": "Audio file not found"}), 404
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
@app.route("/api/conversational/vote", methods=["POST"])
@limiter.limit("30 per minute")
def submit_podcast_vote():
    """Record which of a session's two audio clips the user preferred.

    Reads ``session_id`` and ``chosen_model`` ("a" or "b") from the JSON
    body, stores the vote through ``record_vote`` and then, on a
    best-effort basis, copies both audio files plus a ``metadata.json``
    into a new directory under ``./votes``.

    Returns:
        A JSON response with ``success``, the chosen and rejected model
        ids and names, and the names behind "a" and "b"; otherwise a JSON
        ``error`` with status 403, 404, 410, 400 or 500.
    """
    # If verification not setup, handle it first
    if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
        return jsonify({"error": "Turnstile verification required"}), 403

    # A missing, malformed or non-object JSON body is treated as an empty
    # payload so it ends in the 404 below instead of an AttributeError.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    session_id = data.get("session_id")
    chosen_model_key = data.get("chosen_model")  # "a" or "b"

    # The isinstance check keeps an unhashable value (list/dict) from
    # raising TypeError in the dictionary lookup.
    if (
        not session_id
        or not isinstance(session_id, str)
        or session_id not in app.conversational_sessions
    ):
        return jsonify({"error": "Invalid or expired session"}), 404

    if not chosen_model_key or chosen_model_key not in ["a", "b"]:
        return jsonify({"error": "Invalid chosen model"}), 400

    session_data = app.conversational_sessions[session_id]

    # Check if session expired
    if datetime.utcnow() > session_data["expires_at"]:
        cleanup_conversational_session(session_id)
        return jsonify({"error": "Session expired"}), 410

    # Check if already voted
    if session_data["voted"]:
        return jsonify({"error": "Vote already submitted for this session"}), 400

    # Get model IDs and audio paths
    if chosen_model_key == "a":
        chosen_side, rejected_side = "a", "b"
    else:
        chosen_side, rejected_side = "b", "a"
    chosen_id = session_data[f"model_{chosen_side}"]
    rejected_id = session_data[f"model_{rejected_side}"]
    chosen_audio_path = session_data[f"audio_{chosen_side}"]
    rejected_audio_path = session_data[f"audio_{rejected_side}"]

    # Record vote in database
    user_id = current_user.id if current_user.is_authenticated else None
    vote, error = record_vote(
        user_id, session_data["text"], chosen_id, rejected_id, ModelType.CONVERSATIONAL
    )

    if error:
        return jsonify({"error": error}), 500

    # Mark session as voted as soon as the vote is stored, so a failure
    # further down cannot leave the session open for a second vote.
    session_data["voted"] = True

    # Look the models up outside the best-effort block below: the response
    # needs them even when saving the preference data fails.
    chosen_model_obj = Model.query.get(chosen_id)
    rejected_model_obj = Model.query.get(rejected_id)
    chosen_name = chosen_model_obj.name if chosen_model_obj else "Unknown"
    rejected_name = rejected_model_obj.name if rejected_model_obj else "Unknown"

    # --- Save preference data ---
    try:
        vote_uuid = str(uuid.uuid4())
        vote_dir = os.path.join("./votes", vote_uuid)
        os.makedirs(vote_dir, exist_ok=True)

        # Copy audio files
        shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
        shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))

        # Create metadata
        metadata = {
            "script": session_data["script"],  # Save the full script
            "chosen_model": chosen_name,
            "chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
            "rejected_model": rejected_name,
            "rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
            "session_id": session_id,
            "timestamp": datetime.utcnow().isoformat(),
            "username": current_user.username if current_user.is_authenticated else None,
            "model_type": "CONVERSATIONAL"
        }
        with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=2)

    except Exception as e:
        app.logger.error(f"Error saving preference data for conversational vote {session_id}: {str(e)}")
        # Continue even if saving preference data fails, vote is already recorded

    # Side "a" is the chosen model exactly when the user picked "a".
    if chosen_model_key == "a":
        name_a, name_b = chosen_name, rejected_name
    else:
        name_a, name_b = rejected_name, chosen_name

    # Return updated models (use previously fetched objects)
    return jsonify(
        {
            "success": True,
            "chosen_model": {"id": chosen_id, "name": chosen_name},
            "rejected_model": {
                "id": rejected_id,
                "name": rejected_name,
            },
            "names": {
                "a": name_a,
                "b": name_b,
            },
        }
    )
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
def cleanup_conversational_session(session_id):
    """Delete a conversational session and the two audio files it points to.

    Does nothing when ``session_id`` is not in
    ``app.conversational_sessions``. A failure to delete a file is logged
    and does not stop the session entry from being dropped.
    """
    if session_id not in app.conversational_sessions:
        return

    entry = app.conversational_sessions[session_id]
    wav_paths = (entry["audio_a"], entry["audio_b"])

    # Remove audio files
    for wav_path in wav_paths:
        if not os.path.exists(wav_path):
            continue
        try:
            os.remove(wav_path)
        except Exception as e:
            app.logger.error(
                f"Error removing conversational audio file: {str(e)}"
            )

    # Remove session
    del app.conversational_sessions[session_id]
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
# Schedule periodic cleanup
|
| 1108 |
def setup_cleanup():
|
| 1109 |
def cleanup_expired_sessions():
|
|
@@ -1375,6 +1121,14 @@ def get_reference_audio(filename):
|
|
| 1375 |
return jsonify({"error": "Reference audio not found"}), 404
|
| 1376 |
return send_file(file_path, mimetype="audio/wav")
|
| 1377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1378 |
|
| 1379 |
def get_weighted_random_models(
|
| 1380 |
applicable_models: list[Model], num_to_select: int, model_type: ModelType
|
|
|
|
| 850 |
# Remove session
|
| 851 |
del app.tts_sessions[session_id]
|
| 852 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
# Schedule periodic cleanup
|
| 854 |
def setup_cleanup():
|
| 855 |
def cleanup_expired_sessions():
|
|
|
|
| 1121 |
return jsonify({"error": "Reference audio not found"}), 404
|
| 1122 |
return send_file(file_path, mimetype="audio/wav")
|
| 1123 |
|
| 1124 |
+
@app.route('/api/voice/random', methods=['GET'])
def get_random_voice():
    """Serve one randomly chosen file from ``reference_audio_files``.

    Returns:
        The audio file with a mimetype derived from its extension, or a
        JSON ``error`` with status 404 when no reference audio is
        configured or the chosen file is missing on disk.
    """
    # random.choice raises IndexError on an empty sequence; answer with a
    # 404 instead of letting that surface as a 500.
    if not reference_audio_files:
        return jsonify({"error": "Reference audio not found"}), 404

    # Pick a random audio file
    random_voice = random.choice(reference_audio_files)
    voice_path = os.path.join(REFERENCE_AUDIO_DIR, random_voice)

    if not os.path.exists(voice_path):
        return jsonify({"error": "Reference audio not found"}), 404

    # splitext only looks at the final path component, so a dot in a
    # directory name cannot be mistaken for a file extension.
    extension = os.path.splitext(voice_path)[1].lstrip('.').lower()
    if extension:
        # "audio/mp3" is not a registered media type; MP3 is "audio/mpeg".
        mimetype = 'audio/' + {'mp3': 'mpeg'}.get(extension, extension)
    else:
        mimetype = 'application/octet-stream'

    # Return the audio file
    return send_file(voice_path, mimetype=mimetype)
|
| 1132 |
|
| 1133 |
def get_weighted_random_models(
|
| 1134 |
applicable_models: list[Model], num_to_select: int, model_type: ModelType
|
templates/arena.html
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|