jbilcke-hf HF staff committed on
Commit
f7f7c46
·
1 Parent(s): 13356ed

add a system to recover from broken generation cycles

Browse files
Files changed (7) hide show
  1. README.md +2 -1
  2. database.json +0 -0
  3. src/getDatabase.mts +9 -0
  4. src/getStats.mts +5 -3
  5. src/index.mts +99 -30
  6. src/prompts.mts +0 -0
  7. src/types.mts +23 -0
README.md CHANGED
@@ -8,4 +8,5 @@ pinned: false
8
  app_port: 8000
9
  ---
10
 
11
- Media server 📡
 
 
8
  app_port: 8000
9
  ---
10
 
11
+ Media server 📡
12
+
database.json ADDED
The diff for this file is too large to render. See raw diff
 
src/getDatabase.mts ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import { promises as fs } from 'node:fs'
2
+
3
+ import { Database } from './types.mts'
4
+
5
+ export const getDatabase = async (filePath: string): Promise<Database> => {
6
+ const rawFile = await fs.readFile(filePath, 'utf8')
7
+
8
+ return JSON.parse(rawFile) as Database
9
+ }
src/getStats.mts CHANGED
@@ -1,11 +1,13 @@
1
  import { promises as fs } from 'node:fs'
2
 
3
  export const getStats = async () => {
4
- const videoFilePaths = await fs.readdir(process.env.WEBTV_VIDEO_STORAGE_PATH)
5
- const audioFilePaths = await fs.readdir(process.env.WEBTV_AUDIO_STORAGE_PATH)
 
6
 
7
  return {
8
  nbVideoFiles: videoFilePaths.length,
9
- nbAudioFiles: audioFilePaths.length,
 
10
  }
11
  }
 
1
  import { promises as fs } from 'node:fs'
2
 
3
  export const getStats = async () => {
4
+ const videoFilePaths = await fs.readdir(process.env.WEBTV_VIDEO_STORAGE_PATH_NEXT)
5
+ const legacyVideoFilePaths = await fs.readdir(process.env.WEBTV_VIDEO_STORAGE_PATH)
6
+ const legacyAudioFilePaths = await fs.readdir(process.env.WEBTV_AUDIO_STORAGE_PATH)
7
 
8
  return {
9
  nbVideoFiles: videoFilePaths.length,
10
+ nbLegacyVideoFiles: legacyVideoFilePaths.length,
11
+ nbLegacyAudioFiles: legacyAudioFilePaths.length,
12
  }
13
  }
src/index.mts CHANGED
@@ -2,85 +2,154 @@ import { v4 as uuid } from 'uuid'
2
  import { upscaleVideo } from './upscaleVideo.mts'
3
  import { keepVideo } from './keepVideo.mts'
4
 
5
- import { demoPrompts } from './prompts.mts'
6
  import { getStats } from './getStats.mts'
7
  import { enhanceVideo } from './enhanceVideo.mts'
8
  import { callZeroscope } from './callZeroscope.mts'
9
  import { downloadVideo } from './downloadVideo.mts'
 
 
 
 
 
 
 
 
 
 
10
 
