Tags: reactjs, fastapi, fetch-api, event-stream

How to stream LLM response from FastAPI to React?


I want to stream an LLM (ollama) response using fastapi and react. I can successfully get an answer from the LLM without streaming, but when I try to stream it, I get an error in react. The LLM answer streams successfully when I print each chunk in fastapi. I'm using @microsoft/fetch-event-source to stream the response.

I've created a GitHub repo to help you help me.

Here's the fastapi and react code in question:

react

Something is happening to cause console.log("another error onopen") to trigger.

import { useState } from "react";
import {
  fetchEventSource,
  EventStreamContentType,
} from "@microsoft/fetch-event-source";

// Marker error type: thrown from onmessage when the server sends a
// "FatalError" event, and rethrown by onerror to stop the operation.
class FatalError extends Error {}

// Renders a button that POSTs `input` to http://localhost:8000/stream via
// fetchEventSource and logs every lifecycle callback to the console.
// NOTE(review): `answer` is rendered below, but `setAnswer` is never called in
// this component, so the response box stays empty.
const StreamResponse = ({ input }: { input: string }) => {
  const [answer, setAnswer] = useState("");

  // Fires the streaming request; the promise returned by fetchEventSource is
  // not awaited or caught here.
  const handleClick = () => {
    fetchEventSource("http://localhost:8000/stream", {
      body: JSON.stringify({ question: input }),
      method: "POST",
      headers: {
        "Content-type": "application/json",
        Accept: "text/event-stream",
      },

      async onopen(response) {
        console.log("onopen");
        console.log(response);

        // Strict equality: the header has to be exactly EventStreamContentType.
        // A value carrying any extra parameter (for example a charset suffix)
        // falls through to the branches below.
        // NOTE(review): compare against the header of the logged response.
        if (
          response.ok &&
          response.headers.get("content-type") === EventStreamContentType
        ) {
          // setAnswer(response.answer)
          return; // everything's good
        } else if (
          response.status >= 400 &&
          response.status < 500 &&
          response.status !== 429
        ) {
          console.log("fatal error onopen");
          // client-side errors are usually non-retriable:
          // throw new FatalError();
        } else {
          // Reached by every response that is neither an exact event-stream
          // match nor a non-429 4xx; nothing is thrown here.
          console.log("another error onopen");
          // throw new RetriableError();
        }
      },
      onmessage(msg) {
        // if the server emits an error message, throw an exception
        // so it gets handled by the onerror callback below:
        console.log("onmessage");
        console.log(msg);
        if (msg.event === "FatalError") {
          throw new FatalError(msg.data);
        }
      },
      onclose() {
        console.log("onclose");
        // if the server closes the connection unexpectedly, retry:
        // throw new RetriableError();
      },
      onerror(err) {
        console.log("onerror");
        console.log(err);
        if (err instanceof FatalError) {
          throw err; // rethrow to stop the operation
        } else {
          // do nothing to automatically retry. You can also
          // return a specific retry interval here.
        }
      },
    });
  };

  return (
    <div style={{ width: "100%" }}>
      <div style={{ display: "flex", gap: 10, alignItems: "center" }}>
        <button onClick={handleClick} style={{ height: 24 }}>
          Submit question with streaming
        </button>
      </div>
      <p>Response</p>
      <div style={{ border: 1, borderStyle: "solid", height: 100 }}>
        {answer}
      </div>
    </div>
  );
};

export default StreamResponse;

fastapi

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser


# Application object; the route decorators further down attach to it.
app = FastAPI()

# Origins handed to the CORS middleware ('*' is a wildcard entry).
origins = ['*']

# CORS is configured with wildcard origins, methods and headers, and with
# credentials disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*']
)

# LLM client for the "llama3" model served at localhost:11434.
ollama = Ollama(
    base_url="http://localhost:11434",
    model="llama3"
)

# Fixed system instruction; adjacent string literals are concatenated into a
# single string.
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know. Use three sentences maximum and keep the "
    "answer concise."
)

# Two-message chat template: the system prompt, then a human message whose
# text is filled from the `input` variable.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

# Pipeline used by the endpoints: prompt piped into the model, then into
# StrOutputParser.
chain = (prompt | ollama | StrOutputParser())


