about summary refs log tree commit diff
path: root/htping.go
diff options
context:
space:
mode:
Diffstat (limited to 'htping.go')
-rw-r--r--htping.go98
1 files changed, 58 insertions, 40 deletions
diff --git a/htping.go b/htping.go
index 60b8480..890be28 100644
--- a/htping.go
+++ b/htping.go
@@ -173,15 +173,17 @@ type result struct {
 	code int
 }
 
-func ping(url string, seq int, myTransport *transport, results chan result) {
+func ping(ctx context.Context, url string, seq int, myTransport *transport, results chan result) {
 	start := time.Now()
 
 	requestCounter.WithLabelValues(url).Inc()
 	atomic.AddInt32(&ntotal, 1)
 
-	req, err := http.NewRequest(method, url, nil)
+	req, err := http.NewRequestWithContext(ctx, method, url, nil)
 	if err != nil {
-		fmt.Printf("error=%v\n", err)
+		if !errors.Is(err, context.Canceled) {
+			fmt.Printf("error=%v\n", err)
+		}
 		return
 	}
 
@@ -205,7 +207,9 @@ func ping(url string, seq int, myTransport *transport, results chan result) {
 	}
 	res, err := client.Do(req)
 	if err != nil {
-		fmt.Printf("error=%v\n", err)
+		if !errors.Is(err, context.Canceled) {
+			fmt.Printf("error=%v\n", err)
+		}
 		return
 	}
 
@@ -233,7 +237,7 @@ func ping(url string, seq int, myTransport *transport, results chan result) {
 	results <- result{dur, res.StatusCode}
 }
 
-func stats(results chan result, done chan bool) {
+func stats(ctx context.Context, results chan result) {
 	var min, max, sum, sum2 float64
 	min = math.Inf(1)
 	nrecv := 0
@@ -257,7 +261,7 @@ func stats(results chan result, done chan bool) {
 				nsucc++
 			}
 
-		case <-done:
+		case <-ctx.Done():
 			stop := time.Now()
 			if ntotal > 0 {
 				fmt.Printf("\n%d requests sent, %d (%d%%) responses, %d (%d%%) successful, time %d ms\n",
@@ -274,8 +278,7 @@ func stats(results chan result, done chan bool) {
 				fmt.Printf("rtt min/avg/max/mdev = %.3f/%.3f/%.3f/%.3f s\n",
 					min, sum/float64(nrecv), max, mdev)
 			}
-
-			done <- true
+			return
 		}
 	}
 }
@@ -329,15 +332,23 @@ func main() {
 		os.Exit(2)
 	}
 
+	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
+	defer stop()
+
+	var wg sync.WaitGroup
+
 	if *listenAddr != "" {
+		srv := &http.Server{
+			Addr: *listenAddr,
+		}
+
 		prometheus.MustRegister(requestCounter)
 		prometheus.MustRegister(responseCounter)
 		prometheus.MustRegister(durSummary)
 
-		go func() {
-			http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-				w.Header().Set("Content-Type", "text/html")
-				w.Write([]byte(`<html>
+		http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "text/html")
+			w.Write([]byte(`<html>
     <head><title>htpingd</title></head>
     <body>
     <h1>htpingd</h1>
@@ -345,28 +356,41 @@ func main() {
 </body>
     </html>
 `))
-			})
-			http.Handle("/metrics", promhttp.Handler())
+		})
+		http.Handle("/metrics", promhttp.Handler())
+
+		go func() {
 			log.Println("Prometheus metrics listening on", *listenAddr)
-			err := http.ListenAndServe(*listenAddr, nil)
-			if err != http.ErrServerClosed {
+
+			if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
 				log.Fatal(err)
-				os.Exit(1)
 			}
 		}()
-	}
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			<-ctx.Done()
 
-	interrupt := make(chan os.Signal, 1)
-	signal.Notify(interrupt, os.Interrupt)
+			shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+			defer cancel()
+			if err := srv.Shutdown(shutdownCtx); err != nil {
+				fmt.Fprintf(os.Stderr, "error shutting down http server: %s\n", err)
+			}
+		}()
+	}
 
 	results := make(chan result)
-	done := make(chan bool)
-	go stats(results, done)
+
+	go func() {
+		wg.Add(1)
+		defer wg.Done()
+		stats(ctx, results)
+	}()
 
 	count := 0
 
-	var wg sync.WaitGroup
-	wg.Add(len(args))
+	for _, u := range args {
+		wg.Add(1)
 
 		u := u
 
@@ -394,22 +418,28 @@ func main() {
 			if *flood {
 				for {
 					select {
+					case <-ctx.Done():
+						return
 					default:
-						ping(u, count, myTransport, results)
+						ping(ctx, u, count, myTransport, results)
 						count++
 					}
 				}
 			} else {
 				pingTicker := time.NewTicker(*sleep)
-				go ping(u, count, myTransport, results)
+				defer pingTicker.Stop()
+
+				go ping(ctx, u, count, myTransport, results)
 				count++
 				for {
 					if *maxCount > 0 && count > *maxCount {
 						break
 					}
 					select {
+					case <-ctx.Done():
+						return
 					case <-pingTicker.C:
-						go ping(u, count, myTransport, results)
+						go ping(ctx, u, count, myTransport, results)
 						count++
 					}
 				}
@@ -417,17 +447,5 @@ func main() {
 		}()
 	}
 
-	waitCh := make(chan struct{})
-	go func() {
-		wg.Wait()
-		close(waitCh)
-	}()
-
-	select {
-	case <-waitCh:
-	case <-interrupt:
-	}
-
-	done <- true
-	<-done
+	wg.Wait()
 }