Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 25 additions & 35 deletions modelscope_agent/tools/pipeline_tool.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,30 @@
from modelscope.pipelines import pipeline
from .tool import Tool
from modelscope_agent.tools.base import BaseTool, register_tool
import json
import requests
import os


class ModelscopePipelineTool(Tool):

default_model: str = ''
task: str = ''
model_revision = None
@register_tool('pipeline')
class ModelscopePipelineTool(BaseTool):
API_URL = ""
API_KEY = ""

def __init__(self, cfg):

"""
Initialize a ModelscopePipelineTool instance.

Args:
cfg (Dict[str, object]): Configuration dictionary containing the parameters needed to initialize the object.
"""
super().__init__(cfg)
self.model = self.cfg.get('model', None) or self.default_model
self.model_revision = self.cfg.get('model_revision',
None) or self.model_revision

self.pipeline_params = self.cfg.get('pipeline_params', {})
self.pipeline = None
self.is_initialized = False

def setup(self):

# only initialize when this tool is really called to save memory
if not self.is_initialized:
self.pipeline = pipeline(
task=self.task,
model=self.model,
model_revision=self.model_revision,
**self.pipeline_params)
self.is_initialized = True

def _local_call(self, *args, **kwargs):

self.setup()
self.API_URL = self.cfg.get(self.name, {}).get('url',None) or self.API_URL
self.API_KEY = os.getenv('MODELSCOPE_API_KEY', None) or self.API_KEY


def call(self, params: str, **kwargs) -> dict:
    """POST the verified tool arguments to ``self.API_URL`` and return the parsed reply.

    Args:
        params: Raw tool arguments; passed through ``self._verify_args`` and
            then serialized to JSON for the request body.
        **kwargs: An optional ``timeout`` entry is forwarded to ``requests``.
            It defaults to ``None`` (no time limit), which matches the
            behavior before the option existed.

    Returns:
        The response body decoded as UTF-8 and parsed as JSON.
        NOTE(review): annotated as ``dict`` because the text-to-speech
        subclass indexes the result by key; confirm the API never replies
        with a non-object JSON value.

    Raises:
        json.JSONDecodeError: If the response body is not valid JSON.
    """
    verified_params = self._verify_args(params)
    headers = {"Authorization": f"Bearer {self.API_KEY}"}
    response = requests.post(
        self.API_URL,
        headers=headers,
        data=json.dumps(verified_params),
        timeout=kwargs.get('timeout'),
    )
    return json.loads(response.content.decode("utf-8"))

parsed_args, parsed_kwargs = self._local_parse_input(*args, **kwargs)
origin_result = self.pipeline(*parsed_args, **parsed_kwargs)
final_result = self._parse_output(origin_result, remote=False)
return final_result
22 changes: 16 additions & 6 deletions modelscope_agent/tools/plugin_tool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from copy import deepcopy
from modelscope_agent.tools.base import BaseTool, register_tool

from .tool import Tool


class LangchainTool(Tool):
@register_tool('plugin')
class LangchainTool(BaseTool):
description = '通过调用langchain插件来支持对语言模型的输入输出格式进行处理,输入文本字符,输出经过格式处理的结果'
name = 'plugin'
parameters: list = [{
'name': 'commands',
'description': '需要进行格式处理的文本字符列表',
'required': True,
'type': "string"
}]

def __init__(self, langchain_tool):
from langchain.tools import BaseTool
Expand All @@ -23,8 +30,11 @@ def parse_langchain_schema(self):
tool_arg = deepcopy(arg)
tool_arg['name'] = name
tool_arg['required'] = True
tool_arg['type'] = arg['anyOf'][0].get("type","string")
tool_arg.pop('title')
self.parameters.append(tool_arg)

def _local_call(self, *args, **kwargs):
return {'result': self.langchain_tool.run(kwargs)}
def call(self, params: str, **kwargs):
    """Verify the raw arguments and run the wrapped langchain tool on them.

    Args:
        params: Raw tool arguments; passed through ``self._verify_args``
            before being handed to the tool.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``self.langchain_tool.run`` returns for the verified arguments.
    """
    verified_args = self._verify_args(params)
    return self.langchain_tool.run(verified_args)
40 changes: 9 additions & 31 deletions modelscope_agent/tools/text_to_speech_tool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from modelscope_agent.output_wrapper import AudioWrapper

