105 lines
2.4 KiB
Go
105 lines
2.4 KiB
Go
package db
|
|
|
|
import (
	"context"
	"fmt"
	"strings"
	"time"

	"doudizhu-server/internal/models"

	"github.com/jackc/pgx/v5/pgxpool"
)
|
|
|
|
// Config holds the PostgreSQL connection parameters consumed by New.
type Config struct {
	Host     string // server hostname or address
	Port     int    // server TCP port
	User     string // role name to authenticate as
	Password string // password for User
	Database string // name of the database to open
}
|
|
|
|
type DB struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
func New(cfg Config) (*DB, error) {
|
|
connStr := fmt.Sprintf(
|
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
|
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.Database,
|
|
)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
pool, err := pgxpool.New(ctx, connStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
return &DB{pool: pool}, nil
|
|
}
|
|
|
|
func (d *DB) Close() {
|
|
if d.pool != nil {
|
|
d.pool.Close()
|
|
}
|
|
}
|
|
|
|
func (d *DB) CreateUser(ctx context.Context, user *models.User) error {
|
|
query := `
|
|
INSERT INTO users (id, username, password, nickname, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $5)
|
|
`
|
|
_, err := d.pool.Exec(ctx, query, user.ID, user.Username, user.Password, user.Nickname, time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (d *DB) GetUserByUsername(ctx context.Context, username string) (*models.User, error) {
|
|
query := `
|
|
SELECT id, username, password, nickname
|
|
FROM users
|
|
WHERE username = $1
|
|
`
|
|
user := &models.User{}
|
|
err := d.pool.QueryRow(ctx, query, username).Scan(
|
|
&user.ID, &user.Username, &user.Password, &user.Nickname,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (d *DB) GetUserByID(ctx context.Context, id string) (*models.User, error) {
|
|
query := `
|
|
SELECT id, username, password, nickname
|
|
FROM users
|
|
WHERE id = $1
|
|
`
|
|
user := &models.User{}
|
|
err := d.pool.QueryRow(ctx, query, id).Scan(
|
|
&user.ID, &user.Username, &user.Password, &user.Nickname,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (d *DB) UsernameExists(ctx context.Context, username string) (bool, error) {
|
|
query := `SELECT EXISTS(SELECT 1 FROM users WHERE username = $1)`
|
|
var exists bool
|
|
err := d.pool.QueryRow(ctx, query, username).Scan(&exists)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to check username: %w", err)
|
|
}
|
|
return exists, nil
|
|
}
|