Adaptyst
A comprehensive and architecture-agnostic performance analysis tool
Loading...
Searching...
No Matches
socket.hpp
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2025 CERN
2// SPDX-License-Identifier: GPL-3.0-or-later
3
4#ifndef ADAPTYST_SOCKET_HPP_
5#define ADAPTYST_SOCKET_HPP_
6
#include "os_detect.h"
#include <poll.h>
#include <unistd.h>
#include <cerrno>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <queue>
#include <stdexcept>
#include <string>
#include <Poco/Buffer.h>
#include <Poco/FileStream.h>
#include <Poco/Net/NetException.h>
#include <Poco/Net/ServerSocket.h>
#include <Poco/Net/SocketStream.h>
#include <Poco/StreamCopier.h>
23
// Passed as max_accepted to an Acceptor: no limit on accepted connections.
#define UNLIMITED_ACCEPTED (-1)

// Passed as a timeout argument: wait without any time limit.
#define NO_TIMEOUT (-1)

// Chunk size in bytes used when streaming a file through a FileDescriptor.
// Can be overridden by defining it before this header is included.
#ifndef FILE_BUFFER_SIZE
#define FILE_BUFFER_SIZE 1048576
#endif
30
31namespace adaptyst {
32 namespace net = Poco::Net;
33 namespace fs = std::filesystem;
34
/// std::streambuf exposing an existing char array as its get area, so the
/// array can be consumed through a std::istream without copying.
///
/// The readable range is [begin, begin + length - 1): the FINAL byte of the
/// array is intentionally not part of the get area. `length` must be at
/// least 1; a length of 0 would produce an end pointer before the array.
/// The streambuf does not own the array; `begin` must outlive it.
class charstreambuf : public std::streambuf {
public:
  charstreambuf(std::unique_ptr<char[]> &begin, unsigned int length) {
    char *const first = begin.get();
    char *const last = first + length - 1;
    setg(first, first, last);
  }
};
41
/// Exception thrown when an operation on a Connection or Acceptor fails.
///
/// The description of the underlying error is copied into the exception
/// at construction time. Copy-constructing only a plain std::exception
/// base from another exception would slice it and leave what() returning
/// a generic, implementation-defined string instead of the real message.
class ConnectionException : public std::exception {
private:
  std::string message;

public:
  /// Creates an exception with a generic description.
  ConnectionException() : message("Connection error") { }

  /// Creates an exception carrying the what() text of @p other.
  ///
  /// @param other Underlying error. Taken by const reference so that both
  ///              named exception objects and temporaries can be passed.
  ConnectionException(const std::exception &other)
      : message(other.what()) { }

  /// Returns the stored description of the error.
  const char *what() const noexcept override {
    return this->message.c_str();
  }
};
53
59
60 };
61
/// Exception thrown when a read or accept operation does not complete
/// within the requested timeout.
class TimeoutException : public std::exception {
public:
  /// Returns a fixed, human-readable description. Without this override
  /// what() would return the implementation-defined std::exception text.
  const char *what() const noexcept override {
    return "Operation timed out";
  }
};
68
72 class Connection {
73 protected:
77 virtual void close() = 0;
78
79 public:
80 virtual ~Connection() { }
81
94 virtual int read(char *buf, unsigned int len, long timeout_seconds) = 0;
95
106 virtual std::string read(long timeout_seconds = NO_TIMEOUT) = 0;
107
117 virtual void write(std::string msg, bool new_line = true) = 0;
118
126 virtual void write(fs::path file) = 0;
127
136 virtual void write(unsigned int len, char *buf) = 0;
137
141 virtual unsigned int get_buf_size() = 0;
142 };
143
147 class Socket : public Connection {
148 protected:
149 virtual void close() = 0;
150
151 public:
152 virtual ~Socket() { }
153
157 virtual std::string get_address() = 0;
158
162 virtual unsigned short get_port() = 0;
163
164 virtual unsigned int get_buf_size() = 0;
165 virtual int read(char *buf, unsigned int len, long timeout_seconds) = 0;
166 virtual std::string read(long timeout_seconds = NO_TIMEOUT) = 0;
167 virtual void write(std::string msg, bool new_line = true) = 0;
168 virtual void write(fs::path file) = 0;
169 virtual void write(unsigned int len, char *buf) = 0;
170 };
171
175 class Acceptor {
176 private:
177 int max_accepted;
178 int accepted;
179
180 protected:
188 Acceptor(int max_accepted) {
189 this->max_accepted = max_accepted;
190 this->accepted = 0;
191 }
192
205 virtual std::unique_ptr<Connection> accept_connection(unsigned int buf_size,
206 long timeout) = 0;
207
211 virtual void close() = 0;
212
213 public:
217 class Factory {
218 public:
226 virtual std::unique_ptr<Acceptor> make_acceptor(int max_accepted) = 0;
227
232 virtual std::string get_type() = 0;
233 };
234
251 std::unique_ptr<Connection> accept(unsigned int buf_size,
252 long timeout = NO_TIMEOUT) {
253 if (this->max_accepted != UNLIMITED_ACCEPTED &&
254 this->accepted >= this->max_accepted) {
255 throw std::runtime_error("Maximum accepted connections reached.");
256 }
257
258 std::unique_ptr<Connection> connection = this->accept_connection(buf_size,
259 timeout);
260 this->accepted++;
261
262 return connection;
263 }
264
265 virtual ~Acceptor() { }
266
275 virtual std::string get_connection_instructions() = 0;
276
281 virtual std::string get_type() = 0;
282 };
283
287 class TCPSocket : public Socket {
288 private:
289 net::StreamSocket socket;
290 std::unique_ptr<char[]> buf;
291 unsigned int buf_size;
292 int start_pos;
293 std::queue<std::string> buffered_msgs;
294
295 protected:
296 void close();
297
298 public:
306 TCPSocket(net::StreamSocket &sock, unsigned int buf_size) {
307 this->socket = sock;
308 this->buf.reset(new char[buf_size]);
309 this->buf_size = buf_size;
310 this->start_pos = 0;
311 }
312
314 this->close();
315 }
316
317 std::string get_address() {
318 return this->socket.address().host().toString();
319 }
320
321 unsigned short get_port() {
322 return this->socket.address().port();
323 }
324
325 unsigned int get_buf_size() {
326 return this->buf_size;
327 }
328
329 int read(char *buf, unsigned int len, long timeout_seconds) {
330 try {
331 this->socket.setReceiveTimeout(Poco::Timespan(timeout_seconds, 0));
332 int bytes = this->socket.receiveBytes(buf, len);
333 this->socket.setReceiveTimeout(Poco::Timespan());
334 return bytes;
335 } catch (net::NetException &e) {
336 this->socket.setReceiveTimeout(Poco::Timespan());
337 throw ConnectionException(e);
338 } catch (Poco::TimeoutException &e) {
339 this->socket.setReceiveTimeout(Poco::Timespan());
340 throw TimeoutException();
341 }
342 }
343
344 std::string read(long timeout_seconds = NO_TIMEOUT) {
345 try {
346 if (!this->buffered_msgs.empty()) {
347 std::string msg = this->buffered_msgs.front();
348 this->buffered_msgs.pop();
349 return msg;
350 }
351
352 std::string cur_msg = "";
353
354 while (true) {
355 int bytes_received;
356
357 if (timeout_seconds == NO_TIMEOUT) {
358 bytes_received =
359 this->socket.receiveBytes(this->buf.get() + this->start_pos,
360 this->buf_size - this->start_pos);
361 } else {
362 bytes_received =
363 this->read(this->buf.get() + this->start_pos,
364 this->buf_size - this->start_pos, timeout_seconds);
365 }
366
367 if (bytes_received == 0) {
368 return std::string(this->buf.get(), this->start_pos);
369 }
370
371 bool first_msg_to_receive = true;
372 std::string first_msg;
373
374 charstreambuf buf(this->buf, bytes_received + this->start_pos);
375 std::istream in(&buf);
376
377 int cur_pos = 0;
378 bool last_is_newline = this->buf.get()[bytes_received + this->start_pos - 1] == '\n';
379
380 while (!in.eof()) {
381 std::string msg;
382 std::getline(in, msg);
383
384 if (in.eof() && !last_is_newline) {
385 int size = bytes_received + this->start_pos - cur_pos;
386
387 if (size == this->buf_size) {
388 cur_msg += std::string(this->buf.get(), this->buf_size);
389 this->start_pos = 0;
390 } else {
391 std::memmove(this->buf.get(), this->buf.get() + cur_pos, size);
392 this->start_pos = size;
393 }
394 } else {
395 if (!cur_msg.empty() || !msg.empty()) {
396 if (first_msg_to_receive) {
397 first_msg = cur_msg + msg;
398 first_msg_to_receive = false;
399 } else {
400 this->buffered_msgs.push(cur_msg + msg);
401 }
402
403 cur_msg = "";
404 }
405
406 cur_pos += msg.length() + 1;
407 }
408 }
409
410 if (last_is_newline) {
411 this->start_pos = 0;
412 }
413
414 if (!first_msg_to_receive) {
415 return first_msg;
416 }
417 }
418
419 // Should not get here.
420 return "";
421 } catch (net::NetException &e) {
422 throw ConnectionException(e);
423 }
424 }
425
426 void write(std::string msg, bool new_line) {
427 try {
428 if (new_line) {
429 msg += "\n";
430 }
431
432 const char *buf = msg.c_str();
433
434 int bytes_written = this->socket.sendBytes(buf, msg.size());
435
436 if (bytes_written != msg.size()) {
437 std::runtime_error err("Wrote " +
438 std::to_string(bytes_written) +
439 " bytes instead of " +
440 std::to_string(msg.size()) +
441 " to " +
442 this->socket.address().toString());
443 throw ConnectionException(err);
444 }
445 } catch (net::NetException &e) {
446 throw ConnectionException(e);
447 }
448 }
449
450 void write(fs::path file) {
451 try {
452 net::SocketStream socket_stream(this->socket);
453 Poco::FileInputStream stream(file, std::ios::in | std::ios::binary);
454 Poco::StreamCopier::copyStream(stream, socket_stream);
455 } catch (net::NetException &e) {
456 throw ConnectionException(e);
457 }
458 }
459
460 void write(unsigned int len, char *buf) {
461 try {
462 int bytes_written = this->socket.sendBytes(buf, len);
463 if (bytes_written != len) {
464 std::runtime_error err("Wrote " +
465 std::to_string(bytes_written) +
466 " bytes instead of " +
467 std::to_string(len) +
468 " to " +
469 this->socket.address().toString());
470 throw ConnectionException(err);
471 }
472 } catch (net::NetException &e) {
473 throw ConnectionException(e);
474 }
475 }
476 };
477
481 class TCPAcceptor : public Acceptor {
482 private:
483 net::ServerSocket acceptor;
484
485 TCPAcceptor(std::string address, unsigned short port,
486 int max_accepted,
487 bool try_subsequent_ports) : Acceptor(max_accepted) {
488 if (try_subsequent_ports) {
489 bool success = false;
490 while (!success) {
491 try {
492 this->acceptor.bind(net::SocketAddress(address, port), false);
493 success = true;
494 } catch (net::NetException &e) {
495 if (e.message().find("already in use") != std::string::npos) {
496 port++;
497 } else {
498 throw ConnectionException(e);
499 }
500 }
501 }
502 } else {
503 try {
504 this->acceptor.bind(net::SocketAddress(address, port), false);
505 } catch (net::NetException &e) {
506 if (e.message().find("already in use") != std::string::npos) {
507 throw AlreadyInUseException();
508 } else {
509 throw ConnectionException(e);
510 }
511 }
512 }
513
514 try {
515 this->acceptor.listen();
516 } catch (net::NetException &e) {
517 throw ConnectionException(e);
518 }
519 }
520
521 protected:
522 std::unique_ptr<Connection> accept_connection(unsigned int buf_size,
523 long timeout) {
524 try {
525 net::StreamSocket socket = this->acceptor.acceptConnection();
526 return std::make_unique<TCPSocket>(socket, buf_size);
527 } catch (net::NetException &e) {
528 throw ConnectionException(e);
529 }
530 }
531
532 void close() { this->acceptor.close(); }
533
534 public:
538 class Factory : public Acceptor::Factory {
539 private:
540 std::string address;
541 unsigned short port;
542 bool try_subsequent_ports;
543
544 public:
556 Factory(std::string address, unsigned short port,
557 bool try_subsequent_ports = false) {
558 this->address = address;
559 this->port = port;
560 this->try_subsequent_ports = try_subsequent_ports;
561 };
562
563 std::unique_ptr<Acceptor> make_acceptor(int max_accepted) {
564 return std::unique_ptr<Acceptor>(new TCPAcceptor(this->address,
565 this->port,
566 max_accepted,
567 this->try_subsequent_ports));
568 }
569
570 std::string get_type() {
571 return "tcp";
572 }
573 };
574
576 this->close();
577 }
578
583 return this->acceptor.address().host().toString() + "_" + std::to_string(this->acceptor.address().port());
584 }
585
586 std::string get_type() { return "tcp"; }
587 };
588
589#ifdef ADAPTYST_UNIX
594 class FileDescriptor : public Connection {
595 private:
596 int read_fd[2];
597 int write_fd[2];
598 unsigned int buf_size;
599 std::queue<std::string> buffered_msgs;
600 std::unique_ptr<char[]> buf;
601 int start_pos;
602
603 public:
615 FileDescriptor(int read_fd[2],
616 int write_fd[2],
617 unsigned int buf_size) {
618 this->buf.reset(new char[buf_size]);
619 this->buf_size = buf_size;
620 this->start_pos = 0;
621
622 if (read_fd != nullptr) {
623 this->read_fd[0] = read_fd[0];
624 this->read_fd[1] = read_fd[1];
625 } else {
626 this->read_fd[0] = -1;
627 this->read_fd[1] = -1;
628 }
629
630 if (write_fd != nullptr) {
631 this->write_fd[0] = write_fd[0];
632 this->write_fd[1] = write_fd[1];
633 } else {
634 this->write_fd[0] = -1;
635 this->write_fd[1] = -1;
636 }
637 }
638
639 ~FileDescriptor() {
640 this->close();
641 }
642
643 int read(char *buf, unsigned int len, long timeout_seconds) {
644 struct pollfd poll_struct;
645 poll_struct.fd = this->read_fd[0];
646 poll_struct.events = POLLIN;
647
648 int code = ::poll(&poll_struct, 1, 1000 * timeout_seconds);
649
650 if (code == -1) {
651 throw ConnectionException();
652 } else if (code == 0) {
653 throw TimeoutException();
654 }
655
656 return ::read(this->read_fd[0], buf, len);
657 }
658
659 std::string read(long timeout_seconds = NO_TIMEOUT) {
660 if (!this->buffered_msgs.empty()) {
661 std::string msg = this->buffered_msgs.front();
662 this->buffered_msgs.pop();
663 return msg;
664 }
665
666 std::string cur_msg = "";
667
668 while (true) {
669 int bytes_received;
670
671 if (timeout_seconds == NO_TIMEOUT) {
672 bytes_received =
673 ::read(this->read_fd[0], this->buf.get() + this->start_pos,
674 this->buf_size - this->start_pos);
675
676 if (bytes_received == -1) {
677 throw ConnectionException();
678 }
679 } else {
680 bytes_received = this->read(this->buf.get() + this->start_pos,
681 this->buf_size - this->start_pos,
682 timeout_seconds);
683 }
684
685 if (bytes_received == 0) {
686 return std::string(this->buf.get(), this->start_pos);
687 }
688
689 bool first_msg_to_receive = true;
690 std::string first_msg;
691
692 charstreambuf buf(this->buf, bytes_received + this->start_pos);
693 std::istream in(&buf);
694
695 int cur_pos = 0;
696 bool last_is_newline = this->buf.get()[bytes_received + this->start_pos - 1] == '\n';
697
698 while (!in.eof()) {
699 std::string msg;
700 std::getline(in, msg);
701
702 if (in.eof() && !last_is_newline) {
703 int size = bytes_received + this->start_pos - cur_pos;
704
705 if (size == this->buf_size) {
706 cur_msg += std::string(this->buf.get(), this->buf_size);
707 this->start_pos = 0;
708 } else {
709 std::memmove(this->buf.get(), this->buf.get() + cur_pos, size);
710 this->start_pos = size;
711 }
712 } else {
713 if (!cur_msg.empty() || !msg.empty()) {
714 if (first_msg_to_receive) {
715 first_msg = cur_msg + msg;
716 first_msg_to_receive = false;
717 } else {
718 this->buffered_msgs.push(cur_msg + msg);
719 }
720
721 cur_msg = "";
722 }
723
724 cur_pos += msg.length() + 1;
725 }
726 }
727
728 if (last_is_newline) {
729 this->start_pos = 0;
730 }
731
732 if (!first_msg_to_receive) {
733 return first_msg;
734 }
735 }
736
737 // Should not get here.
738 return "";
739 }
740
741 void write(std::string msg, bool new_line) {
742 if (new_line) {
743 msg += "\n";
744 }
745
746 const char *buf = msg.c_str();
747 int written = ::write(this->write_fd[1], buf, msg.size());
748
749 if (written != msg.size()) {
750 std::runtime_error err("Wrote " +
751 std::to_string(written) +
752 " bytes instead of " +
753 std::to_string(msg.size()) +
754 " to fd " +
755 std::to_string(this->write_fd[1]));
756 throw ConnectionException(err);
757 }
758 }
759
760 void write(fs::path file) {
761 std::unique_ptr<char> buf(new char[FILE_BUFFER_SIZE]);
762 std::ifstream file_stream(file, std::ios_base::in |
763 std::ios_base::binary);
764
765 if (!file_stream) {
766 std::runtime_error err("Could not open the file " +
767 file.string() + "!");
768 throw ConnectionException(err);
769 }
770
771 while (file_stream) {
772 file_stream.read(buf.get(), FILE_BUFFER_SIZE);
773 int bytes_read = file_stream.gcount();
774 int bytes_written = ::write(this->write_fd[1], buf.get(),
775 bytes_read);
776
777 if (bytes_written != bytes_read) {
778 std::runtime_error err("Wrote " +
779 std::to_string(bytes_written) +
780 " bytes instead of " +
781 std::to_string(bytes_read) +
782 " to fd " +
783 std::to_string(this->write_fd[1]));
784 throw ConnectionException(err);
785 }
786 }
787 }
788
789 void write(unsigned int len, char *buf) {
790 int bytes_written = ::write(this->write_fd[1], buf, len);
791
792 if (bytes_written != len) {
793 std::runtime_error err("Wrote " +
794 std::to_string(bytes_written) +
795 " bytes instead of " +
796 std::to_string(len) +
797 " to fd " +
798 std::to_string(this->write_fd[1]));
799 throw ConnectionException(err);
800 }
801 }
802
803 unsigned int get_buf_size() {
804 return this->buf_size;
805 }
806
807 void close() {
808 if (this->read_fd[0] != -1) {
809 ::close(this->read_fd[0]);
810 this->read_fd[0] = -1;
811 }
812
813 if (this->write_fd[1] != -1) {
814 ::close(this->write_fd[1]);
815 this->write_fd[1] = -1;
816 }
817 }
818 };
819
824 class PipeAcceptor : public Acceptor {
825 private:
826 int read_fd[2];
827 int write_fd[2];
828
834 PipeAcceptor() : Acceptor(1) {
835 if (pipe(this->read_fd) != 0) {
836 std::runtime_error err("Could not open read pipe for FileDescriptor, "
837 "code " + std::to_string(errno));
838 throw ConnectionException(err);
839 }
840
841 if (pipe(this->write_fd) != 0) {
842 std::runtime_error err("Could not open write pipe for FileDescriptor, "
843 "code " + std::to_string(errno));
844 throw ConnectionException(err);
845 }
846 }
847
848 protected:
849 std::unique_ptr<Connection> accept_connection(unsigned int buf_size,
850 long timeout) {
851 std::string expected = "connect";
852 const int size = expected.size();
853
854 char buf[size];
855 int bytes_received = 0;
856
857 while (bytes_received < size) {
858 if (timeout != NO_TIMEOUT) {
859 struct pollfd poll_struct;
860 poll_struct.fd = this->read_fd[0];
861 poll_struct.events = POLLIN;
862
863 int code = ::poll(&poll_struct, 1, 1000 * timeout);
864
865 if (code == -1) {
866 throw ConnectionException();
867 } else if (code == 0) {
868 throw TimeoutException();
869 }
870 }
871
872 int received = ::read(this->read_fd[0], buf + bytes_received,
873 size - bytes_received);
874
875 if (received <= 0) {
876 break;
877 }
878
879 bytes_received += received;
880 }
881
882 std::string msg(buf, size);
883
884 if (msg != expected) {
885 std::runtime_error err("Message received from pipe when establishing connection "
886 "is \"" + msg + "\" instead of \"" + expected + "\".");
887 throw ConnectionException(err);
888 }
889
890 return std::unique_ptr<Connection>(new FileDescriptor(this->read_fd,
891 this->write_fd,
892 buf_size));
893 }
894
895 void close() {}
896
897 public:
901 class Factory : public Acceptor::Factory {
902 public:
911 std::unique_ptr<Acceptor> make_acceptor(int max_accepted) {
912 if (max_accepted != 1) {
913 throw std::runtime_error("max_accepted can only be 1 for FileDescriptor");
914 }
915
916 return std::unique_ptr<Acceptor>(new PipeAcceptor());
917 }
918
919 std::string get_type() {
920 return "pipe";
921 }
922 };
923
927 std::string get_connection_instructions() {
928 return std::to_string(this->write_fd[0]) + "_" + std::to_string(this->read_fd[1]);
929 }
930
931 std::string get_type() { return "pipe"; }
932 };
933#endif
934}
935
936#endif
Definition socket.hpp:217
virtual std::string get_type()=0
virtual std::unique_ptr< Acceptor > make_acceptor(int max_accepted)=0
Definition socket.hpp:175
virtual std::string get_type()=0
virtual std::string get_connection_instructions()=0
Acceptor(int max_accepted)
Definition socket.hpp:188
virtual void close()=0
virtual std::unique_ptr< Connection > accept_connection(unsigned int buf_size, long timeout)=0
std::unique_ptr< Connection > accept(unsigned int buf_size, long timeout=NO_TIMEOUT)
Definition socket.hpp:251
virtual ~Acceptor()
Definition socket.hpp:265
Definition socket.hpp:58
Definition socket.hpp:48
ConnectionException(std::exception &other)
Definition socket.hpp:51
ConnectionException()
Definition socket.hpp:50
Definition socket.hpp:72
virtual void write(unsigned int len, char *buf)=0
virtual std::string read(long timeout_seconds=NO_TIMEOUT)=0
virtual void write(std::string msg, bool new_line=true)=0
virtual unsigned int get_buf_size()=0
virtual void write(fs::path file)=0
virtual int read(char *buf, unsigned int len, long timeout_seconds)=0
virtual void close()=0
virtual ~Connection()
Definition socket.hpp:80
Definition socket.hpp:147
virtual std::string read(long timeout_seconds=NO_TIMEOUT)=0
virtual std::string get_address()=0
virtual unsigned int get_buf_size()=0
virtual unsigned short get_port()=0
virtual int read(char *buf, unsigned int len, long timeout_seconds)=0
virtual void write(fs::path file)=0
virtual void write(std::string msg, bool new_line=true)=0
virtual void write(unsigned int len, char *buf)=0
virtual void close()=0
virtual ~Socket()
Definition socket.hpp:152
std::unique_ptr< Acceptor > make_acceptor(int max_accepted)
Definition socket.hpp:563
std::string get_type()
Definition socket.hpp:570
Factory(std::string address, unsigned short port, bool try_subsequent_ports=false)
Definition socket.hpp:556
void close()
Definition socket.hpp:532
std::string get_type()
Definition socket.hpp:586
~TCPAcceptor()
Definition socket.hpp:575
std::unique_ptr< Connection > accept_connection(unsigned int buf_size, long timeout)
Definition socket.hpp:522
std::string get_connection_instructions()
Definition socket.hpp:582
std::string read(long timeout_seconds=NO_TIMEOUT)
Definition socket.hpp:344
void write(unsigned int len, char *buf)
Definition socket.hpp:460
void write(std::string msg, bool new_line)
Definition socket.hpp:426
TCPSocket(net::StreamSocket &sock, unsigned int buf_size)
Definition socket.hpp:306
unsigned int get_buf_size()
Definition socket.hpp:325
unsigned short get_port()
Definition socket.hpp:321
std::string get_address()
Definition socket.hpp:317
void write(fs::path file)
Definition socket.hpp:450
~TCPSocket()
Definition socket.hpp:313
int read(char *buf, unsigned int len, long timeout_seconds)
Definition socket.hpp:329
Definition socket.hpp:65
Definition socket.hpp:35
charstreambuf(std::unique_ptr< char[]> &begin, unsigned int length)
Definition socket.hpp:37
Definition output.hpp:12
#define FILE_BUFFER_SIZE
Definition socket.hpp:28
#define NO_TIMEOUT
Definition socket.hpp:25
#define UNLIMITED_ACCEPTED
Definition socket.hpp:24