signed version of python code

This commit is contained in:
Vojta Mrazek 2024-07-18 13:16:15 +02:00
parent 4cd1189d4a
commit f34471bfe3
2 changed files with 18 additions and 4 deletions

View File

@ -349,16 +349,18 @@ class GeneralCircuit():
return self.out.return_bus_wires_values_python_flat()
# Generating flat Python code representation of circuit
def get_python_code_flat(self, file_object):
def get_python_code_flat(self, file_object, retype=True):
"""Generates flat Python code representation of corresponding arithmetic circuit.
Args:
file_object (TextIOWrapper): Destination file object where circuit's representation will be written to.
retype (bool) specifies if signed output should return int64_t
"""
file_object.write(self.get_prototype_python())
file_object.write(self.get_init_python_flat()+"\n")
file_object.write(self.get_function_out_python_flat())
file_object.write(self.out.return_bus_wires_sign_extend_python_flat())
file_object.write(self.out.return_bus_wires_sign_extend_python_flat(retype=True))
file_object.write(f" return {self.out.prefix}"+"\n")
""" C CODE GENERATION """

View File

@ -136,7 +136,7 @@ class Bus():
mapped_positions = [(w_id, self.bus[w_id]) for w_id in range(self.N)]
return "".join([f" {self.prefix} = 0\n"] + [f" {self.prefix} = ({self.prefix}) | {w[1].return_wire_value_python_flat(offset=w[0])}" for w in mapped_positions])
def return_bus_wires_sign_extend_python_flat(self):
def return_bus_wires_sign_extend_python_flat(self, retype: bool = False):
"""Sign extends the bus's corresponding Python variable (object) to ensure proper flat Python code variable signedness.
Returns:
@ -144,7 +144,19 @@ class Bus():
"""
if self.signed is True:
last_bus_wire = self.bus[-1]
return "".join([f" {self.prefix} = ({self.prefix}) | {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)])
assert self.N < 64, "Sign extension is not supported for bus with more than 64 bits"
if retype:
rewrite = f"""
if hasattr({self.prefix}, 'astype'):
{self.prefix} = {self.prefix}.astype("int64")
else:
from ctypes import c_int64
{self.prefix} = c_int64({self.prefix}).value\n"""
else:
rewrite = ""
return "".join([f" {self.prefix} = ({self.prefix}) | {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)]) + rewrite
else:
return ""