4#ifndef ADAPTYST_SOCKET_HPP_
5#define ADAPTYST_SOCKET_HPP_
13#include <Poco/Net/ServerSocket.h>
18#include <Poco/Buffer.h>
19#include <Poco/Net/NetException.h>
20#include <Poco/StreamCopier.h>
21#include <Poco/FileStream.h>
22#include <Poco/Net/SocketStream.h>
24#define UNLIMITED_ACCEPTED -1
27#ifndef FILE_BUFFER_SIZE
28#define FILE_BUFFER_SIZE 1048576
32 namespace net = Poco::Net;
33 namespace fs = std::filesystem;
38 this->setg(begin.get(), begin.get(), begin.get() + length - 1);
95 virtual int read(
char *buf,
unsigned int len,
long timeout_seconds) = 0;
118 virtual void write(std::string msg,
bool new_line =
true) = 0;
127 virtual void write(fs::path file) = 0;
137 virtual void write(
unsigned int len,
char *buf) = 0;
166 virtual int read(
char *buf,
unsigned int len,
long timeout_seconds) = 0;
168 virtual void write(std::string msg,
bool new_line =
true) = 0;
169 virtual void write(fs::path file) = 0;
170 virtual void write(
unsigned int len,
char *buf) = 0;
190 this->max_accepted = max_accepted;
252 std::unique_ptr<Connection>
accept(
unsigned int buf_size,
255 this->accepted >= this->max_accepted) {
256 throw std::runtime_error(
"Maximum accepted connections reached.");
290 net::StreamSocket socket;
291 std::unique_ptr<char[]> buf;
292 unsigned int buf_size;
294 std::queue<std::string> buffered_msgs;
307 TCPSocket(net::StreamSocket &sock,
unsigned int buf_size) {
309 this->buf.reset(
new char[buf_size]);
310 this->buf_size = buf_size;
319 return this->socket.address().host().toString();
323 return this->socket.address().port();
327 return this->buf_size;
330 int read(
char *buf,
unsigned int len,
long timeout_seconds) {
331 auto set_timeout = [
this, &timeout_seconds]() {
333 this->socket.setReceiveTimeout(Poco::Timespan(timeout_seconds, 0));
337 auto unset_timeout = [
this, &timeout_seconds]() {
339 this->socket.setReceiveTimeout(Poco::Timespan());
345 int bytes = this->socket.receiveBytes(buf, len);
348 }
catch (net::NetException &e) {
351 }
catch (Poco::TimeoutException &e) {
359 if (!this->buffered_msgs.empty()) {
360 std::string msg = this->buffered_msgs.front();
361 this->buffered_msgs.pop();
365 std::string cur_msg =
"";
369 this->
read(this->buf.get() + this->start_pos,
370 this->buf_size - this->start_pos, timeout_seconds);
372 if (bytes_received == 0) {
373 return std::string(this->buf.get(), this->start_pos);
376 bool first_msg_to_receive =
true;
377 std::string first_msg;
379 charstreambuf buf(this->buf, bytes_received + this->start_pos);
380 std::istream in(&buf);
383 bool last_is_newline = this->buf.get()[bytes_received + this->start_pos - 1] ==
'\n';
387 std::getline(in, msg);
389 if (in.eof() && !last_is_newline) {
390 int size = bytes_received + this->start_pos - cur_pos;
392 if (size == this->buf_size) {
393 cur_msg += std::string(this->buf.get(), this->buf_size);
396 std::memmove(this->buf.get(), this->buf.get() + cur_pos, size);
397 this->start_pos = size;
400 if (!cur_msg.empty() || !msg.empty()) {
401 if (first_msg_to_receive) {
402 first_msg = cur_msg + msg;
403 first_msg_to_receive =
false;
405 this->buffered_msgs.push(cur_msg + msg);
411 cur_pos += msg.length() + 1;
415 if (last_is_newline) {
419 if (!first_msg_to_receive) {
426 }
catch (net::NetException &e) {
431 void write(std::string msg,
bool new_line) {
437 const char *buf = msg.c_str();
439 int bytes_written = this->socket.sendBytes(buf, msg.size());
441 if (bytes_written != msg.size()) {
442 std::runtime_error err(
"Wrote " +
443 std::to_string(bytes_written) +
444 " bytes instead of " +
445 std::to_string(msg.size()) +
447 this->socket.address().toString());
450 }
catch (net::NetException &e) {
457 net::SocketStream socket_stream(this->socket);
458 Poco::FileInputStream stream(file, std::ios::in | std::ios::binary);
459 Poco::StreamCopier::copyStream(stream, socket_stream);
460 }
catch (net::NetException &e) {
465 void write(
unsigned int len,
char *buf) {
467 int bytes_written = this->socket.sendBytes(buf, len);
468 if (bytes_written != len) {
469 std::runtime_error err(
"Wrote " +
470 std::to_string(bytes_written) +
471 " bytes instead of " +
472 std::to_string(len) +
474 this->socket.address().toString());
477 }
catch (net::NetException &e) {
488 net::ServerSocket acceptor;
490 TCPAcceptor(std::string address,
unsigned short port,
492 bool try_subsequent_ports) :
Acceptor(max_accepted) {
493 if (try_subsequent_ports) {
494 bool success =
false;
497 this->acceptor.bind(net::SocketAddress(address, port),
false);
499 }
catch (net::NetException &e) {
500 if (e.message().find(
"already in use") != std::string::npos) {
509 this->acceptor.bind(net::SocketAddress(address, port),
false);
510 }
catch (net::NetException &e) {
511 if (e.message().find(
"already in use") != std::string::npos) {
520 this->acceptor.listen();
521 }
catch (net::NetException &e) {
530 net::StreamSocket socket = this->acceptor.acceptConnection();
531 return std::make_unique<TCPSocket>(socket, buf_size);
532 }
catch (net::NetException &e) {
537 void close() { this->acceptor.close(); }
547 bool try_subsequent_ports;
561 Factory(std::string address,
unsigned short port,
562 bool try_subsequent_ports =
false) {
563 this->address = address;
565 this->try_subsequent_ports = try_subsequent_ports;
569 return std::unique_ptr<Acceptor>(
new TCPAcceptor(this->address,
572 this->try_subsequent_ports));
588 return this->acceptor.address().host().toString() +
"_" + std::to_string(this->acceptor.address().port());
599 class FileDescriptor :
public Connection {
603 unsigned int buf_size;
604 std::queue<std::string> buffered_msgs;
605 std::unique_ptr<char[]> buf;
607 bool close_on_destruct;
623 FileDescriptor(
int read_fd[2],
625 unsigned int buf_size,
626 bool close_on_destruct =
true) {
627 this->buf.reset(
new char[buf_size]);
628 this->buf_size = buf_size;
630 this->close_on_destruct = close_on_destruct;
632 if (read_fd !=
nullptr) {
633 this->read_fd[0] = read_fd[0];
634 this->read_fd[1] = read_fd[1];
636 this->read_fd[0] = -1;
637 this->read_fd[1] = -1;
640 if (write_fd !=
nullptr) {
641 this->write_fd[0] = write_fd[0];
642 this->write_fd[1] = write_fd[1];
644 this->write_fd[0] = -1;
645 this->write_fd[1] = -1;
650 if (this->close_on_destruct) {
655 int read(
char *buf,
unsigned int len,
long timeout_seconds) {
657 struct pollfd poll_struct;
658 poll_struct.fd = this->read_fd[0];
659 poll_struct.events = POLLIN;
661 int code = ::poll(&poll_struct, 1, 1000 * timeout_seconds);
664 throw ConnectionException();
665 }
else if (code == 0) {
666 throw TimeoutException();
670 return ::read(this->read_fd[0], buf, len);
673 std::string read(
long timeout_seconds =
NO_TIMEOUT) {
674 if (!this->buffered_msgs.empty()) {
675 std::string msg = this->buffered_msgs.front();
676 this->buffered_msgs.pop();
680 std::string cur_msg =
"";
684 this->read(this->buf.get() + this->start_pos,
685 this->buf_size - this->start_pos,
688 if (bytes_received == -1) {
689 throw ConnectionException();
690 }
else if (bytes_received == 0) {
691 return std::string(this->buf.get(), this->start_pos);
694 bool first_msg_to_receive =
true;
695 std::string first_msg;
697 charstreambuf buf(this->buf, bytes_received + this->start_pos);
698 std::istream in(&buf);
701 bool last_is_newline = this->buf.get()[bytes_received + this->start_pos - 1] ==
'\n';
705 std::getline(in, msg);
707 if (in.eof() && !last_is_newline) {
708 int size = bytes_received + this->start_pos - cur_pos;
710 if (size == this->buf_size) {
711 cur_msg += std::string(this->buf.get(), this->buf_size);
714 std::memmove(this->buf.get(), this->buf.get() + cur_pos, size);
715 this->start_pos = size;
718 if (!cur_msg.empty() || !msg.empty()) {
719 if (first_msg_to_receive) {
720 first_msg = cur_msg + msg;
721 first_msg_to_receive =
false;
723 this->buffered_msgs.push(cur_msg + msg);
729 cur_pos += msg.length() + 1;
733 if (last_is_newline) {
737 if (!first_msg_to_receive) {
746 void write(std::string msg,
bool new_line) {
751 const char *buf = msg.c_str();
752 int written = ::write(this->write_fd[1], buf, msg.size());
754 if (written != msg.size()) {
755 std::runtime_error err(
"Wrote " +
756 std::to_string(written) +
757 " bytes instead of " +
758 std::to_string(msg.size()) +
760 std::to_string(this->write_fd[1]));
761 throw ConnectionException(err);
765 void write(fs::path file) {
767 std::ifstream file_stream(file, std::ios_base::in |
768 std::ios_base::binary);
771 std::runtime_error err(
"Could not open the file " +
772 file.string() +
"!");
773 throw ConnectionException(err);
776 while (file_stream) {
778 int bytes_read = file_stream.gcount();
779 int bytes_written = ::write(this->write_fd[1], buf.get(),
782 if (bytes_written != bytes_read) {
783 std::runtime_error err(
"Wrote " +
784 std::to_string(bytes_written) +
785 " bytes instead of " +
786 std::to_string(bytes_read) +
788 std::to_string(this->write_fd[1]));
789 throw ConnectionException(err);
794 void write(
unsigned int len,
char *buf) {
795 int bytes_written = ::write(this->write_fd[1], buf, len);
797 if (bytes_written != len) {
798 std::runtime_error err(
"Wrote " +
799 std::to_string(bytes_written) +
800 " bytes instead of " +
801 std::to_string(len) +
803 std::to_string(this->write_fd[1]));
804 throw ConnectionException(err);
808 unsigned int get_buf_size() {
809 return this->buf_size;
813 if (this->read_fd[0] != -1) {
814 ::close(this->read_fd[0]);
815 this->read_fd[0] = -1;
818 if (this->write_fd[1] != -1) {
819 ::close(this->write_fd[1]);
820 this->write_fd[1] = -1;
824 std::pair<int, int> get_read_fd() {
825 return std::make_pair(this->read_fd[0],
829 std::pair<int, int> get_write_fd() {
830 return std::make_pair(this->write_fd[0],
839 class PipeAcceptor :
public Acceptor {
849 PipeAcceptor() : Acceptor(1) {
850 if (pipe(this->read_fd) != 0) {
851 std::runtime_error err(
"Could not open read pipe for FileDescriptor, "
852 "code " + std::to_string(errno));
853 throw ConnectionException(err);
856 if (pipe(this->write_fd) != 0) {
857 std::runtime_error err(
"Could not open write pipe for FileDescriptor, "
858 "code " + std::to_string(errno));
859 throw ConnectionException(err);
864 std::unique_ptr<Connection> accept_connection(
unsigned int buf_size,
866 std::string expected =
"connect";
867 const int size = expected.size();
870 int bytes_received = 0;
872 while (bytes_received < size) {
874 struct pollfd poll_struct;
875 poll_struct.fd = this->read_fd[0];
876 poll_struct.events = POLLIN;
878 int code = ::poll(&poll_struct, 1, 1000 * timeout);
881 throw ConnectionException();
882 }
else if (code == 0) {
883 throw TimeoutException();
887 int received = ::read(this->read_fd[0], buf + bytes_received,
888 size - bytes_received);
894 bytes_received += received;
897 std::string msg(buf, size);
899 if (msg != expected) {
900 std::runtime_error err(
"Message received from pipe when establishing connection "
901 "is \"" + msg +
"\" instead of \"" + expected +
"\".");
902 throw ConnectionException(err);
905 return std::unique_ptr<Connection>(
new FileDescriptor(this->read_fd,
916 class Factory :
public Acceptor::Factory {
926 std::unique_ptr<Acceptor> make_acceptor(
int max_accepted) {
927 if (max_accepted != 1) {
928 throw std::runtime_error(
"max_accepted can only be 1 for FileDescriptor");
931 return std::unique_ptr<Acceptor>(
new PipeAcceptor());
934 std::string get_type() {
942 std::string get_connection_instructions() {
943 return std::to_string(this->write_fd[0]) +
"_" + std::to_string(this->read_fd[1]);
946 std::string get_type() {
return "pipe"; }
Definition socket.hpp:218
virtual std::string get_type()=0
virtual std::unique_ptr< Acceptor > make_acceptor(int max_accepted)=0
Definition socket.hpp:176
virtual std::string get_type()=0
virtual std::string get_connection_instructions()=0
Acceptor(int max_accepted)
Definition socket.hpp:189
virtual std::unique_ptr< Connection > accept_connection(unsigned int buf_size, long timeout)=0
std::unique_ptr< Connection > accept(unsigned int buf_size, long timeout=NO_TIMEOUT)
Definition socket.hpp:252
virtual ~Acceptor()
Definition socket.hpp:266
ConnectionException(std::exception &other)
Definition socket.hpp:51
ConnectionException()
Definition socket.hpp:50
virtual void write(unsigned int len, char *buf)=0
virtual std::string read(long timeout_seconds=NO_TIMEOUT)=0
virtual void write(std::string msg, bool new_line=true)=0
virtual unsigned int get_buf_size()=0
virtual void write(fs::path file)=0
virtual int read(char *buf, unsigned int len, long timeout_seconds)=0
virtual ~Connection()
Definition socket.hpp:80
Definition socket.hpp:148
virtual std::string read(long timeout_seconds=NO_TIMEOUT)=0
virtual std::string get_address()=0
virtual unsigned int get_buf_size()=0
virtual unsigned short get_port()=0
virtual int read(char *buf, unsigned int len, long timeout_seconds)=0
virtual void write(fs::path file)=0
virtual void write(std::string msg, bool new_line=true)=0
virtual void write(unsigned int len, char *buf)=0
virtual ~Socket()
Definition socket.hpp:153
std::unique_ptr< Acceptor > make_acceptor(int max_accepted)
Definition socket.hpp:568
std::string get_type()
Definition socket.hpp:575
Factory(std::string address, unsigned short port, bool try_subsequent_ports=false)
Definition socket.hpp:561
void close()
Definition socket.hpp:537
std::string get_type()
Definition socket.hpp:591
~TCPAcceptor()
Definition socket.hpp:580
std::unique_ptr< Connection > accept_connection(unsigned int buf_size, long timeout)
Definition socket.hpp:527
std::string get_connection_instructions()
Definition socket.hpp:587
std::string read(long timeout_seconds=NO_TIMEOUT)
Definition socket.hpp:357
void write(unsigned int len, char *buf)
Definition socket.hpp:465
void write(std::string msg, bool new_line)
Definition socket.hpp:431
TCPSocket(net::StreamSocket &sock, unsigned int buf_size)
Definition socket.hpp:307
unsigned int get_buf_size()
Definition socket.hpp:326
unsigned short get_port()
Definition socket.hpp:322
std::string get_address()
Definition socket.hpp:318
void write(fs::path file)
Definition socket.hpp:455
~TCPSocket()
Definition socket.hpp:314
int read(char *buf, unsigned int len, long timeout_seconds)
Definition socket.hpp:330
charstreambuf(std::unique_ptr< char[]> &begin, unsigned int length)
Definition socket.hpp:37
#define FILE_BUFFER_SIZE
Definition socket.hpp:28
#define NO_TIMEOUT
Definition socket.hpp:25
#define UNLIMITED_ACCEPTED
Definition socket.hpp:24