347 lines
10 KiB
Go
347 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/matrix-org/gomatrix"
|
|
)
|
|
|
|
// BotConfig represents the structure of the configuration file
|
|
type BotConfig struct {
|
|
Homeserver string `json:"homeserver"`
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
RoomID string `json:"room_id"`
|
|
AccessToken string `json:"access_token"`
|
|
AIAPIKey string `json:"ai_api_key"`
|
|
AIConfig AIConfig `json:"ai_config"`
|
|
Prompt string `json:"prompt"` // Global prompt for initial context
|
|
Commands map[string]string `json:"commands"` // Command-specific prompts
|
|
}
|
|
|
|
// AIConfig is the static, per-bot part of an AI request. Its JSON keys match
// the ones in the API payload, and every field is copied unchanged into each
// AIRequest.
type AIConfig struct {
	// Model is the model identifier sent with each request.
	Model string `json:"model"`
	// Stream is forwarded as-is. The reply is decoded as a single JSON
	// document, so streamed responses are presumably unsupported — confirm
	// before enabling.
	Stream bool `json:"stream"`
	// Length and sampling controls, forwarded unchanged.
	MaxTokens        int      `json:"max_tokens"`
	Stop             []string `json:"stop"`
	Temperature      float64  `json:"temperature"`
	TopP             float64  `json:"top_p"`
	TopK             int      `json:"top_k"`
	FrequencyPenalty float64  `json:"frequency_penalty"`
	// N is the number of completions requested; only the first is used.
	N int `json:"n"`
}
|
|
|
|
// AIRequest represents the complete request structure for the AI API
|
|
type AIRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
Stop []string `json:"stop"`
|
|
Temperature float64 `json:"temperature"`
|
|
TopP float64 `json:"top_p"`
|
|
TopK int `json:"top_k"`
|
|
FrequencyPenalty float64 `json:"frequency_penalty"`
|
|
N int `json:"n"`
|
|
}
|
|
|
|
// Message is one chat turn. Outgoing requests use the roles "system" and
// "user"; in replies, Role is whatever the API returns.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
|
|
|
// AIResponse represents the response structure from the AI API
|
|
type AIResponse struct {
|
|
ID string `json:"id"`
|
|
Choices []Choice `json:"choices"`
|
|
Usage Usage `json:"usage"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Object string `json:"object"`
|
|
}
|
|
|
|
// Choice represents a single choice in the AI response
|
|
type Choice struct {
|
|
Message Message `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
}
|
|
|
|
// Usage holds the token counts the API reports for a single request.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
|
|
// TemporaryPrompts holds per-user conversation state. It is keyed first by the
// Matrix sender ID, then by command name ("default" for the global prompt).
// Each value starts as the configured system prompt, and getAIResponse appends
// every completed exchange to it. Nothing in this file persists or locks it.
var TemporaryPrompts = map[string]map[string]string{}
|
|
|
|
func main() {
|
|
// Load configuration
|
|
config, err := loadBotConfig("bot-config.json")
|
|
if err != nil {
|
|
log.Fatalf("Failed to load config: %v", err)
|
|
}
|
|
|
|
// Create Matrix client
|
|
client, err := gomatrix.NewClient(config.Homeserver, "", "")
|
|
if err != nil {
|
|
log.Fatalf("Failed to create Matrix client: %v", err)
|
|
}
|
|
|
|
// Login to Matrix
|
|
resp, err := client.Login(&gomatrix.ReqLogin{
|
|
Type: "m.login.password",
|
|
User: config.Username,
|
|
Password: config.Password,
|
|
})
|
|
if err != nil {
|
|
log.Fatalf("Failed to login: %v", err)
|
|
}
|
|
|
|
client.SetCredentials(resp.UserID, resp.AccessToken)
|
|
|
|
// Join room
|
|
_, err = client.JoinRoom(config.RoomID, "", nil)
|
|
if err != nil {
|
|
log.Fatalf("Failed to join room: %v", err)
|
|
}
|
|
|
|
log.Println("Successfully login")
|
|
|
|
// Listen for messages
|
|
syncer := client.Syncer.(*gomatrix.DefaultSyncer)
|
|
syncer.OnEventType("m.room.message", func(event *gomatrix.Event) {
|
|
if event.Sender != resp.UserID {
|
|
handleMessage(client, config, event)
|
|
}
|
|
})
|
|
|
|
// Start syncing
|
|
if err := client.Sync(); err != nil {
|
|
log.Fatalf("Sync failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func handleMessage(client *gomatrix.Client, config *BotConfig, event *gomatrix.Event) {
|
|
if event.Content["msgtype"] == "m.text" {
|
|
userMessage := event.Content["body"].(string)
|
|
// Check if the message starts with $$
|
|
if strings.HasPrefix(userMessage, "$$") {
|
|
// Remove the $$ to get the actual message content
|
|
userMessage = strings.TrimPrefix(userMessage, "$$")
|
|
command := getCommand(userMessage)
|
|
|
|
aiResponse, err := getAIResponse(client, config, event.Sender, userMessage, command)
|
|
if err != nil {
|
|
log.Printf("Failed to get AI response: %v", err)
|
|
|
|
if strings.Contains(err.Error(), "429") {
|
|
aiResponse = "Sorry Reached Limited"
|
|
} else {
|
|
return
|
|
}
|
|
}
|
|
|
|
_, err = client.SendText(config.RoomID, aiResponse)
|
|
if err != nil {
|
|
log.Printf("Failed to send message: %v", err)
|
|
}
|
|
|
|
// Log the interaction
|
|
logInteraction(event.Sender, userMessage, aiResponse)
|
|
}
|
|
}
|
|
}
|
|
|
|
// getCommand returns the first whitespace-separated word of message, or ""
// when message is empty or all whitespace. Leading whitespace is skipped.
func getCommand(message string) string {
	var command string
	if words := strings.Fields(message); len(words) != 0 {
		command = words[0]
	}
	return command
}
|
|
|
|
func getAIResponse(client *gomatrix.Client, config *BotConfig, sender, userMessage, command string) (string, error) {
|
|
url := "https://api.siliconflow.cn/v1/chat/completions"
|
|
|
|
// Ensure TemporaryPrompts has an entry for the sender
|
|
if _, exists := TemporaryPrompts[sender]; !exists {
|
|
TemporaryPrompts[sender] = make(map[string]string)
|
|
}
|
|
|
|
// Check if command exists in configuration
|
|
prompt, exists := config.Commands[command]
|
|
if !exists {
|
|
// If command doesn't exist, use default prompt
|
|
command = "default"
|
|
prompt = config.Prompt
|
|
}
|
|
|
|
// Initialize the prompt for the specific command if it doesn't exist
|
|
if _, exists := TemporaryPrompts[sender][command]; !exists {
|
|
TemporaryPrompts[sender][command] = prompt
|
|
}
|
|
|
|
// Prepare AI request with dynamic message content
|
|
messages := []Message{
|
|
{Role: "system", Content: TemporaryPrompts[sender][command]}, // Use user and command-specific prompt
|
|
{Role: "user", Content: userMessage},
|
|
}
|
|
payload := AIRequest{
|
|
Model: config.AIConfig.Model,
|
|
Messages: messages,
|
|
Stream: config.AIConfig.Stream,
|
|
MaxTokens: config.AIConfig.MaxTokens,
|
|
Stop: config.AIConfig.Stop,
|
|
Temperature: config.AIConfig.Temperature,
|
|
TopP: config.AIConfig.TopP,
|
|
TopK: config.AIConfig.TopK,
|
|
FrequencyPenalty: config.AIConfig.FrequencyPenalty,
|
|
N: config.AIConfig.N,
|
|
}
|
|
|
|
payloadBytes, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("accept", "application/json")
|
|
req.Header.Set("content-type", "application/json")
|
|
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", config.AIAPIKey))
|
|
|
|
clientHTTP := &http.Client{}
|
|
resp, err := clientHTTP.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
if resp.StatusCode == http.StatusTooManyRequests {
|
|
// Notify user of rate limit and retry after 10 seconds
|
|
_, err := client.SendText(config.RoomID, "Request limit reached. Retrying in 10 seconds...")
|
|
if err != nil {
|
|
log.Printf("Failed to send retry notification: %v", err)
|
|
}
|
|
time.Sleep(10 * time.Second)
|
|
return getAIResponse(client, config, sender, userMessage, command)
|
|
}
|
|
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
|
}
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to read response body: %w", err)
|
|
}
|
|
|
|
var aiResp AIResponse
|
|
if err := json.Unmarshal(body, &aiResp); err != nil {
|
|
return "", fmt.Errorf("failed to unmarshal response: %w", err)
|
|
}
|
|
|
|
if len(aiResp.Choices) > 0 {
|
|
// Update the user's prompt with the latest conversation for memory
|
|
TemporaryPrompts[sender][command] += "\nUser: " + userMessage + "\nAI: " + aiResp.Choices[0].Message.Content
|
|
fmt.Printf("Size %v\n", len(TemporaryPrompts[sender][command]))
|
|
return aiResp.Choices[0].Message.Content, nil
|
|
}
|
|
|
|
return "No response from AI", nil
|
|
}
|
|
|
|
func loadBotConfig(filename string) (*BotConfig, error) {
|
|
if _, err := os.Stat(filename); os.IsNotExist(err) {
|
|
// Create default config
|
|
defaultConfig := BotConfig{
|
|
Homeserver: "https://matrix.org",
|
|
Username: "your_username",
|
|
Password: "your_password",
|
|
RoomID: "!your_room_id:matrix.org",
|
|
AccessToken: "",
|
|
AIAPIKey: "your_api_key",
|
|
Prompt: "Your default prompt here", // Default prompt
|
|
AIConfig: AIConfig{
|
|
Model: "deepseek-ai/deepseek-v2-chat",
|
|
Stream: false,
|
|
MaxTokens: 512,
|
|
Stop: []string{"string"},
|
|
Temperature: 0.7,
|
|
TopP: 0.7,
|
|
TopK: 50,
|
|
FrequencyPenalty: 0.5,
|
|
N: 1,
|
|
},
|
|
Commands: map[string]string{
|
|
"translate": "Translate prompt here",
|
|
"summarize": "Summarize prompt here",
|
|
},
|
|
}
|
|
|
|
file, err := os.Create(filename)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create config file: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
encoder := json.NewEncoder(file)
|
|
if err := encoder.Encode(defaultConfig); err != nil {
|
|
return nil, fmt.Errorf("failed to write default config to file: %w", err)
|
|
}
|
|
|
|
return &defaultConfig, nil
|
|
}
|
|
|
|
file, err := os.Open(filename)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open config file: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
var config BotConfig
|
|
decoder := json.NewDecoder(file)
|
|
if err := decoder.Decode(&config); err != nil {
|
|
return nil, fmt.Errorf("failed to decode config file: %w", err)
|
|
}
|
|
|
|
return &config, nil
|
|
}
|
|
|
|
// logInteraction appends one request/response pair to logs/<YYYY-MM-DD>.log,
// relative to the working directory, creating the logs directory if needed.
// It is best-effort: failures are reported with log.Printf and never
// propagated, so logging cannot interrupt message handling.
func logInteraction(sender, userMessage, aiResponse string) {
	const logDir = "logs"
	// MkdirAll is a no-op when the directory already exists, which avoids the
	// Stat-then-Mkdir race. Its error used to be silently dropped.
	if err := os.MkdirAll(logDir, os.ModePerm); err != nil {
		log.Printf("Failed to create log directory: %v", err)
		return
	}

	// Take a single timestamp so an entry's date always matches the file it
	// goes into, even when the call straddles midnight.
	now := time.Now()
	logFilename := fmt.Sprintf("%s/%s.log", logDir, now.Format("2006-01-02"))

	logFile, err := os.OpenFile(logFilename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		log.Printf("Failed to open log file: %v", err)
		return
	}
	defer logFile.Close()

	logEntry := fmt.Sprintf("%s - %s: %s\nAI: %s\n\n", now.Format(time.RFC3339), sender, userMessage, aiResponse)
	if _, err := logFile.WriteString(logEntry); err != nil {
		log.Printf("Failed to write log entry: %v", err)
	}
}
|