/*
    This file is part of TON Blockchain Library.
    TON Blockchain Library is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 2 of the License, or
    (at your option) any later version.
    TON Blockchain Library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.
    You should have received a copy of the GNU Lesser General Public License
    along with TON Blockchain Library.  If not, see <http://www.gnu.org/licenses/>.
    Copyright 2017-2019 Telegram Systems LLP
*/
#include "td/utils/port/SocketFd.h"
#include "td/utils/common.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
#include "td/utils/port/PollFlags.h"
#if TD_PORT_WINDOWS
#include "td/utils/buffer.h"
#include "td/utils/port/detail/Iocp.h"
#include "td/utils/SpinLock.h"
#include "td/utils/VectorQueue.h"
#endif
#if TD_PORT_POSIX
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#endif
#include 
#include 
namespace td {
namespace detail {
#if TD_PORT_WINDOWS
// Windows implementation of the socket. Connect, receive and send are issued as overlapped
// operations (ConnectEx / WSARecv / WSASend) and their completions arrive in on_iocp(); the public
// write()/read() only move bytes between the caller and the internal chain buffers.
//
// Lifetime is controlled by refcnt_: it starts at 1, is incremented for every pending
// connect/read/write and for the "write" and "connected" notifications, and is decremented at the
// start of every on_iocp() call. The close notification takes no reference of its own, so handling
// it releases one reference. The object deletes itself when the counter reaches zero.
class SocketFdImpl : private Iocp::Callback {
 public:
  // Wraps an existing socket: subscribes it to IOCP and posts a "connected" notification so that
  // on_connected() is run from on_iocp().
  explicit SocketFdImpl(NativeFd native_fd) : info(std::move(native_fd)) {
    VLOG(fd) << get_native_fd() << " create from native_fd";
    get_poll_info().add_flags(PollFlags::Write());
    Iocp::get()->subscribe(get_native_fd(), this);
    // on_connected() CHECKs is_read_active_, so it is set before the notification is posted.
    is_read_active_ = true;
    notify_iocp_connected();
  }
  // Wraps a socket and starts an overlapped connect to addr through ConnectEx, whose function
  // pointer is looked up with WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER).
  SocketFdImpl(NativeFd native_fd, const IPAddress &addr) : info(std::move(native_fd)) {
    VLOG(fd) << get_native_fd() << " create from native_fd and connect";
    get_poll_info().add_flags(PollFlags::Write());
    Iocp::get()->subscribe(get_native_fd(), this);
    LPFN_CONNECTEX ConnectExPtr = nullptr;
    GUID guid = WSAID_CONNECTEX;
    DWORD numBytes;
    auto error =
        ::WSAIoctl(get_native_fd().socket(), SIO_GET_EXTENSION_FUNCTION_POINTER, static_cast(&guid),
                   sizeof(guid), static_cast(&ConnectExPtr), sizeof(ConnectExPtr), &numBytes, nullptr, nullptr);
    if (error) {
      // The failure is queued as a pending error; no connect is attempted.
      on_error(OS_SOCKET_ERROR("WSAIoctl failed"));
      return;
    }
    std::memset(&read_overlapped_, 0, sizeof(read_overlapped_));
    // The connect reuses read_overlapped_; a reference is taken before the call is issued.
    inc_refcnt();
    is_read_active_ = true;
    auto status = ConnectExPtr(get_native_fd().socket(), addr.get_sockaddr(), narrow_cast(addr.get_sockaddr_len()),
                               nullptr, 0, nullptr, &read_overlapped_);
    // If ConnectEx returns TRUE, or fails with anything other than ERROR_IO_PENDING (in which case
    // check_status() has queued the error), the reference and the read marker are rolled back.
    if (status == TRUE || !check_status("Failed to connect")) {
      is_read_active_ = false;
      dec_refcnt();
    }
  }
  // Requests closing by posting a close notification; the actual work happens in on_iocp().
  void close() {
    notify_iocp_close();
  }
  PollableFdInfo &get_poll_info() {
    return info;
  }
  const PollableFdInfo &get_poll_info() const {
    return info;
  }
  const NativeFd &get_native_fd() const {
    return info.native_fd();
  }
  // Appends data to the output buffer and returns data.size(); nothing is sent from this call
  // itself, sending is done by loop_write().
  Result write(Slice data) {
    // LOG(ERROR) << "Write: " << format::as_hex_dump<0>(data);
    output_writer_.append(data);
    return write_finish(data.size());
  }
  // Appends every slice to the output buffer and returns the combined size.
  Result writev(Span slices) {
    size_t total_size = 0;
    for (auto io_slice : slices) {
      total_size += as_slice(io_slice).size();
    }
    // left_size is passed to append() as the number of bytes still to come, including this slice
    // (presumably an allocation hint - confirm against ChainBufferWriter::append).
    auto left_size = total_size;
    for (auto io_slice : slices) {
      auto slice = as_slice(io_slice);
      output_writer_.append(slice, left_size);
      left_size -= slice.size();
    }
    return write_finish(total_size);
  }
  // If loop_write() previously found nothing to send and set is_write_waiting_, clears the flag
  // under lock_ and posts a write notification. Always returns total_size.
  Result write_finish(size_t total_size) {
    if (is_write_waiting_) {
      auto lock = lock_.lock();
      is_write_waiting_ = false;
      lock.reset();
      notify_iocp_write();
    }
    return total_size;
  }
  // Returns a pending error first, if the poll flags report one. Otherwise takes up to slice.size()
  // bytes from the input buffer (advance() presumably copies them into slice - confirm) and returns
  // the count; a zero count clears the Read flag.
  Result read(MutableSlice slice) {
    if (get_poll_info().get_flags().has_pending_error()) {
      TRY_STATUS(get_pending_error());
    }
    input_reader_.sync_with_writer();
    auto res = input_reader_.advance(td::min(slice.size(), input_reader_.size()), slice);
    if (res == 0) {
      get_poll_info().clear_flags(PollFlags::Read());
    } else {
      // LOG(ERROR) << "Read: " << format::as_hex_dump<0>(Slice(slice.substr(0, res)));
    }
    return res;
  }
  // Pops one queued error under lock_ and returns it. When the result is OK (nothing was queued,
  // or an OK status was popped) the Error poll flag is cleared.
  Status get_pending_error() {
    Status res;
    {
      auto lock = lock_.lock();
      if (!pending_errors_.empty()) {
        res = pending_errors_.pop();
      }
      if (res.is_ok()) {
        get_poll_info().clear_flags(PollFlags::Error());
      }
    }
    return res;
  }
 private:
  PollableFdInfo info;
  // Taken around pending_errors_ and around the is_write_waiting_ hand-off.
  SpinLock lock_;
  // Self-owned reference count; see dec_refcnt().
  std::atomic refcnt_{1};
  // Set by on_close(); afterwards on_iocp() ignores every completion.
  bool close_flag_{false};
  bool is_connected_{false};
  // True while a connect or receive using read_overlapped_ is outstanding.
  bool is_read_active_{false};
  ChainBufferWriter input_writer_;
  ChainBufferReader input_reader_ = input_writer_.extract_reader();
  WSAOVERLAPPED read_overlapped_;
  VectorQueue pending_errors_;
  // True while a WSASend using write_overlapped_ is outstanding.
  bool is_write_active_{false};
  // Set by loop_write() when the output buffer is empty; consumed by write_finish().
  std::atomic is_write_waiting_{false};
  ChainBufferWriter output_writer_;
  ChainBufferReader output_reader_ = output_writer_.extract_reader();
  WSAOVERLAPPED write_overlapped_;
  // Only the address of this member is used: it is the tag that identifies a close notification.
  char close_overlapped_;
  // Called after an overlapped call did not complete immediately. Returns true when the last WSA
  // error is ERROR_IO_PENDING; otherwise queues an error built from `message` and returns false.
  bool check_status(Slice message) {
    auto last_error = WSAGetLastError();
    if (last_error == ERROR_IO_PENDING) {
      return true;
    }
    on_error(OS_SOCKET_ERROR(message));
    return false;
  }
  // Issues the next overlapped WSARecv into space obtained from input_writer_. Does nothing once
  // close_flag_ is set.
  void loop_read() {
    CHECK(is_connected_);
    CHECK(!is_read_active_);
    if (close_flag_) {
      return;
    }
    std::memset(&read_overlapped_, 0, sizeof(read_overlapped_));
    auto dest = input_writer_.prepare_append();
    WSABUF buf;
    buf.len = narrow_cast(dest.size());
    buf.buf = dest.data();
    DWORD flags = 0;
    int status = WSARecv(get_native_fd().socket(), &buf, 1, nullptr, &flags, &read_overlapped_, nullptr);
    // A reference is taken only when the receive succeeded or is pending.
    if (status == 0 || check_status("Failed to read from connection")) {
      inc_refcnt();
      is_read_active_ = true;
    }
  }
  // Issues the next overlapped WSASend with up to buf_size buffers from the output buffer. If the
  // buffer is empty, sets is_write_waiting_ so that write_finish() posts a notification later.
  void loop_write() {
    CHECK(is_connected_);
    CHECK(!is_write_active_);
    output_reader_.sync_with_writer();
    auto to_write = output_reader_.prepare_read();
    if (to_write.empty()) {
      // Re-check under lock_ before setting the waiting flag; write_finish() clears the flag under
      // the same lock.
      auto lock = lock_.lock();
      output_reader_.sync_with_writer();
      to_write = output_reader_.prepare_read();
      if (to_write.empty()) {
        is_write_waiting_ = true;
        return;
      }
    }
    // Cannot be empty here given the block above; kept as a guard.
    if (to_write.empty()) {
      return;
    }
    std::memset(&write_overlapped_, 0, sizeof(write_overlapped_));
    constexpr size_t buf_size = 20;
    WSABUF buf[buf_size];
    // The buffers are collected from a clone; output_reader_ itself is advanced only in on_write()
    // by the number of bytes reported as written.
    auto it = output_reader_.clone();
    size_t buf_i;
    for (buf_i = 0; buf_i < buf_size; buf_i++) {
      auto src = it.prepare_read();
      if (src.empty()) {
        break;
      }
      buf[buf_i].len = narrow_cast(src.size());
      buf[buf_i].buf = const_cast(src.data());
      it.confirm_read(src.size());
    }
    int status =
        WSASend(get_native_fd().socket(), buf, narrow_cast(buf_i), nullptr, 0, &write_overlapped_, nullptr);
    // A reference is taken only when the send succeeded or is pending.
    if (status == 0 || check_status("Failed to write to connection")) {
      inc_refcnt();
      is_write_active_ = true;
    }
  }
  // Completion entry point. Releases one reference, then dispatches on which overlapped pointer
  // the completion carries: write_overlapped_ -> on_write, nullptr (posted write notification,
  // size must be 0) -> on_write, read_overlapped_ -> on_connected or on_read, address of
  // close_overlapped_ -> on_close.
  void on_iocp(Result r_size, WSAOVERLAPPED *overlapped) override {
    // called from other thread
    // If dec_refcnt() returns true the object has just been deleted, so nothing may be touched.
    if (dec_refcnt() || close_flag_) {
      VLOG(fd) << "Ignore IOCP (socket is closing)";
      return;
    }
    if (r_size.is_error()) {
      return on_error(get_socket_pending_error(get_native_fd(), overlapped, r_size.move_as_error()));
    }
    // Before the connection is established, read_overlapped_ identifies the connect completion
    // (or the notification posted by notify_iocp_connected()).
    if (!is_connected_ && overlapped == &read_overlapped_) {
      return on_connected();
    }
    auto size = r_size.move_as_ok();
    if (overlapped == &write_overlapped_) {
      return on_write(size);
    }
    if (overlapped == nullptr) {
      CHECK(size == 0);
      return on_write(size);
    }
    if (overlapped == &read_overlapped_) {
      return on_read(size);
    }
    if (overlapped == reinterpret_cast(&close_overlapped_)) {
      return on_close();
    }
    UNREACHABLE();
  }
  // Queues the error under lock_ and raises the Error poll flag.
  void on_error(Status status) {
    VLOG(fd) << get_native_fd() << " on error " << status;
    {
      auto lock = lock_.lock();
      pending_errors_.push(std::move(status));
    }
    get_poll_info().add_flags_from_poll(PollFlags::Error());
  }
  // Marks the socket as connected and starts the read and write loops.
  void on_connected() {
    VLOG(fd) << get_native_fd() << " on connected";
    CHECK(!is_connected_);
    CHECK(is_read_active_);
    is_connected_ = true;
    is_read_active_ = false;
    loop_read();
    loop_write();
  }
  // Handles a finished receive. A zero-byte result raises the Close flag and stops reading;
  // otherwise the bytes are committed to the input buffer, the Read flag is raised and the next
  // receive is issued.
  void on_read(size_t size) {
    VLOG(fd) << get_native_fd() << " on read " << size;
    CHECK(is_read_active_);
    is_read_active_ = false;
    if (size == 0) {
      get_poll_info().add_flags_from_poll(PollFlags::Close());
      return;
    }
    input_writer_.confirm_append(size);
    get_poll_info().add_flags_from_poll(PollFlags::Read());
    loop_read();
  }
  // Handles a finished send, or a write notification (size == 0).
  void on_write(size_t size) {
    VLOG(fd) << get_native_fd() << " on write " << size;
    if (size == 0) {
      // With a send already in flight there is nothing to do for a zero-size event.
      if (is_write_active_) {
        return;
      }
      // Otherwise pretend a send was active so the common path below can run unchanged.
      is_write_active_ = true;
    }
    CHECK(is_write_active_);
    is_write_active_ = false;
    output_reader_.advance(size);
    loop_write();
  }
  // Marks the socket as closed and replaces the stored native fd with an empty one.
  void on_close() {
    VLOG(fd) << get_native_fd() << " on close";
    close_flag_ = true;
    info.set_native_fd({});
  }
  // Drops one reference. Returns true if that was the last one, in which case the object has been
  // deleted and must not be used by the caller any more.
  bool dec_refcnt() {
    VLOG(fd) << get_native_fd() << " dec_refcnt from " << refcnt_;
    if (--refcnt_ == 0) {
      delete this;
      return true;
    }
    return false;
  }
  // Adds one reference; must not be called once the counter has reached zero.
  void inc_refcnt() {
    CHECK(refcnt_ != 0);
    refcnt_++;
    VLOG(fd) << get_native_fd() << " inc_refcnt to " << refcnt_;
  }
  // Posts a zero-size event with a null overlapped pointer; on_iocp() routes it to on_write().
  void notify_iocp_write() {
    inc_refcnt();
    Iocp::get()->post(0, this, nullptr);
  }
  // Posts the close event. No reference is taken here, so on_iocp() releases one reference when it
  // processes this event.
  void notify_iocp_close() {
    Iocp::get()->post(0, this, reinterpret_cast(&close_overlapped_));
  }
  // Posts a zero-size event tagged with read_overlapped_; while not yet connected on_iocp() routes
  // it to on_connected().
  void notify_iocp_connected() {
    inc_refcnt();
    Iocp::get()->post(0, this, &read_overlapped_);
  }
};
// Windows deleter: does not delete the object directly but asks it to close; SocketFdImpl deletes
// itself once its reference count drops to zero.
void SocketFdImplDeleter::operator()(SocketFdImpl *impl) {
  impl->close();
}
class InitWSA {
 public:
  InitWSA() {
    /* Use the MAKEWORD(lowbyte, highbyte) macro declared in Windef.h */
    WORD wVersionRequested = MAKEWORD(2, 2);
    WSADATA wsaData;
    if (WSAStartup(wVersionRequested, &wsaData) != 0) {
      auto error = OS_SOCKET_ERROR("Failed to init WSA");
      LOG(FATAL) << error;
    }
  }
};
static InitWSA init_wsa;
#else
// POSIX implementation of the socket: forwards to read(2), write(2) and writev(2) on the wrapped
// descriptor and mirrors the outcome of every call into the poll flags.
class SocketFdImpl {
 public:
  PollableFdInfo info;

