package main

import (
	"bytes"
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/pbkdf2"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"errors"
	"flag"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"text/tabwriter"

	"golang.org/x/term"
	_ "modernc.org/sqlite"
)

// On-disk vault format and key-derivation parameters.
//
// A vault file is laid out as: magic | salt(16) | nonce(12) | ciphertext+tag.
const (
	magic   = "DEMO1" // file signature; also bound into GCM as additional data
	saltLen = 16      // PBKDF2 salt length in bytes
	keyLen  = 32      // derived key length in bytes (AES-256)
	iter    = 600_000 // PBKDF2-HMAC-SHA256 iteration count
)

// serializer is the optional capability this program needs from the raw
// driver connection obtained through (*sql.Conn).Raw: dumping the whole
// database to a byte image (sqlite3_serialize) and loading such an image
// back (sqlite3_deserialize).
type serializer interface {
	Deserialize([]byte) error
	Serialize() ([]byte, error)
}

// mustExec runs a statement that produces no rows and terminates the process
// via fatal if the database reports an error.
func mustExec(db *sql.DB, q string, args ...any) {
	_, err := db.Exec(q, args...)
	if err == nil {
		return
	}
	fatal(err)
}

// fileExists reports whether something may already live at path.
//
// Only a definite "does not exist" answer from os.Stat counts as absent. Any
// other Stat error (permission denied, not-a-directory, I/O error, ...) is
// treated as present. The caller then takes the "open existing vault" path
// and surfaces the real error, rather than the "create new vault" path,
// which would try to write over whatever is there.
func fileExists(path string) bool {
	_, err := os.Stat(path)
	return !errors.Is(err, fs.ErrNotExist)
}

func fatal(err error) {
	fmt.Fprintln(os.Stderr, "error:", err)
	os.Exit(1)
}

// --- snapshot of the in-RAM database --------------------------------------

// snapshot returns the serialized image of the database held by one pooled
// connection of db. It fails if the underlying driver connection does not
// implement serializer.
func snapshot(db *sql.DB) ([]byte, error) {
	ctx := context.Background()
	conn, err := db.Conn(ctx)
	if err != nil {
		return nil, err
	}
	defer func() { _ = conn.Close() }()

	var image []byte
	rawErr := conn.Raw(func(driverConn any) error {
		ser, ok := driverConn.(serializer)
		if !ok {
			return errors.New("the driver does not expose Serialize")
		}
		var serErr error
		image, serErr = ser.Serialize()
		return serErr
	})
	return image, rawErr
}

// restore loads buf, a serialized database image, into the database held by
// one pooled connection of db. It fails if the underlying driver connection
// does not implement serializer.
func restore(db *sql.DB, buf []byte) error {
	conn, err := db.Conn(context.Background())
	if err != nil {
		return err
	}
	defer func() { _ = conn.Close() }()

	load := func(driverConn any) error {
		if ser, ok := driverConn.(serializer); ok {
			return ser.Deserialize(buf)
		}
		return errors.New("the driver does not expose Deserialize")
	}
	return conn.Raw(load)
}

// --- encrypted vault on disk ----------------------------------------------

// seal serializes the in-memory database, encrypts the image with pass and
// replaces the file at path with the result.
//
// The ciphertext is first written to a temporary file in the SAME directory
// as path and then renamed over it. Keeping both on one filesystem is what
// lets os.Rename replace the vault atomically. A temp file in os.TempDir()
// (often a separate tmpfs) makes the rename fail with "invalid cross-device
// link". The temp file is fsynced before the rename so its contents are on
// disk. It is removed on every failure path, including a failed rename.
func seal(db *sql.DB, path, pass string) (err error) {
	plain, err := snapshot(db)
	if err != nil {
		return err
	}
	blob, err := encrypt(pass, plain)
	if err != nil {
		return err
	}

	tmp, err := os.CreateTemp(filepath.Dir(path), ".vault-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	defer func() {
		if err != nil {
			_ = tmp.Close() // may already be closed; nothing useful to report
			_ = os.Remove(tmpName)
		}
	}()

	// CreateTemp already uses 0600; the explicit chmod keeps the vault's
	// permissions from depending on that detail.
	if err = tmp.Chmod(0o600); err != nil {
		return err
	}
	if _, err = tmp.Write(blob); err != nil {
		return err
	}
	if err = tmp.Sync(); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}

