Streaming OpenAI function calls

Given the long inference latency of current LLMs, streaming responses to the user as they are generated is almost a requirement in AI applications.

Streaming decreases the perceived and actual wait for results – the user can start reading and reflecting on the results as soon as output is available. This is not only useful but also pleasing: the interplay between the LLM chugging along and the built-in predicting machine that is our brain can somehow make us feel more involved in the outcome.

Our pipeline includes multiple steps, and not every one of them can stream its output to the user. We start by classifying the intent of the query, then retrieve relevant document chunks from several systems (text and semantic search), and then typically perform a "fact extraction" step to reduce hallucination before feeding the facts to the final LLM generation prompt. At each step we try to keep the user updated on what's going on, for example by displaying a count of the relevant documents we're processing, but it's only at the final generation step that we can confidently stream the results to the user. Since we currently use GPT-4 for this generation step, it can be relatively slow, so streaming its output is a big win for the user experience of Connie AI.

In this article, we describe how we stream the response from OpenAI function calls to our user interface. We use function calls extensively in order to improve our chances of constraining the LLM output to a valid, parseable format.
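
For context, the functions and function_call values passed in the request shown below look roughly like this. The exact schema is simplified here, but the property names match the example output shown later in this article:

// A simplified sketch of our function definition; the property names
// match the example output shown later in this article
const functions = [
  {
    name: 'answer_question',
    description: 'Answer the user question using the provided facts',
    parameters: {
      type: 'object',
      properties: {
        confidence: { type: 'string' },
        referenceDocuments: { type: 'array', items: { type: 'string' } },
        answer: { type: 'string' },
      },
      required: ['confidence', 'referenceDocuments', 'answer'],
    },
  },
];

// Force the model to call this specific function
const function_call = { name: 'answer_question' };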

While streaming the response from an LLM is typically as easy as forwarding its output to the front-end, streaming the response from a function call is slightly trickier. If you're interested in streaming a normal text completion, OpenAI offer a cookbook that may be more useful.

The first step is to enable streaming responses when invoking the completions endpoint. This is done by passing a boolean parameter stream in the payload (see docs). This is an example request to the OpenAI chat completion endpoint with streaming enabled:

const response = await fetch('https://api.openai.com/v1/chat/completions', {
  method: 'POST',
  headers: {
    'Content-Type': 'application/json',
    Authorization: `Bearer ${API_KEY}`,
  },
  body: JSON.stringify({
    model,
    max_tokens,
    temperature,
    messages,
    functions,
    function_call,
    stream: true,
  }),
});

if (response.ok) {
  return streamingGenerator(response.body); // Explained below
} 

If the request is successful, instead of a single response body we get a stream of data-only server-sent events: a very simple format in which each message is a single line starting with the prefix data:, with messages separated by blank lines.
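
As an illustration, the raw streamed body looks roughly like this (payloads abbreviated), ending with a final [DONE] marker:

data: {"id":"chatcmpl-{ID}","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant","content":null,"function_call":{"name":"answer_question","arguments":""}},"finish_reason":null}]}

data: {"id":"chatcmpl-{ID}","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"function_call":{"arguments":"{\""}},"finish_reason":null}]}

data: [DONE]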

To make it convenient to consume the messages incrementally, we use a JavaScript generator function that yields the accumulated completion after each message. This is a simplified version of our generator:

async function* streamingGenerator(stream) {
  let completion;
  let leftover = '';
  
  for await (const chunk of stream) {
    const chunkText = Buffer.from(chunk).toString('utf8');
    // We may receive more than a message per chunk,
    // and messages may be split across chunks
    const lines = (leftover + chunkText).split('\n');
    // The last line in a chunk is either empty or an incomplete message
    leftover = lines.pop();
    for (const line of lines) {
      if (line.trim() === '') continue;
      if (!line.startsWith('data:'))
        throw new Error(
          "Expected all server-side event messages to start with 'data:'"
        );

      const message = line.slice(5).trim();
      // The last message is literally 'data: [DONE]'
      if (message === '[DONE]') break;
      
      let data;
      try {
        data = JSON.parse(message);
      } catch (err) {
        throw new Error('Chunk JSON parsing failed');
      }
      
      // Update the completion by merging each choice's new delta
      // into the message we have accumulated so far
      completion = {
        ...data,
        choices: data.choices.map((choice, i) => ({
          ...choice,
          delta: undefined,
          message: mergeDelta(
                     completion?.choices?.[i]?.message || {},
                     choice.delta
                   ),
        })),
      };
      
      yield fixFunctionArguments(completion);
    }
  }

  return {
    ...fixFunctionArguments(completion),
    usage: computeUsage(completion),
  };
}

The output of this generator can be consumed with code similar to this:

const updates = await createChatCompletion(...);
 
for await (const update of updates) {
  const completion = update; // the generator yields the merged completion directly
  
  const partialFunctionCallArguments =
    completion.choices?.[0]?.message?.function_call?.arguments;
    
  // The partial arguments string is an incomplete JSON object
  const partialAnswer = 
    parsePartialAnswer(partialFunctionCallArguments ?? '');
    
  if (partialAnswer) {
    sendPartialAnswerToUI(partialAnswer);
  } 
}

Let's see how this works. Every message received from the OpenAI completion endpoint is a JSON object in the form:

{
    "id": "chatcmpl-{ID}", // where ID is the completion ID
    "object": "chat.completion.chunk",
    "created": 1694792391,
    "model": "gpt-4-0613",
    "choices": [
        {
            "index": 0,
            "delta": {
                "role": "assistant",
                "content": null,
                "function_call": {
                    "name": "answer_question", // name of our function
                    "arguments": "" // incremental
                }
            },
            "finish_reason": null
        }
    ]
}

(This assumes you only requested one completion choice in your request.)

Notice this is very similar to the non-streaming response; the main difference is that instead of a message field in each choice, you get a delta field.

As the LLM generates tokens, you receive repeated messages almost exactly like the above, with the arguments field containing the newly-generated tokens. In order to obtain the full response, we need to accumulate them, and this is what the mergeDelta function does:

function mergeDelta(obj, delta) {
  Object.entries(delta).forEach(([key, value]) => {
    if (!(key in obj)) {
      obj[key] = value;
      return;
    }
    
    if (typeof value !== typeof obj[key]) {
      throw new Error('Merging: type mismatch');
    }
    
    switch (typeof value) {
      case 'string':
        obj[key] += value;
        break;
      case 'object':
        mergeDelta(obj[key], value);
        break;
      default:
        throw new Error('Merging: non-string property');
    }
  });
  
  return obj;
}
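
For instance, merging a second delta into the message accumulated so far simply concatenates the arguments strings (a toy example):

const merged = mergeDelta(
  {
    role: 'assistant',
    function_call: { name: 'answer_question', arguments: '{"confi' },
  },
  { function_call: { arguments: 'dence":"high"' } }
);

// merged.function_call.arguments === '{"confidence":"high"'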

Thus, with each message we get a completion object that has a bit more of the function call arguments.

The function call arguments field is itself a stringified JSON object (this is the reason we use function calls in the first place – to try to obtain valid JSON as output!). For example, at some point mid-generation our merged completion may look like:

{
    ...
    "choices": [
        {
            ...
            "delta": {
                ...
                "role": "assistant",
                "function_call": {
                    "name": "answer_question",
                    "arguments": "{\"confidence\":\"high\",\"referenceDocuments\":[\"docId1\",\"docId2\"],\"answer\":\"According to the facts in the reference documents, the an"
                }
            },
        }
    ]
}

Now, since the JSON is still being generated, we can't just call JSON.parse(partialFunctionCallArguments). Instead, we need to partially parse it in order to extract and forward the current value of the answer to the user interface. There are partial JSON parsers out there that you can use, but for many cases a regex will do the trick:

function parsePartialAnswer(partialFunctionCallArguments) {
  // Capture the (possibly still unterminated) value of the "answer" property;
  // the result may still contain JSON escape sequences such as \" or \n
  const match = partialFunctionCallArguments.match(
    /"answer":\s*"((?:[^"\\]|\\.)*)/
  );
  return match ? match[1] : null;
}
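
Applied to the partial arguments from the earlier example, this extracts the text of the answer generated so far:

const partialArguments =
  '{"confidence":"high","referenceDocuments":["docId1","docId2"],"answer":"According to the facts in the reference documents, the an';

parsePartialAnswer(partialArguments);
// => 'According to the facts in the reference documents, the an'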

This partial answer can then be sent to the user interface for the user to see it as it's being generated.

Finally, one more difference from non-streaming requests: the OpenAI API doesn't include usage statistics in streamed completions.

Since it's important for us to track our costs, we work around this limitation using tiktoken to count the token usage for our streaming requests. Once we receive the message saying the generation is done, we compute the token counts and return them with the final completion response as can be seen at the end of our generator function above.

The code we use to count tokens is similar to this:

import { get_encoding as getEncoding } from 'tiktoken';

const encoder = getEncoding('cl100k_base');

function countTokens(text) {
  return encoder.encode(text).length;
}

let estimatedPromptTokens = messages.reduce(
  (acc, msg) => acc + countTokens(msg.content),
  0
);
if (functions.length > 0) {
  // Estimate the token cost of the function definitions
  estimatedPromptTokens += countTokens(JSON.stringify(functions));
}

const estimatedCompletionTokens = countTokens(
  completion.choices
    .map(
      (choice) =>
        `${choice.message.content || ''}${choice.message.function_call?.arguments || ''}`
    )
    .join('')
);

completion.usage = {
  prompt_tokens: estimatedPromptTokens,
  completion_tokens: estimatedCompletionTokens,
  total_tokens: estimatedPromptTokens + estimatedCompletionTokens,
};
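
One detail to keep in mind on the consuming side: for await...of discards a generator's return value, so to read the usage attached to the final completion we drive the generator manually. A sketch of how this can look:

let result = await updates.next();
while (!result.done) {
  const args =
    result.value.choices?.[0]?.message?.function_call?.arguments ?? '';
  const partialAnswer = parsePartialAnswer(args);
  if (partialAnswer) sendPartialAnswerToUI(partialAnswer);
  result = await updates.next();
}

// The value returned with done: true is the final completion,
// including the estimated usage
const { usage } = result.value;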

Sagacious readers may have noticed that the generator calls a function we haven't defined above, fixFunctionArguments. This function is necessary to correctly parse the output of OpenAI function calls, even when not streaming.

We realized that OpenAI's function calls can generate malformed JSON by including literal newlines inside stringified strings. This is something they may fix in the future, but as of September 2023 the issue is still there.

This is the implementation of our workaround:

const fixFunctionArguments = (completion) => {
  // Function arguments are supposed to be stringified JSON,
  // but the LLM sometimes emits newline characters (as opposed to
  // a '\', 'n' character sequence) inside of strings, which makes
  // parsing fail.
  completion.choices?.forEach((choice) => {
    const fcArgs = choice.message?.function_call?.arguments;
    if (fcArgs) {
      choice.message.function_call.arguments = fcArgs
        .split(/("(?:[^"\\]|\\.)*(?:"|$))/g)
        // Every odd piece is a quoted string
        .map((piece, i) => (
            i % 2 ? piece.replaceAll('\n', '\\n') : piece
        ))
        .join('');
    }
  });
  return completion;
};
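
As a toy example of the problem and the fix:

// The model emitted a literal newline inside a JSON string value,
// so this is not valid JSON as-is
const broken = '{"answer":"line one\nline two"}';

const fixed = fixFunctionArguments({
  choices: [
    {
      message: {
        function_call: { name: 'answer_question', arguments: broken },
      },
    },
  ],
});

// The literal newline is now escaped, so the arguments parse correctly
JSON.parse(fixed.choices[0].message.function_call.arguments);
// => { answer: 'line one\nline two' }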