axum/response/
sse.rs

1//! Server-Sent Events (SSE) responses.
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//!     Router,
8//!     routing::get,
9//!     response::sse::{Event, KeepAlive, Sse},
10//! };
11//! use std::{time::Duration, convert::Infallible};
12//! use tokio_stream::StreamExt as _ ;
13//! use futures_util::stream::{self, Stream};
14//!
15//! let app = Router::new().route("/sse", get(sse_handler));
16//!
17//! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
18//!     // A `Stream` that repeats an event every second
19//!     let stream = stream::repeat_with(|| Event::default().data("hi!"))
20//!         .map(Ok)
21//!         .throttle(Duration::from_secs(1));
22//!
23//!     Sse::new(stream).keep_alive(KeepAlive::default())
24//! }
25//! # let _: Router = app;
26//! ```
27
28use crate::{
29    body::{Bytes, HttpBody},
30    BoxError,
31};
32use axum_core::{
33    body::Body,
34    response::{IntoResponse, Response},
35};
36use bytes::{BufMut, BytesMut};
37use futures_util::{
38    ready,
39    stream::{Stream, TryStream},
40};
41use http_body::Frame;
42use pin_project_lite::pin_project;
43use std::{
44    fmt,
45    future::Future,
46    pin::Pin,
47    task::{Context, Poll},
48    time::Duration,
49};
50use sync_wrapper::SyncWrapper;
51use tokio::time::Sleep;
52
/// An SSE response
///
/// Build one with [`Sse::new`] and optionally enable keep-alive messages with
/// [`Sse::keep_alive`]; it becomes an HTTP response through [`IntoResponse`].
#[derive(Clone)]
#[must_use]
pub struct Sse<S> {
    // The stream of events sent to the client.
    stream: S,
    // Keep-alive configuration; `None` means no keep-alive messages are sent.
    keep_alive: Option<KeepAlive>,
}
60
61impl<S> Sse<S> {
62    /// Create a new [`Sse`] response that will respond with the given stream of
63    /// [`Event`]s.
64    ///
65    /// See the [module docs](self) for more details.
66    pub fn new(stream: S) -> Self
67    where
68        S: TryStream<Ok = Event> + Send + 'static,
69        S::Error: Into<BoxError>,
70    {
71        Sse {
72            stream,
73            keep_alive: None,
74        }
75    }
76
77    /// Configure the interval between keep-alive messages.
78    ///
79    /// Defaults to no keep-alive messages.
80    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
81        self.keep_alive = Some(keep_alive);
82        self
83    }
84}
85
86impl<S> fmt::Debug for Sse<S> {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("Sse")
89            .field("stream", &format_args!("{}", std::any::type_name::<S>()))
90            .field("keep_alive", &self.keep_alive)
91            .finish()
92    }
93}
94
95impl<S, E> IntoResponse for Sse<S>
96where
97    S: Stream<Item = Result<Event, E>> + Send + 'static,
98    E: Into<BoxError>,
99{
100    fn into_response(self) -> Response {
101        (
102            [
103                (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
104                (http::header::CACHE_CONTROL, "no-cache"),
105            ],
106            Body::new(SseBody {
107                event_stream: SyncWrapper::new(self.stream),
108                keep_alive: self.keep_alive.map(KeepAliveStream::new),
109            }),
110        )
111            .into_response()
112    }
113}
114
pin_project! {
    // HTTP body that encodes events from `event_stream` and, while that stream
    // is pending, emits keep-alive messages from `keep_alive` (if configured).
    struct SseBody<S> {
        // The user's event stream; `SyncWrapper` makes the body `Sync` without
        // requiring `S: Sync`.
        #[pin]
        event_stream: SyncWrapper<S>,
        // Keep-alive timer; `None` disables keep-alive messages.
        #[pin]
        keep_alive: Option<KeepAliveStream>,
    }
}
123
124impl<S, E> HttpBody for SseBody<S>
125where
126    S: Stream<Item = Result<Event, E>>,
127{
128    type Data = Bytes;
129    type Error = E;
130
131    fn poll_frame(
132        self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
135        let this = self.project();
136
137        match this.event_stream.get_pin_mut().poll_next(cx) {
138            Poll::Pending => {
139                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
140                    keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
141                } else {
142                    Poll::Pending
143                }
144            }
145            Poll::Ready(Some(Ok(event))) => {
146                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
147                    keep_alive.reset();
148                }
149                Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
150            }
151            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
152            Poll::Ready(None) => Poll::Ready(None),
153        }
154    }
155}
156
/// Server-sent event
///
/// Built with the chained setter methods; each setter appends its encoded
/// field line(s) to an internal buffer.
#[derive(Debug, Default, Clone)]
#[must_use]
pub struct Event {
    // Encoded field lines written so far, each terminated by `\n`.
    buffer: BytesMut,
    // Which single-use fields (data, event, retry, id) have already been set.
    flags: EventFlags,
}
164
165impl Event {
166    /// Set the event's data data field(s) (`data: <content>`)
167    ///
168    /// Newlines in `data` will automatically be broken across `data: ` fields.
169    ///
170    /// This corresponds to [`MessageEvent`'s data field].
171    ///
172    /// Note that events with an empty data field will be ignored by the browser.
173    ///
174    /// # Panics
175    ///
176    /// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
177    /// - Panics if `data` or `json_data` have already been called.
178    ///
179    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
180    pub fn data<T>(mut self, data: T) -> Event
181    where
182        T: AsRef<str>,
183    {
184        if self.flags.contains(EventFlags::HAS_DATA) {
185            panic!("Called `EventBuilder::data` multiple times");
186        }
187
188        for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
189            self.field("data", line);
190        }
191
192        self.flags.insert(EventFlags::HAS_DATA);
193
194        self
195    }
196
197    /// Set the event's data field to a value serialized as unformatted JSON (`data: <content>`).
198    ///
199    /// This corresponds to [`MessageEvent`'s data field].
200    ///
201    /// # Panics
202    ///
203    /// Panics if `data` or `json_data` have already been called.
204    ///
205    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
206    #[cfg(feature = "json")]
207    pub fn json_data<T>(mut self, data: T) -> Result<Event, axum_core::Error>
208    where
209        T: serde::Serialize,
210    {
211        struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>);
212        impl std::io::Write for IgnoreNewLines<'_> {
213            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
214                let mut last_split = 0;
215                for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
216                    self.0.write_all(&buf[last_split..delimiter])?;
217                    last_split = delimiter + 1;
218                }
219                self.0.write_all(&buf[last_split..])?;
220                Ok(buf.len())
221            }
222
223            fn flush(&mut self) -> std::io::Result<()> {
224                self.0.flush()
225            }
226        }
227        if self.flags.contains(EventFlags::HAS_DATA) {
228            panic!("Called `EventBuilder::json_data` multiple times");
229        }
230
231        self.buffer.extend_from_slice(b"data: ");
232        serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data)
233            .map_err(axum_core::Error::new)?;
234        self.buffer.put_u8(b'\n');
235
236        self.flags.insert(EventFlags::HAS_DATA);
237
238        Ok(self)
239    }
240
241    /// Set the event's comment field (`:<comment-text>`).
242    ///
243    /// This field will be ignored by most SSE clients.
244    ///
245    /// Unlike other functions, this function can be called multiple times to add many comments.
246    ///
247    /// # Panics
248    ///
249    /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
250    /// comments.
251    pub fn comment<T>(mut self, comment: T) -> Event
252    where
253        T: AsRef<str>,
254    {
255        self.field("", comment.as_ref());
256        self
257    }
258
259    /// Set the event's name field (`event:<event-name>`).
260    ///
261    /// This corresponds to the `type` parameter given when calling `addEventListener` on an
262    /// [`EventSource`]. For example, `.event("update")` should correspond to
263    /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a
264    /// [`message` event] instead.
265    ///
266    /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource
267    /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event
268    ///
269    /// # Panics
270    ///
271    /// - Panics if `event` contains any newlines or carriage returns.
272    /// - Panics if this function has already been called on this event.
273    pub fn event<T>(mut self, event: T) -> Event
274    where
275        T: AsRef<str>,
276    {
277        if self.flags.contains(EventFlags::HAS_EVENT) {
278            panic!("Called `EventBuilder::event` multiple times");
279        }
280        self.flags.insert(EventFlags::HAS_EVENT);
281
282        self.field("event", event.as_ref());
283
284        self
285    }
286
287    /// Set the event's retry timeout field (`retry:<timeout>`).
288    ///
289    /// This sets how long clients will wait before reconnecting if they are disconnected from the
290    /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
291    /// wish, such as if they implement exponential backoff.
292    ///
293    /// # Panics
294    ///
295    /// Panics if this function has already been called on this event.
296    pub fn retry(mut self, duration: Duration) -> Event {
297        if self.flags.contains(EventFlags::HAS_RETRY) {
298            panic!("Called `EventBuilder::retry` multiple times");
299        }
300        self.flags.insert(EventFlags::HAS_RETRY);
301
302        self.buffer.extend_from_slice(b"retry:");
303
304        let secs = duration.as_secs();
305        let millis = duration.subsec_millis();
306
307        if secs > 0 {
308            // format seconds
309            self.buffer
310                .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
311
312            // pad milliseconds
313            if millis < 10 {
314                self.buffer.extend_from_slice(b"00");
315            } else if millis < 100 {
316                self.buffer.extend_from_slice(b"0");
317            }
318        }
319
320        // format milliseconds
321        self.buffer
322            .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
323
324        self.buffer.put_u8(b'\n');
325
326        self
327    }
328
329    /// Set the event's identifier field (`id:<identifier>`).
330    ///
331    /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself,
332    /// the browser will set that field to the last known message ID, starting with the empty
333    /// string.
334    ///
335    /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId
336    ///
337    /// # Panics
338    ///
339    /// - Panics if `id` contains any newlines, carriage returns or null characters.
340    /// - Panics if this function has already been called on this event.
341    pub fn id<T>(mut self, id: T) -> Event
342    where
343        T: AsRef<str>,
344    {
345        if self.flags.contains(EventFlags::HAS_ID) {
346            panic!("Called `EventBuilder::id` multiple times");
347        }
348        self.flags.insert(EventFlags::HAS_ID);
349
350        let id = id.as_ref().as_bytes();
351        assert_eq!(
352            memchr::memchr(b'\0', id),
353            None,
354            "Event ID cannot contain null characters",
355        );
356
357        self.field("id", id);
358        self
359    }
360
361    fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
362        let value = value.as_ref();
363        assert_eq!(
364            memchr::memchr2(b'\r', b'\n', value),
365            None,
366            "SSE field value cannot contain newlines or carriage returns",
367        );
368        self.buffer.extend_from_slice(name.as_bytes());
369        self.buffer.put_u8(b':');
370        self.buffer.put_u8(b' ');
371        self.buffer.extend_from_slice(value);
372        self.buffer.put_u8(b'\n');
373    }
374
375    fn finalize(mut self) -> Bytes {
376        self.buffer.put_u8(b'\n');
377        self.buffer.freeze()
378    }
379}
380
/// Bit set recording which single-use fields have already been written to an event.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self::from_bits(1 << 0);
    const HAS_EVENT: Self = Self::from_bits(1 << 1);
    const HAS_RETRY: Self = Self::from_bits(1 << 2);
    const HAS_ID: Self = Self::from_bits(1 << 3);

    /// Raw bit pattern.
    const fn bits(&self) -> u8 {
        self.0
    }

    /// Wrap a raw bit pattern.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// True when every bit of `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        let wanted = other.bits();
        (self.bits() & wanted) == wanted
    }

    /// Set every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
