Files
Nils O. Selåsdal e0531098d8 Add file duration
2026-07-04 19:30:25 +02:00

440 lines
9.9 KiB
Go

package main
import (
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"html/template"
	"io"
	"io/fs"
	"log"
	"mime"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"
)
// Package-level state shared by the handlers. Both variables are assigned in
// main before the HTTP server starts.
var (
	// uploadDir is the storage root; every upload lives at
	// <uploadDir>/<id>/<filename>. main reads it from UPLOAD_DIR.
	uploadDir string

	// indexHTML is the front-page template parsed by main from
	// "index.html.tmpl".
	indexHTML *template.Template
)

// Shape of the random per-upload directory names produced by generateID.
const (
	// idLength is the number of characters in an upload id.
	idLength = 8

	// idCharset is the alphabet ids are drawn from (lowercase ASCII letters
	// and digits).
	idCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
)
// main configures the service from environment variables, starts the
// background cleanup loop and serves HTTP on all interfaces.
//
// Environment variables:
//
//	UPLOAD_DIR     storage root (default "/uploads"; created if missing)
//	PORT           TCP port to listen on (default "8080")
//	FILE_DURATION  age after which files are deleted, in time.ParseDuration
//	               syntax (default 168h; "0" disables cleanup; negative
//	               values are rejected at startup)
//	DRYRUN_CLEAN   cleanup only logs what it would delete unless this is
//	               "false" (case-insensitive)
func main() {
	uploadDir = os.Getenv("UPLOAD_DIR")
	if uploadDir == "" {
		uploadDir = "/uploads"
	}
	// Normalise once so every later Join/WalkDir works from the same root.
	uploadDir = filepath.Clean(uploadDir)

	fileExpires := 7 * 24 * time.Hour
	dryRun := strings.ToLower(os.Getenv("DRYRUN_CLEAN")) != "false"
	if fileDuration := os.Getenv("FILE_DURATION"); fileDuration != "" {
		expires, err := time.ParseDuration(fileDuration)
		if err != nil {
			log.Fatalf("FILE_DURATION %s is invalid: %v", fileDuration, err)
		}
		// A negative age would treat essentially every file as expired on
		// every cleanup pass, which is never what an operator wants.
		if expires < 0 {
			log.Fatalf("FILE_DURATION %s must not be negative", fileDuration)
		}
		fileExpires = expires
	}

	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}

	if err := os.MkdirAll(uploadDir, 0755); err != nil {
		log.Fatalf("Failed to create upload directory: %v", err)
	}
	indexHTML = template.Must(template.ParseFiles("index.html.tmpl"))
	http.HandleFunc("/", handleRequest)

	log.Printf("Starting server on port %s, upload directory: %s", port, uploadDir)
	log.Printf("Files expires after: %v (DRYRUN_CLEAN=%v)", fileExpires, dryRun)
	if fileExpires != 0 {
		go cleanupProcedure(fileExpires, dryRun)
	} else {
		log.Print("Cleaning old files is disabled")
	}

	// A nil Handler means http.DefaultServeMux, where handleRequest was
	// registered above. Only the header phase is time-limited: a whole-request
	// ReadTimeout/WriteTimeout would cut off legitimately slow large
	// uploads and downloads, but idle clients can no longer hold a
	// connection open forever before sending headers.
	server := &http.Server{
		Addr:              ":" + port,
		ReadHeaderTimeout: 30 * time.Second,
	}
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Server failed: %v", err)
	}
}
// handleRequest is the single handler registered on "/" and routes by method
// and path:
//
//	GET  /            front page (logged here)
//	GET  /healthz     returns without writing or logging
//	POST /upload      multipart upload (handleMultipartUpload)
//	PUT  /<anything>  raw-body upload (handleUpload)
//	GET  /<anything>  download (handleDownload)
//	HEAD /<anything>  returns without writing or logging
//
// Every other method/path combination is answered with 405 and logged.
func handleRequest(w http.ResponseWriter, r *http.Request) {
	began := time.Now()
	switch method, path := r.Method, r.URL.Path; {
	case method == "GET" && path == "/":
		serveFrontpage(w, r)
		logRequest(r, 200, time.Since(began))
	case method == "GET" && path == "/healthz":
		// Nothing written: net/http sends an empty 200 on return.
	case method == "POST" && path == "/upload":
		handleMultipartUpload(w, r, began)
	case method == "PUT":
		handleUpload(w, r, began)
	case method == "GET":
		handleDownload(w, r, began)
	case method == "HEAD":
		// Nothing written: net/http sends an empty 200 on return.
	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
		logRequest(r, 405, time.Since(began))
	}
}
// handleUpload stores the raw body of a PUT request under a fresh random id
// directory and replies 200 with the download URL followed by a newline.
//
// The stored name is the last element of the request path, so "PUT /a/b.txt"
// is saved as "b.txt". An empty name, "." or ".." is rejected with 400.
// The filename is percent-escaped in the returned URL, so names containing
// spaces, '#', '?' and similar characters produce a working link. If the file
// cannot be created or fully written, whatever was created is removed
// (best-effort) and 500 is returned.
func handleUpload(w http.ResponseWriter, r *http.Request, start time.Time) {
	// Base("") is ".", so an empty path is caught by the same check.
	filename := filepath.Base(strings.TrimPrefix(r.URL.Path, "/"))
	if filename == "." || filename == ".." || filename == string(filepath.Separator) {
		http.Error(w, "Filename required", http.StatusBadRequest)
		logRequest(r, 400, time.Since(start))
		return
	}

	id, err := generateUniqueID()
	if err != nil {
		http.Error(w, "Failed to generate unique ID", http.StatusInternalServerError)
		logRequest(r, 500, time.Since(start))
		return
	}
	dir := filepath.Join(uploadDir, id)
	targetPath := filepath.Join(dir, filename)

	file, err := os.Create(targetPath)
	if err != nil {
		os.Remove(dir) // best-effort: don't leave an empty id directory behind
		log.Printf("Upload of %q failed: %v", filename, err)
		http.Error(w, "Failed to create file", http.StatusInternalServerError)
		logRequest(r, 500, time.Since(start))
		return
	}
	written, err := io.Copy(file, r.Body)
	// Close explicitly and check it: a failing Close can mean the data never
	// made it to storage, which must count as a failed upload.
	if closeErr := file.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// best-effort removal of the truncated file and its directory
		os.Remove(targetPath)
		os.Remove(dir)
		log.Printf("Upload of %q failed: %v", filename, err)
		http.Error(w, "Failed to write file", http.StatusInternalServerError)
		logRequest(r, 500, time.Since(start))
		return
	}

	downloadURL := fmt.Sprintf("%s://%s/%s/%s", getProtocol(r), getHost(r), id, url.PathEscape(filename))
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(downloadURL + "\n"))
	log.Printf("Uploaded %s (%d bytes) -> %s", filename, written, downloadURL)
	logRequest(r, 200, time.Since(start))
}
// handleMultipartUpload handles POST /upload with a multipart form whose
// "files" field carries one or more files. Each file is stored under its own
// fresh id directory. The response is JSON of the form
// {"urls": [...]} listing a download URL for every file that was stored.
//
// The form is parsed with a 32 MiB in-memory limit (larger parts spill to
// temporary files managed by net/http). A file that cannot be stored is
// logged and skipped rather than failing the whole request; a request
// without any "files" part is rejected with 400.
func handleMultipartUpload(w http.ResponseWriter, r *http.Request, start time.Time) {
	if err := r.ParseMultipartForm(32 << 20); err != nil {
		http.Error(w, "Failed to parse multipart form", http.StatusBadRequest)
		logRequest(r, 400, time.Since(start))
		return
	}
	files := r.MultipartForm.File["files"]
	if len(files) == 0 {
		http.Error(w, "No files provided", http.StatusBadRequest)
		logRequest(r, 400, time.Since(start))
		return
	}

	proto := getProtocol(r)
	host := getHost(r)
	// Non-nil so an all-failed request still encodes as "urls": [] not null.
	urls := make([]string, 0, len(files))
	for _, fileHeader := range files {
		filename := filepath.Base(fileHeader.Filename)
		// Base("") is "."; these names would resolve to a directory, not a file.
		if filename == "." || filename == ".." || filename == string(filepath.Separator) {
			log.Printf("Skipping upload with invalid filename %q", fileHeader.Filename)
			continue
		}
		src, err := fileHeader.Open()
		if err != nil {
			log.Printf("Upload of %q failed: %v", filename, err)
			continue
		}
		id, written, err := storeMultipartFile(src, filename)
		src.Close()
		if err != nil {
			log.Printf("Upload of %q failed: %v", filename, err)
			continue
		}
		downloadURL := fmt.Sprintf("%s://%s/%s/%s", proto, host, id, url.PathEscape(filename))
		urls = append(urls, downloadURL)
		log.Printf("Uploaded %s (%d bytes) -> %s", filename, written, downloadURL)
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string][]string{"urls": urls}); err != nil {
		log.Printf("Failed to write upload response: %v", err)
	}
	logRequest(r, 200, time.Since(start))
}

