signed version of python code

This commit is contained in:
Vojta Mrazek 2024-07-18 13:16:15 +02:00
parent 4cd1189d4a
commit f34471bfe3
2 changed files with 18 additions and 4 deletions

View File

@ -349,16 +349,18 @@ class GeneralCircuit():
return self.out.return_bus_wires_values_python_flat() return self.out.return_bus_wires_values_python_flat()
# Generating flat Python code representation of circuit # Generating flat Python code representation of circuit
def get_python_code_flat(self, file_object):
def get_python_code_flat(self, file_object, retype=True):
"""Generates flat Python code representation of corresponding arithmetic circuit. """Generates flat Python code representation of corresponding arithmetic circuit.
Args: Args:
file_object (TextIOWrapper): Destination file object where circuit's representation will be written to. file_object (TextIOWrapper): Destination file object where circuit's representation will be written to.
retype (bool) specifies if signed output should return int64_t
""" """
file_object.write(self.get_prototype_python()) file_object.write(self.get_prototype_python())
file_object.write(self.get_init_python_flat()+"\n") file_object.write(self.get_init_python_flat()+"\n")
file_object.write(self.get_function_out_python_flat()) file_object.write(self.get_function_out_python_flat())
file_object.write(self.out.return_bus_wires_sign_extend_python_flat()) file_object.write(self.out.return_bus_wires_sign_extend_python_flat(retype=True))
file_object.write(f" return {self.out.prefix}"+"\n") file_object.write(f" return {self.out.prefix}"+"\n")
""" C CODE GENERATION """ """ C CODE GENERATION """

View File

@ -136,7 +136,7 @@ class Bus():
mapped_positions = [(w_id, self.bus[w_id]) for w_id in range(self.N)] mapped_positions = [(w_id, self.bus[w_id]) for w_id in range(self.N)]
return "".join([f" {self.prefix} = 0\n"] + [f" {self.prefix} = ({self.prefix}) | {w[1].return_wire_value_python_flat(offset=w[0])}" for w in mapped_positions]) return "".join([f" {self.prefix} = 0\n"] + [f" {self.prefix} = ({self.prefix}) | {w[1].return_wire_value_python_flat(offset=w[0])}" for w in mapped_positions])
def return_bus_wires_sign_extend_python_flat(self): def return_bus_wires_sign_extend_python_flat(self, retype: bool = False):
"""Sign extends the bus's corresponding Python variable (object) to ensure proper flat Python code variable signedness. """Sign extends the bus's corresponding Python variable (object) to ensure proper flat Python code variable signedness.
Returns: Returns:
@ -144,7 +144,19 @@ class Bus():
""" """
if self.signed is True: if self.signed is True:
last_bus_wire = self.bus[-1] last_bus_wire = self.bus[-1]
return "".join([f" {self.prefix} = ({self.prefix}) | {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)])
assert self.N < 64, "Sign extension is not supported for bus with more than 64 bits"
if retype:
rewrite = f"""
if hasattr({self.prefix}, 'astype'):
{self.prefix} = {self.prefix}.astype("int64")
else:
from ctypes import c_int64
{self.prefix} = c_int64({self.prefix}).value\n"""
else:
rewrite = ""
return "".join([f" {self.prefix} = ({self.prefix}) | {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)]) + rewrite
else: else:
return "" return ""