from fastapi.responses import Response
from io import BytesIO
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
import numpy as np
import soundfile as sf
from flask import Flask, jsonify
# Flask application object; the route handlers in this file are registered on it.
app = Flask(__name__)
# Cache of TTS_Config objects, keyed by the YAML path each one was built from.
tts_config_cache = {}
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    """Encode audio samples as WAV into the supplied in-memory buffer.

    Args:
        io_buffer: Destination buffer that receives the WAV bytes. ``None``
            is tolerated for backward compatibility (earlier versions
            ignored this argument); a new ``BytesIO`` is created in that case.
        data: Audio samples, passed straight to ``soundfile.write``.
        rate: Sample rate, passed straight to ``soundfile.write``.

    Returns:
        The buffer the WAV data was written to (the caller's ``io_buffer``
        unless it was ``None``). The stream position is left wherever
        ``soundfile.write`` leaves it; callers that want to read should
        ``seek(0)`` first.
    """
    # Write into the caller's buffer instead of silently replacing it, so a
    # caller that keeps its own reference actually sees the encoded data.
    if io_buffer is None:
        io_buffer = BytesIO()
    sf.write(io_buffer, data, rate, format='wav')
    return io_buffer
def pack_audio(data: np.ndarray, rate: int, media_type: str):
    """Pack audio samples into a rewound in-memory buffer.

    Args:
        data: Audio samples to encode.
        rate: Sample rate of ``data``.
        media_type: Requested output format. It is accepted but not consulted
            here: the samples are always handed to ``pack_wav``.

    Returns:
        The ``BytesIO`` returned by ``pack_wav``, positioned at offset 0 so it
        can be read from the start.
    """
    packed = pack_wav(BytesIO(), data, rate)
    # Rewind so consumers reading the stream start at the first byte.
    packed.seek(0)
    return packed
def get_tts_config(tts_infer_yaml_path):
    """Return the ``TTS_Config`` for a YAML path, building it on first use.

    Args:
        tts_infer_yaml_path: Path of the inference YAML file; also used as the
            key into the module-level ``tts_config_cache``.

    Returns:
        The cached ``TTS_Config`` for this path. A new one is constructed and
        stored only when the path has not been seen before.
    """
    # Fast path: this path has already been loaded once.
    if tts_infer_yaml_path in tts_config_cache:
        return tts_config_cache[tts_infer_yaml_path]

    config = TTS_Config(tts_infer_yaml_path)
    tts_config_cache[tts_infer_yaml_path] = config
    return config
def tts_handle(req: dict):
    """Run one TTS inference for ``req`` and wrap the audio in a ``Response``.

    Args:
        req: Request parameters. Two keys are read here: ``media_type``
            (default ``"wav"``) and ``tts_infer_yaml_path`` (default
            ``"GPT_SoVITS/configs/tts_infer.yaml"``). Every other key,
            including ``streaming_mode``, is not interpreted by this function;
            the whole dict is forwarded unchanged to ``TTS.run``.

    Returns:
        On success, a ``Response`` whose body is the packed audio bytes and
        whose media type is ``audio/<media_type>``. On any failure (config
        loading, model construction, inference, or packing), a ``Response``
        with an empty body and status code 500, so that callers can tell a
        failed generation apart from a successful one.
    """
    # Log the incoming request parameters.
    print(f"传入的配置是: {req}")
    media_type = req.get("media_type", "wav")
    tts_infer_yaml_path = req.get("tts_infer_yaml_path", "GPT_SoVITS/configs/tts_infer.yaml")
    # NOTE(review): pack_audio in this file does not look at media_type, so the
    # declared "audio/<media_type>" only matches the payload for "wav" - confirm
    # whether other formats are meant to be supported.
    content_type = f"audio/{media_type}"

    try:
        # Config loading sits inside the try so that a bad or missing YAML path
        # takes the same failure path as an inference error.
        tts_config = get_tts_config(tts_infer_yaml_path)
        # NOTE(review): a new TTS instance is built on every call; presumably
        # this is expensive (model loading) - consider reusing one instance.
        tts_instance = TTS(tts_config)
        tts_generator = tts_instance.run(req)
        # Only the first (sample_rate, audio) item yielded is consumed.
        sr, audio_data = next(tts_generator)
        audio_bytes = pack_audio(audio_data, sr, media_type).getvalue()
    except Exception as e:
        # Boundary handler: report the failure and signal it via the status
        # code instead of returning an empty body that looks like success.
        print(f"生成失败: {str(e)}")
        return Response("", status_code=500, media_type=content_type)

    return Response(audio_bytes, media_type=content_type)
@app.route('/hello', methods=['GET'])
def hello():
    """Handle ``GET /hello``: run a fixed demo synthesis and reply with a greeting.

    The synthesis result is only printed; the HTTP reply is always the JSON
    message ``"Hello, World!"`` regardless of what ``tts_handle`` returned.
    """
    # Fixed demo parameters forwarded verbatim to tts_handle.
    demo_request = {
        "text": "你好",
        "text_lang": "zh",
        "ref_audio_path": "666.mp3",
        "prompt_text": "",
        "prompt_lang": "zh",
        "top_k": 5,
        "top_p": 1,
        "temperature": 1,
        "text_split_method": "cut0",
        "batch_size": 1,
        "batch_threshold": 1,
        "speed_factor": 1.0,
        "split_bucket": True,
        "fragment_interval": 0.3,
        "seed": -1,
        "media_type": "wav",
        "streaming_mode": False,
        "parallel_infer": True,
        "repetition_penalty": 1.35,
        "tts_infer_yaml_path": "GPT_SoVITS/configs/tts_infer.yaml",
    }
    result = tts_handle(demo_request)
    print(f"生成结果: {result}")
    return jsonify(message="Hello, World!")
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server on port 5001 with
    # debug mode switched on. Not executed when the module is imported.
    app.run(port=5001, debug=True)