diff --git a/handler/paste/service.go b/handler/paste/service.go index afaa299..1593725 100644 --- a/handler/paste/service.go +++ b/handler/paste/service.go @@ -1,6 +1,7 @@ package paste import ( + "github.com/google/uuid" "github.com/ollien/updown/repository" ) @@ -27,3 +28,16 @@ func (service Service) CreatePaste(title string) (Paster, error) { pasteDir: service.pasteDir, }, nil } + +// GetPaste gets a paste from the database +func (service Service) GetPaste(handle uuid.UUID) (Paster, error) { + paste, err := service.db.GetPaste(handle) + if err != nil { + return Paster{}, err + } + + return Paster{ + Paste: paste, + pasteDir: service.pasteDir, + }, nil +} diff --git a/handler/paste/service_test.go b/handler/paste/service_test.go index e283862..c3c151f 100644 --- a/handler/paste/service_test.go +++ b/handler/paste/service_test.go @@ -56,3 +56,20 @@ func TestCreatePaste(t *testing.T) { assert.Equal(t, testPaste, paster.Paste) assert.Equal(t, service.pasteDir, paster.pasteDir) } + +func TestGetPaste(t *testing.T) { + service := setupService() + mockRepo := service.db.(*MockPasteRepository) + testPaste := repository.Paste{ + ID: 23, + Title: "My Awesome Paste", + Handle: uuid.New(), + } + + mockRepo.On("GetPaste", testPaste.Handle).Return(testPaste, nil).Once() + paster, err := service.GetPaste(testPaste.Handle) + mockRepo.AssertExpectations(t) + assert.Nil(t, err) + assert.Equal(t, testPaste, paster.Paste) + assert.Equal(t, service.pasteDir, paster.pasteDir) +} diff --git a/web/paste.go b/web/paste.go index 206dfc7..cb078c4 100644 --- a/web/paste.go +++ b/web/paste.go @@ -1,8 +1,12 @@ package web import ( + "database/sql" + "fmt" "net/http" + "github.com/go-chi/chi" + "github.com/google/uuid" "github.com/sirupsen/logrus" ) @@ -45,5 +49,33 @@ func (server *Webserver) makePaste(w http.ResponseWriter, req *http.Request) { return } - w.Write(paster.Paste.Handle[:]) + redirectURI := fmt.Sprintf("/paste/%s", paster.Paste.Handle) + http.Redirect(w, req, redirectURI, http.StatusSeeOther) +} + +func (server *Webserver) getPaste(w http.ResponseWriter, req *http.Request) { + rawHandle := chi.URLParam(req, "handle") + handle, err := uuid.Parse(rawHandle) + if err != nil { + // If the handle doesn't parse, we won't ever be able to find the resource + w.WriteHeader(404) + server.log(logrus.InfoLevel, req, err) + return + } + + paster, err := server.pasteService.GetPaste(handle) + if err == sql.ErrNoRows { + w.WriteHeader(404) + server.log(logrus.InfoLevel, req, err) + } else if err != nil { + w.WriteHeader(500) + server.log(logrus.ErrorLevel, req, err) + return + } + err = paster.Pipe(w) + if err != nil { + w.WriteHeader(500) + server.log(logrus.ErrorLevel, req, err) + return + } } diff --git a/web/server.go b/web/server.go index bb289ac..58eaceb 100644 --- a/web/server.go +++ b/web/server.go @@ -38,7 +38,10 @@ func (server *Webserver) Start() error { func (server *Webserver) setupRoutes() { router := server.router router.Get("/", server.index) - router.Post("/paste", server.makePaste) + router.Route("/paste", func(r chi.Router) { + r.Post("/", server.makePaste) + r.Get(`/{handle:[a-zA-Z0-9-]+}`, server.getPaste) + }) router.Get("/static/*", http.StripPrefix("/static", http.FileServer(server.staticFiles)).ServeHTTP) }