406
/// Configure the interval between keep-alive messages and the content of
/// each message.
#[derive(Debug, Clone)]
#[must_use]
pub struct KeepAlive {
    // Pre-encoded message bytes sent verbatim on every keep-alive tick.
    event: Bytes,
    // How long the event stream may stay idle before a keep-alive is sent.
    max_interval: Duration,
}
415
416impl KeepAlive {
417    /// Create a new `KeepAlive`.
418    pub fn new() -> Self {
419        Self {
420            event: Bytes::from_static(b":\n\n"),
421            max_interval: Duration::from_secs(15),
422        }
423    }
424
425    /// Customize the interval between keep-alive messages.
426    ///
427    /// Default is 15 seconds.
428    pub fn interval(mut self, time: Duration) -> Self {
429        self.max_interval = time;
430        self
431    }
432
433    /// Customize the text of the keep-alive message.
434    ///
435    /// Default is an empty comment.
436    ///
437    /// # Panics
438    ///
439    /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
440    /// comments.
441    pub fn text<I>(self, text: I) -> Self
442    where
443        I: AsRef<str>,
444    {
445        self.event(Event::default().comment(text))
446    }
447
448    /// Customize the event of the keep-alive message.
449    ///
450    /// Default is an empty comment.
451    ///
452    /// # Panics
453    ///
454    /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
455    /// comments.
456    pub fn event(mut self, event: Event) -> Self {
457        self.event = event.finalize();
458        self
459    }
460}
461
462impl Default for KeepAlive {
463    fn default() -> Self {
464        Self::new()
465    }
466}
467
pin_project! {
    // Timer state that drives keep-alive messages for one SSE body.
    #[derive(Debug)]
    struct KeepAliveStream {
        // Message bytes and interval configuration.
        keep_alive: KeepAlive,
        // Completes once `max_interval` has elapsed since creation or the last reset.
        #[pin]
        alive_timer: Sleep,
    }
}
476
477impl KeepAliveStream {
478    fn new(keep_alive: KeepAlive) -> Self {
479        Self {
480            alive_timer: tokio::time::sleep(keep_alive.max_interval),
481            keep_alive,
482        }
483    }
484
485    fn reset(self: Pin<&mut Self>) {
486        let this = self.project();
487        this.alive_timer
488            .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
489    }
490
491    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
492        let this = self.as_mut().project();
493
494        ready!(this.alive_timer.poll(cx));
495
496        let event = this.keep_alive.event.clone();
497
498        self.reset();
499
500        Poll::Ready(event)
501    }
502}
503
504fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
505    MemchrSplit {
506        needle,
507        haystack: Some(haystack),
508    }
509}
510
/// Iterator returned by `memchr_split`.
struct MemchrSplit<'a> {
    // Byte to split on.
    needle: u8,
    // Unconsumed remainder; `None` once the final segment has been yielded.
    haystack: Option<&'a [u8]>,
}
515
516impl<'a> Iterator for MemchrSplit<'a> {
517    type Item = &'a [u8];
518    fn next(&mut self) -> Option<Self::Item> {
519        let haystack = self.haystack?;
520        if let Some(pos) = memchr::memchr(self.needle, haystack) {
521            let (front, back) = haystack.split_at(pos);
522            self.haystack = Some(&back[1..]);
523            Some(front)
524        } else {
525            self.haystack.take()
526        }
527    }
528}
529
// Unit tests for event encoding plus end-to-end tests through a `Router`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use serde_json::value::RawValue;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    // The encoder always writes exactly one space after `data:`, so a value's own
    // leading whitespace must appear verbatim after it.
    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data: \tfoobar\n\n");

        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    // Raw `\n`/`\r` bytes inside a `RawValue` are dropped by `json_data`, while the
    // escaped `\\n` inside the JSON string is kept.
    #[test]
    fn valid_json_raw_value_chars_stripped() {
        let json_string = "{\r\"foo\":  \n\r\r   \"bar\\n\"\n}";
        let json_raw_value_event = Event::default()
            .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap())
            .unwrap();
        assert_eq!(
            &*json_raw_value_event.finalize(),
            format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes()
        );
    }

    // Headers plus one chunk per event, covering every field kind; the body ends
    // when the stream does.
    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!event_fields.contains_key("comment"));

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(!event_fields.contains_key("comment"));

        assert!(stream.chunk_text().await.is_none());
    }

    // With a paused clock, events arrive every 5s and the keep-alive interval is 1s,
    // so each cycle is one event followed by four keep-alive comments.
    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        for _ in 0..5 {
            // first message should be an event
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // then 4 seconds of keep-alive messages
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    // Once the event stream finishes (`take(2)`), the body ends and no further
    // keep-alive messages are produced.
    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        // first message should be an event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then 4 seconds of keep-alive messages
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // then the last event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then no more events or keep-alive messages
        assert!(stream.chunk_text().await.is_none());
    }

    // Minimal parser for a single encoded event.
    //
    // Splits each line on its first `:`, trims the value, and records
    // empty-named fields under "comment". Repeated keys overwrite earlier ones.
    // Parsing stops at the first empty line, which must be the last line in
    // `payload`.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines().peekable();
        while let Some(line) = lines.next() {
            if line.is_empty() {
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.trim();
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }

    // `memchr_split` matches `<[u8]>::split`: at least one segment is always
    // yielded, and leading, trailing and adjacent delimiters produce empty segments.
    #[test]
    fn memchr_splitting() {
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }
}