// storeMultipartFile copies src to <uploadDir>/<new id>/<filename> and returns
// the id and the byte count. On failure it removes (best-effort) whatever it
// created, so no empty directory or truncated file is left behind.
func storeMultipartFile(src io.Reader, filename string) (string, int64, error) {
	id, err := generateUniqueID()
	if err != nil {
		return "", 0, fmt.Errorf("generating id: %w", err)
	}
	dir := filepath.Join(uploadDir, id)
	targetPath := filepath.Join(dir, filename)
	dst, err := os.Create(targetPath)
	if err != nil {
		os.Remove(dir)
		return "", 0, err
	}
	written, err := io.Copy(dst, src)
	// A failing Close can mean buffered data was lost; treat it as a failure.
	if closeErr := dst.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(targetPath)
		os.Remove(dir)
		return "", 0, err
	}
	return id, written, nil
}
// handleDownload serves GET /<id>/<filename> from <uploadDir>/<id>/<filename>.
//
// It answers 404 when:
//   - the path is not exactly two elements;
//   - either element is empty, "." or "..";
//   - the file cannot be opened;
//   - the path resolves to a directory.
//
// Path validation happens before anything is opened, and the opened file is
// always closed. The body is sent with a Content-Type guessed from the
// extension, an inline Content-Disposition carrying the (properly quoted)
// name, and an explicit Content-Length.
func handleDownload(w http.ResponseWriter, r *http.Request, start time.Time) {
	notFound := func() {
		http.NotFound(w, r)
		logRequest(r, 404, time.Since(start))
	}
	// The path is already split on "/", so this mainly rejects "", "." and
	// "..", which would otherwise point the join at a directory (e.g. a
	// trailing-slash request "/id/") or outside the upload root.
	validElement := func(s string) bool {
		return s != "" && s != "." && s != ".." && !strings.ContainsRune(s, filepath.Separator)
	}

	parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/"), "/")
	if len(parts) != 2 || !validElement(parts[0]) || !validElement(parts[1]) {
		notFound()
		return
	}
	id, filename := parts[0], parts[1]
	// Defense in depth: the element checks above already keep the path inside
	// the upload root; verify lexically anyway before touching the disk.
	if !filepath.IsLocal(filepath.Join(id, filename)) {
		notFound()
		return
	}
	filePath := filepath.Join(uploadDir, id, filename)

	file, err := os.Open(filePath)
	if err != nil {
		notFound()
		return
	}
	defer file.Close()

	stat, err := file.Stat()
	if err != nil {
		http.Error(w, "Failed to stat file", http.StatusInternalServerError)
		logRequest(r, 500, time.Since(start))
		return
	}
	if stat.IsDir() {
		notFound()
		return
	}

	// FormatMediaType quotes/escapes the name (RFC 2231 for non-ASCII), so a
	// filename containing '"' cannot break out of the header parameter.
	disposition := mime.FormatMediaType("inline", map[string]string{"filename": filename})
	if disposition == "" {
		disposition = "inline"
	}
	w.Header().Set("Content-Type", detectContentType(filename))
	w.Header().Set("Content-Disposition", disposition)
	w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
	// NOTE(review): uploads are served inline with an extension-derived type,
	// so an uploaded .html/.svg renders on this origin. nosniff at least
	// stops browsers from reinterpreting other types as something riskier.
	w.Header().Set("X-Content-Type-Options", "nosniff")
	if _, err := io.Copy(w, file); err != nil {
		// Headers are already sent; the client only sees a truncated body.
		log.Printf("Download of %s interrupted: %v", filePath, err)
	}
	logRequest(r, 200, time.Since(start))
}
func generateUniqueID() (string, error) {
const maxRetries = 100
for range maxRetries {
id := generateID()
targetDir := filepath.Join(uploadDir, id)
if _, err := os.Stat(targetDir); os.IsNotExist(err) {
if err := os.MkdirAll(targetDir, 0755); err != nil {
return "", err
}
return id, nil
}
}
return "", fmt.Errorf("failed to generate unique ID after %d retries", maxRetries)
}
func generateID() string {
bytes := make([]byte, idLength)
if _, err := rand.Read(bytes); err != nil {
panic(err)
}
for i, b := range bytes {
bytes[i] = idCharset[int(b)%len(idCharset)]
}
return string(bytes)
}
// getHost returns the host (and optional port) the client used to reach the
// service, as needed for absolute download URLs. Sources, in order:
//
//  1. the first entry of X-Forwarded-Host;
//  2. the host= parameter of the first RFC 7239 Forwarded element;
//  3. r.Host;
//  4. a literal Host header;
//  5. "localhost:8080".
//
// For server requests net/http moves the Host header into r.Host and deletes
// it from r.Header, so r.Host must be consulted; reading only the header
// always fell through to "localhost:8080".
func getHost(r *http.Request) string {
	if host := firstHeaderValue(r, "X-Forwarded-Host"); host != "" {
		return host
	}
	if host := forwardedParam(r, "host"); host != "" {
		return host
	}
	if r.Host != "" {
		return r.Host
	}
	// Only reachable for hand-built requests that put Host in the header map.
	if host := r.Header.Get("Host"); host != "" {
		return host
	}
	return "localhost:8080"
}

