From 01f56892e6c7554d0c17fb3c38415f0e5ab1ca6b Mon Sep 17 00:00:00 2001 From: Leah Neukirchen Date: Sat, 24 Feb 2024 00:15:09 +0100 Subject: use contexts to properly structure shutdown on interrupt ...instead of this channel contraption. --- htping.go | 98 +++++++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 58 insertions(+), 40 deletions(-) diff --git a/htping.go b/htping.go index 60b8480..890be28 100644 --- a/htping.go +++ b/htping.go @@ -173,15 +173,17 @@ type result struct { code int } -func ping(url string, seq int, myTransport *transport, results chan result) { +func ping(ctx context.Context, url string, seq int, myTransport *transport, results chan result) { start := time.Now() requestCounter.WithLabelValues(url).Inc() atomic.AddInt32(&ntotal, 1) - req, err := http.NewRequest(method, url, nil) + req, err := http.NewRequestWithContext(ctx, method, url, nil) if err != nil { - fmt.Printf("error=%v\n", err) + if !errors.Is(err, context.Canceled) { + fmt.Printf("error=%v\n", err) + } return } @@ -205,7 +207,9 @@ func ping(url string, seq int, myTransport *transport, results chan result) { } res, err := client.Do(req) if err != nil { - fmt.Printf("error=%v\n", err) + if !errors.Is(err, context.Canceled) { + fmt.Printf("error=%v\n", err) + } return } @@ -233,7 +237,7 @@ func ping(url string, seq int, myTransport *transport, results chan result) { results <- result{dur, res.StatusCode} } -func stats(results chan result, done chan bool) { +func stats(ctx context.Context, results chan result) { var min, max, sum, sum2 float64 min = math.Inf(1) nrecv := 0 @@ -257,7 +261,7 @@ func stats(results chan result, done chan bool) { nsucc++ } - case <-done: + case <-ctx.Done(): stop := time.Now() if ntotal > 0 { fmt.Printf("\n%d requests sent, %d (%d%%) responses, %d (%d%%) successful, time %d ms\n", @@ -274,8 +278,7 @@ func stats(results chan result, done chan bool) { fmt.Printf("rtt min/avg/max/mdev = %.3f/%.3f/%.3f/%.3f s\n", min, sum/float64(nrecv), max, mdev) } - - done <- true + return } } } @@ -329,15 +332,23 @@ func main() { os.Exit(2) } + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + + var wg sync.WaitGroup + if *listenAddr != "" { + srv := &http.Server{ + Addr: *listenAddr, + } + prometheus.MustRegister(requestCounter) prometheus.MustRegister(responseCounter) prometheus.MustRegister(durSummary) - go func() { - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - w.Write([]byte(` + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(` htpingd

htpingd

@@ -345,28 +356,41 @@ func main() { `)) - }) - http.Handle("/metrics", promhttp.Handler()) + }) + http.Handle("/metrics", promhttp.Handler()) + + go func() { log.Println("Prometheus metrics listening on", *listenAddr) - err := http.ListenAndServe(*listenAddr, nil) - if err != http.ErrServerClosed { + + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatal(err) - os.Exit(1) } }() - } + wg.Add(1) + go func() { + defer wg.Done() + <-ctx.Done() - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + fmt.Fprintf(os.Stderr, "error shutting down http server: %s\n", err) + } + }() + } results := make(chan result) - done := make(chan bool) - go stats(results, done) + + go func() { + wg.Add(1) + defer wg.Done() + stats(ctx, results) + }() count := 0 - var wg sync.WaitGroup - wg.Add(len(args)) + for _, u := range args { + wg.Add(1) u := u @@ -394,22 +418,28 @@ func main() { if *flood { for { select { + case <-ctx.Done(): + return default: - ping(u, count, myTransport, results) + ping(ctx, u, count, myTransport, results) count++ } } } else { pingTicker := time.NewTicker(*sleep) - go ping(u, count, myTransport, results) + defer pingTicker.Stop() + + go ping(ctx, u, count, myTransport, results) count++ for { if *maxCount > 0 && count > *maxCount { break } select { + case <-ctx.Done(): + return case <-pingTicker.C: - go ping(u, count, myTransport, results) + go ping(ctx, u, count, myTransport, results) count++ } } @@ -417,17 +447,5 @@ func main() { }() } - waitCh := make(chan struct{}) - go func() { - wg.Wait() - close(waitCh) - }() - - select { - case <-waitCh: - case <-interrupt: - } - - done <- true - <-done + wg.Wait() } -- cgit 1.4.1