Merge pull request #1874 from nghttp2/nghttpx-llhttp-resume-after-upgrade

nghttpx: Fix bug that causes 400 response after upgrade failure
This commit is contained in:
Tatsuhiro Tsujikawa 2023-02-26 23:22:13 +09:00 committed by GitHub
commit 14cc308d53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 0 deletions

View File

@ -594,6 +594,76 @@ func TestH1H1ReqPhaseReturn(t *testing.T) {
}
}
// TestH1H1ReqPhaseReturnCONNECTMethod tests that mruby request phase
// hook resets llhttp HPE_PAUSED_UPGRADE.
func TestH1H1ReqPhaseReturnCONNECTMethod(t *testing.T) {
opts := options{
args: []string{"--mruby-file=" + testDir + "/req-return.rb"},
handler: func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("request should not be forwarded")
},
}
st := newServerTester(t, opts)
defer st.Close()
if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1ReqPhaseReturnCONNECTMethod\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil {
t.Fatalf("Error io.WriteString() = %v", err)
}
resp, err := http.ReadResponse(bufio.NewReader(st.conn), nil)
if err != nil {
t.Fatalf("Error http.ReadResponse() = %v", err)
}
defer resp.Body.Close()
if got, want := resp.StatusCode, http.StatusNotFound; got != want {
t.Errorf("status: %v; want %v", got, want)
}
hdCheck := func() {
hdtests := []struct {
k, v string
}{
{"content-length", "20"},
{"from", "mruby"},
}
for _, tt := range hdtests {
if got, want := resp.Header.Get(tt.k), tt.v; got != want {
t.Errorf("%v = %v; want %v", tt.k, got, want)
}
}
if _, err := io.ReadAll(resp.Body); err != nil {
t.Fatalf("Error io.ReadAll() = %v", err)
}
}
hdCheck()
if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1ReqPhaseReturnCONNECTMethod\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil {
t.Fatalf("Error io.WriteString() = %v", err)
}
resp, err = http.ReadResponse(bufio.NewReader(st.conn), nil)
if err != nil {
t.Fatalf("Error http.ReadResponse() = %v", err)
}
defer resp.Body.Close()
if got, want := resp.StatusCode, http.StatusNotFound; got != want {
t.Errorf("status: %v; want %v", got, want)
}
hdCheck()
if _, err := io.ReadAll(resp.Body); err != nil {
t.Fatalf("Error io.ReadAll() = %v", err)
}
}
// TestH1H1RespPhaseSetHeader tests mruby response phase hook modifies
// response header fields.
func TestH1H1RespPhaseSetHeader(t *testing.T) {
@ -737,6 +807,54 @@ func TestH1H1POSTRequests(t *testing.T) {
}
}
// TestH1H1CONNECTMethodFailure tests that CONNECT method failure
// resets llhttp HPE_PAUSED_UPGRADE.
func TestH1H1CONNECTMethodFailure(t *testing.T) {
opts := options{
handler: func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("required-header") == "" {
w.WriteHeader(http.StatusNotFound)
}
},
}
st := newServerTester(t, opts)
defer st.Close()
if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1CONNECTMethodFailure\r\nHost: 127.0.0.1:443\r\n\r\n"); err != nil {
t.Fatalf("Error io.WriteString() = %v", err)
}
resp, err := http.ReadResponse(bufio.NewReader(st.conn), nil)
if err != nil {
t.Fatalf("Error http.ReadResponse() = %v", err)
}
defer resp.Body.Close()
if got, want := resp.StatusCode, http.StatusNotFound; got != want {
t.Errorf("status: %v; want %v", got, want)
}
if _, err := io.ReadAll(resp.Body); err != nil {
t.Fatalf("Error io.ReadAll() = %v", err)
}
if _, err := io.WriteString(st.conn, "CONNECT 127.0.0.1:443 HTTP/1.1\r\nTest-Case: TestH1H1CONNECTMethodFailure\r\nHost: 127.0.0.1:443\r\nrequired-header: foo\r\n\r\n"); err != nil {
t.Fatalf("Error io.WriteString() = %v", err)
}
resp, err = http.ReadResponse(bufio.NewReader(st.conn), nil)
if err != nil {
t.Fatalf("Error http.ReadResponse() = %v", err)
}
defer resp.Body.Close()
if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("status: %v; want %v", got, want)
}
}
// // TestH1H2ConnectFailure tests that server handles the situation that
// // connection attempt to HTTP/2 backend failed.
// func TestH1H2ConnectFailure(t *testing.T) {

View File

@ -658,6 +658,15 @@ int HttpsUpstream::on_read() {
auto htperr = llhttp_execute(&htp_, reinterpret_cast<const char *>(rb->pos()),
rb->rleft());
if (htperr == HPE_PAUSED_UPGRADE &&
rb->pos() ==
reinterpret_cast<const uint8_t *>(llhttp_get_error_pos(&htp_))) {
llhttp_resume_after_upgrade(&htp_);
htperr = llhttp_execute(&htp_, reinterpret_cast<const char *>(rb->pos()),
rb->rleft());
}
auto nread =
htperr == HPE_OK
? rb->rleft()