Spaces:
Sleeping
Sleeping
package whisper | |
import ( | |
"fmt" | |
) | |
/////////////////////////////////////////////////////////////////////////////// | |
// CGO | |
/* | |
#include <whisper.h> | |
*/ | |
import "C" | |
/////////////////////////////////////////////////////////////////////////////// | |
// PUBLIC METHODS | |
func (p *Params) SetTranslate(v bool) { | |
p.translate = toBool(v) | |
} | |
func (p *Params) SetSplitOnWord(v bool) { | |
p.split_on_word = toBool(v) | |
} | |
func (p *Params) SetNoContext(v bool) { | |
p.no_context = toBool(v) | |
} | |
func (p *Params) SetSingleSegment(v bool) { | |
p.single_segment = toBool(v) | |
} | |
func (p *Params) SetPrintSpecial(v bool) { | |
p.print_special = toBool(v) | |
} | |
func (p *Params) SetPrintProgress(v bool) { | |
p.print_progress = toBool(v) | |
} | |
func (p *Params) SetPrintRealtime(v bool) { | |
p.print_realtime = toBool(v) | |
} | |
func (p *Params) SetPrintTimestamps(v bool) { | |
p.print_timestamps = toBool(v) | |
} | |
// Set language id | |
func (p *Params) SetLanguage(lang int) error { | |
if lang == -1 { | |
p.language = nil | |
return nil | |
} | |
str := C.whisper_lang_str(C.int(lang)) | |
if str == nil { | |
return ErrInvalidLanguage | |
} else { | |
p.language = str | |
} | |
return nil | |
} | |
// Get language id | |
func (p *Params) Language() int { | |
if p.language == nil { | |
return -1 | |
} | |
return int(C.whisper_lang_id(p.language)) | |
} | |
// Threads available | |
func (p *Params) Threads() int { | |
return int(p.n_threads) | |
} | |
// Set number of threads to use | |
func (p *Params) SetThreads(threads int) { | |
p.n_threads = C.int(threads) | |
} | |
// Set start offset in ms | |
func (p *Params) SetOffset(offset_ms int) { | |
p.offset_ms = C.int(offset_ms) | |
} | |
// Set audio duration to process in ms | |
func (p *Params) SetDuration(duration_ms int) { | |
p.duration_ms = C.int(duration_ms) | |
} | |
// Set timestamp token probability threshold (~0.01) | |
func (p *Params) SetTokenThreshold(t float32) { | |
p.thold_pt = C.float(t) | |
} | |
// Set timestamp token sum probability threshold (~0.01) | |
func (p *Params) SetTokenSumThreshold(t float32) { | |
p.thold_ptsum = C.float(t) | |
} | |
// Set max segment length in characters | |
func (p *Params) SetMaxSegmentLength(n int) { | |
p.max_len = C.int(n) | |
} | |
func (p *Params) SetTokenTimestamps(b bool) { | |
p.token_timestamps = toBool(b) | |
} | |
// Set max tokens per segment (0 = no limit) | |
func (p *Params) SetMaxTokensPerSegment(n int) { | |
p.max_tokens = C.int(n) | |
} | |
// Set audio encoder context | |
func (p *Params) SetAudioCtx(n int) { | |
p.audio_ctx = C.int(n) | |
} | |
func (p *Params) SetMaxContext(n int) { | |
p.n_max_text_ctx = C.int(n) | |
} | |
func (p *Params) SetBeamSize(n int) { | |
p.beam_search.beam_size = C.int(n) | |
} | |
func (p *Params) SetEntropyThold(t float32) { | |
p.entropy_thold = C.float(t) | |
} | |
// Set initial prompt | |
func (p *Params) SetInitialPrompt(prompt string) { | |
p.initial_prompt = C.CString(prompt) | |
} | |
/////////////////////////////////////////////////////////////////////////////// | |
// PRIVATE METHODS | |
func toBool(v bool) C.bool { | |
if v { | |
return C.bool(true) | |
} | |
return C.bool(false) | |
} | |
/////////////////////////////////////////////////////////////////////////////// | |
// STRINGIFY | |
func (p *Params) String() string { | |
str := "<whisper.params" | |
str += fmt.Sprintf(" strategy=%v", p.strategy) | |
str += fmt.Sprintf(" n_threads=%d", p.n_threads) | |
if p.language != nil { | |
str += fmt.Sprintf(" language=%s", C.GoString(p.language)) | |
} | |
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx) | |
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms) | |
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms) | |
str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx) | |
str += fmt.Sprintf(" initial_prompt=%s", C.GoString(p.initial_prompt)) | |
str += fmt.Sprintf(" entropy_thold=%f", p.entropy_thold) | |
str += fmt.Sprintf(" beam_size=%d", p.beam_search.beam_size) | |
if p.translate { | |
str += " translate" | |
} | |
if p.no_context { | |
str += " no_context" | |
} | |
if p.single_segment { | |
str += " single_segment" | |
} | |
if p.print_special { | |
str += " print_special" | |
} | |
if p.print_progress { | |
str += " print_progress" | |
} | |
if p.print_realtime { | |
str += " print_realtime" | |
} | |
if p.print_timestamps { | |
str += " print_timestamps" | |
} | |
if p.token_timestamps { | |
str += " token_timestamps" | |
} | |
return str + ">" | |
} | |