11
  const main = async () => {
12
- console.log('Generating ideas..')
13
- const ideas = demoPrompts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  console.log('Generating videos sequences..')
16
  const instanceId = process.env.WEBTV_WORKER_INSTANCE_ID || '0'
17
 
18
- for (const { input, captions } of ideas) {
19
- console.log(`\nVideo sequence to generate: ${input}`)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  const sequenceId = uuid()
22
 
23
- const silentVideos: string[] = []
24
 
25
  // this is hardcoded everywhere for now, since videos longer than 3 sec require the Nvidia A100
26
  const videoDurationInSecs = 3
27
 
28
- for (const caption of captions) {
29
- console.log(`- generating shot: ${caption}`)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  try {
31
- const generatedVideoUrl = await callZeroscope(caption)
32
 
33
- const videoName = `inst_${instanceId}_seq_${sequenceId}_shot_${Date.now()}.mp4`
34
 
35
- console.log(`- downloading ${videoName} from ${generatedVideoUrl}`)
36
- await downloadVideo(generatedVideoUrl, videoName)
37
 
38
- console.log(`- downloaded ${videoName}`)
39
 
40
- console.log('- upscaling video..')
41
 
42
  try {
43
- await upscaleVideo(videoName, caption)
44
  } catch (err) {
45
  // upscaling is finicky, if it fails we try again
46
- console.log('- trying again to upscale video..')
47
- await upscaleVideo(videoName, caption)
48
  }
49
 
50
- console.log('- enhancing video..')
51
- await enhanceVideo(videoName)
52
 
53
- console.log('- saving final video..')
54
- await keepVideo(videoName, process.env.WEBTV_VIDEO_STORAGE_PATH_NEXT)
55
 
56
- silentVideos.push(videoName)
57
 
58
  console.log('- done!')
59
  } catch (err) {
60
  console.log(`- error: ${err}`)
61
  }
62
 
63
- const totalRunTime = videoDurationInSecs * silentVideos.length
64
 
65
  if (totalRunTime <= 0) {
66
  continue
67
  }
68
 
69
- console.log(`TODO: generate ${totalRunTime} seconds of music`)
70
  // TODO: generate music from MusicGen, with the correct length
71
  // (or we could generate a slightly larger track and let ffmpeg cut it)
 
 
72
 
73
- }
74
-
75
- console.log('Finished gerating all video sequences')
76
 
77
- console.log('Current stats:', await getStats())
78
  }
79
 
80
- console.log('Finished the full cycle')
 
 
 
81
  setTimeout(() => {
82
  main()
83
- }, 3000)
84
  }
85
 
