diff options
author | Leah Neukirchen <leah@vuxu.org> | 2024-02-24 00:15:09 +0100 |
---|---|---|
committer | Leah Neukirchen <leah@vuxu.org> | 2024-02-24 00:16:03 +0100 |
commit | 01f56892e6c7554d0c17fb3c38415f0e5ab1ca6b (patch) | |
tree | f33313d938c171f0311956065bd8f5cc2436b061 | |
parent | 9471f96da61279b1cf335cc8e4243ad79e9cef9d (diff) | |
download | htping-01f56892e6c7554d0c17fb3c38415f0e5ab1ca6b.tar.gz htping-01f56892e6c7554d0c17fb3c38415f0e5ab1ca6b.tar.xz htping-01f56892e6c7554d0c17fb3c38415f0e5ab1ca6b.zip |
use contexts to properly structure shutdown on interrupt
...instead of this channel contraption.
-rw-r--r-- | htping.go | 98 |
1 files changed, 58 insertions, 40 deletions
diff --git a/htping.go b/htping.go index 60b8480..890be28 100644 --- a/htping.go +++ b/htping.go @@ -173,15 +173,17 @@ type result struct { code int } -func ping(url string, seq int, myTransport *transport, results chan result) { +func ping(ctx context.Context, url string, seq int, myTransport *transport, results chan result) { start := time.Now() requestCounter.WithLabelValues(url).Inc() atomic.AddInt32(&ntotal, 1) - req, err := http.NewRequest(method, url, nil) + req, err := http.NewRequestWithContext(ctx, method, url, nil) if err != nil { - fmt.Printf("error=%v\n", err) + if !errors.Is(err, context.Canceled) { + fmt.Printf("error=%v\n", err) + } return } @@ -205,7 +207,9 @@ func ping(url string, seq int, myTransport *transport, results chan result) { } res, err := client.Do(req) if err != nil { - fmt.Printf("error=%v\n", err) + if !errors.Is(err, context.Canceled) { + fmt.Printf("error=%v\n", err) + } return } @@ -233,7 +237,7 @@ func ping(url string, seq int, myTransport *transport, results chan result) { results <- result{dur, res.StatusCode} } -func stats(results chan result, done chan bool) { +func stats(ctx context.Context, results chan result) { var min, max, sum, sum2 float64 min = math.Inf(1) nrecv := 0 @@ -257,7 +261,7 @@ func stats(results chan result, done chan bool) { nsucc++ } - case <-done: + case <-ctx.Done(): stop := time.Now() if ntotal > 0 { fmt.Printf("\n%d requests sent, %d (%d%%) responses, %d (%d%%) successful, time %d ms\n", @@ -274,8 +278,7 @@ func stats(results chan result, done chan bool) { fmt.Printf("rtt min/avg/max/mdev = %.3f/%.3f/%.3f/%.3f s\n", min, sum/float64(nrecv), max, mdev) } - - done <- true + return } } } @@ -329,15 +332,23 @@ func main() { os.Exit(2) } + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + + var wg sync.WaitGroup + if *listenAddr != "" { + srv := &http.Server{ + Addr: *listenAddr, + } + prometheus.MustRegister(requestCounter) prometheus.MustRegister(responseCounter) prometheus.MustRegister(durSummary) - go func() { - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - w.Write([]byte(`<html> + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(`<html> <head><title>htpingd</title></head> <body> <h1>htpingd</h1> @@ -345,28 +356,41 @@ func main() { </body> </html> `)) - }) - http.Handle("/metrics", promhttp.Handler()) + }) + http.Handle("/metrics", promhttp.Handler()) + + go func() { log.Println("Prometheus metrics listening on", *listenAddr) - err := http.ListenAndServe(*listenAddr, nil) - if err != http.ErrServerClosed { + + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatal(err) - os.Exit(1) } }() - } + wg.Add(1) + go func() { + defer wg.Done() + <-ctx.Done() - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + fmt.Fprintf(os.Stderr, "error shutting down http server: %s\n", err) + } + }() + } results := make(chan result) - done := make(chan bool) - go stats(results, done) + + go func() { + wg.Add(1) + defer wg.Done() + stats(ctx, results) + }() count := 0 - var wg sync.WaitGroup - wg.Add(len(args)) + for _, u := range args { + wg.Add(1) u := u @@ -394,22 +418,28 @@ func main() { if *flood { for { select { + case <-ctx.Done(): + return default: - ping(u, count, myTransport, results) + ping(ctx, u, count, myTransport, results) count++ } } } else { pingTicker := time.NewTicker(*sleep) - go ping(u, count, myTransport, results) + defer pingTicker.Stop() + + go ping(ctx, u, count, myTransport, results) count++ for { if *maxCount > 0 && count > *maxCount { break } select { + case <-ctx.Done(): + return case <-pingTicker.C: - go ping(u, count, myTransport, results) + go ping(ctx, u, count, myTransport, results) count++ } } @@ -417,17 +447,5 @@ func main() { }() } - waitCh := make(chan struct{}) - go func() { - wg.Wait() - close(waitCh) - }() - - select { - case <-waitCh: - case <-interrupt: - } - - done <- true - <-done + wg.Wait() } |