// readVault reads and decrypts the vault at path.
//
// It returns (plaintext, true, nil) on success. A missing file yields
// (nil, false, nil). A read error yields (nil, false, err). A decryption
// failure yields (nil, true, err), because the file was found.
func readVault(path, pass string) ([]byte, bool, error) {
	blob, err := os.ReadFile(path) // #nosec G304 -- path is the operator-provided -vault flag, not untrusted input
	switch {
	case errors.Is(err, fs.ErrNotExist):
		return nil, false, nil
	case err != nil:
		return nil, false, err
	}
	plain, decErr := decrypt(pass, blob)
	return plain, true, decErr
}

// encrypt seals plain under a key derived from pass and returns a blob laid
// out as magic | salt | nonce | ciphertext+tag.
//
// A fresh random salt and nonce are drawn on every call. The magic string is
// also passed to GCM as additional authenticated data.
func encrypt(pass string, plain []byte) ([]byte, error) {
	salt := make([]byte, saltLen)
	if _, err := rand.Read(salt); err != nil {
		return nil, err
	}
	aead, err := newGCM(pass, salt)
	if err != nil {
		return nil, err
	}
	nonce := make([]byte, aead.NonceSize())
	if _, err = rand.Read(nonce); err != nil {
		return nil, err
	}

	// Build the header in a buffer large enough for the sealed payload, then
	// let Seal append ciphertext+tag directly after it.
	headerLen := len(magic) + len(salt) + len(nonce)
	out := make([]byte, 0, headerLen+len(plain)+aead.Overhead())
	out = append(out, magic...)
	out = append(out, salt...)
	out = append(out, nonce...)
	return aead.Seal(out, nonce, plain, []byte(magic)), nil
}

// decrypt parses a blob produced by encrypt and returns the plaintext.
//
// A blob that is too short or lacks the magic prefix is rejected as
// "invalid format". Any GCM authentication failure is reported as a wrong
// password or a tampered file; the two cases are indistinguishable.
func decrypt(pass string, blob []byte) ([]byte, error) {
	headerLen := len(magic) + saltLen
	if len(blob) < headerLen+12 || string(blob[:len(magic)]) != magic {
		return nil, errors.New("invalid format")
	}
	salt := blob[len(magic):headerLen]
	payload := blob[headerLen:]

	aead, err := newGCM(pass, salt)
	if err != nil {
		return nil, err
	}
	nonceLen := aead.NonceSize()
	if len(payload) < nonceLen {
		return nil, errors.New("invalid format")
	}
	plain, err := aead.Open(nil, payload[:nonceLen], payload[nonceLen:], []byte(magic))
	if err != nil {
		return nil, errors.New("wrong password (or tampered file)")
	}
	return plain, nil
}

// newGCM derives a keyLen-byte key from pass and salt with PBKDF2-HMAC-SHA256
// (iter rounds) and returns an AES-GCM AEAD keyed with it.
func newGCM(pass string, salt []byte) (cipher.AEAD, error) {
	derived, kdfErr := pbkdf2.Key(sha256.New, pass, salt, iter, keyLen)
	if kdfErr != nil {
		return nil, kdfErr
	}
	aesBlock, blockErr := aes.NewCipher(derived)
	if blockErr != nil {
		return nil, blockErr
	}
	aead, gcmErr := cipher.NewGCM(aesBlock)
	return aead, gcmErr
}

// --- terminal and helpers -------------------------------------------------

// askPassphrase reads the vault password from the terminal without echo.
// When confirm is true (a new vault is being created) the password is asked
// twice and both entries must match. An empty password is rejected, and
// stdin must be an interactive terminal.
//
// Prompts go to stderr rather than stdout so they stay visible when stdout
// is redirected (e.g. into a file) and do not mix with the listing printed
// to stdout.
func askPassphrase(confirm bool) (string, error) {
	fd := int(os.Stdin.Fd()) // #nosec G115 -- stdin's fd is small; the conversion is safe
	if !term.IsTerminal(fd) {
		return "", errors.New("an interactive terminal is required for the password")
	}

	fmt.Fprint(os.Stderr, "password: ")
	p1, err := term.ReadPassword(fd)
	fmt.Fprintln(os.Stderr) // echo is off, so finish the prompt line ourselves
	if err != nil {
		return "", fmt.Errorf("reading the password: %w", err)
	}
	// Best-effort wipe of the raw bytes once the function returns. The string
	// handed back to the caller is a separate copy that cannot be erased.
	defer clear(p1)
	if len(p1) == 0 {
		return "", errors.New("empty password")
	}
	if !confirm {
		return string(p1), nil
	}

	fmt.Fprint(os.Stderr, "confirm password: ")
	p2, err := term.ReadPassword(fd)
	fmt.Fprintln(os.Stderr)
	if err != nil {
		return "", fmt.Errorf("reading the confirmation: %w", err)
	}
	defer clear(p2)
	if !bytes.Equal(p1, p2) {
		return "", errors.New("passwords do not match")
	}
	return string(p1), nil
}

