Skip to content

Commit

Permalink
cleaning up socket and making it work again
Browse files Browse the repository at this point in the history
  • Loading branch information
thatstoasty committed Aug 29, 2024
1 parent 22b7b66 commit ffc62f8
Show file tree
Hide file tree
Showing 16 changed files with 181 additions and 150 deletions.
Binary file modified benchmarks/gojo.mojopkg
Binary file not shown.
Empty file removed examples/__init__.mojo
Empty file.
Empty file removed examples/scanner/__init__.mojo
Empty file.
Empty file removed examples/tcp/__init__.mojo
Empty file.
2 changes: 1 addition & 1 deletion examples/tcp/dial_client.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn main() raises:
return

# Read the response from the connection
var response = List[UInt8](capacity=4096)
var response = List[UInt8, True](capacity=4096)
var bytes_read: Int = 0
bytes_read, err = connection.read(response)
if err and str(err) != io.EOF:
Expand Down
2 changes: 1 addition & 1 deletion examples/tcp/get_request.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn main() raises:
return

# Read the response from the connection
var response = List[UInt8](capacity=4096)
var response = List[UInt8, True](capacity=4096)
var bytes_read: Int = 0
bytes_read, err = connection.read(response)
if err:
Expand Down
2 changes: 1 addition & 1 deletion examples/tcp/listener_server.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn main() raises:
var connection = listener.accept()

# Read the contents of the message from the client.
var bytes = List[UInt8](capacity=4096)
var bytes = List[UInt8, True](capacity=4096)
var bytes_read: Int
var err: Error
bytes_read, err = connection.read(bytes)
Expand Down
2 changes: 1 addition & 1 deletion examples/tcp/socket_client.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn main() raises:
bytes_sent, err = socket.write(message.as_bytes())
print("Message sent:", message)

var bytes = List[UInt8](capacity=16)
var bytes = List[UInt8, True](capacity=16)
var bytes_read: Int
bytes_read, err = socket.read(bytes)
if str(err) != str(io.EOF):
Expand Down
2 changes: 1 addition & 1 deletion examples/tcp/socket_server.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn main() raises:
print("Serving", str(connection.remote_address_as_tcp()))

# Read the contents of the message from the client.
var bytes = List[UInt8](capacity=4096)
var bytes = List[UInt8, True](capacity=4096)
var bytes_read: Int
var err: Error
bytes_read, err = connection.read(bytes)
Expand Down
Empty file removed examples/udp/__init__.mojo
Empty file.
1 change: 1 addition & 0 deletions src/gojo/net/fd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from ..syscall import (
close,
FileDescriptorBase,
)
from sys import external_call

alias O_RDWR = 0o2

Expand Down
13 changes: 12 additions & 1 deletion src/gojo/net/ip.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,21 @@ fn build_sockaddr_pointer(ip_address: String, port: Int, address_family: Int) ->
var bin_port = convert_port_to_binary(port)
var bin_ip = convert_ip_to_binary(ip_address, address_family)

var ai = sockaddr_in(address_family, bin_port, bin_ip, InlineArray[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0))
var ai = sockaddr_in(address_family, bin_port, bin_ip, StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0))
return UnsafePointer.address_of(ai).bitcast[sockaddr]()


fn build_sockaddr_in(ip_address: String, port: Int, address_family: Int) -> sockaddr_in:
"""Build a sockaddr pointer from an IP address and port number.
https://learn.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-sockaddr_in.
"""
var bin_port = convert_port_to_binary(port)
var bin_ip = convert_ip_to_binary(ip_address, address_family)

return sockaddr_in(address_family, bin_port, bin_ip, StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0))


