1use crate::{
29 body::{Bytes, HttpBody},
30 BoxError,
31};
32use axum_core::{
33 body::Body,
34 response::{IntoResponse, Response},
35};
36use bytes::{BufMut, BytesMut};
37use futures_util::{
38 ready,
39 stream::{Stream, TryStream},
40};
41use http_body::Frame;
42use pin_project_lite::pin_project;
43use std::{
44 fmt,
45 future::Future,
46 pin::Pin,
47 task::{Context, Poll},
48 time::Duration,
49};
50use sync_wrapper::SyncWrapper;
51use tokio::time::Sleep;
52
53#[derive(Clone)]
55#[must_use]
56pub struct Sse<S> {
57 stream: S,
58 keep_alive: Option<KeepAlive>,
59}
60
61impl<S> Sse<S> {
62 pub fn new(stream: S) -> Self
67 where
68 S: TryStream<Ok = Event> + Send + 'static,
69 S::Error: Into<BoxError>,
70 {
71 Sse {
72 stream,
73 keep_alive: None,
74 }
75 }
76
77 pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
81 self.keep_alive = Some(keep_alive);
82 self
83 }
84}
85
86impl<S> fmt::Debug for Sse<S> {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("Sse")
89 .field("stream", &format_args!("{}", std::any::type_name::<S>()))
90 .field("keep_alive", &self.keep_alive)
91 .finish()
92 }
93}
94
95impl<S, E> IntoResponse for Sse<S>
96where
97 S: Stream<Item = Result<Event, E>> + Send + 'static,
98 E: Into<BoxError>,
99{
100 fn into_response(self) -> Response {
101 (
102 [
103 (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
104 (http::header::CACHE_CONTROL, "no-cache"),
105 ],
106 Body::new(SseBody {
107 event_stream: SyncWrapper::new(self.stream),
108 keep_alive: self.keep_alive.map(KeepAliveStream::new),
109 }),
110 )
111 .into_response()
112 }
113}
114
115pin_project! {
116 struct SseBody<S> {
117 #[pin]
118 event_stream: SyncWrapper<S>,
119 #[pin]
120 keep_alive: Option<KeepAliveStream>,
121 }
122}
123
124impl<S, E> HttpBody for SseBody<S>
125where
126 S: Stream<Item = Result<Event, E>>,
127{
128 type Data = Bytes;
129 type Error = E;
130
131 fn poll_frame(
132 self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
135 let this = self.project();
136
137 match this.event_stream.get_pin_mut().poll_next(cx) {
138 Poll::Pending => {
139 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
140 keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
141 } else {
142 Poll::Pending
143 }
144 }
145 Poll::Ready(Some(Ok(event))) => {
146 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
147 keep_alive.reset();
148 }
149 Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
150 }
151 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
152 Poll::Ready(None) => Poll::Ready(None),
153 }
154 }
155}
156
157#[derive(Debug, Default, Clone)]
159#[must_use]
160pub struct Event {
161 buffer: BytesMut,
162 flags: EventFlags,
163}
164
165impl Event {
166 pub fn data<T>(mut self, data: T) -> Event
181 where
182 T: AsRef<str>,
183 {
184 if self.flags.contains(EventFlags::HAS_DATA) {
185 panic!("Called `EventBuilder::data` multiple times");
186 }
187
188 for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
189 self.field("data", line);
190 }
191
192 self.flags.insert(EventFlags::HAS_DATA);
193
194 self
195 }
196
197 #[cfg(feature = "json")]
207 pub fn json_data<T>(mut self, data: T) -> Result<Event, axum_core::Error>
208 where
209 T: serde::Serialize,
210 {
211 struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>);
212 impl std::io::Write for IgnoreNewLines<'_> {
213 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
214 let mut last_split = 0;
215 for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
216 self.0.write_all(&buf[last_split..delimiter])?;
217 last_split = delimiter + 1;
218 }
219 self.0.write_all(&buf[last_split..])?;
220 Ok(buf.len())
221 }
222
223 fn flush(&mut self) -> std::io::Result<()> {
224 self.0.flush()
225 }
226 }
227 if self.flags.contains(EventFlags::HAS_DATA) {
228 panic!("Called `EventBuilder::json_data` multiple times");
229 }
230
231 self.buffer.extend_from_slice(b"data: ");
232 serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data)
233 .map_err(axum_core::Error::new)?;
234 self.buffer.put_u8(b'\n');
235
236 self.flags.insert(EventFlags::HAS_DATA);
237
238 Ok(self)
239 }
240
241 pub fn comment<T>(mut self, comment: T) -> Event
252 where
253 T: AsRef<str>,
254 {
255 self.field("", comment.as_ref());
256 self
257 }
258
259 pub fn event<T>(mut self, event: T) -> Event
274 where
275 T: AsRef<str>,
276 {
277 if self.flags.contains(EventFlags::HAS_EVENT) {
278 panic!("Called `EventBuilder::event` multiple times");
279 }
280 self.flags.insert(EventFlags::HAS_EVENT);
281
282 self.field("event", event.as_ref());
283
284 self
285 }
286
287 pub fn retry(mut self, duration: Duration) -> Event {
297 if self.flags.contains(EventFlags::HAS_RETRY) {
298 panic!("Called `EventBuilder::retry` multiple times");
299 }
300 self.flags.insert(EventFlags::HAS_RETRY);
301
302 self.buffer.extend_from_slice(b"retry:");
303
304 let secs = duration.as_secs();
305 let millis = duration.subsec_millis();
306
307 if secs > 0 {
308 self.buffer
310 .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
311
312 if millis < 10 {
314 self.buffer.extend_from_slice(b"00");
315 } else if millis < 100 {
316 self.buffer.extend_from_slice(b"0");
317 }
318 }
319
320 self.buffer
322 .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
323
324 self.buffer.put_u8(b'\n');
325
326 self
327 }
328
329 pub fn id<T>(mut self, id: T) -> Event
342 where
343 T: AsRef<str>,
344 {
345 if self.flags.contains(EventFlags::HAS_ID) {
346 panic!("Called `EventBuilder::id` multiple times");
347 }
348 self.flags.insert(EventFlags::HAS_ID);
349
350 let id = id.as_ref().as_bytes();
351 assert_eq!(
352 memchr::memchr(b'\0', id),
353 None,
354 "Event ID cannot contain null characters",
355 );
356
357 self.field("id", id);
358 self
359 }
360
361 fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
362 let value = value.as_ref();
363 assert_eq!(
364 memchr::memchr2(b'\r', b'\n', value),
365 None,
366 "SSE field value cannot contain newlines or carriage returns",
367 );
368 self.buffer.extend_from_slice(name.as_bytes());
369 self.buffer.put_u8(b':');
370 self.buffer.put_u8(b' ');
371 self.buffer.extend_from_slice(value);
372 self.buffer.put_u8(b'\n');
373 }
374
375 fn finalize(mut self) -> Bytes {
376 self.buffer.put_u8(b'\n');
377 self.buffer.freeze()
378 }
379}
380
// Small bit set recording which once-only `Event` fields have been written.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self::from_bits(1 << 0);
    const HAS_EVENT: Self = Self::from_bits(1 << 1);
    const HAS_RETRY: Self = Self::from_bits(1 << 2);
    const HAS_ID: Self = Self::from_bits(1 << 3);

    // Raw bit pattern.
    const fn bits(&self) -> u8 {
        self.0
    }

    // Wraps a raw bit pattern.
    const fn from_bits(bits: u8) -> Self {
        EventFlags(bits)
    }

    // True when every bit set in `other` is also set in `self`, i.e. `other`
    // has no bit outside of `self`.
    const fn contains(&self, other: Self) -> bool {
        (other.bits() & !self.bits()) == 0
    }

    // Sets every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