86
  setTimeout(() => {
 
2
  import { upscaleVideo } from './upscaleVideo.mts'
3
  import { keepVideo } from './keepVideo.mts'
4
 
 
5
  import { getStats } from './getStats.mts'
6
  import { enhanceVideo } from './enhanceVideo.mts'
7
  import { callZeroscope } from './callZeroscope.mts'
8
  import { downloadVideo } from './downloadVideo.mts'
9
+ import { getDatabase } from './getDatabase.mts'
10
+ import { callMusicgen } from './callMusicgen.mts'
11
+
12
+ let hasReachedStartingPoint = false
13
+
14
+ type RunMode = 'running' | 'paused' | 'dry_run'
15
+
16
+ const status = `${process.env.WEBTV_STATUS || 'dry_run'}` as RunMode
17
+
18
+ console.log(`Web TV server status: ${status}`)
19
 
20
  const main = async () => {
21
+ if (status === 'paused') {
22
+ setTimeout(() => {
23
+ main()
24
+ }, 30000)
25
+ return
26
+ }
27
+
28
+ console.log('Reading persistent file structure..')
29
+ const stats = await getStats()
30
+ console.log(`New format: We have ${stats.nbVideoFiles} video files`)
31
+
32
+ console.log(`Legacy: We have ${stats.nbLegacyVideoFiles} video files and ${stats.nbLegacyAudioFiles} audio files`)
33
+
34
+ console.log('Reading prompt database..')
35
+ const db = await getDatabase('./database.json')
36
+
37
+ const nbTotalShots = db.sequences.reduce((a, s) => a + s.shots.length, 0)
38
+ console.log(`Prompt database version: ${db.version}`)
39
+ console.log(`We got ${db.sequences.length} sequences for ${nbTotalShots} shots in total`)
40
 
41
  console.log('Generating videos sequences..')
42
  const instanceId = process.env.WEBTV_WORKER_INSTANCE_ID || '0'
43
 
44
+ const startingPointExists = db.sequences.some(seq => seq.shots.some(shot => shot.shotId === db.startAtShotId))
45
+
46
+ if (!startingPointExists) {
47
+ console.log(`Starting point ${db.startAtShotId} not found, we will start at the beginning`)
48
+ hasReachedStartingPoint = true
49
+ } else if (db.startAtShotId) {
50
+ console.log(`We are going to start at shot ${db.startAtShotId}`)
51
+ } else {
52
+ console.log('We are going to start at the beginning')
53
+ }
54
+
55
+ for (const sequence of db.sequences) {
56
+ const containsStartingPoint = sequence.shots.some(shot => shot.shotId === db.startAtShotId)
57
+
58
+ // we skip sequences that were already processed
59
+ if (!hasReachedStartingPoint && !containsStartingPoint) {
60
+ continue
61
+ }
62
+
63
+ console.log(`
64
+ -----------------------------------------------------------
65
+ Going to generate ${sequence.shots.length} for prompt:
66
+ ${sequence.videoPrompt}
67
+ `)
68
 
69
  const sequenceId = uuid()
70
 
71
+ const generatedShots: string[] = []
72
 
73
  // this is hardcoded everywhere for now, since videos longer than 3 sec require the Nvidia A100
74
  const videoDurationInSecs = 3
75
 
76
+ let shotIndex = 0
77
+
78
+ for (const shot of sequence.shots) {
79
+
80
+ if (shot.shotId === db.startAtShotId) {
81
+ hasReachedStartingPoint = true
82
+ }
83
+
84
+ if (!hasReachedStartingPoint) {
85
+ shotIndex++
86
+ continue
87
+ }
88
+
89
+ console.log(`- generating shot: ${shot.shotId}`)
90
+
91
+ if (status === 'dry_run') {
92
+ // console.log('DRY RUN')
93
+ shotIndex++
94
+ continue
95
+ }
96
+
97
  try {
98
+ const generatedVideoUrl = await callZeroscope(shot.videoPrompt)
99
 
100
+ const shotFileName = `inst_${instanceId}_seq_${sequenceId}_shot_${shotIndex++}_${Date.now()}.mp4`
101
 
102
+ console.log(`- downloading shot ${shotFileName} from ${generatedVideoUrl}`)
103
+ await downloadVideo(generatedVideoUrl, shotFileName)
104
 
105
+ console.log(`- downloaded shot ${shotFileName}`)
106
 
107
+ console.log('- upscaling shot..')
108
 
109
  try {
110
+ await upscaleVideo(shotFileName, shot.videoPrompt)
111
  } catch (err) {
112
  // upscaling is finicky, if it fails we try again
113
+ console.log('- trying again to upscale shot..')
114
+ await upscaleVideo(shotFileName, shot.videoPrompt)
115
  }
116
 
117
+ console.log('- enhancing shot..')
118
+ await enhanceVideo(shotFileName)
119
 
120
+ console.log('- saving final shot..')
121
+ await keepVideo(shotFileName, process.env.WEBTV_VIDEO_STORAGE_PATH_NEXT)
122
 
123
+ generatedShots.push(shotFileName)
124
 
125
  console.log('- done!')
126
  } catch (err) {
127
  console.log(`- error: ${err}`)
128
  }
129
 
130
+ const totalRunTime = videoDurationInSecs * generatedShots.length
131
 
132
  if (totalRunTime <= 0) {
133
  continue
134
  }
135
 
 
136
  // TODO: generate music from MusicGen, with the correct length
137
  // (or we could generate a slightly larger track and let ffmpeg cut it)
138
+ console.log(`TODO: generate ${totalRunTime} seconds of music`)
139
+ await callMusicgen(sequence.audioPrompt) // this does nothing for now
140
 
141
+ }
 
 
142
 
143
+ console.log('Finished generating sequence')
144
  }
145
 
146
+ console.log('Finished the cycle')
147
+
148
+ hasReachedStartingPoint = true // set this to true in all cases
149
+
150
  setTimeout(() => {
151
  main()
152
+ }, 2000)
153
  }
154
 
155
  setTimeout(() => {
src/prompts.mts CHANGED
The diff for this file is too large to render. See raw diff
 
src/types.mts ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export interface Shot {
2
+ shotId: string
3
+ index: number
4
+ lastGenerationAt: string
5
+ videoPrompt: string
6
+ audioPrompt: string
7
+ }
8
+
9
+ export interface Sequence {
10
+ sequenceId: string
11
+ lastGenerationAt: string
12
+ videoPrompt: string
13
+ audioPrompt: string
14
+ channel: string
15
+ tags: string[]
16
+ shots: Shot[]
17
+ }
18
+
19
+ export interface Database {
20
+ version: number
21
+ startAtShotId: string
22
+ sequences: Sequence[]
23
+ }