  explicit SocketFdImpl(NativeFd fd) : info(std::move(fd)) {
  }

  PollableFdInfo &get_poll_info() {
    return info;
  }
  const PollableFdInfo &get_poll_info() const {
    return info;
  }
  const NativeFd &get_native_fd() const {
    return info.native_fd();
  }

  // Gathers all slices into a single writev(2) call (restarted on EINTR by skip_eintr) and lets
  // write_finish() translate the outcome.
  Result writev(Span slices) {
    int socket_fd = get_native_fd().socket();
    auto written =
        detail::skip_eintr([&] { return ::writev(socket_fd, slices.begin(), narrow_cast(slices.size())); });
    return write_finish(written);
  }

  // Same as writev(), but for one contiguous slice via write(2).
  Result write(Slice slice) {
    int socket_fd = get_native_fd().socket();
    auto written = detail::skip_eintr([&] { return ::write(socket_fd, slice.begin(), slice.size()); });
    return write_finish(written);
  }

  // Converts the raw result of write(2)/writev(2) into a Result:
  //   * non-negative result          -> number of bytes written;
  //   * EAGAIN / EWOULDBLOCK         -> 0, and the Write flag is cleared;
  //   * EBADF, ENXIO, EFAULT, EINVAL -> fatal log;
  //   * anything else                -> Write flag cleared, Close flag set, error returned
  //                                     (with a warning for errno values not listed below).
  Result write_finish(ssize_t write_res) {
    // errno has to be read first, before any other call can overwrite it.
    const auto saved_errno = errno;
    if (write_res >= 0) {
      return narrow_cast(write_res);
    }

    bool would_block = saved_errno == EAGAIN;
#if EAGAIN != EWOULDBLOCK
    would_block = would_block || saved_errno == EWOULDBLOCK;
#endif
    if (would_block) {
      get_poll_info().clear_flags(PollFlags::Write());
      return 0;
    }

    auto error = Status::PosixError(saved_errno, PSLICE() << "Write to " << get_native_fd() << " has failed");
    switch (saved_errno) {
      case EBADF:
      case ENXIO:
      case EFAULT:
      case EINVAL:
        LOG(FATAL) << error;
        UNREACHABLE();
      case ECONNRESET:
      case EDQUOT:
      case EFBIG:
      case EIO:
      case ENETDOWN:
      case ENETUNREACH:
      case ENOSPC:
      case EPIPE:
        // Listed failures are reported to the caller without extra logging.
        break;
      default:
        LOG(WARNING) << error;
        break;
    }
    get_poll_info().clear_flags(PollFlags::Write());
    get_poll_info().add_flags(PollFlags::Close());
    return std::move(error);
  }