fn convert_sockaddr_to_host_port(sockaddr: UnsafePointer[sockaddr]) -> (HostPort, Error):
"""Casts a sockaddr pointer to a sockaddr_in pointer and converts the binary IP and port to a string and int respectively.
Expand Down
78 changes: 44 additions & 34 deletions src/gojo/net/socket.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ from .fd import FileDescriptor, FileDescriptorBase
from .ip import (
convert_binary_ip_to_string,
build_sockaddr_pointer,
build_sockaddr_in,
convert_binary_port_to_int,
convert_sockaddr_to_host_port,
)
Expand Down Expand Up @@ -178,21 +179,22 @@ struct Socket(FileDescriptorBase):
The return value is a connection where conn is a new socket object usable to send and receive data on the connection,
and address is the address bound to the socket on the other end of the connection.
"""
var remote_address_ptr = UnsafePointer[sockaddr].alloc(1)
var sin_size = socklen_t(sizeof[socklen_t]())
var remote_address = sockaddr()
var new_fd = accept(
self.fd.fd,
remote_address_ptr,
UnsafePointer[socklen_t].address_of(sin_size),
UnsafePointer.address_of(remote_address),
UnsafePointer.address_of(socklen_t(sizeof[socklen_t]())),
)
if new_fd == -1:
_ = external_call["perror", c_void, UnsafePointer[UInt8]](String("accept").unsafe_ptr())
raise Error("Failed to accept connection")

var remote: HostPort
var err: Error
remote, err = convert_sockaddr_to_host_port(remote_address_ptr)
remote, err = convert_sockaddr_to_host_port(UnsafePointer.address_of(remote_address))
if err:
raise err
_ = remote_address

