Adaptyst
A comprehensive and architecture-agnostic performance analysis tool
Loading...
Searching...
No Matches
socket.hpp
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2026 CERN
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4#ifndef ADAPTYST_SOCKET_HPP_
5#define ADAPTYST_SOCKET_HPP_
6
7#include "os_detect.h"
8#include <string>
9#include <queue>
10#include <memory>
11#include <iostream>
12#include <filesystem>
13#include <Poco/Net/ServerSocket.h>
14#include <cstring>
15#include <unistd.h>
16#include <fstream>
17#include <poll.h>
18#include <Poco/Buffer.h>
19#include <Poco/Net/NetException.h>
20#include <Poco/StreamCopier.h>
21#include <Poco/FileStream.h>
22#include <Poco/Net/SocketStream.h>
23
24#define UNLIMITED_ACCEPTED -1
25#define NO_TIMEOUT -1
26
27#ifndef FILE_BUFFER_SIZE
28#define FILE_BUFFER_SIZE 1048576
29#endif
30
31namespace adaptyst {
32 namespace net = Poco::Net;
33 namespace fs = std::filesystem;
34
/**
 * Read-only stream buffer over an existing character array.
 *
 * The get area deliberately stops one character before the end of the
 * supplied range, i.e. it exposes the first (length - 1) characters only.
 */
class charstreambuf : public std::streambuf {
public:
  /**
   * @param begin  Owner of the character array to expose (not copied).
   * @param length Number of characters in the array; the last one is
   *               left outside the readable area.
   */
  charstreambuf(std::unique_ptr<char[]> &begin, unsigned int length) {
    char *const first = begin.get();
    char *const last = first + length - 1;

    this->setg(first, first, last);
  }
};
41
/**
 * Exception signalling a connection-level failure.
 *
 * It can wrap another exception; the wrapped exception's what() text is
 * copied so that it stays available to whoever catches this one. (Copying
 * only the std::exception base sub-object would discard that text.)
 */
class ConnectionException : public std::exception {
private:
  std::string message;  // what() text of the wrapped exception, if any

public:
  /** Constructs an exception with no specific message. */
  ConnectionException() { }

  /**
   * Wraps another exception, preserving its description.
   *
   * @param other Exception whose what() text should be reported.
   */
  ConnectionException(const std::exception &other) : message(other.what()) { }

  /**
   * @return The wrapped exception's description, or the generic
   *         std::exception description when nothing was wrapped.
   */
  const char *what() const noexcept override {
    return this->message.empty() ? std::exception::what()
                                 : this->message.c_str();
  }
};
53
59
60 };
61
/**
 * Exception thrown by the connections and acceptors in this file when a
 * timed wait elapses before any data or peer becomes available.
 */