// getProtocol returns the URL scheme the client used. Sources, in order:
//
//  1. the first entry of X-Forwarded-Proto;
//  2. the proto= parameter of the first Forwarded element;
//  3. X-Forwarded-SSL: on (case-insensitive);
//  4. whether this connection itself is TLS.
//
// It defaults to "http".
func getProtocol(r *http.Request) string {
	if proto := firstHeaderValue(r, "X-Forwarded-Proto"); proto != "" {
		return proto
	}
	if proto := forwardedParam(r, "proto"); proto != "" {
		return proto
	}
	if strings.EqualFold(r.Header.Get("X-Forwarded-SSL"), "on") {
		return "https"
	}
	if r.TLS != nil {
		return "https"
	}
	return "http"
}

// firstHeaderValue returns the first comma-separated entry of header name,
// trimmed. Proxies chain X-Forwarded-* values as "client, proxy1, ...", and
// only the first entry describes the original request.
func firstHeaderValue(r *http.Request, name string) string {
	first, _, _ := strings.Cut(r.Header.Get(name), ",")
	return strings.TrimSpace(first)
}

// forwardedParam returns the value of parameter key (matched
// case-insensitively) from the first element of the Forwarded header, with
// surrounding double quotes removed. It returns "" when the header or the
// parameter is absent.
//
// Per RFC 7239 each proxy appends its own comma-separated element of
// ";"-separated pairs, so the first element reflects the original request.
func forwardedParam(r *http.Request, key string) string {
	forwarded := r.Header.Get("Forwarded")
	if forwarded == "" {
		return ""
	}
	first, _, _ := strings.Cut(forwarded, ",")
	for _, pair := range strings.Split(first, ";") {
		name, value, ok := strings.Cut(strings.TrimSpace(pair), "=")
		if !ok || !strings.EqualFold(strings.TrimSpace(name), key) {
			continue
		}
		value = strings.TrimSpace(value)
		if len(value) >= 2 && value[0] == '"' && value[len(value)-1] == '"' {
			value = value[1 : len(value)-1]
		}
		return value
	}
	return ""
}
// detectContentType maps a file name to a MIME type through its extension,
// using the mime package's extension table. It falls back to
// "application/octet-stream" when the name has no extension or the
// extension is unknown.
func detectContentType(filename string) string {
	if known := mime.TypeByExtension(filepath.Ext(filename)); known != "" {
		return known
	}
	return "application/octet-stream"
}
// logRequest writes one access-log line through the standard logger, shaped
// "[YYYY-MM-DD hh:mm:ss] METHOD PATH STATUS DURATION". The bracketed local
// timestamp is added here, in addition to whatever prefix the logger's own
// flags produce.
func logRequest(r *http.Request, status int, duration time.Duration) {
	stamp := time.Now().Format("2006-01-02 15:04:05")
	log.Printf("[%s] %s %s %d %v", stamp, r.Method, r.URL.Path, status, duration)
}
// serveFrontpage renders the indexHTML template as text/html. The template
// receives a map with keys "Proto" and "Host", the externally visible scheme
// and host (presumably used to print absolute upload URLs — see the
// template).
//
// A rendering error can no longer change the response (headers and possibly
// part of the body are already written), so it is logged instead of being
// silently dropped.
func serveFrontpage(w http.ResponseWriter, r *http.Request) {
	pageData := map[string]string{
		"Proto": getProtocol(r),
		"Host":  getHost(r),
	}
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	if err := indexHTML.Execute(w, pageData); err != nil {
		log.Printf("Failed to render front page: %v", err)
	}
}
// Cleanup
// isDirEmpty reports whether the directory name has no entries. It asks the
// directory for at most one entry rather than listing it fully. Any error from
// opening or reading name is returned together with false.
func isDirEmpty(name string) (bool, error) {
	dir, err := os.Open(name)
	if err != nil {
		return false, err
	}
	defer dir.Close()

	switch _, err = dir.ReadDir(1); {
	case errors.Is(err, io.EOF):
		// ReadDir with n > 0 signals "no entries at all" as io.EOF.
		return true, nil
	default:
		// nil means an entry exists; any other error is a genuine failure.
		return false, err
	}
}
func cleanupProcedure(duration time.Duration, dryRun bool) {
var oldFiles []string
var emptyDirs []string
deleteIfTooOld := func(path string, d fs.DirEntry, err error) error {
if err != nil {
return nil
}
if d.IsDir() {
return nil
}
info, err := d.Info()
if err != nil {
log.Printf("Error cleaning file '%s' : %v", d.Name(), err)
return nil
}
age := time.Since(info.ModTime())
if age >= duration {
oldFiles = append(oldFiles, path)
}
return nil
}
findEmptyDirs := func(path string, d fs.DirEntry, err error) error {
if !d.IsDir() {
return nil
}
empty, err := isDirEmpty(path)
if empty {
emptyDirs = append(emptyDirs, path)
} else if err != nil {
log.Printf("Error cleaning directory '%s' : %v", path, err)
}
return nil
}
for {
filepath.WalkDir(uploadDir, deleteIfTooOld)
for _, file := range oldFiles {
if !dryRun {
err := os.Remove(file)
if err != nil {
log.Printf("Failed to remove expired file '%s' : %v", file, err)
} else {
log.Printf("Removed expired file '%s'", file)
}
} else {
log.Printf("Dry run: Would remove expired file '%s'", file)
}
}
filepath.WalkDir(uploadDir, findEmptyDirs)
for _, file := range emptyDirs {
if !dryRun {
err := os.Remove(file)
if err != nil {
log.Printf("Failed to remove empty folder '%s' : %v", file, err)
} else {
log.Printf("Removed empty folder '%s'", file)
}
} else {
log.Printf("Dry run: Would remove empty folder '%s'", file)
}
}
time.Sleep(6 * time.Hour)
}
}