diff --git a/xservice-server/src/main/java/com/xiang/xsa/xservice/ai/server/ChatController.java b/xservice-server/src/main/java/com/xiang/xsa/xservice/ai/server/ChatController.java index 8f80bce..ac4c488 100644 --- a/xservice-server/src/main/java/com/xiang/xsa/xservice/ai/server/ChatController.java +++ b/xservice-server/src/main/java/com/xiang/xsa/xservice/ai/server/ChatController.java @@ -2,7 +2,10 @@ package com.xiang.xsa.xservice.ai.server; import com.xiang.xservice.ai.agent.BaseAgent; import com.xiang.xservice.ai.core.enums.ModelTypeEnum; +import com.xiang.xservice.ai.pojo.enums.AgentEnums; +import com.xiang.xservice.ai.service.AgentService; import lombok.RequiredArgsConstructor; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; @@ -11,11 +14,12 @@ import org.springframework.web.bind.annotation.RestController; @RequiredArgsConstructor public class ChatController { - private final BaseAgent baseAgent; + private final AgentService agentService; @GetMapping("/chat") public String chatDemo(@RequestParam("question") String question, @RequestParam("memoryId") Long memoryId, @RequestParam("userId") Long userId) { - baseAgent.chat(ModelTypeEnum.OPEN_AI, memoryId, userId, question); + BaseAgent agent = agentService.createAgent(AgentEnums.SIMPLE_CHAT_AGENT); + agent.chat(ModelTypeEnum.OPEN_AI, memoryId, userId, question); return "321"; } diff --git a/xservice-service/src/main/java/com/xiang/xservice/ai/agent/BaseAgent.java b/xservice-service/src/main/java/com/xiang/xservice/ai/agent/BaseAgent.java index 8a2ad56..685bc77 100644 --- a/xservice-service/src/main/java/com/xiang/xservice/ai/agent/BaseAgent.java +++ b/xservice-service/src/main/java/com/xiang/xservice/ai/agent/BaseAgent.java @@ -1,6 +1,7 @@ package com.xiang.xservice.ai.agent; import com.xiang.xservice.ai.core.enums.ModelTypeEnum; +import com.xiang.xservice.ai.pojo.enums.AgentEnums; public interface BaseAgent { /** @@ -11,4 +12,6 @@ public interface BaseAgent { * @param message */ void chat(ModelTypeEnum modelType, Long memoryId, Long userId, String message); + + AgentEnums agent(); } diff --git a/xservice-service/src/main/java/com/xiang/xservice/ai/agent/SimpleChatAgent.java b/xservice-service/src/main/java/com/xiang/xservice/ai/agent/SimpleChatAgent.java index 099e4fd..0b3a1f1 100644 --- a/xservice-service/src/main/java/com/xiang/xservice/ai/agent/SimpleChatAgent.java +++ b/xservice-service/src/main/java/com/xiang/xservice/ai/agent/SimpleChatAgent.java @@ -11,6 +11,7 @@ import com.xiang.xservice.ai.core.handler.MyStreamingHandler; import com.xiang.xservice.ai.core.route.TaskRouter; import com.xiang.xservice.ai.core.storage.DbPersistentStore; import com.xiang.xservice.ai.core.storage.MemoryPersistentStore; +import com.xiang.xservice.ai.pojo.enums.AgentEnums; import com.xiang.xservice.ai.repository.manage.IAiSimpleChatMessageManage; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; @@ -69,4 +70,9 @@ public class SimpleChatAgent implements BaseAgent { .ignoreErrors() .start(); } + + @Override + public AgentEnums agent() { + return AgentEnums.SIMPLE_CHAT_AGENT; + } } diff --git a/xservice-service/src/main/java/com/xiang/xservice/ai/agent/StockAnalysisAgent.java b/xservice-service/src/main/java/com/xiang/xservice/ai/agent/StockAnalysisAgent.java new file mode 100644 index 0000000..e6e6e9b --- /dev/null +++ b/xservice-service/src/main/java/com/xiang/xservice/ai/agent/StockAnalysisAgent.java @@ -0,0 +1,18 @@ +package com.xiang.xservice.ai.agent; + +import com.xiang.xservice.ai.core.enums.ModelTypeEnum; +import com.xiang.xservice.ai.pojo.enums.AgentEnums; +import org.springframework.stereotype.Service; + +@Service +public class StockAnalysisAgent implements BaseAgent{ + @Override + public void chat(ModelTypeEnum modelType, Long memoryId, Long userId, String message) { + + } + + @Override + public AgentEnums agent() { + return AgentEnums.STOCK_ANALYZER_AGENT; + } +} diff --git a/xservice-service/src/main/java/com/xiang/xservice/ai/pojo/enums/AgentEnums.java b/xservice-service/src/main/java/com/xiang/xservice/ai/pojo/enums/AgentEnums.java new file mode 100644 index 0000000..3f77cde --- /dev/null +++ b/xservice-service/src/main/java/com/xiang/xservice/ai/pojo/enums/AgentEnums.java @@ -0,0 +1,13 @@ +package com.xiang.xservice.ai.pojo.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public enum AgentEnums { + SIMPLE_CHAT_AGENT("SimpleChatAgent"), + STOCK_ANALYZER_AGENT("StockAnalysisAgent"), + ; + private final String name; +} diff --git a/xservice-service/src/main/java/com/xiang/xservice/ai/service/AgentService.java b/xservice-service/src/main/java/com/xiang/xservice/ai/service/AgentService.java new file mode 100644 index 0000000..bba4824 --- /dev/null +++ b/xservice-service/src/main/java/com/xiang/xservice/ai/service/AgentService.java @@ -0,0 +1,36 @@ +package com.xiang.xservice.ai.service; + +import com.google.common.collect.Maps; +import com.xiang.xservice.ai.agent.BaseAgent; +import com.xiang.xservice.ai.agent.SimpleChatAgent; +import com.xiang.xservice.ai.agent.StockAnalysisAgent; +import com.xiang.xservice.ai.config.OpenAIBaseConfig; +import com.xiang.xservice.ai.core.route.TaskRouter; +import com.xiang.xservice.ai.pojo.enums.AgentEnums; +import com.xiang.xservice.ai.repository.manage.IAiSimpleChatMessageManage; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +@Service +@RequiredArgsConstructor +public class AgentService { + + private final static Map agents = Maps.newHashMap(); + + public AgentService(List agentList) { + agents.putAll(agentList.stream() + .collect(Collectors.toMap(BaseAgent::agent, Function.identity()))); + } + + public BaseAgent createAgent(AgentEnums name) { + BaseAgent agent = agents.get(name); + if (Objects.isNull(agent)) throw new RuntimeException("Agent not found: " + name); + return agent; + } +}