406
407#[derive(Debug, Clone)]
410#[must_use]
411pub struct KeepAlive {
412 event: Bytes,
413 max_interval: Duration,
414}
415
416impl KeepAlive {
417 pub fn new() -> Self {
419 Self {
420 event: Bytes::from_static(b":\n\n"),
421 max_interval: Duration::from_secs(15),
422 }
423 }
424
425 pub fn interval(mut self, time: Duration) -> Self {
429 self.max_interval = time;
430 self
431 }
432
433 pub fn text<I>(self, text: I) -> Self
442 where
443 I: AsRef<str>,
444 {
445 self.event(Event::default().comment(text))
446 }
447
448 pub fn event(mut self, event: Event) -> Self {
457 self.event = event.finalize();
458 self
459 }
460}
461
462impl Default for KeepAlive {
463 fn default() -> Self {
464 Self::new()
465 }
466}
467
468pin_project! {
469 #[derive(Debug)]
470 struct KeepAliveStream {
471 keep_alive: KeepAlive,
472 #[pin]
473 alive_timer: Sleep,
474 }
475}
476
477impl KeepAliveStream {
478 fn new(keep_alive: KeepAlive) -> Self {
479 Self {
480 alive_timer: tokio::time::sleep(keep_alive.max_interval),
481 keep_alive,
482 }
483 }
484
485 fn reset(self: Pin<&mut Self>) {
486 let this = self.project();
487 this.alive_timer
488 .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
489 }
490
491 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
492 let this = self.as_mut().project();
493
494 ready!(this.alive_timer.poll(cx));
495
496 let event = this.keep_alive.event.clone();
497
498 self.reset();
499
500 Poll::Ready(event)
501 }
502}
503
504fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
505 MemchrSplit {
506 needle,
507 haystack: Some(haystack),
508 }
509}
510
511struct MemchrSplit<'a> {
512 needle: u8,
513 haystack: Option<&'a [u8]>,
514}
515
516impl<'a> Iterator for MemchrSplit<'a> {
517 type Item = &'a [u8];
518 fn next(&mut self) -> Option<Self::Item> {
519 let haystack = self.haystack?;
520 if let Some(pos) = memchr::memchr(self.needle, haystack) {
521 let (front, back) = haystack.split_at(pos);
522 self.haystack = Some(&back[1..]);
523 Some(front)
524 } else {
525 self.haystack.take()
526 }
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use serde_json::value::RawValue;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data: \tfoobar\n\n");

        // `Event::field` always writes `": "` after the field name, so a value
        // that itself starts with a space must come out with two spaces. Clients
        // strip exactly one space after the colon, so they see " foobar".
        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    #[test]
    fn valid_json_raw_value_chars_stripped() {
        let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}";
        let json_raw_value_event = Event::default()
            .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap())
            .unwrap();
        // Raw `\n`/`\r` bytes are dropped; the escaped `\\n` inside the string stays.
        assert_eq!(
            &*json_raw_value_event.finalize(),
            format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes()
        );
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!event_fields.contains_key("comment"));

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(!event_fields.contains_key("comment"));

        assert!(stream.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        for _ in 0..5 {
            // first message should be an event
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // then 4 seconds of keep-alive messages
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        // first message should be an event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then 4 seconds of keep-alive messages
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // then the last event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then no more events or keep-alive messages
        assert!(stream.chunk_text().await.is_none());
    }

    // Parses a single SSE message into `field -> value`, trimming values and
    // storing comment lines (empty field name) under the key "comment".
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines().peekable();
        while let Some(line) = lines.next() {
            if line.is_empty() {
                // The blank line must be the last line of the chunk.
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.trim();
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }

    #[test]
    fn memchr_splitting() {
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }
}