shrlok: require length before the message
This commit is contained in:
parent
911b60b792
commit
e7dfc359f5
|
@ -1,4 +1,27 @@
|
|||
#!/usr/bin/env python3
|
||||
"""shrlok server: receive text/files from a socket, put them on a Web server.
|
||||
|
||||
The server expects messages with the following format to come through the
|
||||
socket:
|
||||
|
||||
1. a length, as ASCII digits,
|
||||
2. a null byte,
|
||||
3. a JSON object containing at least the "type" key with a known value,
|
||||
4. another null byte,
|
||||
5. the text or file itself.
|
||||
|
||||
The length is of the JSON object + the null char + the content itself.
|
||||
|
||||
Example, with \\0 representing a null byte:
|
||||
|
||||
28\\0{"type":"txt"}\\0hello shrlok!
|
||||
|
||||
The content is 13 bytes (assuming no LF at the end), the header is 14 bytes,
|
||||
plus the null byte it is 13 + 14 + 1 = 28 bytes, thus the prefixed length.
|
||||
|
||||
After the content is succesfully retrieved and put on an appropriate location,
|
||||
the server will reply the file path through the socket and close the connection.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
@ -29,11 +52,9 @@ HTML_TEMPLATE = """\
|
|||
class Handler(socketserver.StreamRequestHandler):
|
||||
|
||||
def handle(self):
|
||||
fragments = []
|
||||
while (chunk := self.request.recv(4096)):
|
||||
fragments.append(chunk)
|
||||
data = b"".join(fragments)
|
||||
data = self.receive_input()
|
||||
|
||||
# Extract header.
|
||||
try:
|
||||
first_zero = data.index(b"\0")
|
||||
header_data, data = data[:first_zero], data[first_zero + 1:]
|
||||
|
@ -42,17 +63,38 @@ class Handler(socketserver.StreamRequestHandler):
|
|||
print("Bad header.")
|
||||
return
|
||||
|
||||
file_name = None
|
||||
if header.get("type") == "txt":
|
||||
file_name = write_text(data, title=header.get("title"))
|
||||
else:
|
||||
print("Unknown type.")
|
||||
return
|
||||
|
||||
if not file_name:
|
||||
if file_name is None:
|
||||
return
|
||||
|
||||
print(f"{len(data)} bytes — {header} — '{file_name}'.")
|
||||
self.request.sendall(file_name.encode())
|
||||
try:
|
||||
self.request.sendall(file_name.encode())
|
||||
except BrokenPipeError:
|
||||
print("Broken pipe.")
|
||||
|
||||
def receive_input(self):
|
||||
length = None
|
||||
data = b""
|
||||
while True:
|
||||
chunk = self.request.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
if length is None:
|
||||
try:
|
||||
first_zero = data.index(b"\0")
|
||||
except ValueError:
|
||||
return b""
|
||||
length_data, data = data[:first_zero], data[first_zero + 1:]
|
||||
length = int(length_data.decode())
|
||||
if len(data) >= length: # retrieval completed
|
||||
break
|
||||
return data
|
||||
|
||||
|
||||
def write_text(data: bytes, title=None):
|
||||
|
@ -92,6 +134,7 @@ def main():
|
|||
|
||||
try:
|
||||
with socketserver.UnixStreamServer(socket_path, Handler) as server:
|
||||
os.chmod(socket_path, 0o664)
|
||||
server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("Stopping server.")
|
||||
|
|
Reference in a new issue