# Request body for the POST endpoints: one string field named `question`.
class Question(BaseModel):
    question: str


# Fixed-payload endpoint; returns the same small mapping on every call.
@app.get('/test')
def read_test():
    payload = {'hello': 'world'}
    return payload


# Non-streaming endpoint: runs the chain to completion, echoes the answer to
# stdout and returns it under the 'answer' key.
@app.post('/nostream')
def no_stream_llm(question: Question):
    chain_input = {'input': question.question}
    reply = chain.invoke(chain_input)
    print(reply)
    return {'answer': reply}


def stream_answer(question):
    """Yield the chain's output chunks for `question`, echoing each to stdout.

    Chunks are yielded exactly as the chain produces them, with nothing added.
    """
    pieces = chain.stream(question)
    for piece in pieces:
        # flush so the chunk shows up in the server console immediately
        print(piece, end='', flush=True)
        yield piece


# Streaming endpoint: wraps the chunk generator in a StreamingResponse that is
# labelled as text/event-stream.
@app.post('/stream')
def stream_response_from_llm(question: Question):
    body = stream_answer(question=question.question)
    return StreamingResponse(body, media_type="text/event-stream")

Solution

  • I was able to figure it out. The working code is provided below and can also be found in this repo. The code is by no means perfect, but it should work. Here are some of my observations while trying to figure out the issue:

    1. Whether using EventSource or fetchEventSource (@microsoft/fetch-event-source), both needed to appear within useEffect. I tried to create a function that would set this up when a button was clicked (similar to NoStreamResponse.tsx), but that didn't work (not sure why).
    2. Since EventSource only supports GET, I tried POST with fetchEventSource and noticed that it always retries (according to console.log("retriableerror")). When I implemented GET with fetchEventSource, that log never triggered.
    3. In the backend, for some reason I have yet to determine, the yield statement had to be in this format: yield f'data: {chunk}\n\n'. Both \n\n were necessary in order for it to work (again, not sure why).
    4. I had to include asyncio.sleep() to slow down the response so it could be streamed. The stream doesn't get captured without it.

    Working code

    main.py

    import asyncio
    
    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import StreamingResponse
    from pydantic import BaseModel
    
    from langchain_community.llms import Ollama
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    
    
    # Application object; the route decorators further down attach to it.
    app = FastAPI()
    
    # Origins handed to the CORS middleware ('*' is a wildcard entry).
    origins = ['*']
    
    # CORS is configured with wildcard origins, methods and headers, and with
    # credentials disabled.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=False,
        allow_methods=['*'],
        allow_headers=['*']
    )
    
    # LLM client for the "llama3" model served at localhost:11434.
    ollama = Ollama(
        base_url="http://localhost:11434",
        model="llama3"
    )
    
    # Fixed system instruction; adjacent string literals are concatenated into
    # a single string.
    system_prompt = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
    )
    
    # Two-message chat template: the system prompt, then a human message whose
    # text is filled from the `input` variable.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    
    # Pipeline used by the endpoints: prompt piped into the model, then into
    # StrOutputParser.
    chain = (prompt | ollama | StrOutputParser())
    
    
    # Request body for the POST endpoints: one string field named `question`.
    class Question(BaseModel):
        question: str
    
    
    # Non-streaming endpoint: runs the chain to completion, echoes the answer
    # to stdout and returns it under the 'answer' key.
    @app.post('/nostream')
    def no_stream_llm(question: Question):
        chain_input = {'input': question.question}
        reply = chain.invoke(chain_input)
        print(reply)
        return {'answer': reply}
    
    
    def _format_sse_event(text):
        """Frame `text` as one Server-Sent Event, one `data:` line per text line."""
        # A CR or LF inside a data line ends that line on the client, and the
        # remainder would be parsed as an unknown field and dropped. Giving
        # every line of the payload its own `data:` field lets the client join
        # them back together with '\n'.
        lines = text.replace('\r\n', '\n').replace('\r', '\n').split('\n')
        # The trailing blank line is what tells the client to dispatch the event.
        return ''.join(f'data: {line}\n' for line in lines) + '\n'
    
    
    async def stream_answer(question):
        """Stream the chain's answer to `question` as Server-Sent Events.

        Args:
            question: Text passed to the chain as its input.

        Yields:
            One SSE-framed string (``data: ...`` lines followed by a blank
            line) per chunk produced by the chain.
        """
        # `astream` awaits the model between chunks, so the event loop stays
        # free to flush each event to the client. Iterating the blocking
        # `chain.stream` inside this coroutine stalled the loop, which is why
        # an artificial `asyncio.sleep` was needed before.
        async for chunk in chain.astream(question):
            print(chunk, end='', flush=True)
            yield _format_sse_event(chunk)
    
    
    # GET variant: `question` is a plain string parameter, and the SSE generator
    # is returned as a text/event-stream response.
    @app.get('/stream-with-get')
    async def stream_response_from_llm_get(question: str):
        events = stream_answer(question=question)
        return StreamingResponse(events, media_type='text/event-stream')
    
    
    # POST variant: the question comes from the `Question` body model, and the
    # SSE generator is returned as a text/event-stream response.
    @app.post('/stream-with-post')
    async def stream_response_from_llm_post(question: Question):
        events = stream_answer(question=question.question)
        return StreamingResponse(events, media_type='text/event-stream')
    

    App.tsx

    import { useState } from "react";
    import NoStreamResponse from "./components/NoStreamResponse";
    import StreamResponseEventSource from "./components/StreamResponseEventSource";
    import StreamResponseFetchEventSourcePost from "./components/StreamResponseFetchEventSourcePost";
    
    // Root component: shows the fixed question and one panel per request style.
    function App() {
      // Only the state value is used; the question never changes.
      const [input] = useState("What color is the sky?");
      const columnLayout = {
        display: "flex",
        flexDirection: "column",
        gap: 50,
      } as const;
    
      return (
        <div style={columnLayout}>
          <p>Question: {input}</p>
          <NoStreamResponse input={input} />
          <StreamResponseEventSource input={input} />
          <StreamResponseFetchEventSourcePost input={input} />
        </div>
      );
    }
    
    export default App;
    

    NoStreamResponse.tsx

    import { useState } from "react";
    
    // Shape of the JSON body this component reads from /nostream.
    // NOTE(review): the name is the same as the global fetch `Response` type;
    // a more specific name would avoid confusion.
    interface Response {
      answer: string;
    }
    
    // Button + output box for the non-streaming endpoint: POSTs the question
    // to /nostream and shows the `answer` field of the JSON reply.
    const NoStreamResponse = ({ input }: { input: string }) => {
      const [answer, setAnswer] = useState("");
    
      const handleClick = () => {
        const handleResponse = (response: Response) => {
          console.log(response);
          setAnswer(response.answer);
        };
    
        fetch("http://localhost:8000/nostream", {
          body: JSON.stringify({ question: input }),
          method: "POST",
          headers: { "Content-type": "application/json" },
        })
          .then((response) => {
            // fetch only rejects on network failure. Without this check an
            // HTTP error body would be parsed and `answer` set to undefined.
            if (!response.ok) {
              throw new Error(`Request failed with status ${response.status}`);
            }
            return response.json();
          })
          .then((response) => handleResponse(response))
          .catch((error) => console.error(error));
      };
      return (
        <div style={{ width: "100%" }}>
          <div style={{ display: "flex", gap: 10, alignItems: "center" }}>
            <button onClick={handleClick} style={{ height: 24 }}>
              Submit question with no stream
            </button>
          </div>
          <p>Response</p>
          <div style={{ border: 1, borderStyle: "solid", height: 100 }}>
            {answer}
          </div>
        </div>
      );
    };
    
    export default NoStreamResponse;
    

    StreamResponseEventSource.tsx

    import { useState, useEffect } from "react";
    
    // Button + output box that streams the answer with the browser's native
    // EventSource (GET only). Clicking the button sets `startStream`; the
    // effect opens the connection and appends each message to `answer`.
    const StreamResponseEventSource = ({ input }: { input: string }) => {
      const [answer, setAnswer] = useState("");
      const [startStream, setStartStream] = useState(false);
    
      useEffect(() => {
        if (startStream) {
          setAnswer("");
          // Encode the question so characters such as '&', '#' or '+' cannot
          // truncate or alter the query string.
          const eventSource = new EventSource(
            `http://localhost:8000/stream-with-get?question=${encodeURIComponent(
              input
            )}`
          );
    
          eventSource.onmessage = function (event) {
            console.log(event);
            setAnswer((prevAnswer) => prevAnswer + event.data);
          };
    
          eventSource.onerror = function (err) {
            // EventSource reports the server closing the stream as an error
            // and would otherwise reconnect, asking the question again.
            console.error("EventSource failed.");
            console.error(err);
            eventSource.close();
            // Re-arm the button. Without this `startStream` stays true and
            // further clicks do nothing.
            setStartStream(false);
          };
    
          // Cleanup only releases the connection. If `input` changes while a
          // stream is active, the effect restarts with the new question.
          return () => {
            eventSource.close();
          };
        }
      }, [startStream, input]);
    
      return (
        <div style={{ width: "100%" }}>
          <div style={{ display: "flex", gap: 10, alignItems: "center" }}>
            <button onClick={() => setStartStream(true)} style={{ height: 24 }}>
              Stream with EventSource
            </button>
          </div>
          <p>Response</p>
          <div style={{ border: 1, borderStyle: "solid", height: 100 }}>
            {answer}
          </div>
        </div>
      );
    };
    
    export default StreamResponseEventSource;
    

    StreamResponseFetchEventSourceGet.tsx

    import { useState, useEffect } from "react";
    import {
      fetchEventSource,
      EventStreamContentType,
    } from "@microsoft/fetch-event-source";
    
    // Marker error types for the fetchEventSource callbacks below.
    // RetriableError is only referenced from commented-out `throw` statements.
    class RetriableError extends Error {}
    // FatalError is rethrown by onerror to stop the operation.
    class FatalError extends Error {}
    
    // Button + output box that streams the answer with fetchEventSource over
    // GET. Clicking the button sets `startStream`; the effect issues the
    // request and appends each message to `answer`.
    const StreamResponseFetchEventSourceGet = ({ input }: { input: string }) => {
      const [answer, setAnswer] = useState("");
      const [startStream, setStartStream] = useState(false);
    
      useEffect(() => {
        if (startStream) {
          setAnswer("");
          // Lets the cleanup below cancel the in-flight request.
          const controller = new AbortController();
    
          fetchEventSource(
            // Encode the question so characters such as '&', '#' or '+' cannot
            // truncate or alter the query string.
            `http://localhost:8000/stream-with-get?question=${encodeURIComponent(
              input
            )}`,
            {
              signal: controller.signal,
              // By default the request is closed when the tab is hidden and
              // reopened when it becomes visible, which would ask the question
              // again. Keep it open instead.
              openWhenHidden: true,
              async onopen(response) {
                // Prefix match: servers commonly append a parameter such as
                // "; charset=utf-8", which a strict equality check rejects.
                const contentType = response.headers.get("content-type") ?? "";
                if (
                  response.ok &&
                  contentType.startsWith(EventStreamContentType)
                ) {
                  console.log("everytings good");
                  return; // everything's good
                } else if (
                  response.status >= 400 &&
                  response.status < 500 &&
                  response.status !== 429
                ) {
                  // client-side errors are usually non-retriable:
                  throw new FatalError();
                } else {
                  console.log("retriableerror");
                  // throw new RetriableError();
                }
              },
              onmessage(event) {
                // if the server emits an error message, throw an exception
                // so it gets handled by the onerror callback below:
                if (event.event === "FatalError") {
                  throw new FatalError(event.data);
                }
                console.log(event);
                setAnswer((prevMessages) => prevMessages + event.data);
              },
              onclose() {
                // if the server closes the connection unexpectedly, retry:
                console.log("onclose");
                // throw new RetriableError();
              },
              onerror(err) {
                if (err instanceof FatalError) {
                  throw err; // rethrow to stop the operation
                } else {
                  console.log("onerror");
                  // do nothing to automatically retry. You can also
                  // return a specific retry interval here.
                }
              },
            }
          )
            // A rethrown FatalError rejects the promise; log it rather than
            // leaving an unhandled rejection.
            .catch((err) => console.error(err))
            .finally(() => {
              // Re-arm the button once the stream has ended by itself. An
              // aborted request was replaced or unmounted, so leave the state
              // alone in that case.
              if (!controller.signal.aborted) {
                setStartStream(false);
              }
            });
    
          return () => {
            controller.abort();
          };
        }
      }, [startStream, input]);
    
      return (
        <div style={{ width: "100%" }}>
          <div style={{ display: "flex", gap: 10, alignItems: "center" }}>
            <button onClick={() => setStartStream(true)} style={{ height: 24 }}>
              Stream with fetchEventSource (GET)
            </button>
          </div>
          <p>Response</p>
          <div style={{ border: 1, borderStyle: "solid", height: 100 }}>
            {answer}
          </div>
        </div>
      );
    };
    
    export default StreamResponseFetchEventSourceGet;
    

    StreamResponseFetchEventSourcePost.tsx

    import { useState, useEffect } from "react";
    import {
      fetchEventSource,
      EventStreamContentType,
    } from "@microsoft/fetch-event-source";
    
    // class RetriableError extends Error {}
    // Marker error type: thrown from onopen/onmessage and rethrown by onerror
    // to stop the operation.
    class FatalError extends Error {}
    
    // Button + output box that streams the answer with fetchEventSource over
    // POST. Clicking the button sets `startStream`; the effect issues the
    // request and appends each message to `answer`.
    const StreamResponseFetchEventSourcePost = ({ input }: { input: string }) => {
      const [answer, setAnswer] = useState("");
      const [startStream, setStartStream] = useState(false);
    
      useEffect(() => {
        if (startStream) {
          setAnswer("");
          // Lets the cleanup below cancel the in-flight request.
          const controller = new AbortController();
    
          fetchEventSource("http://localhost:8000/stream-with-post", {
            method: "POST",
            headers: {
              "Content-Type": "application/json",
            },
            body: JSON.stringify({ question: input }),
            signal: controller.signal,
            // By default the request is closed when the tab is hidden and
            // reopened when it becomes visible, which would submit the question
            // again. Keep it open instead.
            openWhenHidden: true,
            async onopen(response) {
              // Prefix match: servers commonly append a parameter such as
              // "; charset=utf-8". A strict equality check rejects that header
              // and lands in the last branch even for a healthy stream.
              const contentType = response.headers.get("content-type") ?? "";
              if (
                response.ok &&
                contentType.startsWith(EventStreamContentType)
              ) {
                console.log("everything is good");
                return; // everything's good
              } else if (
                response.status >= 400 &&
                response.status < 500 &&
                response.status !== 429
              ) {
                // client-side errors are usually non-retriable:
                throw new FatalError();
              } else {
                console.log("retriableerror");
                // throw new RetriableError();
              }
            },
            onmessage(event) {
              // if the server emits an error message, throw an exception
              // so it gets handled by the onerror callback below:
              if (event.event === "FatalError") {
                throw new FatalError(event.data);
              }
              console.log(event);
              setAnswer((prevMessages) => prevMessages + event.data);
            },
            onclose() {
              // if the server closes the connection unexpectedly, retry:
              console.log("onclose");
              // throw new RetriableError();
            },
            onerror(err) {
              if (err instanceof FatalError) {
                throw err; // rethrow to stop the operation
              } else {
                console.log("onerror");
                // do nothing to automatically retry. You can also
                // return a specific retry interval here.
              }
            },
          })
            // A rethrown FatalError rejects the promise; log it rather than
            // leaving an unhandled rejection.
            .catch((err) => console.error(err))
            .finally(() => {
              // Re-arm the button once the stream has ended by itself. An
              // aborted request was replaced or unmounted, so leave the state
              // alone in that case.
              if (!controller.signal.aborted) {
                setStartStream(false);
              }
            });
    
          return () => {
            controller.abort();
          };
        }
      }, [startStream, input]);
    
      return (
        <div style={{ width: "100%" }}>
          <div style={{ display: "flex", gap: 10, alignItems: "center" }}>
            <button onClick={() => setStartStream(true)} style={{ height: 24 }}>
              Stream with fetchEventSource (POST)
            </button>
          </div>
          <p>Response</p>
          <div style={{ border: 1, borderStyle: "solid", height: 100 }}>
            {answer}
          </div>
        </div>
      );
    };
    
    export default StreamResponseFetchEventSourcePost;