shrlok: require length before the message

This commit is contained in:
dece 2022-07-05 23:38:16 +02:00
parent 911b60b792
commit e7dfc359f5

View file

@ -1,4 +1,27 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""shrlok server: receive text/files from a socket, put them on a Web server.
The server expects messages with the following format to come through the
socket:
1. a length, as ASCII digits,
2. a null byte,
3. a JSON object containing at least the "type" key with a known value,
4. another null byte,
5. the text or file itself.
The length is of the JSON object + the null char + the content itself.
Example, with \\0 representing a null byte:
28\\0{"type":"txt"}\\0hello shrlok!
The content is 13 bytes (assuming no LF at the end), the header is 14 bytes,
plus the null byte it is 13 + 14 + 1 = 28 bytes, thus the prefixed length.
After the content is succesfully retrieved and put on an appropriate location,
the server will reply the file path through the socket and close the connection.
"""
import argparse import argparse
import json import json
@ -29,11 +52,9 @@ HTML_TEMPLATE = """\
class Handler(socketserver.StreamRequestHandler): class Handler(socketserver.StreamRequestHandler):
def handle(self): def handle(self):
fragments = [] data = self.receive_input()
while (chunk := self.request.recv(4096)):
fragments.append(chunk)
data = b"".join(fragments)
# Extract header.
try: try:
first_zero = data.index(b"\0") first_zero = data.index(b"\0")
header_data, data = data[:first_zero], data[first_zero + 1:] header_data, data = data[:first_zero], data[first_zero + 1:]
@ -42,17 +63,38 @@ class Handler(socketserver.StreamRequestHandler):
print("Bad header.") print("Bad header.")
return return
file_name = None
if header.get("type") == "txt": if header.get("type") == "txt":
file_name = write_text(data, title=header.get("title")) file_name = write_text(data, title=header.get("title"))
else: else:
print("Unknown type.") print("Unknown type.")
return if file_name is None:
if not file_name:
return return
print(f"{len(data)} bytes — {header}'{file_name}'.") print(f"{len(data)} bytes — {header}'{file_name}'.")
self.request.sendall(file_name.encode()) try:
self.request.sendall(file_name.encode())
except BrokenPipeError:
print("Broken pipe.")
def receive_input(self):
length = None
data = b""
while True:
chunk = self.request.recv(4096)
if not chunk:
break
data += chunk
if length is None:
try:
first_zero = data.index(b"\0")
except ValueError:
return b""
length_data, data = data[:first_zero], data[first_zero + 1:]
length = int(length_data.decode())
if len(data) >= length: # retrieval completed
break
return data
def write_text(data: bytes, title=None): def write_text(data: bytes, title=None):
@ -92,6 +134,7 @@ def main():
try: try:
with socketserver.UnixStreamServer(socket_path, Handler) as server: with socketserver.UnixStreamServer(socket_path, Handler) as server:
os.chmod(socket_path, 0o664)
server.serve_forever() server.serve_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
print("Stopping server.") print("Stopping server.")