440 lines
9.9 KiB
Go
440 lines
9.9 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"html/template"
|
|
"io"
|
|
"io/fs"
|
|
"log"
|
|
"mime"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// uploadDir is the root under which every upload gets its own ID
// subdirectory. main sets it from UPLOAD_DIR before serving.
var uploadDir string

// Shape of generated upload IDs.
const (
	// idLength is the number of characters in an upload ID.
	idLength = 8
	// idCharset is the alphabet upload IDs are drawn from.
	idCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
)

// indexHTML is the front-page template, parsed once in main.
var indexHTML *template.Template
|
|
|
|
func main() {
|
|
uploadDir = os.Getenv("UPLOAD_DIR")
|
|
if uploadDir == "" {
|
|
uploadDir = "/uploads"
|
|
}
|
|
var fileExpires time.Duration = 7 * 24 * time.Hour
|
|
|
|
dryRun := strings.ToLower(os.Getenv("DRYRUN_CLEAN")) != "false"
|
|
fileDuration := os.Getenv("FILE_DURATION")
|
|
if fileDuration != "" {
|
|
expires, err := time.ParseDuration(fileDuration)
|
|
if err != nil {
|
|
log.Fatalf("FILE_DURATION %s is invalid: %v", fileDuration, err)
|
|
}
|
|
fileExpires = expires
|
|
}
|
|
|
|
port := os.Getenv("PORT")
|
|
if port == "" {
|
|
port = "8080"
|
|
}
|
|
|
|
if err := os.MkdirAll(uploadDir, 0755); err != nil {
|
|
log.Fatalf("Failed to create upload directory: %v", err)
|
|
}
|
|
indexHTML = template.Must(template.ParseFiles("index.html.tmpl"))
|
|
http.HandleFunc("/", handleRequest)
|
|
|
|
log.Printf("Starting server on port %s, upload directory: %s", port, uploadDir)
|
|
log.Printf("Files expires after: %v (DRYRUN_CLEAN=%v)", fileExpires, dryRun)
|
|
if fileExpires != 0 {
|
|
go cleanupProcedure(fileExpires, dryRun)
|
|
} else {
|
|
log.Print("Cleaning old files is disabled")
|
|
}
|
|
|
|
if err := http.ListenAndServe(":"+port, nil); err != nil {
|
|
log.Fatalf("Server failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func handleRequest(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
|
|
if r.Method == "GET" && r.URL.Path == "/" {
|
|
serveFrontpage(w, r)
|
|
logRequest(r, 200, time.Since(start))
|
|
return
|
|
}
|
|
if r.Method == "GET" && r.URL.Path == "/healthz" {
|
|
return
|
|
}
|
|
|
|
if r.Method == "POST" && r.URL.Path == "/upload" {
|
|
handleMultipartUpload(w, r, start)
|
|
return
|
|
}
|
|
|
|
if r.Method == "PUT" {
|
|
handleUpload(w, r, start)
|
|
return
|
|
}
|
|
|
|
if r.Method == "GET" {
|
|
handleDownload(w, r, start)
|
|
return
|
|
}
|
|
|
|
if r.Method == "HEAD" {
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
logRequest(r, 405, time.Since(start))
|
|
}
|
|
|
|
func handleUpload(w http.ResponseWriter, r *http.Request, start time.Time) {
|
|
filename := strings.TrimPrefix(r.URL.Path, "/")
|
|
if filename == "" {
|
|
http.Error(w, "Filename required", http.StatusBadRequest)
|
|
logRequest(r, 400, time.Since(start))
|
|
return
|
|
}
|
|
|
|
filename = filepath.Base(filename)
|
|
|
|
id, err := generateUniqueID()
|
|
if err != nil {
|
|
http.Error(w, "Failed to generate unique ID", http.StatusInternalServerError)
|
|
logRequest(r, 500, time.Since(start))
|
|
return
|
|
}
|
|
|
|
targetPath := filepath.Join(uploadDir, id, filename)
|
|
file, err := os.Create(targetPath)
|
|
if err != nil {
|
|
http.Error(w, "Failed to create file", http.StatusInternalServerError)
|
|
logRequest(r, 500, time.Since(start))
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
written, err := io.Copy(file, r.Body)
|
|
if err != nil {
|
|
http.Error(w, "Failed to write file", http.StatusInternalServerError)
|
|
logRequest(r, 500, time.Since(start))
|
|
return
|
|
}
|
|
|
|
proto := getProtocol(r)
|
|
host := getHost(r)
|
|
downloadURL := fmt.Sprintf("%s://%s/%s/%s\n", proto, host, id, filename)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(downloadURL))
|
|
|
|
log.Printf("Uploaded %s (%d bytes) -> %s", filename, written, downloadURL)
|
|
logRequest(r, 200, time.Since(start))
|
|
}
|
|
|
|
func handleMultipartUpload(w http.ResponseWriter, r *http.Request, start time.Time) {
|
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
|
http.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
|
|
logRequest(r, 400, time.Since(start))
|
|
return
|
|
}
|
|
|
|
files := r.MultipartForm.File["files"]
|
|
if len(files) == 0 {
|
|
http.Error(w, "No files provided", http.StatusBadRequest)
|
|
logRequest(r, 400, time.Since(start))
|
|
return
|
|
}
|
|
|
|
proto := getProtocol(r)
|
|
host := getHost(r)
|
|
urls := make([]string, 0, len(files))
|
|
|
|
for _, fileHeader := range files {
|
|
src, err := fileHeader.Open()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
id, err := generateUniqueID()
|
|
if err != nil {
|
|
src.Close()
|
|
continue
|
|
}
|
|
|
|
filename := filepath.Base(fileHeader.Filename)
|
|
targetPath := filepath.Join(uploadDir, id, filename)
|
|
|
|
dst, err := os.Create(targetPath)
|
|
if err != nil {
|
|
src.Close()
|
|
continue
|
|
}
|
|
|
|
written, err := io.Copy(dst, src)
|
|
dst.Close()
|
|
src.Close()
|
|
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
downloadURL := fmt.Sprintf("%s://%s/%s/%s", proto, host, id, filename)
|
|
urls = append(urls, downloadURL)
|
|
log.Printf("Uploaded %s (%d bytes) -> %s", filename, written, downloadURL)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string][]string{"urls": urls})
|
|
logRequest(r, 200, time.Since(start))
|
|
}
|
|
|
|
func handleDownload(w http.ResponseWriter, r *http.Request, start time.Time) {
|
|
parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/"), "/")
|
|
if len(parts) != 2 {
|
|
http.NotFound(w, r)
|
|
logRequest(r, 404, time.Since(start))
|
|
return
|
|
}
|
|
|
|
id := parts[0]
|
|
filename := parts[1]
|
|
|
|
filePath := filepath.Join(uploadDir, id, filename)
|
|
file, err := os.Open(filePath)
|
|
if err != nil {
|
|
http.NotFound(w, r)
|
|
logRequest(r, 404, time.Since(start))
|
|
return
|
|
}
|
|
if !strings.HasPrefix(filePath, uploadDir) {
|
|
http.NotFound(w, r)
|
|
logRequest(r, 404, time.Since(start))
|
|
return
|
|
}
|
|
|
|
defer file.Close()
|
|
|
|
stat, err := file.Stat()
|
|
if err != nil {
|
|
http.Error(w, "Failed to stat file", http.StatusInternalServerError)
|
|
logRequest(r, 500, time.Since(start))
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", detectContentType(filename))
|
|
w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=\"%s\"", filename))
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
|
|
|
io.Copy(w, file)
|
|
logRequest(r, 200, time.Since(start))
|
|
}
|
|
|
|
func generateUniqueID() (string, error) {
|
|
const maxRetries = 100
|
|
|
|
for range maxRetries {
|
|
id := generateID()
|
|
targetDir := filepath.Join(uploadDir, id)
|
|
|
|
if _, err := os.Stat(targetDir); os.IsNotExist(err) {
|
|
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
|
return "", err
|
|
}
|
|
return id, nil
|
|
}
|
|
}
|
|
|
|
return "", fmt.Errorf("failed to generate unique ID after %d retries", maxRetries)
|
|
}
|
|
|
|
func generateID() string {
|
|
bytes := make([]byte, idLength)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
for i, b := range bytes {
|
|
bytes[i] = idCharset[int(b)%len(idCharset)]
|
|
}
|
|
|
|
return string(bytes)
|
|
}
|
|
|
|
func getHost(r *http.Request) string {
|
|
// Check X-Forwarded-Host (common proxy header)
|
|
if host := r.Header.Get("X-Forwarded-Host"); host != "" {
|
|
return host
|
|
}
|
|
|
|
// Check standard Forwarded header (RFC 7239)
|
|
if forwarded := r.Header.Get("Forwarded"); forwarded != "" {
|
|
// Parse "Forwarded: for=...; host=example.com; proto=https"
|
|
parts := strings.SplitSeq(forwarded, ";")
|
|
for part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if host, found := strings.CutPrefix(part, "host="); found {
|
|
return host
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fallback to Host header
|
|
if host := r.Header.Get("Host"); host != "" {
|
|
return host
|
|
}
|
|
|
|
return "localhost:8080"
|
|
}
|
|
|
|
func getProtocol(r *http.Request) string {
|
|
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
|
return proto
|
|
}
|
|
|
|
if forwarded := r.Header.Get("Forwarded"); forwarded != "" {
|
|
parts := strings.SplitSeq(forwarded, ";")
|
|
for part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if proto, found := strings.CutPrefix(part, "proto="); found {
|
|
return proto
|
|
}
|
|
}
|
|
}
|
|
|
|
if ssl := r.Header.Get("X-Forwarded-SSL"); ssl == "on" {
|
|
return "https"
|
|
}
|
|
if r.TLS != nil {
|
|
return "https"
|
|
}
|
|
|
|
return "http"
|
|
}
|
|
|
|
// detectContentType picks a Content-Type for filename based only on its
// extension. Unknown or missing extensions map to application/octet-stream.
func detectContentType(filename string) string {
	if known := mime.TypeByExtension(filepath.Ext(filename)); known != "" {
		return known
	}
	return "application/octet-stream"
}
|
|
|
|
// logRequest writes one access-log line through the standard logger. The line
// contains a bracketed local timestamp, then the method, path, response
// status and handling time.
func logRequest(r *http.Request, status int, duration time.Duration) {
	stamp := time.Now().Format("2006-01-02 15:04:05")
	log.Printf("[%s] %s %s %d %v", stamp, r.Method, r.URL.Path, status, duration)
}
|
|
|
|
func serveFrontpage(w http.ResponseWriter, r *http.Request) {
|
|
proto := getProtocol(r)
|
|
host := getHost(r)
|
|
pageData := map[string]string{"Proto": proto, "Host": host}
|
|
|
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
indexHTML.Execute(w, pageData)
|
|
}
|
|
|
|
// Cleanup: periodic removal of expired uploads and of empty upload directories.
|
|
|
|
// isDirEmpty reports whether the directory name contains no entries.
//
// It reads at most one entry rather than listing the whole directory. Open or
// read failures (including name not being a directory) are returned with
// false.
func isDirEmpty(name string) (bool, error) {
	dir, openErr := os.Open(name)
	if openErr != nil {
		return false, openErr
	}
	defer dir.Close()

	// ReadDir(1) reports io.EOF exactly when there is nothing to read.
	_, readErr := dir.ReadDir(1)
	if errors.Is(readErr, io.EOF) {
		return true, nil
	}
	// nil here means an entry exists; anything else is a genuine failure.
	return false, readErr
}
|
|
|
|
func cleanupProcedure(duration time.Duration, dryRun bool) {
|
|
var oldFiles []string
|
|
var emptyDirs []string
|
|
|
|
deleteIfTooOld := func(path string, d fs.DirEntry, err error) error {
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
if d.IsDir() {
|
|
return nil
|
|
}
|
|
|
|
info, err := d.Info()
|
|
if err != nil {
|
|
log.Printf("Error cleaning file '%s' : %v", d.Name(), err)
|
|
return nil
|
|
}
|
|
age := time.Since(info.ModTime())
|
|
if age >= duration {
|
|
oldFiles = append(oldFiles, path)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
findEmptyDirs := func(path string, d fs.DirEntry, err error) error {
|
|
if !d.IsDir() {
|
|
return nil
|
|
}
|
|
empty, err := isDirEmpty(path)
|
|
if empty {
|
|
emptyDirs = append(emptyDirs, path)
|
|
} else if err != nil {
|
|
log.Printf("Error cleaning directory '%s' : %v", path, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
for {
|
|
|
|
filepath.WalkDir(uploadDir, deleteIfTooOld)
|
|
for _, file := range oldFiles {
|
|
if !dryRun {
|
|
err := os.Remove(file)
|
|
if err != nil {
|
|
log.Printf("Failed to remove expired file '%s' : %v", file, err)
|
|
} else {
|
|
log.Printf("Removed expired file '%s'", file)
|
|
}
|
|
} else {
|
|
log.Printf("Dry run: Would remove expired file '%s'", file)
|
|
}
|
|
}
|
|
|
|
filepath.WalkDir(uploadDir, findEmptyDirs)
|
|
for _, file := range emptyDirs {
|
|
if !dryRun {
|
|
err := os.Remove(file)
|
|
if err != nil {
|
|
log.Printf("Failed to remove empty folder '%s' : %v", file, err)
|
|
} else {
|
|
log.Printf("Removed empty folder '%s'", file)
|
|
}
|
|
} else {
|
|
log.Printf("Dry run: Would remove empty folder '%s'", file)
|
|
}
|
|
}
|
|
time.Sleep(6 * time.Hour)
|
|
}
|
|
}
|