return Socket(
new_fd,
Expand Down Expand Up @@ -230,14 +232,14 @@ struct Socket(FileDescriptorBase):
address: String - The IP address to bind the socket to.
port: The port number to bind the socket to.
"""
var sockaddr_pointer = build_sockaddr_pointer(address, port, self.address_family)
print(sockaddr_pointer.bitcast[sockaddr_in]()[].sin_family)
print(sockaddr_pointer.bitcast[sockaddr_in]()[].sin_port)
print(sockaddr_pointer.bitcast[sockaddr_in]()[].sin_addr.s_addr)
if bind(self.fd.fd, sockaddr_pointer, sizeof[sockaddr]()) == -1:
# var sockaddr_pointer = build_sockaddr_pointer(address, port, self.address_family)
var sa_in = build_sockaddr_in(address, port, self.address_family)
# var sa_in = build_sockaddr_in(address, port, self.address_family)
if bind(self.fd.fd, UnsafePointer.address_of(sa_in), sizeof[sockaddr_in]()) == -1:
_ = external_call["perror", c_void, UnsafePointer[UInt8]](String("bind").unsafe_ptr())
_ = shutdown(self.fd.fd, SHUT_RDWR)
raise Error("Binding socket failed. Wait a few seconds and try again?")
_ = sa_in

var local = self.get_sock_name()
self.local_address = BaseAddr(local.host, local.port)
Expand All @@ -252,21 +254,26 @@ struct Socket(FileDescriptorBase):
raise SocketClosedError

# TODO: Add check to see if the socket is bound and error if not.

var local_address_ptr = UnsafePointer[sockaddr].alloc(1)
var local_address_ptr_size = socklen_t(sizeof[sockaddr]())
var sa = sockaddr()
# print(sa.sa_family)
var status = getsockname(
self.fd.fd,
local_address_ptr,
UnsafePointer[socklen_t].address_of(local_address_ptr_size),
UnsafePointer.address_of(sa),
UnsafePointer.address_of(socklen_t(sizeof[sockaddr]())),
)
if status == -1:
_ = external_call["perror", c_void, UnsafePointer[UInt8]](String("getsockname").unsafe_ptr())
raise Error("Socket.get_sock_name: Failed to get address of local socket.")
var addr_in = local_address_ptr.bitcast[sockaddr_in]().take_pointee()

# print(sa.sa_family)
var addr_in = UnsafePointer.address_of(sa).bitcast[sockaddr_in]()
# print(sa.sa_family, addr_in.sin_addr.s_addr, addr_in.sin_port)
# var addr_in = local_address_ptr.bitcast[sockaddr_in]().take_pointee()
# print(convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AddressFamily.AF_INET, 16), convert_binary_port_to_int(addr_in.sin_port))
# _ = sa
_ = addr_in
return HostPort(
host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AddressFamily.AF_INET, 16),
port=convert_binary_port_to_int(addr_in.sin_port),
host=convert_binary_ip_to_string(addr_in[].sin_addr.s_addr, AddressFamily.AF_INET, 16),
port=convert_binary_port_to_int(addr_in[].sin_port),
)

fn get_peer_name(self) -> (HostPort, Error):
Expand Down Expand Up @@ -340,11 +347,12 @@ struct Socket(FileDescriptorBase):
address: String - The IP address to connect to.
port: The port number to connect to.
"""
var sockaddr_pointer = build_sockaddr_pointer(address, port, self.address_family)

if connect(self.fd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1:
var sa_in = build_sockaddr_in(address, port, self.address_family)
if connect(self.fd.fd, UnsafePointer.address_of(sa_in), sizeof[sockaddr_in]()) == -1:
_ = external_call["perror", c_void, UnsafePointer[UInt8]](String("connect").unsafe_ptr())
self.shutdown()
return Error("Socket.connect: Failed to connect to the remote socket at: " + address + ":" + str(port))
_ = sa_in

var remote: HostPort
var err: Error
Expand Down Expand Up @@ -444,7 +452,7 @@ struct Socket(FileDescriptorBase):
return bytes, Error()

fn _read(inout self, inout dest: UnsafePointer[UInt8], capacity: Int) -> (Int, Error):
"""Receive data from the socket into the buffer dest. Equivalent to recv_into().
"""Receive data from the socket into the buffer dest.
Args:
dest: The buffer to read data into.
Expand All @@ -456,24 +464,26 @@ struct Socket(FileDescriptorBase):
return self.fd._read(dest, capacity)

fn read(inout self, inout dest: List[UInt8, True]) -> (Int, Error):
"""Receive data from the socket into the buffer dest. Equivalent to recv_into().
"""Receive data from the socket into the buffer dest. Equivalent to `recv_into()`.
Args:
dest: The buffer to read data into.
Returns:
The number of bytes read, and an error if one occurred.
"""
if dest.size == dest.capacity:
return 0, Error("net.socket.Socket.read: no space left in destination buffer.")

var dest_ptr = dest.unsafe_ptr().offset(dest.size)
var bytes_read: Int
var err: Error
bytes_read, err = self._read(dest_ptr, dest.capacity - dest.size)
dest.size += bytes_read

return bytes_read, err
return self.fd.read(dest)
# if dest.size == dest.capacity:
# return 0, Error("net.socket.Socket.read: no space left in destination buffer.")

# var dest_ptr = dest.unsafe_ptr().offset(dest.size)
# var bytes_read: Int
# var err: Error
# bytes_read, err = self._read(dest_ptr, dest.capacity - dest.size)
# dest.size += bytes_read

# print(bytes_read, str(err))
# return bytes_read, err

fn receive_from(inout self, size: Int = io.BUFFER_SIZE) -> (List[UInt8, True], HostPort, Error):
"""Receive data from the socket into the buffer dest.
Expand Down
93 changes: 51 additions & 42 deletions src/gojo/syscall/net.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ struct SocketOptions:
struct in_addr:
var s_addr: in_addr_t

fn __init__(inout self, addr: in_addr_t = 0):
self.s_addr = addr


@value
@register_passable("trivial")
Expand All @@ -297,17 +300,35 @@ struct in6_addr:


@value
@register_passable("trivial")
struct sockaddr:
var sa_family: sa_family_t
var sa_data: InlineArray[c_char, 14]
var sa_data: StaticTuple[c_char, 14]

fn __init__(inout self, family: sa_family_t = 0, data: StaticTuple[c_char, 14] = StaticTuple[c_char, 14]()):
self.sa_family = family
self.sa_data = data


@value
@register_passable("trivial")
struct sockaddr_in:
var sin_family: sa_family_t
var sin_port: in_port_t
var sin_addr: in_addr
var sin_zero: InlineArray[c_char, 8]
var sin_zero: StaticTuple[c_char, 8]

fn __init__(
inout self,
family: sa_family_t = 0,
port: in_port_t = 0,
addr: in_addr = in_addr(),
zero: StaticTuple[c_char, 8] = StaticTuple[c_char, 8](),
):
self.sin_family = family
self.sin_port = port
self.sin_addr = addr
self.sin_zero = zero


@value
Expand Down Expand Up @@ -337,26 +358,6 @@ struct addrinfo:
var ai_addr: UnsafePointer[sockaddr]
var ai_next: UnsafePointer[addrinfo]

fn __init__(
inout self,
ai_flags: c_int = 0,
ai_family: c_int = 0,
ai_socktype: c_int = 0,
ai_protocol: c_int = 0,
ai_addrlen: socklen_t = 0,
ai_canonname: UnsafePointer[UInt8] = UnsafePointer[UInt8](),
ai_addr: UnsafePointer[sockaddr] = UnsafePointer[sockaddr](),
ai_next: UnsafePointer[addrinfo] = UnsafePointer[addrinfo](),
):
self.ai_flags = ai_flags
self.ai_family = ai_family
self.ai_socktype = ai_socktype
self.ai_protocol = ai_protocol
self.ai_addrlen = ai_addrlen
self.ai_canonname = ai_canonname
self.ai_addr = ai_addr
self.ai_next = ai_next


@value
@register_passable("trivial")
Expand All @@ -375,26 +376,6 @@ struct addrinfo_unix:
var ai_canonname: UnsafePointer[UInt8]
var ai_next: UnsafePointer[addrinfo]

fn __init__(
inout self,
ai_flags: c_int = 0,
ai_family: c_int = 0,
ai_socktype: c_int = 0,
ai_protocol: c_int = 0,
ai_addrlen: socklen_t = 0,
ai_canonname: UnsafePointer[UInt8] = UnsafePointer[UInt8](),
ai_addr: UnsafePointer[sockaddr] = UnsafePointer[sockaddr](),
ai_next: UnsafePointer[addrinfo] = UnsafePointer[addrinfo](),
):
self.ai_flags = ai_flags
self.ai_family = ai_family
self.ai_socktype = ai_socktype
self.ai_protocol = ai_protocol
self.ai_addrlen = ai_addrlen
self.ai_canonname = ai_canonname
self.ai_addr = ai_addr
self.ai_next = ai_next


# --- ( Network Related Syscalls & Structs )------------------------------------

Expand Down Expand Up @@ -675,6 +656,16 @@ fn bind(socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t)
)


fn bind(socket: c_int, address: UnsafePointer[sockaddr_in], address_len: socklen_t) -> c_int:
"""Libc POSIX `bind` function
Reference: https://man7.org/linux/man-pages/man3/bind.3p.html
Fn signature: `int bind(int socket, const struct sockaddr *address, socklen_t address_len)`.
"""
return external_call["bind", c_int, c_int, UnsafePointer[sockaddr_in], socklen_t]( # FnName, RetType # Args
socket, address, address_len
)


fn listen(socket: c_int, backlog: c_int) -> c_int:
"""Libc POSIX `listen` function
Reference: https://man7.org/linux/man-pages/man3/listen.3p.html
Expand Down Expand Up @@ -734,6 +725,24 @@ fn connect(socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen
)


fn connect(socket: c_int, address: UnsafePointer[sockaddr_in], address_len: socklen_t) -> c_int:
"""Libc POSIX `connect` function
Reference: https://man7.org/linux/man-pages/man3/connect.3p.html
Fn signature: `int connect(int socket, const struct sockaddr *address, socklen_t address_len)`.
Args:
socket: A File Descriptor.
address: A pointer to the address to connect to.
address_len: The size of the address.
Returns:
0 on success, -1 on error.
"""
return external_call["connect", c_int, c_int, UnsafePointer[sockaddr_in], socklen_t]( # FnName, RetType # Args
socket, address, address_len
)


fn recv(
socket: c_int,
buffer: UnsafePointer[UInt8],
Expand Down
Binary file modified test/gojo.mojopkg
Binary file not shown.
Loading

0 comments on commit ffc62f8

Please sign in to comment.