// list prints every row of the secrets table, ordered by id, as an aligned
// ID/KEY/VALUE table on stdout. Any query or scan error terminates the
// process via fatal.
func list(db *sql.DB) {
	rows, err := db.Query(`SELECT id, key, value FROM secrets ORDER BY id`)
	if err != nil {
		fatal(err)
	}
	defer func() { _ = rows.Close() }()

	tw := tabwriter.NewWriter(os.Stdout, 0, 2, 2, ' ', 0)
	fmt.Fprintln(tw, "ID\tKEY\tVALUE")
	var (
		id         int
		key, value string
	)
	for rows.Next() {
		if scanErr := rows.Scan(&id, &key, &value); scanErr != nil {
			fatal(scanErr)
		}
		fmt.Fprintf(tw, "%d\t%s\t%s\n", id, key, value)
	}
	_ = tw.Flush()

	if iterErr := rows.Err(); iterErr != nil {
		fatal(iterErr)
	}
}

// main opens the encrypted vault named by -vault, creating it on first run.
//
// Existing vault: ask the password once, decrypt the file, load the plaintext
// SQLite image into an in-memory database and print its secrets. A wrong
// password or tampered file aborts without writing anything.
// New vault: ask the password twice, create and seed the secrets table in
// memory, print it, then encrypt the serialized database to disk.
func main() {
	vault := flag.String("vault", "vault.blob", "encrypted vault file")
	flag.Parse()

	// Decide existence BEFORE asking for the password: a new vault asks twice.
	exists := fileExists(*vault)
	pass, err := askPassphrase(!exists)
	if err != nil {
		fatal(err)
	}

	// Database ONLY in RAM. A ":memory:" DB belongs to the connection that
	// opened it, so the pool must hold exactly one connection and keep it.
	// MaxOpenConns(1) pins every query to that connection, and MaxIdleConns(1)
	// keeps it cached between uses. If the pool ever closed it, the next query
	// would get a fresh, empty database.
	// temp_store=MEMORY keeps SQLite's temporary tables and indices off disk.
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		fatal(err)
	}
	defer func() { _ = db.Close() }()
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	mustExec(db, `PRAGMA temp_store=MEMORY;`)

	// Vault exists: decrypt -> deserialize into RAM -> list.
	if exists {
		blob, found, rerr := readVault(*vault, pass)
		if rerr != nil {
			fatal(rerr) // wrong password or tampered: abort WITHOUT overwriting
		}
		if !found {
			// The file was present when checked above but is gone now. Do not
			// fall back to creating a new vault (the password was entered only
			// once), and do not try to restore a nil image.
			fatal(fmt.Errorf("vault %s disappeared before it could be read", *vault))
		}
		if derr := restore(db, blob); derr != nil {
			fatal(derr)
		}
		fmt.Println("vault opened:")
		list(db)
		return
	}

	// First time: create in RAM, serialize the bytes, encrypt and seal.
	fmt.Println("vault not found — creating and inserting secrets…")
	mustExec(db, `CREATE TABLE secrets (id INTEGER PRIMARY KEY, key TEXT, value TEXT);`)
	mustExec(db, `INSERT INTO secrets (key, value) VALUES (?,?),(?,?),(?,?)`,
		"api_key", "sk-demo-3f9a…",
		"pin", "4729",
		"note", "the plaintext only ever exists in RAM",
	)
	list(db)
	if serr := seal(db, *vault, pass); serr != nil {
		fatal(serr)
	}
	fmt.Printf("\nsealed to %s (encrypted). run it again to open it.\n", *vault)
}