from modelscope.utils.constant import Tasks
from .pipeline_tool import ModelscopePipelineTool

from modelscope_agent.output_wrapper import AudioWrapper

class TexttoSpeechTool(ModelscopePipelineTool):
default_model = 'damo/speech_sambert-hifigan_tts_zh-cn_16k'
Expand All @@ -11,35 +8,16 @@ class TexttoSpeechTool(ModelscopePipelineTool):
parameters: list = [{
'name': 'input',
'description': '要转成语音的文本',
'required': True
'required': True,
'type': 'string'
}, {
'name': 'gender',
'description': '用户身份',
'required': True
'required': True,
'type': 'string'
}]
task = Tasks.text_to_speech

def _local_parse_input(self, *args, **kwargs):
if 'gender' not in kwargs:
kwargs['gender'] = 'man'
voice = 'zhizhe_emo' if kwargs['gender'] == 'man' else 'zhiyan_emo'
kwargs['voice'] = voice
if 'text' in kwargs and 'input' not in kwargs:
kwargs['input'] = kwargs['text']
kwargs.pop('text')
kwargs.pop('gender')
return args, kwargs

def _remote_parse_input(self, *args, **kwargs):
if 'gender' not in kwargs:
kwargs['gender'] = 'man'
voice = 'zhizhe_emo' if kwargs['gender'] == 'man' or kwargs[
'gender'] == 'male' else 'zhiyan_emo'
kwargs['parameters'] = {'voice': voice}
kwargs.pop('gender')
return kwargs

def _parse_output(self, origin_result, remote=True):
    """Wrap the audio found in a pipeline result.

    Args:
        origin_result: Mapping that holds the audio under ``'output_wav'``.
        remote: Accepted for interface compatibility; not used here.

    Returns:
        A dict whose ``'result'`` entry is an ``AudioWrapper`` around the audio.
    """
    return {'result': AudioWrapper(origin_result['output_wav'])}
def call(self, params: str, **kwargs) -> str:
    """Run the parent ``call`` and wrap the returned audio.

    The parent's result is indexed as ``['Data']['output_wav']`` and the
    value found there is returned inside an ``AudioWrapper``.

    NOTE(review): the signature is annotated ``-> str`` but the value
    returned is an ``AudioWrapper`` - confirm which one callers expect.
    """
    response_payload = super().call(params, **kwargs)
    return AudioWrapper(response_payload['Data']['output_wav'])
8 changes: 6 additions & 2 deletions tests/tools/test_langchain_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ def test_is_langchain_tool():
def test_run_langchin_tool():
# test run langchain tool
shell_tool = LangchainTool(ShellTool())
res = shell_tool(commands=["echo 'Hello World!'"])
assert res['result'] == 'Hello World!\n'
input = """{'commands': ["echo 'Hello World!'"]}"""
res = shell_tool.call(input)
print(res)
assert res == 'Hello World!\n'

# Allow running this file directly as a script without executing the test
# as a side effect of importing the module (e.g. during pytest collection).
if __name__ == '__main__':
    test_run_langchin_tool()
18 changes: 18 additions & 0 deletions tests/tools/test_pipeline_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from modelscope_agent.tools.pipeline_tool import ModelscopePipelineTool
from modelscope.utils.config import Config
import os

cfg = Config.from_file('config/cfg_tool_template.json')
# Provide your own ModelScope SDK token either by exporting MODELSCOPE_API_KEY
# before running, or by replacing {YOUR_MODELSCOPE_SDK_TOKEN} below (braces
# included). The placeholder is a plain string rather than an f-string so that
# importing this module does not raise NameError on an undefined name, and
# setdefault keeps a key that is already present in the environment.
os.environ.setdefault('MODELSCOPE_API_KEY', '{YOUR_MODELSCOPE_SDK_TOKEN}')

def test_modelscope_speech_generation():
    """Smoke-test ``TexttoSpeechTool.call`` with a sample request and print the result."""
    from modelscope_agent.tools.text_to_speech_tool import TexttoSpeechTool

    speech_tool = TexttoSpeechTool(cfg)
    # The tool receives its arguments as a single string.
    request = """{'input': '北京今天天气怎样?', 'gender': 'man'}"""
    print(speech_tool.call(request))


# Allow running this file directly as a script without issuing the remote
# request as a side effect of importing the module (e.g. during pytest
# collection).
if __name__ == '__main__':
    test_modelscope_speech_generation()