class TimeoutException : public std::exception { };
68
72 class Connection {
73 protected:
77 virtual void close() = 0;
78
79 public:
80 virtual ~Connection() { }
81
95 virtual int read(char *buf, unsigned int len, long timeout_seconds) = 0;
96
107 virtual std::string read(long timeout_seconds = NO_TIMEOUT) = 0;
108
118 virtual void write(std::string msg, bool new_line = true) = 0;
119
127 virtual void write(fs::path file) = 0;
128
137 virtual void write(unsigned int len, char *buf) = 0;
138
142 virtual unsigned int get_buf_size() = 0;
143 };
144
148 class Socket : public Connection {
149 protected:
150 virtual void close() = 0;
151
152 public:
153 virtual ~Socket() { }
154
158 virtual std::string get_address() = 0;
159
163 virtual unsigned short get_port() = 0;
164
165 virtual unsigned int get_buf_size() = 0;
166 virtual int read(char *buf, unsigned int len, long timeout_seconds) = 0;
167 virtual std::string read(long timeout_seconds = NO_TIMEOUT) = 0;
168 virtual void write(std::string msg, bool new_line = true) = 0;
169 virtual void write(fs::path file) = 0;
170 virtual void write(unsigned int len, char *buf) = 0;
171 };
172
176 class Acceptor {
177 private:
178 int max_accepted;
179 int accepted;
180
181 protected:
189 Acceptor(int max_accepted) {
190 this->max_accepted = max_accepted;
191 this->accepted = 0;
192 }
193
206 virtual std::unique_ptr<Connection> accept_connection(unsigned int buf_size,
207 long timeout) = 0;
208
212 virtual void close() = 0;
213
214 public:
218 class Factory {
219 public:
227 virtual std::unique_ptr<Acceptor> make_acceptor(int max_accepted) = 0;
228
233 virtual std::string get_type() = 0;
234 };
235
252 std::unique_ptr<Connection> accept(unsigned int buf_size,
253 long timeout = NO_TIMEOUT) {
254 if (this->max_accepted != UNLIMITED_ACCEPTED &&
255 this->accepted >= this->max_accepted) {
256 throw std::runtime_error("Maximum accepted connections reached.");
257 }
258
259 std::unique_ptr<Connection> connection = this->accept_connection(buf_size,
260 timeout);
261 this->accepted++;
262
263 return connection;
264 }
265
266 virtual ~Acceptor() { }
267
276 virtual std::string get_connection_instructions() = 0;
277
282 virtual std::string get_type() = 0;
283 };
284
288 class TCPSocket : public Socket {
289 private:
290 net::StreamSocket socket;
291 std::unique_ptr<char[]> buf;
292 unsigned int buf_size;
293 int start_pos;
294 std::queue<std::string> buffered_msgs;
295
296 protected:
297 void close();
298
299 public:
307 TCPSocket(net::StreamSocket &sock, unsigned int buf_size) {
308 this->socket = sock;
309 this->buf.reset(new char[buf_size]);
310 this->buf_size = buf_size;
311 this->start_pos = 0;
312 }
313
315 this->close();
316 }
317
318 std::string get_address() {
319 return this->socket.address().host().toString();
320 }
321
322 unsigned short get_port() {
323 return this->socket.address().port();
324 }
325
326 unsigned int get_buf_size() {
327 return this->buf_size;
328 }
329
330 int read(char *buf, unsigned int len, long timeout_seconds) {
331 auto set_timeout = [this, &timeout_seconds]() {
332 if (timeout_seconds != NO_TIMEOUT) {
333 this->socket.setReceiveTimeout(Poco::Timespan(timeout_seconds, 0));
334 }
335 };
336
337 auto unset_timeout = [this, &timeout_seconds]() {
338 if (timeout_seconds != NO_TIMEOUT) {
339 this->socket.setReceiveTimeout(Poco::Timespan());
340 }
341 };
342
343 try {
344 set_timeout();
345 int bytes = this->socket.receiveBytes(buf, len);
346 unset_timeout();
347 return bytes;
348 } catch (net::NetException &e) {
349 unset_timeout();
350 throw ConnectionException(e);
351 } catch (Poco::TimeoutException &e) {
352 unset_timeout();
353 throw TimeoutException();
354 }
355 }
356
357 std::string read(long timeout_seconds = NO_TIMEOUT) {
358 try {
359 if (!this->buffered_msgs.empty()) {
360 std::string msg = this->buffered_msgs.front();
361 this->buffered_msgs.pop();
362 return msg;
363 }
364
365 std::string cur_msg = "";
366
367 while (true) {
368 int bytes_received =
369 this->read(this->buf.get() + this->start_pos,
370 this->buf_size - this->start_pos, timeout_seconds);
371
372 if (bytes_received == 0) {
373 return std::string(this->buf.get(), this->start_pos);
374 }
375
376 bool first_msg_to_receive = true;
377 std::string first_msg;
378
379 charstreambuf buf(this->buf, bytes_received + this->start_pos);
380 std::istream in(&buf);
381
382 int cur_pos = 0;
383 bool last_is_newline = this->buf.get()[bytes_received + this->start_pos - 1] == '\n';
384
385 while (!in.eof()) {
386 std::string msg;
387 std::getline(in, msg);
388
389 if (in.eof() && !last_is_newline) {
390 int size = bytes_received + this->start_pos - cur_pos;
391
392 if (size == this->buf_size) {
393 cur_msg += std::string(this->buf.get(), this->buf_size);
394 this->start_pos = 0;
395 } else {
396 std::memmove(this->buf.get(), this->buf.get() + cur_pos, size);
397 this->start_pos = size;
398 }
399 } else {
400 if (!cur_msg.empty() || !msg.empty()) {
401 if (first_msg_to_receive) {
402 first_msg = cur_msg + msg;
403 first_msg_to_receive = false;
404 } else {
405 this->buffered_msgs.push(cur_msg + msg);
406 }
407
408 cur_msg = "";
409 }
410
411 cur_pos += msg.length() + 1;
412 }
413 }
414
415 if (last_is_newline) {
416 this->start_pos = 0;
417 }
418
419 if (!first_msg_to_receive) {
420 return first_msg;
421 }
422 }
423
424 // Should not get here.
425 return "";
426 } catch (net::NetException &e) {
427 throw ConnectionException(e);
428 }
429 }
430
431 void write(std::string msg, bool new_line) {
432 try {
433 if (new_line) {
434 msg += "\n";
435 }
436
437 const char *buf = msg.c_str();
438
439 int bytes_written = this->socket.sendBytes(buf, msg.size());
440
441 if (bytes_written != msg.size()) {
442 std::runtime_error err("Wrote " +
443 std::to_string(bytes_written) +
444 " bytes instead of " +
445 std::to_string(msg.size()) +
446 " to " +
447 this->socket.address().toString());
448 throw ConnectionException(err);
449 }
450 } catch (net::NetException &e) {
451 throw ConnectionException(e);
452 }
453 }
454
455 void write(fs::path file) {
456 try {
457 net::SocketStream socket_stream(this->socket);
458 Poco::FileInputStream stream(file, std::ios::in | std::ios::binary);
459 Poco::StreamCopier::copyStream(stream, socket_stream);
460 } catch (net::NetException &e) {
461 throw ConnectionException(e);
462 }
463 }
464
465 void write(unsigned int len, char *buf) {
466 try {
467 int bytes_written = this->socket.sendBytes(buf, len);
468 if (bytes_written != len) {
469 std::runtime_error err("Wrote " +
470 std::to_string(bytes_written) +
471 " bytes instead of " +
472 std::to_string(len) +
473 " to " +
474 this->socket.address().toString());
475 throw ConnectionException(err);
476 }
477 } catch (net::NetException &e) {
478 throw ConnectionException(e);
479 }
480 }
481 };
482
486 class TCPAcceptor : public Acceptor {
487 private:
488 net::ServerSocket acceptor;
489
490 TCPAcceptor(std::string address, unsigned short port,
491 int max_accepted,
492 bool try_subsequent_ports) : Acceptor(max_accepted) {
493 if (try_subsequent_ports) {
494 bool success = false;
495 while (!success) {
496 try {
497 this->acceptor.bind(net::SocketAddress(address, port), false);
498 success = true;
499 } catch (net::NetException &e) {
500 if (e.message().find("already in use") != std::string::npos) {
501 port++;
502 } else {
503 throw ConnectionException(e);
504 }
505 }
506 }
507 } else {
508 try {
509 this->acceptor.bind(net::SocketAddress(address, port), false);
510 } catch (net::NetException &e) {
511 if (e.message().find("already in use") != std::string::npos) {
512 throw AlreadyInUseException();
513 } else {
514 throw ConnectionException(e);
515 }
516 }
517 }
518
519 try {
520 this->acceptor.listen();
521 } catch (net::NetException &e) {
522 throw ConnectionException(e);
523 }
524 }
525
526 protected:
527 std::unique_ptr<Connection> accept_connection(unsigned int buf_size,
528 long timeout) {
529 try {
530 net::StreamSocket socket = this->acceptor.acceptConnection();
531 return std::make_unique<TCPSocket>(socket, buf_size);
532 } catch (net::NetException &e) {
533 throw ConnectionException(e);
534 }
535 }
536
537 void close() { this->acceptor.close(); }
538
539 public:
543 class Factory : public Acceptor::Factory {
544 private:
545 std::string address;
546 unsigned short port;
547 bool try_subsequent_ports;
548
549 public:
561 Factory(std::string address, unsigned short port,
562 bool try_subsequent_ports = false) {
563 this->address = address;
564 this->port = port;
565 this->try_subsequent_ports = try_subsequent_ports;
566 };
567
568 std::unique_ptr<Acceptor> make_acceptor(int max_accepted) {
569 return std::unique_ptr<Acceptor>(new TCPAcceptor(this->address,
570 this->port,
571 max_accepted,
572 this->try_subsequent_ports));
573 }
574
575 std::string get_type() {
576 return "tcp";
577 }
578 };
579
581 this->close();
582 }
583
588 return this->acceptor.address().host().toString() + "_" + std::to_string(this->acceptor.address().port());
589 }
590
591 std::string get_type() { return "tcp"; }
592 };
593
594#ifdef ADAPTYST_UNIX
599 class FileDescriptor : public Connection {
600 private:
601 int read_fd[2];
602 int write_fd[2];
603 unsigned int buf_size;
604 std::queue<std::string> buffered_msgs;
605 std::unique_ptr<char[]> buf;
606 int start_pos;
607 bool close_on_destruct;
608
609 public:
623 FileDescriptor(int read_fd[2],
624 int write_fd[2],
625 unsigned int buf_size,
626 bool close_on_destruct = true) {
627 this->buf.reset(new char[buf_size]);
628 this->buf_size = buf_size;
629 this->start_pos = 0;
630 this->close_on_destruct = close_on_destruct;
631
632 if (read_fd != nullptr) {
633 this->read_fd[0] = read_fd[0];
634 this->read_fd[1] = read_fd[1];
635 } else {
636 this->read_fd[0] = -1;
637 this->read_fd[1] = -1;
638 }
639
640 if (write_fd != nullptr) {
641 this->write_fd[0] = write_fd[0];
642 this->write_fd[1] = write_fd[1];
643 } else {
644 this->write_fd[0] = -1;
645 this->write_fd[1] = -1;
646 }
647 }
648
649 ~FileDescriptor() {
650 if (this->close_on_destruct) {
651 this->close();
652 }
653 }
654
655 int read(char *buf, unsigned int len, long timeout_seconds) {
656 if (timeout_seconds != NO_TIMEOUT) {
657 struct pollfd poll_struct;
658 poll_struct.fd = this->read_fd[0];
659 poll_struct.events = POLLIN;
660
661 int code = ::poll(&poll_struct, 1, 1000 * timeout_seconds);
662
663 if (code == -1) {
664 throw ConnectionException();
665 } else if (code == 0) {
666 throw TimeoutException();
667 }
668 }
669
670 return ::read(this->read_fd[0], buf, len);
671 }
672
673 std::string read(long timeout_seconds = NO_TIMEOUT) {
674 if (!this->buffered_msgs.empty()) {
675 std::string msg = this->buffered_msgs.front();
676 this->buffered_msgs.pop();
677 return msg;
678 }
679
680 std::string cur_msg = "";
681
682 while (true) {
683 int bytes_received =
684 this->read(this->buf.get() + this->start_pos,
685 this->buf_size - this->start_pos,
686 timeout_seconds);
687
688 if (bytes_received == -1) {
689 throw ConnectionException();
690 } else if (bytes_received == 0) {
691 return std::string(this->buf.get(), this->start_pos);
692 }
693
694 bool first_msg_to_receive = true;
695 std::string first_msg;
696
697 charstreambuf buf(this->buf, bytes_received + this->start_pos);
698 std::istream in(&buf);
699
700 int cur_pos = 0;
701 bool last_is_newline = this->buf.get()[bytes_received + this->start_pos - 1] == '\n';
702
703 while (!in.eof()) {
704 std::string msg;
705 std::getline(in, msg);
706
707 if (in.eof() && !last_is_newline) {
708 int size = bytes_received + this->start_pos - cur_pos;
709
710 if (size == this->buf_size) {
711 cur_msg += std::string(this->buf.get(), this->buf_size);
712 this->start_pos = 0;
713 } else {
714 std::memmove(this->buf.get(), this->buf.get() + cur_pos, size);
715 this->start_pos = size;
716 }
717 } else {
718 if (!cur_msg.empty() || !msg.empty()) {
719 if (first_msg_to_receive) {
720 first_msg = cur_msg + msg;
721 first_msg_to_receive = false;
722 } else {
723 this->buffered_msgs.push(cur_msg + msg);
724 }
725
726 cur_msg = "";
727 }
728
729 cur_pos += msg.length() + 1;
730 }
731 }
732
733 if (last_is_newline) {
734 this->start_pos = 0;
735 }
736
737 if (!first_msg_to_receive) {
738 return first_msg;
739 }
740 }
741
742 // Should not get here.
743 return "";
744 }
745
746 void write(std::string msg, bool new_line) {
747 if (new_line) {
748 msg += "\n";
749 }
750
751 const char *buf = msg.c_str();
752 int written = ::write(this->write_fd[1], buf, msg.size());
753
754 if (written != msg.size()) {
755 std::runtime_error err("Wrote " +
756 std::to_string(written) +
757 " bytes instead of " +
758 std::to_string(msg.size()) +
759 " to fd " +
760 std::to_string(this->write_fd[1]));
761 throw ConnectionException(err);
762 }
763 }
764
765 void write(fs::path file) {
766 std::unique_ptr<char> buf(new char[FILE_BUFFER_SIZE]);
767 std::ifstream file_stream(file, std::ios_base::in |
768 std::ios_base::binary);
769
770 if (!file_stream) {
771 std::runtime_error err("Could not open the file " +
772 file.string() + "!");
773 throw ConnectionException(err);
774 }
775
776 while (file_stream) {
777 file_stream.read(buf.get(), FILE_BUFFER_SIZE);
778 int bytes_read = file_stream.gcount();
779 int bytes_written = ::write(this->write_fd[1], buf.get(),
780 bytes_read);
781
782 if (bytes_written != bytes_read) {
783 std::runtime_error err("Wrote " +
784 std::to_string(bytes_written) +
785 " bytes instead of " +
786 std::to_string(bytes_read) +
787 " to fd " +
788 std::to_string(this->write_fd[1]));
789 throw ConnectionException(err);
790 }
791 }
792 }
793
794 void write(unsigned int len, char *buf) {
795 int bytes_written = ::write(this->write_fd[1], buf, len);
796
797 if (bytes_written != len) {
798 std::runtime_error err("Wrote " +
799 std::to_string(bytes_written) +
800 " bytes instead of " +
801 std::to_string(len) +
802 " to fd " +
803 std::to_string(this->write_fd[1]));
804 throw ConnectionException(err);
805 }
806 }
807
808 unsigned int get_buf_size() {
809 return this->buf_size;
810 }
811
812 void close() {
813 if (this->read_fd[0] != -1) {
814 ::close(this->read_fd[0]);
815 this->read_fd[0] = -1;
816 }
817
818 if (this->write_fd[1] != -1) {
819 ::close(this->write_fd[1]);
820 this->write_fd[1] = -1;
821 }
822 }
823
824 std::pair<int, int> get_read_fd() {
825 return std::make_pair(this->read_fd[0],
826 this->read_fd[1]);
827 }
828
829 std::pair<int, int> get_write_fd() {
830 return std::make_pair(this->write_fd[0],
831 this->write_fd[1]);
832 }
833 };
834
839 class PipeAcceptor : public Acceptor {
840 private:
841 int read_fd[2];
842 int write_fd[2];
843
849 PipeAcceptor() : Acceptor(1) {
850 if (pipe(this->read_fd) != 0) {
851 std::runtime_error err("Could not open read pipe for FileDescriptor, "
852 "code " + std::to_string(errno));
853 throw ConnectionException(err);
854 }
855
856 if (pipe(this->write_fd) != 0) {
857 std::runtime_error err("Could not open write pipe for FileDescriptor, "
858 "code " + std::to_string(errno));
859 throw ConnectionException(err);
860 }
861 }
862
863 protected:
864 std::unique_ptr<Connection> accept_connection(unsigned int buf_size,
865 long timeout) {
866 std::string expected = "connect";
867 const int size = expected.size();
868
869 char buf[size];
870 int bytes_received = 0;
871
872 while (bytes_received < size) {
873 if (timeout != NO_TIMEOUT) {
874 struct pollfd poll_struct;
875 poll_struct.fd = this->read_fd[0];
876 poll_struct.events = POLLIN;
877
878 int code = ::poll(&poll_struct, 1, 1000 * timeout);
879
880 if (code == -1) {
881 throw ConnectionException();
882 } else if (code == 0) {
883 throw TimeoutException();
884 }
885 }
886
887 int received = ::read(this->read_fd[0], buf + bytes_received,
888 size - bytes_received);
889
890 if (received <= 0) {
891 break;
892 }
893
894 bytes_received += received;
895 }
896
897 std::string msg(buf, size);
898
899 if (msg != expected) {
900 std::runtime_error err("Message received from pipe when establishing connection "
901 "is \"" + msg + "\" instead of \"" + expected + "\".");
902 throw ConnectionException(err);
903 }
904
905 return std::unique_ptr<Connection>(new FileDescriptor(this->read_fd,
906 this->write_fd,
907 buf_size));
908 }
909
910 void close() {}
911
912 public:
916 class Factory : public Acceptor::Factory {
917 public:
926 std::unique_ptr<Acceptor> make_acceptor(int max_accepted) {
927 if (max_accepted != 1) {
928 throw std::runtime_error("max_accepted can only be 1 for FileDescriptor");
929 }
930
931 return std::unique_ptr<Acceptor>(new PipeAcceptor());
932 }
933
934 std::string get_type() {
935 return "pipe";
936 }
937 };
938
942 std::string get_connection_instructions() {
943 return std::to_string(this->write_fd[0]) + "_" + std::to_string(this->read_fd[1]);
944 }
945
946 std::string get_type() { return "pipe"; }
947 };
948#endif
949}
950
951#endif
Definition socket.hpp:218
virtual std::string get_type()=0
virtual std::unique_ptr< Acceptor > make_acceptor(int max_accepted)=0
Definition socket.hpp:176
virtual std::string get_type()=0
virtual std::string get_connection_instructions()=0
Acceptor(int max_accepted)
Definition socket.hpp:189
virtual void close()=0
virtual std::unique_ptr< Connection > accept_connection(unsigned int buf_size, long timeout)=0
std::unique_ptr< Connection > accept(unsigned int buf_size, long timeout=NO_TIMEOUT)
Definition socket.hpp:252
virtual ~Acceptor()
Definition socket.hpp:266
Definition socket.hpp:58
Definition socket.hpp:48
ConnectionException(std::exception &other)
Definition socket.hpp:51
ConnectionException()
Definition socket.hpp:50
Definition socket.hpp:72
virtual void write(unsigned int len, char *buf)=0
virtual std::string read(long timeout_seconds=NO_TIMEOUT)=0
virtual void write(std::string msg, bool new_line=true)=0
virtual unsigned int get_buf_size()=0
virtual void write(fs::path file)=0
virtual int read(char *buf, unsigned int len, long timeout_seconds)=0
virtual void close()=0
virtual ~Connection()
Definition socket.hpp:80
Definition socket.hpp:148
virtual std::string read(long timeout_seconds=NO_TIMEOUT)=0
virtual std::string get_address()=0
virtual unsigned int get_buf_size()=0
virtual unsigned short get_port()=0
virtual int read(char *buf, unsigned int len, long timeout_seconds)=0
virtual void write(fs::path file)=0
virtual void write(std::string msg, bool new_line=true)=0
virtual void write(unsigned int len, char *buf)=0
virtual void close()=0
virtual ~Socket()
Definition socket.hpp:153
std::unique_ptr< Acceptor > make_acceptor(int max_accepted)
Definition socket.hpp:568
std::string get_type()
Definition socket.hpp:575
Factory(std::string address, unsigned short port, bool try_subsequent_ports=false)
Definition socket.hpp:561
void close()
Definition socket.hpp:537
std::string get_type()
Definition socket.hpp:591
~TCPAcceptor()
Definition socket.hpp:580
std::unique_ptr< Connection > accept_connection(unsigned int buf_size, long timeout)
Definition socket.hpp:527
std::string get_connection_instructions()
Definition socket.hpp:587
std::string read(long timeout_seconds=NO_TIMEOUT)
Definition socket.hpp:357
void write(unsigned int len, char *buf)
Definition socket.hpp:465
void write(std::string msg, bool new_line)
Definition socket.hpp:431
TCPSocket(net::StreamSocket &sock, unsigned int buf_size)
Definition socket.hpp:307
unsigned int get_buf_size()
Definition socket.hpp:326
unsigned short get_port()
Definition socket.hpp:322
std::string get_address()
Definition socket.hpp:318
void write(fs::path file)
Definition socket.hpp:455
~TCPSocket()
Definition socket.hpp:314
int read(char *buf, unsigned int len, long timeout_seconds)
Definition socket.hpp:330
Definition socket.hpp:65
Definition socket.hpp:35
charstreambuf(std::unique_ptr< char[]> &begin, unsigned int length)
Definition socket.hpp:37
Definition archive.cpp:7
#define FILE_BUFFER_SIZE
Definition socket.hpp:28
#define NO_TIMEOUT
Definition socket.hpp:25
#define UNLIMITED_ACCEPTED
Definition socket.hpp:24