summaryrefslogtreecommitdiff
path: root/src/network/tcp_client_socket_handler.cpp
blob: 6f67f02efa987210ab0881cc59093489fa161622 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
#include <network/tcp_client_socket_handler.hpp>
#include <utils/timed_events.hpp>
#include <utils/scopeguard.hpp>
#include <network/poller.hpp>

#include <logger/logger.hpp>

#include <cstring>
#include <unistd.h>
#include <fcntl.h>

using namespace std::string_literals;

TCPClientSocketHandler::TCPClientSocketHandler(std::shared_ptr<Poller>& poller):
   TCPSocketHandler(poller),
   hostname_resolution_failed(false),
   connected(false),
   connecting(false)
{}

/**
 * Tear the handler down by running this class's own close(), which cancels
 * the pending connection timeout and resets the connection state.
 */
TCPClientSocketHandler::~TCPClientSocketHandler()
{
  // Qualified call: names exactly the close() defined by this class.
  TCPClientSocketHandler::close();
}

/**
 * (Re)create this->socket so that it can be used to connect to the address
 * described by rp.
 *
 * Steps, in order:
 *  - close the previous descriptor, if there is one;
 *  - create a new socket with rp's family, type and protocol;
 *  - if bind_addr is set, try to bind the socket to it (a failure here is
 *    only logged, it does not abort the initialization);
 *  - enable SO_KEEPALIVE (a failure is only logged);
 *  - switch the socket to non-blocking mode.
 *
 * @param rp address entry whose family/socktype/protocol are used to
 *           create the socket.
 * @throws std::runtime_error if the socket cannot be created or cannot be
 *         switched to non-blocking mode.
 */
void TCPClientSocketHandler::init_socket(const struct addrinfo* rp)
{
  // Get rid of the descriptor left over from a previous attempt.
  if (this->socket != -1)
    ::close(this->socket);

  this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
  if (this->socket == -1)
    throw std::runtime_error("Could not create socket: "s + std::strerror(errno));

  // Optionally bind the socket to the configured local address.
  if (!this->bind_addr.empty())
    {
      // bind_addr is a string; turn it into sockaddr(s) usable by bind().
      // AI_NUMERICHOST: the string must already be a numeric address.
      struct addrinfo numeric_only{};
      std::memset(&numeric_only, 0, sizeof(numeric_only));
      numeric_only.ai_family = AF_UNSPEC;
      numeric_only.ai_flags = AI_NUMERICHOST;

      struct addrinfo* candidates;
      const int gai_status = ::getaddrinfo(this->bind_addr.data(), nullptr,
                                           &numeric_only, &candidates);
      if (gai_status == 0 && candidates)
        {
          // The guard keeps its own copy of the list head, so walking the
          // list below does not affect what gets freed.
          utils::ScopeGuard free_candidates([candidates](){ freeaddrinfo(candidates); });

          // Stop at the first address bind() accepts; `chosen` ends up null
          // when none of them worked.
          const struct addrinfo* chosen = candidates;
          while (chosen &&
                 ::bind(this->socket, chosen->ai_addr, chosen->ai_addrlen) != 0)
            chosen = chosen->ai_next;

          if (chosen)
            log_info("Socket successfully bound to ", this->bind_addr);
          else
            log_error("Failed to bind socket to ", this->bind_addr, ": ",
                      strerror(errno));
        }
      else
        log_error("Failed to bind socket to ", this->bind_addr, ": ",
                  gai_strerror(gai_status));
    }

  int keepalive_on = 1;
  if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE,
                   &keepalive_on, sizeof(keepalive_on)) == -1)
    log_warning("Failed to enable TCP keepalive on socket: ", strerror(errno));

  // Non-blocking mode: connect() reports that it would block instead of
  // stalling the whole process when the remote is not responsive.
  const int current_flags = ::fcntl(this->socket, F_GETFL, 0);
  if (current_flags == -1)
    throw std::runtime_error("Could not initialize socket: "s + std::strerror(errno));
  if (::fcntl(this->socket, F_SETFL, current_flags | O_NONBLOCK) == -1)
    throw std::runtime_error("Could not initialize socket: "s + std::strerror(errno));
}

void TCPClientSocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
{
  this->address = address;
  this->port = port;
  this->use_tls = tls;

  struct addrinfo* addr_res;

  if (!this->connecting)
    {
      // Get the addrinfo from getaddrinfo (or using udns), only if
      // this is the first call of this function.
      if (!this->resolver.is_resolved())
        {
          log_info("Trying to connect to ", address, ":", port);
          // Start the asynchronous process of resolving the hostname. Once
          // the addresses have been found and `resolved` has been set to true
          // (but connecting will still be false), TCPClientSocketHandler::connect()
          // needs to be called, again.
          this->resolver.resolve(address, port,
                                 [this](const struct addrinfo*)
                                 {
                                   log_debug("Resolution success, calling connect() again");
                                   this->connect();
                                 },
                                 [this](const char*)
                                 {
                                   log_debug("Resolution failed, calling connect() again");
                                   this->connect();
                                 });
          return;
        }
      else
        {
          // The DNS resolver resolved the hostname and the available addresses
          // where saved in the addrinfo linked list. Now, just use
          // this list to try to connect.
          addr_res = this->resolver.get_result().get();
          if (!addr_res)
            {
              this->hostname_resolution_failed = true;
              const auto msg = this->resolver.get_error_message();
              this->close();
              this->on_connection_failed(msg);
              return ;
            }
        }
    }
  else
    { // This function is called again, use the saved addrinfo structure,
      // instead of re-doing the whole getaddrinfo process.
      addr_res = &this->addrinfo;
    }

  for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
    {
      if (!this->connecting)
        {
          try {
            this->init_socket(rp);
          }
          catch (const std::runtime_error& error) {
            log_error("Failed to init socket: ", error.what());
            break;
          }
        }

      this->display_resolved_ip(rp);

      if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0
          || errno == EISCONN)
        {
          log_info("Connection success.");
#ifdef BOTAN_FOUND
          if (this->use_tls)
            try {
                this->start_tls(this->address, this->port);
              } catch (const Botan::Exception& e)
              {
                this->on_connection_failed("TLS error: "s + e.what());
                this->close();
                return ;
              }
#endif
          TimedEventsManager::instance().cancel("connection_timeout" +
                                                std::to_string(this->socket));
          this->poller->add_socket_handler(this);
          this->connected = true;
          this->connecting = false;
          this->connection_date = std::chrono::system_clock::now();

          // Get our local TCP port and store it
          this->local_port = static_cast<uint16_t>(-1);
          if (rp->ai_family == AF_INET6)
            {
              struct sockaddr_in6 a{};
              socklen_t l = sizeof(a);
              if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1)
                this->local_port = ntohs(a.sin6_port);
            }
          else if (rp->ai_family == AF_INET)
            {
              struct sockaddr_in a{};
              socklen_t l = sizeof(a);
              if (::getsockname(this->socket, (struct sockaddr*)&a, &l) != -1)
                this->local_port = ntohs(a.sin_port);
            }

          log_debug("Local port: ", this->local_port, ", and remote port: ", this->port);

          this->on_connected();
          return ;
        }
      else if (errno == EINPROGRESS || errno == EALREADY)
        {   // retry this process later, when the socket
            // is ready to be written on.
          this->connecting = true;
          this->poller->add_socket_handler(this);
          this->poller->watch_send_events(this);
          // Save the addrinfo structure, to use it on the next call
          this->ai_addrlen = rp->ai_addrlen;
          memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
          memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
          this->addrinfo.ai_addr = reinterpret_cast<struct sockaddr*>(&this->ai_addr);
          this->addrinfo.ai_next = nullptr;
          // If the connection has not succeeded or failed in 5s, we consider
          // it to have failed
          TimedEventsManager::instance().add_event(
                                                   TimedEvent(std::chrono::steady_clock::now() + 5s,
                                                              std::bind(&TCPClientSocketHandler::on_connection_timeout, this),
                                                              "connection_timeout" + std::to_string(this->socket)));
          return ;
        }
      log_info("Connection failed:", std::strerror(errno));
    }
  log_error("All connection attempts failed.");
  this->close();
  this->on_connection_failed(std::strerror(errno));
  return ;
}

void TCPClientSocketHandler::on_connection_timeout()
{
  this->close();
  this->on_connection_failed("connection timed out");
}

/**
 * Re-run connect() with the address, port and TLS setting stored by the
 * previous call. Used to resume a connection attempt already in progress.
 */
void TCPClientSocketHandler::connect()
{
  connect(this->address, this->port, this->use_tls);
}

void TCPClientSocketHandler::close()
{
  TimedEventsManager::instance().cancel("connection_timeout" +
                                        std::to_string(this->socket));

  TCPSocketHandler::close();

  this->connected = false;
  this->connecting = false;
  this->port.clear();
  this->resolver.clear();
}

/**
 * Log, at debug level, the address about to be tried. Only IPv4 and IPv6
 * entries are logged; any other address family is silently ignored.
 *
 * @param rp the address entry being tried
 */
void TCPClientSocketHandler::display_resolved_ip(struct addrinfo* rp) const
{
  switch (rp->ai_family)
    {
    case AF_INET:
      {
        log_debug("Trying IPv4 address ", addr_to_string(rp));
        break;
      }
    case AF_INET6:
      {
        log_debug("Trying IPv6 address ", addr_to_string(rp));
        break;
      }
    default:
      break;
    }
}

bool TCPClientSocketHandler::is_connected() const
{
  return this->connected;
}

/**
 * @return true while a connection attempt is under way: either the socket
 *         connection is pending, or the resolver reports that it is still
 *         resolving.
 */
bool TCPClientSocketHandler::is_connecting() const
{
  // The resolver is only consulted when no socket connection is pending.
  if (this->connecting)
    return true;
  return this->resolver.is_resolving();
}

std::string TCPClientSocketHandler::get_port() const
{
  return this->port;
}

/**
 * Tell whether this handler is the connection identified by the given
 * (local port, remote port) pair.
 *
 * @param local  local TCP port to compare with the one stored by connect()
 * @param remote remote TCP port to compare with the numeric value of
 *               this->port
 * @return true only if the handler is connected and both ports match;
 *         always false when it is not connected.
 */
bool TCPClientSocketHandler::match_port_pairt(const uint16_t local, const uint16_t remote) const
{
  // Check connectivity before parsing: close() clears this->port, and
  // std::stoi() throws std::invalid_argument on an empty string.
  if (!this->is_connected())
    return false;
  const auto remote_port = static_cast<uint16_t>(std::stoi(this->port));
  return local == this->local_port && remote == remote_port;
}