Spaces:
Runtime error
Runtime error
| # Copyright 2022 Tristan Behrens. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Lint as: python3 | |
| from fastapi import BackgroundTasks, FastAPI | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| import os | |
| import io | |
| import random | |
| import base64 | |
| from time import time | |
| from statistics import mean | |
| from collections import OrderedDict | |
| import torch | |
| import wave | |
| from source.logging import create_logger | |
| from source.tokensequence import token_sequence_to_audio, token_sequence_to_image | |
| from source import constants | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
logger = create_logger(__name__)

# Hugging Face Hub repository of the tokenizer/model pair used for generation.
MODEL_ID = "ai-guru/lakhclean_mmmtrack_4bars_d-2048"

# Read the auth token from the "authtoken" environment variable
# (os.getenv yields None when the variable is unset).
auth_token = os.getenv("authtoken")

# Loading the model and its tokenizer.
logger.info("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=auth_token)
logger.info("Done.")

# Create the app. The interactive API docs (Swagger UI / ReDoc) are disabled,
# and files under ./static are served at /static.
logger.info("Creating app...")
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")
logger.info("Done.")
class Options(BaseModel):
    """Pydantic model holding three required string generation settings."""

    # Style label (NewTask defaults this to "synth").
    music_style: str
    # Density label (NewTask defaults this to "medium").
    density: str
    # Temperature label (NewTask defaults this to "medium").
    temperature: str
class NewTask(BaseModel):
    """Request body for creating a composition task; every field is optional.

    The values are labels that ``compose`` later resolves through
    ``constants.get_instruments``, ``constants.get_density`` and
    ``constants.get_temperature``.
    """

    # Explicit annotations are required: pydantic v2 refuses non-annotated
    # class attributes ("A non-annotated attribute was detected"). Under
    # pydantic v1 the annotation matches the type inferred from the default,
    # so the model behaves the same there.
    music_style: str = "synth"
    density: str = "medium"
    temperature: str = "medium"
def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among active tasks, or 0.

    Active tasks are the entries of ``tasks`` whose status is "queued" or
    "processing", ordered by ``created_at``. Returns 0 when ``task_id`` is
    not active (completed, failed, abandoned, or unknown).
    """
    active = sorted(
        (
            task
            for task in tasks.values()
            if task["status"] in ("queued", "processing")
        ),
        key=lambda task: task["created_at"],
    )
    active_ids = [task["task_id"] for task in active]
    try:
        return active_ids.index(task_id) + 1
    except ValueError:
        # Not (or no longer) waiting or running.
        return 0
def calculate_eta(task_id):
    """Estimate, in seconds, how long ``task_id`` will take to finish.

    The estimate multiplies the task's ``initial_place_in_queue`` by the mean
    wall-clock duration (``completed_at - started_at``) of completed tasks.
    If no task has completed yet, it uses a fallback of 35 per place instead.
    The result is rounded to one decimal.

    Note that the *initial* place is used, not the current one, so the
    estimate does not shrink as the queue advances.
    """
    durations = [
        entry["completed_at"] - entry["started_at"]
        for entry in tasks.values()
        if "completed_at" in entry and entry["status"] == "completed"
    ]
    per_place = mean(durations) if durations else 35
    return round(tasks[task_id]["initial_place_in_queue"] * per_place, 1)
def next_task(task_id):
    """Stamp ``task_id`` as finished and start the oldest queued task, if any.

    ``completed_at`` is set on the given task regardless of its final status.
    If any task is still "queued", a progress line is printed and the first
    queued task (in ``tasks`` insertion order) is handed to ``process_task``.
    """
    finished = tasks[task_id]
    finished["completed_at"] = time()
    waiting = [entry for entry in tasks.values() if entry["status"] == "queued"]
    if not waiting:
        return
    print(f"{task_id} {finished['status']}. Task/s remaining: {len(waiting)}")
    process_task(waiting[0]["task_id"])
def process_task(task_id):
    """Run the composition job for ``task_id`` if no other job is running.

    Only one task is processed at a time. When any task is already in the
    "processing" state, this returns immediately; ``next_task`` starts the
    next queued task once the running one finishes.

    A task whose client last polled more than 30 seconds ago is marked
    "abandoned" and skipped. Otherwise the task is composed:

    - On success, ``output`` holds the dict returned by ``compose`` and the
      status becomes "completed".
    - On any exception, the status becomes "failed" and ``error`` holds
      ``repr`` of the exception.

    In both cases ``next_task`` is then called to advance the queue.
    """
    if any(task["status"] == "processing" for task in tasks.values()):
        return

    task = tasks[task_id]
    last_poll = task["last_poll"]
    if last_poll and time() - last_poll > 30:
        # The client appears to have stopped polling, so skip generation.
        # The early return matters: without it the abandoned task fell
        # through, was composed anyway, and next_task ran a second time.
        task["status"] = "abandoned"
        next_task(task_id)
        return

    task["status"] = "processing"
    task["started_at"] = time()
    print(f"Processing {task_id}")
    try:
        task["output"] = compose(
            task["music_style"],
            task["density"],
            task["temperature"],
        )
    except Exception as ex:
        # Deliberate catch-all: the failure is reported through the task's
        # status/error fields rather than propagating.
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        next_task(task_id)
def _encode_wav_base64(sample_rate, audio_data):
    """Wrap ``audio_data`` in a mono, 2-byte-sample WAV file and base64-encode it.

    ``audio_data`` is passed to ``Wave_write.writeframes`` unchanged, so it
    must be bytes-like. It is presumably int16 samples, matching the sample
    width of 2 — confirm against ``token_sequence_to_audio``.
    """
    with io.BytesIO() as buffer:
        # Closing the Wave_write finalizes the header; it does not close
        # ``buffer`` because the wave module did not open it.
        with wave.open(buffer, "wb") as wav:
            wav.setframerate(sample_rate)
            wav.setnchannels(1)
            wav.setsampwidth(2)
            wav.writeframes(audio_data)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _encode_png_base64(image):
    """Serialize the PIL ``image`` as PNG and return it base64-encoded."""
    with io.BytesIO() as buffer:
        # NOTE(review): ``quality`` does not appear to be used by Pillow's PNG
        # writer; kept to match the previous call exactly.
        image.save(buffer, "PNG", quality=70)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def compose(music_style, density, temperature):
    """Generate a piece of music and return it as tokens, audio and an image.

    Args:
        music_style: Style label, resolved by ``constants.get_instruments``.
        density: Density label, resolved by ``constants.get_density``.
        temperature: Temperature label, resolved by
            ``constants.get_temperature``.

    Returns:
        A dict with these keys:

        - ``tokens``: the generated token sequence.
        - ``audio``: a ``data:audio/wav;base64,...`` URI.
        - ``image``: a ``data:image/png;base64,...`` URI.
        - ``status``: ``"OK"``.

    Side effects:
        Writes the rendered image to ``compose.png`` in the current working
        directory, overwriting any previous file.
    """
    instruments = constants.get_instruments(music_style)
    density = constants.get_density(density)
    temperature = constants.get_temperature(temperature)
    print(f"instruments: {instruments} density: {density} temperature: {temperature}")

    # Generate with the given parameters.
    logger.info("Generating token sequence...")
    generated_sequence = generate_sequence(instruments, density, temperature)
    logger.info(f"Generated token sequence: {generated_sequence}")

    # Render the token sequence to audio samples and encode them as a WAV file.
    logger.info("Generating audio...")
    sample_rate, audio_data = token_sequence_to_audio(generated_sequence)
    logger.info(f"Done. Audio data: {len(audio_data)}")
    audio_data_base64 = _encode_wav_base64(sample_rate, audio_data)

    # Render the token sequence to a PIL image.
    image = token_sequence_to_image(generated_sequence)
    logger.debug(f"Saving image to harddrive... {type(image)}")
    # NOTE(review): this fixed file name is shared by every task, so each call
    # overwrites it; nothing in the visible code reads it back.
    image.save("compose.png", "PNG")
    image_data_base64 = _encode_png_base64(image)

    return {
        "tokens": generated_sequence,
        "audio": "data:audio/wav;base64," + audio_data_base64,
        "image": "data:image/png;base64," + image_data_base64,
        "status": "OK",
    }
def generate_sequence(instruments, density, temperature):
    """Sample a multi-track token sequence from the language model.

    The instruments are shuffled on a copy, so the caller's sequence is left
    untouched. Then, for each instrument, the prompt
    ``TRACK_START INST=<instrument> DENSITY=<density>`` is appended to
    everything generated so far. The model continues sampling until it emits
    ``TRACK_END`` or ``max_length=2048`` is reached.

    Args:
        instruments: Instrument names; one track is generated per entry.
        density: Value inserted into every track prompt.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded token sequence (``tokenizer.decode`` output).
    """
    instruments = list(instruments)
    random.shuffle(instruments)

    # Loop-invariant: id of the token that ends a track (stops generation).
    track_end_id = tokenizer.encode("TRACK_END")[0]

    generated_ids = tokenizer.encode("PIECE_START", return_tensors="pt")[0]
    for instrument in instruments:
        more_ids = tokenizer.encode(
            f"TRACK_START INST={instrument} DENSITY={density}", return_tensors="pt"
        )[0]
        generated_ids = torch.cat((generated_ids, more_ids))
        # generate() expects a batch dimension; [0] below strips it again so
        # the next iteration can concatenate 1-D id tensors.
        generated_ids = generated_ids.unsqueeze(0)
        generated_ids = model.generate(
            generated_ids,
            max_length=2048,
            do_sample=True,
            temperature=temperature,
            eos_token_id=track_end_id,
        )[0]

    generated_sequence = tokenizer.decode(generated_ids)
    print("GENERATING COMPLETE")
    # Previously printed the function object ``generate_sequence`` by mistake.
    print(generated_sequence)
    return generated_sequence
# In-memory task registry: task id -> task record, in creation order.
# NOTE(review): nothing in the visible code removes entries, so finished
# tasks (including their base64 outputs) accumulate for the process lifetime.
tasks: OrderedDict = OrderedDict()
# Loading page: serves the static front-end HTML.
def index(request):
    """Return ``static/index.html`` as a ``text/html`` file response.

    ``request`` is accepted but not used.

    NOTE(review): no route decorator (e.g. ``@app.get``) is visible on this
    function; confirm how it is registered with ``app``.
    """
    return FileResponse(media_type="text/html", path="static/index.html")
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Register a new composition task and schedule it in the background.

    The task id is ``"<creation timestamp>_<music_style>"``. The record starts
    in the "queued" state. Its initial place in the queue and its ETA are
    computed immediately, and ``process_task`` is scheduled through
    ``background_tasks``.

    Returns:
        The task record, which is the same dict object stored in ``tasks``.
    """
    created_at = time()
    task_id = f"{created_at}_{new_task.music_style}"
    record = OrderedDict(
        task_id=task_id,
        status="queued",
        eta=None,
        created_at=created_at,
        started_at=None,
        completed_at=None,
        last_poll=None,
        poll_count=0,
        initial_place_in_queue=None,
        place_in_queue=None,
        music_style=new_task.music_style,
        density=new_task.density,
        temperature=new_task.temperature,
        output=None,
    )
    tasks[task_id] = record
    record["initial_place_in_queue"] = get_place_in_queue(task_id)
    record["eta"] = calculate_eta(task_id)
    background_tasks.add_task(process_task, task_id)
    return record
def poll_task(task_id: str):
    """Refresh and return the record of ``task_id``.

    Updates the current place in the queue and the ETA, records the poll time
    (which ``process_task`` uses to detect abandoned tasks), and increments
    the poll counter.

    Raises:
        KeyError: if ``task_id`` is not in ``tasks``.
    """
    place = get_place_in_queue(task_id)
    record = tasks[task_id]
    record["place_in_queue"] = place
    record["eta"] = calculate_eta(task_id)
    record["last_poll"] = time()
    record["poll_count"] += 1
    return record