From df646ff0875a59981a2637e1e1e06ecaf54fcf9d Mon Sep 17 00:00:00 2001 From: Layla Date: Sat, 1 Apr 2023 22:25:54 +0000 Subject: [PATCH] Cleanup DB work --- Dockerfile | 2 +- ReadMe.md | 5 ++++- main.go | 5 ++++- persistence/database.go | 1 + persistence/sqlite3.go | 22 ++++++++++++++++++++-- 5 files changed, 30 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index d71a9f9..2f51774 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,4 +4,4 @@ COPY build/birdbot /usr/bin/birdbot VOLUME /etc/birdbot -ENTRYPOINT ["/usr/bin/birdbot", "-c=/etc/birdbot/birdbot.yaml"] \ No newline at end of file +ENTRYPOINT ["/usr/bin/birdbot", "-c=/etc/birdbot/birdbot.yaml", "-db=/var/lib/birdbot/birdbot.db"] \ No newline at end of file diff --git a/ReadMe.md b/ReadMe.md index 12ac8b5..84f76d7 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -24,4 +24,7 @@ Example: docker run -it -v `pwd`:/etc/birdbot yeslayla/birdbot:latest ``` -In this example, your config is in the current directory and call `birdbot.yaml` \ No newline at end of file +In this example, your config is in the current directory and call `birdbot.yaml` + +### Persistant Data + diff --git a/main.go b/main.go index f823d95..046d5f4 100644 --- a/main.go +++ b/main.go @@ -22,10 +22,13 @@ func main() { configDir, _ := os.UserConfigDir() defaultConfigPath := path.Join(configDir, "birdbot", "config.yaml") + defaultDBPath := path.Join(configDir, "birdbot", "birdbot.db") var config_file string + var db_file string var version bool flag.StringVar(&config_file, "c", defaultConfigPath, "Path to config file") + flag.StringVar(&db_file, "db", defaultDBPath, "Path to store persistant data") flag.BoolVar(&version, "v", false, "List version") flag.Parse() @@ -51,7 +54,7 @@ func main() { } } - db := persistence.NewSqlite3Database() + db := persistence.NewSqlite3Database(db_file) if err := db.MigrateUp(); err != nil { log.Fatal("Failed to migrate db: ", err) } diff --git a/persistence/database.go b/persistence/database.go index c74e418..df6aabf 100644 --- a/persistence/database.go +++ b/persistence/database.go @@ -1,5 +1,6 @@ package persistence +// Database is an interface used to wrap persistant data type Database interface { GetDiscordMessage(id string) (string, error) SetDiscordMessage(id string, messageID string) error diff --git a/persistence/sqlite3.go b/persistence/sqlite3.go index 4e0d8ef..3b4db49 100644 --- a/persistence/sqlite3.go +++ b/persistence/sqlite3.go @@ -5,6 +5,8 @@ import ( "embed" "fmt" "log" + "os" + "path/filepath" _ "github.com/mattn/go-sqlite3" migrate "github.com/rubenv/sql-migrate" @@ -17,8 +19,17 @@ type Sqlite3Database struct { db *sql.DB } -func NewSqlite3Database() *Sqlite3Database { - db, err := sql.Open("sqlite3", "./birdbot.db") +// NewSqlite3Database creates a new SqliteDB object +func NewSqlite3Database(path string) *Sqlite3Database { + dir := filepath.Dir(path) + if _, err := os.Stat(dir); os.IsNotExist(err) { + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + log.Printf("failed to create directory for db: %s", err) + return nil + } + } + + db, err := sql.Open("sqlite3", path) if err != nil { log.Printf("failed to open db: %s", err) return nil @@ -36,6 +47,7 @@ func getMigrations() migrate.MigrationSource { } } +// MigrateUp migrates the DB func (db *Sqlite3Database) MigrateUp() error { n, err := migrate.Exec(db.db, "sqlite3", getMigrations(), migrate.Up) @@ -49,6 +61,7 @@ func (db *Sqlite3Database) MigrateUp() error { return nil } +// MigrateUp destroys the DB func (db *Sqlite3Database) MigrateDown() error { n, err := migrate.Exec(db.db, "sqlite3", getMigrations(), migrate.Down) @@ -62,18 +75,23 @@ func (db *Sqlite3Database) MigrateDown() error { return nil } +// GetDiscordMessage finds a discord message ID from a given local ID func (db *Sqlite3Database) GetDiscordMessage(id string) (string, error) { var messageID string row := db.db.QueryRow("SELECT message_id FROM discord_messages WHERE id = $1", id) if err := row.Scan(&messageID); err != nil { + if err == sql.ErrNoRows { + return "", nil + } return "", fmt.Errorf("failed to get discord message from sqlite3: %s", err) } return messageID, nil } +// SetDiscordMessage sets a discord message ID from a given local ID func (db *Sqlite3Database) SetDiscordMessage(id string, messageID string) error { statement, err := db.db.Prepare("INSERT OR IGNORE INTO discord_messages (id, message_id) VALUES (?, ?)")