From bb024e3d82d68497493a0a54532aad7e1499df43 Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Sun, 26 Feb 2023 21:34:21 +0900 Subject: [PATCH] nghttpx: Fix bug that causes 400 response after upgrade failure --- integration-tests/nghttpx_http1_test.go | 118 ++++++++++++++++++++++++ src/shrpx_https_upstream.cc | 9 ++ 2 files changed, 127 insertions(+) diff --git a/integration-tests/nghttpx_http1_test.go b/integration-tests/nghttpx_http1_test.go index a083f0e6..ca74a2e3 100644 --- a/integration-tests/nghttpx_http1_test.go +++ b/integration-tests/nghttpx_http1_test.go @@ -594,6 +594,76 @@ func TestH1H1ReqPhaseReturn(t *testing.T) { } } +// TestH1H1ReqPhaseReturnCONNECTMethod tests that mruby request phase +// hook resets llhttp HPE_PAUSED_UPGRADE. +func TestH1H1ReqPhaseReturnCONNECTMethod(t *testing.T) { + opts := options{ + args: []string{"--mruby-file=" + testDir + "/req-return.rb"}, + handler: func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("request should not be forwarded") + }, + } + st := newServerTester(t, opts) + defer st.Close() + + if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1ReqPhaseReturnCONNECTMethod\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil { + t.Fatalf("Error io.WriteString() = %v", err) + } + + resp, err := http.ReadResponse(bufio.NewReader(st.conn), nil) + if err != nil { + t.Fatalf("Error http.ReadResponse() = %v", err) + } + + defer resp.Body.Close() + + if got, want := resp.StatusCode, http.StatusNotFound; got != want { + t.Errorf("status: %v; want %v", got, want) + } + + hdCheck := func() { + hdtests := []struct { + k, v string + }{ + {"content-length", "20"}, + {"from", "mruby"}, + } + + for _, tt := range hdtests { + if got, want := resp.Header.Get(tt.k), tt.v; got != want { + t.Errorf("%v = %v; want %v", tt.k, got, want) + } + } + + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("Error io.ReadAll() = %v", err) + } + } + + hdCheck() + + if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1ReqPhaseReturnCONNECTMethod\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil { + t.Fatalf("Error io.WriteString() = %v", err) + } + + resp, err = http.ReadResponse(bufio.NewReader(st.conn), nil) + if err != nil { + t.Fatalf("Error http.ReadResponse() = %v", err) + } + + defer resp.Body.Close() + + if got, want := resp.StatusCode, http.StatusNotFound; got != want { + t.Errorf("status: %v; want %v", got, want) + } + + hdCheck() + + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("Error io.ReadAll() = %v", err) + } +} + // TestH1H1RespPhaseSetHeader tests mruby response phase hook modifies // response header fields. func TestH1H1RespPhaseSetHeader(t *testing.T) { @@ -737,6 +807,54 @@ func TestH1H1POSTRequests(t *testing.T) { } } +// TestH1H1CONNECTMethodFailure tests that CONNECT method failure +// resets llhttp HPE_PAUSED_UPGRADE. +func TestH1H1CONNECTMethodFailure(t *testing.T) { + opts := options{ + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("required-header") == "" { + w.WriteHeader(http.StatusNotFound) + } + }, + } + st := newServerTester(t, opts) + defer st.Close() + + if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1CONNECTMethodFailure\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil { + t.Fatalf("Error io.WriteString() = %v", err) + } + + resp, err := http.ReadResponse(bufio.NewReader(st.conn), nil) + if err != nil { + t.Fatalf("Error http.ReadResponse() = %v", err) + } + + defer resp.Body.Close() + + if got, want := resp.StatusCode, http.StatusNotFound; got != want { + t.Errorf("status: %v; want %v", got, want) + } + + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("Error io.ReadAll() = %v", err) + } + + if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1CONNECTMethodFailure\r\nHost: 127.0.0.1:443\r\nrequired-header: foo\r\n\r\n"); err != nil { + t.Fatalf("Error io.WriteString() = %v", err) + } + + resp, err = http.ReadResponse(bufio.NewReader(st.conn), nil) + if err != nil { + t.Fatalf("Error http.ReadResponse() = %v", err) + } + + defer resp.Body.Close() + + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("status: %v; want %v", got, want) + } +} + // // TestH1H2ConnectFailure tests that server handles the situation that // // connection attempt to HTTP/2 backend failed. // func TestH1H2ConnectFailure(t *testing.T) { diff --git a/src/shrpx_https_upstream.cc b/src/shrpx_https_upstream.cc index 458b5426..79cc22a9 100644 --- a/src/shrpx_https_upstream.cc +++ b/src/shrpx_https_upstream.cc @@ -658,6 +658,15 @@ int HttpsUpstream::on_read() { auto htperr = llhttp_execute(&htp_, reinterpret_cast(rb->pos()), rb->rleft()); + if (htperr == HPE_PAUSED_UPGRADE && + rb->pos() == + reinterpret_cast(llhttp_get_error_pos(&htp_))) { + llhttp_resume_after_upgrade(&htp_); + + htperr = llhttp_execute(&htp_, reinterpret_cast(rb->pos()), + rb->rleft()); + } + auto nread = htperr == HPE_OK ? rb->rleft()