  // Reads into slice (which must be non-empty) with read(2):
  //   * a pending socket error, if the poll flags report one, is returned first;
  //   * 0 bytes read                         -> Read flag cleared, Close flag set, returns 0;
  //   * EAGAIN / EWOULDBLOCK                 -> Read flag cleared, returns 0;
  //   * EISDIR, EBADF, ENXIO, EFAULT, EINVAL -> fatal log;
  //   * anything else                        -> Read flag cleared, Close flag set, error returned
  //                                             (with a warning for errno values not listed below).
  Result read(MutableSlice slice) {
    if (get_poll_info().get_flags().has_pending_error()) {
      TRY_STATUS(get_pending_error());
    }
    int socket_fd = get_native_fd().socket();
    CHECK(slice.size() > 0);
    auto received = detail::skip_eintr([&] { return ::read(socket_fd, slice.begin(), slice.size()); });
    const auto saved_errno = errno;

    if (received == 0) {
      errno = 0;
      get_poll_info().clear_flags(PollFlags::Read());
      get_poll_info().add_flags(PollFlags::Close());
    }
    if (received >= 0) {
      return narrow_cast(received);
    }

    bool would_block = saved_errno == EAGAIN;
#if EAGAIN != EWOULDBLOCK
    would_block = would_block || saved_errno == EWOULDBLOCK;
#endif
    if (would_block) {
      get_poll_info().clear_flags(PollFlags::Read());
      return 0;
    }

    auto error = Status::PosixError(saved_errno, PSLICE() << "Read from " << get_native_fd() << " has failed");
    switch (saved_errno) {
      case EISDIR:
      case EBADF:
      case ENXIO:
      case EFAULT:
      case EINVAL:
        LOG(FATAL) << error;
        UNREACHABLE();
      case ENOTCONN:
      case EIO:
      case ENOBUFS:
      case ENOMEM:
      case ECONNRESET:
      case ETIMEDOUT:
        // Listed failures are reported to the caller without extra logging.
        break;
      default:
        LOG(WARNING) << error;
        break;
    }
    get_poll_info().clear_flags(PollFlags::Read());
    get_poll_info().add_flags(PollFlags::Close());
    return std::move(error);
  }

