Add route to get paste

Needs an actual template - it's totally XSS vulnerable at the moment
master
Nick Krichevsky 2019-03-17 17:40:47 -04:00
parent f9b7eebc17
commit 58bfe7729c
4 changed files with 68 additions and 2 deletions

View File

@ -1,6 +1,7 @@
package paste
import (
"github.com/google/uuid"
"github.com/ollien/updown/repository"
)
@ -27,3 +28,16 @@ func (service Service) CreatePaste(title string) (Paster, error) {
pasteDir: service.pasteDir,
}, nil
}
// GetPaste gets a paste from the database
func (service Service) GetPaste(handle uuid.UUID) (Paster, error) {
paste, err := service.db.GetPaste(handle)
if err != nil {
return Paster{}, err
}
return Paster{
Paste: paste,
pasteDir: service.pasteDir,
}, nil
}

View File

@ -56,3 +56,20 @@ func TestCreatePaste(t *testing.T) {
assert.Equal(t, testPaste, paster.Paste)
assert.Equal(t, service.pasteDir, paster.pasteDir)
}
func TestGetPaste(t *testing.T) {
service := setupService()
mockRepo := service.db.(*MockPasteRepository)
testPaste := repository.Paste{
ID: 23,
Title: "My Awesome Paste",
Handle: uuid.New(),
}
mockRepo.On("GetPaste", testPaste.Handle).Return(testPaste, nil).Once()
paster, err := service.GetPaste(testPaste.Handle)
mockRepo.AssertExpectations(t)
assert.Nil(t, err)
assert.Equal(t, testPaste, paster.Paste)
assert.Equal(t, service.pasteDir, paster.pasteDir)
}

View File

@ -1,8 +1,12 @@
package web
import (
"database/sql"
"fmt"
"net/http"
"github.com/go-chi/chi"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
@ -45,5 +49,33 @@ func (server *Webserver) makePaste(w http.ResponseWriter, req *http.Request) {
return
}
w.Write(paster.Paste.Handle[:])
redirectURI := fmt.Sprintf("/paste/%s", paster.Paste.Handle)
http.Redirect(w, req, redirectURI, http.StatusSeeOther)
}
func (server *Webserver) getPaste(w http.ResponseWriter, req *http.Request) {
rawHandle := chi.URLParam(req, "handle")
handle, err := uuid.Parse(rawHandle)
if err != nil {
// If the handle doesn't parse, we won't ever be able to find the resource
w.WriteHeader(404)
server.log(logrus.InfoLevel, req, err)
return
}
paster, err := server.pasteService.GetPaste(handle)
if err == sql.ErrNoRows {
w.WriteHeader(404)
server.log(logrus.InfoLevel, req, err)
} else if err != nil {
w.WriteHeader(500)
server.log(logrus.ErrorLevel, req, err)
return
}
err = paster.Pipe(w)
if err != nil {
w.WriteHeader(500)
server.log(logrus.ErrorLevel, req, err)
return
}
}

View File

@ -38,7 +38,10 @@ func (server *Webserver) Start() error {
func (server *Webserver) setupRoutes() {
router := server.router
router.Get("/", server.index)
router.Post("/paste", server.makePaste)
router.Route("/paste", func(r chi.Router) {
r.Post("/", server.makePaste)
r.Get(`/{handle:[a-zA-Z0-9-]+}`, server.getPaste)
})
router.Get("/static/*", http.StripPrefix("/static", http.FileServer(server.staticFiles)).ServeHTTP)
}