From f34471bfe3e3a976216556247d901180eb75b0d0 Mon Sep 17 00:00:00 2001 From: Vojta Mrazek Date: Thu, 18 Jul 2024 13:16:15 +0200 Subject: [PATCH] signed version of python code --- .../core/arithmetic_circuits/general_circuit.py | 6 ++++-- ariths_gen/wire_components/buses.py | 16 ++++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/ariths_gen/core/arithmetic_circuits/general_circuit.py b/ariths_gen/core/arithmetic_circuits/general_circuit.py index 76142a2..b8543e5 100644 --- a/ariths_gen/core/arithmetic_circuits/general_circuit.py +++ b/ariths_gen/core/arithmetic_circuits/general_circuit.py @@ -349,16 +349,18 @@ class GeneralCircuit(): return self.out.return_bus_wires_values_python_flat() # Generating flat Python code representation of circuit - def get_python_code_flat(self, file_object): + + def get_python_code_flat(self, file_object, retype=True): """Generates flat Python code representation of corresponding arithmetic circuit. Args: file_object (TextIOWrapper): Destination file object where circuit's representation will be written to. + retype (bool) specifies if signed output should return int64_t """ file_object.write(self.get_prototype_python()) file_object.write(self.get_init_python_flat()+"\n") file_object.write(self.get_function_out_python_flat()) - file_object.write(self.out.return_bus_wires_sign_extend_python_flat()) + file_object.write(self.out.return_bus_wires_sign_extend_python_flat(retype=True)) file_object.write(f" return {self.out.prefix}"+"\n") """ C CODE GENERATION """ diff --git a/ariths_gen/wire_components/buses.py b/ariths_gen/wire_components/buses.py index 24afa4b..1919ad6 100644 --- a/ariths_gen/wire_components/buses.py +++ b/ariths_gen/wire_components/buses.py @@ -136,7 +136,7 @@ class Bus(): mapped_positions = [(w_id, self.bus[w_id]) for w_id in range(self.N)] return "".join([f" {self.prefix} = 0\n"] + [f" {self.prefix} = ({self.prefix}) | {w[1].return_wire_value_python_flat(offset=w[0])}" for w in mapped_positions]) - def return_bus_wires_sign_extend_python_flat(self): + def return_bus_wires_sign_extend_python_flat(self, retype: bool = False): """Sign extends the bus's corresponding Python variable (object) to ensure proper flat Python code variable signedness. Returns: @@ -144,7 +144,19 @@ class Bus(): """ if self.signed is True: last_bus_wire = self.bus[-1] - return "".join([f" {self.prefix} = ({self.prefix}) | {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)]) + + assert self.N < 64, "Sign extension is not supported for bus with more than 64 bits" + if retype: + rewrite = f""" + if hasattr({self.prefix}, 'astype'): + {self.prefix} = {self.prefix}.astype("int64") + else: + from ctypes import c_int64 + {self.prefix} = c_int64({self.prefix}).value\n""" + else: + rewrite = "" + + return "".join([f" {self.prefix} = ({self.prefix}) | {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)]) + rewrite else: return ""