  // Returns the error reported by detail::get_socket_pending_error(), if any. The Error poll flag
  // is cleared only when no error was reported.
  Status get_pending_error() {
    if (get_poll_info().get_flags().has_pending_error()) {
      TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
      get_poll_info().clear_flags(PollFlags::Error());
    }
    return Status::OK();
  }
};
// POSIX deleter: the implementation object is destroyed immediately.
void SocketFdImplDeleter::operator()(SocketFdImpl *impl) {
  delete impl;
}
#endif
#if TD_PORT_POSIX
// Queries SO_ERROR on the socket. Returns OK when the stored error is 0, a PosixError built from
// the stored error otherwise, and an OS error (also logged at INFO level) when getsockopt itself
// fails.
Status get_socket_pending_error(const NativeFd &fd) {
  int so_error = 0;
  socklen_t so_error_len = sizeof(so_error);
  const bool is_loaded =
      getsockopt(fd.socket(), SOL_SOCKET, SO_ERROR, static_cast(&so_error), &so_error_len) == 0;
  if (!is_loaded) {
    auto status = OS_SOCKET_ERROR(PSLICE() << "Can't load error on socket " << fd);
    LOG(INFO) << "Can't load pending socket error: " << status;
    return status;
  }
  if (so_error != 0) {
    return Status::PosixError(so_error, PSLICE() << "Error on " << fd);
  }
  return Status::OK();
}
#elif TD_PORT_WINDOWS
// Builds the Status for a failed overlapped operation. If WSAGetOverlappedResult() unexpectedly
// reports success, that is logged and the error received from IOCP is returned unchanged.
Status get_socket_pending_error(const NativeFd &fd, WSAOVERLAPPED *overlapped, Status iocp_error) {
  // WSAGetOverlappedResult() is called only so that WSAGetLastError() returns the correct error. See
  // https://stackoverflow.com/questions/28925003/calling-wsagetlasterror-from-an-iocp-thread-return-incorrect-result
  DWORD transferred = 0;
  DWORD result_flags = 0;
  if (!WSAGetOverlappedResult(fd.socket(), overlapped, &transferred, false, &result_flags)) {
    return OS_SOCKET_ERROR(PSLICE() << "Error on " << fd);
  }
  LOG(ERROR) << "WSAGetOverlappedResult succeded after " << iocp_error;
  return iocp_error;
}
#endif
// Switches the descriptor to non-blocking mode (failure is returned to the caller) and then turns
// on SO_REUSEADDR, SO_KEEPALIVE and TCP_NODELAY. The setsockopt results are not checked.
Status init_socket_options(NativeFd &native_fd) {
  TRY_STATUS(native_fd.set_is_blocking_unsafe(false));
#if TD_PORT_POSIX
  int enabled = 1;
#elif TD_PORT_WINDOWS
  BOOL enabled = TRUE;
#endif
  auto sock = native_fd.socket();
  auto enable_option = [&](int level, int option) {
    setsockopt(sock, level, option, reinterpret_cast(&enabled), sizeof(enabled));
  };
  enable_option(SOL_SOCKET, SO_REUSEADDR);
  enable_option(SOL_SOCKET, SO_KEEPALIVE);
  enable_option(IPPROTO_TCP, TCP_NODELAY);
  // TODO: SO_REUSEADDR, SO_KEEPALIVE, TCP_NODELAY, SO_SNDBUF, SO_RCVBUF, TCP_QUICKACK, SO_LINGER
  return Status::OK();
}
}  // namespace detail
// Default construction, move construction/assignment and destruction are compiler-generated.
SocketFd::SocketFd() = default;
SocketFd::SocketFd(SocketFd &&) = default;
SocketFd &SocketFd::operator=(SocketFd &&) = default;
SocketFd::~SocketFd() = default;
// Takes over the implementation object: the raw pointer is released from `impl` and stored in
// impl_.
SocketFd::SocketFd(unique_ptr impl) : impl_(impl.release()) {
}
// Wraps an already existing socket descriptor: applies the default socket options first (an error
// from that step is returned to the caller) and then builds the SocketFd around it.
Result SocketFd::from_native_fd(NativeFd fd) {
  TRY_STATUS(detail::init_socket_options(fd));
  auto impl = make_unique(std::move(fd));
  return SocketFd(std::move(impl));
}
// Creates a TCP stream socket for the address family of `address`, applies the default socket
// options and starts connecting to `address`.
//   POSIX:   connect() is called directly; EINPROGRESS is not treated as a failure.
//   Windows: the socket is first bound to address.get_any_addr(), then handed together with
//            `address` to the implementation, which performs the connect.
Result SocketFd::open(const IPAddress &address) {
  NativeFd native_fd{socket(address.get_address_family(), SOCK_STREAM, IPPROTO_TCP)};
  if (!native_fd) {
    return OS_SOCKET_ERROR("Failed to create a socket");
  }
  TRY_STATUS(detail::init_socket_options(native_fd));

#if TD_PORT_POSIX
  const bool connect_failed =
      connect(native_fd.socket(), address.get_sockaddr(), narrow_cast(address.get_sockaddr_len())) == -1;
  if (connect_failed) {
    // Keep a copy of errno: building the message below may call functions that change it.
    const auto saved_errno = errno;
    if (saved_errno != EINPROGRESS) {
      return Status::PosixError(saved_errno, PSLICE() << "Failed to connect to " << address);
    }
  }
  return SocketFd(make_unique(std::move(native_fd)));
#elif TD_PORT_WINDOWS
  auto any_addr = address.get_any_addr();
  const bool is_bound =
      bind(native_fd.socket(), any_addr.get_sockaddr(), narrow_cast(any_addr.get_sockaddr_len())) == 0;
  if (!is_bound) {
    return OS_SOCKET_ERROR("Failed to bind a socket");
  }
  return SocketFd(make_unique(std::move(native_fd), address));
#endif
}
// Drops the implementation object by resetting impl_; afterwards empty() returns true.
void SocketFd::close() {
  impl_.reset();
}
// Returns true when this SocketFd holds no implementation object.
bool SocketFd::empty() const {
  return !impl_;
}
// Forwards to the implementation. impl_ is dereferenced without a check, so the SocketFd must not
// be empty.
PollableFdInfo &SocketFd::get_poll_info() {
  return impl_->get_poll_info();
}
// Const overload; forwards to the implementation. impl_ is dereferenced without a check, so the
// SocketFd must not be empty.
const PollableFdInfo &SocketFd::get_poll_info() const {
  return impl_->get_poll_info();
}
// Returns the native descriptor held by the implementation. impl_ is dereferenced without a
// check, so the SocketFd must not be empty.
const NativeFd &SocketFd::get_native_fd() const {
  return impl_->get_native_fd();
}
// Forwards to the implementation's get_pending_error(). impl_ is dereferenced without a check, so
// the SocketFd must not be empty.
Status SocketFd::get_pending_error() {
  return impl_->get_pending_error();
}
// Forwards to the implementation's write(). impl_ is dereferenced without a check, so the
// SocketFd must not be empty.
Result SocketFd::write(Slice slice) {
  return impl_->write(slice);
}
// Forwards to the implementation's writev(). impl_ is dereferenced without a check, so the
// SocketFd must not be empty.
Result SocketFd::writev(Span slices) {
  return impl_->writev(slices);
}
// Forwards to the implementation's read(). impl_ is dereferenced without a check, so the SocketFd
// must not be empty.
Result SocketFd::read(MutableSlice slice) {
  return impl_->read(slice);
}
}  // namespace td