diff --git a/main.go b/main.go index f8a6891..f283bc7 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,9 @@ package main import ( + "database/sql" "encoding/json" + "fmt" "os" "sort" "time" @@ -30,6 +32,14 @@ func start(path *string) func(*cli.Cmd) { NPomodoros: *pomodoros, Duration: parsed, } + maybe(db.With(func(tx *sql.Tx) error { + id, err := db.CreateTask(tx, *task) + if err != nil { + return err + } + task.ID = id + return nil + })) runner, err := NewTaskRunner(task, db, NewXnotifier(*path+"/icon.png")) maybe(err) server, err := NewServer(*path+"/pomo.sock", runner) @@ -63,39 +73,51 @@ func create(path *string) func(*cli.Cmd) { NPomodoros: *pomodoros, Duration: parsed, } - taskID, err := db.CreateTask(*task) - maybe(err) + maybe(db.With(func(tx *sql.Tx) error { + taskId, err := db.CreateTask(tx, *task) + if err != nil { + return err + } + fmt.Println(taskId) + return nil + })) } } } func begin(path *string) func(*cli.Cmd) { return func(cmd *cli.Cmd) { - cmd.Spec = "ID" - var jobId = cmd.IntArg("ID", -1, "ID of Pomodoro to begin") + cmd.Spec = "[OPTIONS] TASK_ID" + var ( + taskId = cmd.IntArg("TASK_ID", -1, "ID of Pomodoro to begin") + ) cmd.Action = func() { db, err := NewStore(*path) maybe(err) defer db.Close() - tasks, err := db.ReadTasks() - maybe(err) - task := &Task{} - for _, task = range tasks { - if task.ID == *jobId { - break + var task *Task + maybe(db.With(func(tx *sql.Tx) error { + read, err := db.ReadTask(tx, *taskId) + if err != nil { + return err } - } - if task.ID == *jobId { - runner, err := NewTaskRunner(task, db, NewXnotifier(*path+"/icon.png")) - maybe(err) - server, err := NewServer(*path+"/pomo.sock", runner) - maybe(err) - server.Start() - defer server.Stop() - runner.Start() - startUI(runner) - } + task = read + err = db.DeletePomodoros(tx, *taskId) + if err != nil { + return err + } + task.Pomodoros = []*Pomodoro{} + return nil + })) + runner, err := NewTaskRunner(task, db, NewXnotifier(*path+"/icon.png")) + maybe(err) + server, err := NewServer(*path+"/pomo.sock", runner) + maybe(err) + server.Start() + defer server.Stop() + runner.Start() + startUI(runner) } } } @@ -128,24 +150,27 @@ func list(path *string) func(*cli.Cmd) { db, err := NewStore(*path) maybe(err) defer db.Close() - tasks, err := db.ReadTasks() - maybe(err) - if *assend { - sort.Sort(sort.Reverse(ByID(tasks))) - } - if !*all { - tasks = After(time.Now().Add(-duration), tasks) - } - if *limit > 0 && (len(tasks) > *limit) { - tasks = tasks[0:*limit] - } - if *asJSON { - maybe(json.NewEncoder(os.Stdout).Encode(tasks)) - return - } - config, err := NewConfig(*path + "/config.json") - maybe(err) - summerizeTasks(config, tasks) + maybe(db.With(func(tx *sql.Tx) error { + tasks, err := db.ReadTasks(tx) + maybe(err) + if *assend { + sort.Sort(sort.Reverse(ByID(tasks))) + } + if !*all { + tasks = After(time.Now().Add(-duration), tasks) + } + if *limit > 0 && (len(tasks) > *limit) { + tasks = tasks[0:*limit] + } + if *asJSON { + maybe(json.NewEncoder(os.Stdout).Encode(tasks)) + return nil + } + config, err := NewConfig(*path + "/config.json") + maybe(err) + summerizeTasks(config, tasks) + return nil + })) } } } @@ -158,7 +183,9 @@ func _delete(path *string) func(*cli.Cmd) { db, err := NewStore(*path) maybe(err) defer db.Close() - maybe(db.DeleteTask(*taskID)) + maybe(db.With(func(tx *sql.Tx) error { + return db.DeleteTask(tx, *taskID) + })) } } } diff --git a/runner.go b/runner.go index 9823484..97869d6 100644 --- a/runner.go +++ b/runner.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "time" ) @@ -20,12 +21,8 @@ type TaskRunner struct { } func NewTaskRunner(task *Task, store *Store, notifier Notifier) (*TaskRunner, error) { - taskID, err := store.CreateTask(*task) - if err != nil { - return nil, err - } tr := &TaskRunner{ - taskID: taskID, + taskID: task.ID, taskMessage: task.Message, nPomodoros: task.NPomodoros, origDuration: task.Duration, @@ -89,7 +86,9 @@ func (t *TaskRunner) run() error { goto loop } pomodoro.End = time.Now() - err := t.store.CreatePomodoro(t.taskID, *pomodoro) + err := t.store.With(func(tx *sql.Tx) error { + return t.store.CreatePomodoro(tx, t.taskID, *pomodoro) + }) if err != nil { return err } diff --git a/store.go b/store.go index a56ed75..0152fbe 100644 --- a/store.go +++ b/store.go @@ -12,6 +12,8 @@ import ( // 2018-01-16 19:05:21.752851759+08:00 const datetimeFmt = "2006-01-02 15:04:05.999999999-07:00" +type StoreFunc func(tx *sql.Tx) error + type Store struct { db *sql.DB } @@ -25,39 +27,45 @@ func NewStore(path string) (*Store, error) { return &Store{db: db}, nil } -func (s Store) CreateTask(task Task) (int, error) { - var taskID int +// With applies all of the given functions with +// a single transaction, rolling back on failure +// and commiting on success. +func (s Store) With(fns ...func(tx *sql.Tx) error) error { tx, err := s.db.Begin() if err != nil { - return -1, err + return err } - _, err = tx.Exec( + for _, fn := range fns { + err = fn(tx) + if err != nil { + tx.Rollback() + return err + } + } + return tx.Commit() +} + +func (s Store) CreateTask(tx *sql.Tx, task Task) (int, error) { + var taskID int + _, err := tx.Exec( "INSERT INTO task (message,pomodoros,duration,tags) VALUES ($1,$2,$3,$4)", task.Message, task.NPomodoros, task.Duration.String(), strings.Join(task.Tags, ",")) if err != nil { - tx.Rollback() return -1, err } err = tx.QueryRow("SELECT last_insert_rowid() FROM task").Scan(&taskID) if err != nil { - tx.Rollback() return -1, err } - return taskID, tx.Commit() + err = tx.QueryRow("SELECT last_insert_rowid() FROM task").Scan(&taskID) + if err != nil { + return -1, err + } + return taskID, nil } -func (s Store) CreatePomodoro(taskID int, pomodoro Pomodoro) error { - _, err := s.db.Exec( - `INSERT INTO pomodoro (task_id, start, end) VALUES ($1, $2, $3)`, - taskID, - pomodoro.Start, - pomodoro.End, - ) - return err -} - -func (s Store) ReadTasks() ([]*Task, error) { - rows, err := s.db.Query(`SELECT rowid,message,pomodoros,duration,tags FROM task`) +func (s Store) ReadTasks(tx *sql.Tx) ([]*Task, error) { + rows, err := tx.Query(`SELECT rowid,message,pomodoros,duration,tags FROM task`) if err != nil { return nil, err } @@ -77,7 +85,7 @@ func (s Store) ReadTasks() ([]*Task, error) { if tags != "" { task.Tags = strings.Split(tags, ",") } - pomodoros, err := s.ReadPomodoros(task.ID) + pomodoros, err := s.ReadPomodoros(tx, task.ID) if err != nil { return nil, err } @@ -89,8 +97,49 @@ func (s Store) ReadTasks() ([]*Task, error) { return tasks, nil } -func (s Store) ReadPomodoros(taskID int) ([]*Pomodoro, error) { - rows, err := s.db.Query(`SELECT start,end FROM pomodoro WHERE task_id = $1`, &taskID) +func (s Store) DeleteTask(tx *sql.Tx, taskID int) error { + _, err := tx.Exec("DELETE FROM task WHERE rowid = $1", &taskID) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM pomodoro WHERE task_id = $1", &taskID) + if err != nil { + return err + } + return nil +} + +func (s Store) ReadTask(tx *sql.Tx, taskID int) (*Task, error) { + task := &Task{} + var ( + tags string + strDuration string + ) + err := tx.QueryRow(`SELECT rowid,message,pomodoros,duration,tags FROM task WHERE rowid = $1`, &taskID). + Scan(&task.ID, &task.Message, &task.NPomodoros, &strDuration, &tags) + if err != nil { + return nil, err + } + duration, _ := time.ParseDuration(strDuration) + task.Duration = duration + if tags != "" { + task.Tags = strings.Split(tags, ",") + } + return task, nil +} + +func (s Store) CreatePomodoro(tx *sql.Tx, taskID int, pomodoro Pomodoro) error { + _, err := tx.Exec( + `INSERT INTO pomodoro (task_id, start, end) VALUES ($1, $2, $3)`, + taskID, + pomodoro.Start, + pomodoro.End, + ) + return err +} + +func (s Store) ReadPomodoros(tx *sql.Tx, taskID int) ([]*Pomodoro, error) { + rows, err := tx.Query(`SELECT start,end FROM pomodoro WHERE task_id = $1`, &taskID) if err != nil { return nil, err } @@ -114,22 +163,9 @@ func (s Store) ReadPomodoros(taskID int) ([]*Pomodoro, error) { return pomodoros, nil } -func (s Store) DeleteTask(taskID int) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM task WHERE rowid = $1", &taskID) - if err != nil { - tx.Rollback() - return err - } - _, err = tx.Exec("DELETE FROM pomodoro WHERE task_id = $1", &taskID) - if err != nil { - tx.Rollback() - return err - } - return tx.Commit() +func (s Store) DeletePomodoros(tx *sql.Tx, taskID int) error { + _, err := tx.Exec("DELETE FROM pomodoro WHERE task_id = $1", &taskID) + return err } func (s Store) Close() error { return s.db.Close() }