// matrix/gpt.go
// 2024-11-21 20:09:48 +08:00
//
// 347 lines
// 10 KiB
// Go

package main
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"strings"
"time"
"github.com/matrix-org/gomatrix"
)
// BotConfig is the bot's JSON configuration as read by loadBotConfig.
// Field order determines the key order of the generated default file.
type BotConfig struct {
	// Matrix homeserver URL and password-login credentials used by main.
	Homeserver string `json:"homeserver"`
	Username   string `json:"username"`
	Password   string `json:"password"`
	// RoomID is the room the bot joins and posts its replies to.
	RoomID string `json:"room_id"`
	// AccessToken is not read anywhere in this file; main always logs in
	// with Username/Password.
	AccessToken string `json:"access_token"`

	// AIAPIKey is sent as the Bearer token to the AI API.
	AIAPIKey string `json:"ai_api_key"`
	// AIConfig holds the static request parameters copied into each AIRequest.
	AIConfig AIConfig `json:"ai_config"`

	// Prompt is the system prompt used when the first word of a message is
	// not a key of Commands.
	Prompt string `json:"prompt"`
	// Commands maps a command word (the first word after "$$") to the system
	// prompt used for that command.
	Commands map[string]string `json:"commands"`
}
// AIConfig is the configurable, per-deployment part of a chat-completion
// request; getAIResponse copies every field into the matching AIRequest field.
//
// NOTE(review): getAIResponse decodes the reply as a single JSON document, so
// Stream presumably has to stay false — confirm before enabling it.
type AIConfig struct {
	Model            string   `json:"model"`
	Stream           bool     `json:"stream"`
	MaxTokens        int      `json:"max_tokens"`
	Stop             []string `json:"stop"`
	Temperature      float64  `json:"temperature"`
	TopP             float64  `json:"top_p"`
	TopK             int      `json:"top_k"`
	FrequencyPenalty float64  `json:"frequency_penalty"`
	N                int      `json:"n"`
}
// AIRequest is the JSON body POSTed to the chat-completions endpoint: the
// AIConfig parameters plus the Messages (system prompt and user turn) for
// this request.
type AIRequest struct {
	Model            string    `json:"model"`
	Messages         []Message `json:"messages"`
	Stream           bool      `json:"stream"`
	MaxTokens        int       `json:"max_tokens"`
	Stop             []string  `json:"stop"`
	Temperature      float64   `json:"temperature"`
	TopP             float64   `json:"top_p"`
	TopK             int       `json:"top_k"`
	FrequencyPenalty float64   `json:"frequency_penalty"`
	N                int       `json:"n"`
}
// Message is one chat turn; Role is "system" or "user" in requests built by
// getAIResponse, and the assistant's reply arrives in the same shape.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// AIResponse is the decoded JSON body of a successful chat-completions call.
// getAIResponse only reads Choices; the other fields are decoded but unused.
type AIResponse struct {
	ID      string   `json:"id"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Object  string   `json:"object"`
}
// Choice is one candidate completion; getAIResponse uses the Message of the
// first choice as the bot's reply.
type Choice struct {
	Message      Message `json:"message"`
	FinishReason string  `json:"finish_reason"`
}
// Usage carries the token accounting reported by the AI API for one request.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// TemporaryPrompts is the in-memory conversation state, keyed by sender ID and
// then by command name ("default" when no configured command matched). Each
// entry starts as the command's system prompt, and getAIResponse appends every
// successful exchange to it. It is not guarded by a mutex, so it assumes
// access only from the Matrix sync callback.
var TemporaryPrompts = map[string]map[string]string{}
// main loads bot-config.json, logs in to the homeserver with the configured
// username and password, and joins config.RoomID. It then runs the Matrix sync
// loop, handing every m.room.message event not sent by the bot itself to
// handleMessage. Any setup or sync failure is fatal.
func main() {
	config, err := loadBotConfig("bot-config.json")
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Start unauthenticated; credentials are attached after the password login.
	client, err := gomatrix.NewClient(config.Homeserver, "", "")
	if err != nil {
		log.Fatalf("Failed to create Matrix client: %v", err)
	}

	login, err := client.Login(&gomatrix.ReqLogin{
		Type:     "m.login.password",
		User:     config.Username,
		Password: config.Password,
	})
	if err != nil {
		log.Fatalf("Failed to login: %v", err)
	}
	botUserID := login.UserID
	client.SetCredentials(botUserID, login.AccessToken)

	if _, err := client.JoinRoom(config.RoomID, "", nil); err != nil {
		log.Fatalf("Failed to join room: %v", err)
	}
	log.Println("Successfully login")

	syncer := client.Syncer.(*gomatrix.DefaultSyncer)
	syncer.OnEventType("m.room.message", func(ev *gomatrix.Event) {
		// Ignore the bot's own messages so it never answers itself.
		if ev.Sender == botUserID {
			return
		}
		handleMessage(client, config, ev)
	})

	// Sync keeps polling the homeserver; it only returns on failure or stop.
	if err := client.Sync(); err != nil {
		log.Fatalf("Sync failed: %v", err)
	}
}
// handleMessage answers "$$"-prefixed text messages posted in config.RoomID.
//
// Everything after "$$" is sent to the AI API via getAIResponse. Its first
// word (see getCommand) selects a command-specific system prompt. The reply is
// posted to config.RoomID and the exchange is appended to the daily log.
//
// If getAIResponse fails with an error mentioning "429", a short rate-limit
// notice is posted instead. Any other failure is only logged.
func handleMessage(client *gomatrix.Client, config *BotConfig, event *gomatrix.Event) {
	// The sync listener fires for every room the account has joined. Replies
	// (and the rate-limit notice in getAIResponse) always go to config.RoomID,
	// so only handle messages posted there. Otherwise one room's conversation
	// would leak into another.
	if event.RoomID != config.RoomID {
		return
	}
	if event.Content["msgtype"] != "m.text" {
		return
	}
	// A malformed or unusual event may lack a string body; skip it rather than
	// panicking on an unchecked type assertion.
	body, ok := event.Content["body"].(string)
	if !ok || !strings.HasPrefix(body, "$$") {
		return
	}

	userMessage := strings.TrimPrefix(body, "$$")
	command := getCommand(userMessage)

	aiResponse, err := getAIResponse(client, config, event.Sender, userMessage, command)
	if err != nil {
		log.Printf("Failed to get AI response: %v", err)
		if !strings.Contains(err.Error(), "429") {
			return
		}
		aiResponse = "Sorry Reached Limited"
	}

	if _, err := client.SendText(config.RoomID, aiResponse); err != nil {
		log.Printf("Failed to send message: %v", err)
	}
	logInteraction(event.Sender, userMessage, aiResponse)
}
// getCommand returns the first whitespace-separated word of message, which is
// looked up as a key of BotConfig.Commands. It returns "" when message is
// empty or all whitespace.
func getCommand(message string) string {
	if words := strings.Fields(message); len(words) != 0 {
		return words[0]
	}
	return ""
}
func getAIResponse(client *gomatrix.Client, config *BotConfig, sender, userMessage, command string) (string, error) {
url := "https://api.siliconflow.cn/v1/chat/completions"
// Ensure TemporaryPrompts has an entry for the sender
if _, exists := TemporaryPrompts[sender]; !exists {
TemporaryPrompts[sender] = make(map[string]string)
}
// Check if command exists in configuration
prompt, exists := config.Commands[command]
if !exists {
// If command doesn't exist, use default prompt
command = "default"
prompt = config.Prompt
}
// Initialize the prompt for the specific command if it doesn't exist
if _, exists := TemporaryPrompts[sender][command]; !exists {
TemporaryPrompts[sender][command] = prompt
}
// Prepare AI request with dynamic message content
messages := []Message{
{Role: "system", Content: TemporaryPrompts[sender][command]}, // Use user and command-specific prompt
{Role: "user", Content: userMessage},
}
payload := AIRequest{
Model: config.AIConfig.Model,
Messages: messages,
Stream: config.AIConfig.Stream,
MaxTokens: config.AIConfig.MaxTokens,
Stop: config.AIConfig.Stop,
Temperature: config.AIConfig.Temperature,
TopP: config.AIConfig.TopP,
TopK: config.AIConfig.TopK,
FrequencyPenalty: config.AIConfig.FrequencyPenalty,
N: config.AIConfig.N,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("accept", "application/json")
req.Header.Set("content-type", "application/json")
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", config.AIAPIKey))
clientHTTP := &http.Client{}
resp, err := clientHTTP.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusTooManyRequests {
// Notify user of rate limit and retry after 10 seconds
_, err := client.SendText(config.RoomID, "Request limit reached. Retrying in 10 seconds...")
if err != nil {
log.Printf("Failed to send retry notification: %v", err)
}
time.Sleep(10 * time.Second)
return getAIResponse(client, config, sender, userMessage, command)
}
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body: %w", err)
}
var aiResp AIResponse
if err := json.Unmarshal(body, &aiResp); err != nil {
return "", fmt.Errorf("failed to unmarshal response: %w", err)
}
if len(aiResp.Choices) > 0 {
// Update the user's prompt with the latest conversation for memory
TemporaryPrompts[sender][command] += "\nUser: " + userMessage + "\nAI: " + aiResp.Choices[0].Message.Content
fmt.Printf("Size %v\n", len(TemporaryPrompts[sender][command]))
return aiResp.Choices[0].Message.Content, nil
}
return "No response from AI", nil
}
// loadBotConfig reads the JSON bot configuration from filename.
//
// If the file does not exist, a template with placeholder values is written to
// filename and returned, so the caller continues with placeholder credentials
// until the file is edited. The template is created with mode 0600 because it
// contains the Matrix password and the AI API key.
func loadBotConfig(filename string) (*BotConfig, error) {
	// Open directly instead of Stat-then-Open, avoiding a check/use race.
	file, err := os.Open(filename)
	if os.IsNotExist(err) {
		return writeDefaultBotConfig(filename)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to open config file: %w", err)
	}
	defer file.Close()

	var config BotConfig
	if err := json.NewDecoder(file).Decode(&config); err != nil {
		return nil, fmt.Errorf("failed to decode config file: %w", err)
	}
	return &config, nil
}

// writeDefaultBotConfig creates filename, failing if it already exists. It
// writes the placeholder configuration there as indented JSON and returns it.
func writeDefaultBotConfig(filename string) (*BotConfig, error) {
	config := defaultBotConfig()

	// O_EXCL: never clobber a config that appeared after the failed open.
	// 0600: the file holds credentials, so keep it private to the owner.
	file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return nil, fmt.Errorf("failed to create config file: %w", err)
	}

	encoder := json.NewEncoder(file)
	encoder.SetIndent("", "  ") // the user has to edit this file by hand
	if err := encoder.Encode(config); err != nil {
		file.Close()
		return nil, fmt.Errorf("failed to write default config to file: %w", err)
	}
	// The file was written, so a Close failure is reported rather than ignored.
	if err := file.Close(); err != nil {
		return nil, fmt.Errorf("failed to write default config to file: %w", err)
	}
	return config, nil
}

// defaultBotConfig returns the placeholder configuration written on first run.
func defaultBotConfig() *BotConfig {
	return &BotConfig{
		Homeserver:  "https://matrix.org",
		Username:    "your_username",
		Password:    "your_password",
		RoomID:      "!your_room_id:matrix.org",
		AccessToken: "",
		AIAPIKey:    "your_api_key",
		Prompt:      "Your default prompt here",
		AIConfig: AIConfig{
			Model:     "deepseek-ai/deepseek-v2-chat",
			Stream:    false,
			MaxTokens: 512,
			// NOTE(review): "string" looks like a placeholder copied from an API
			// example. As a real stop sequence it would truncate any reply
			// containing the word "string". Consider leaving Stop empty.
			Stop:             []string{"string"},
			Temperature:      0.7,
			TopP:             0.7,
			TopK:             50,
			FrequencyPenalty: 0.5,
			N:                1,
		},
		Commands: map[string]string{
			"translate": "Translate prompt here",
			"summarize": "Summarize prompt here",
		},
	}
}
func logInteraction(sender, userMessage, aiResponse string) {
logDir := "logs"
if _, err := os.Stat(logDir); os.IsNotExist(err) {
os.Mkdir(logDir, os.ModePerm)
}
timestamp := time.Now().Format("2006-01-02")
logFilename := fmt.Sprintf("%s/%s.log", logDir, timestamp)
logFile, err := os.OpenFile(logFilename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Printf("Failed to open log file: %v", err)
return
}
defer logFile.Close()
logEntry := fmt.Sprintf("%s - %s: %s\nAI: %s\n\n", time.Now().Format(time.RFC3339), sender, userMessage, aiResponse)
if _, err := logFile.WriteString(logEntry); err != nil {
log.Printf("Failed to write log entry: %v", err)
}
}