From 35240abc631fcea813e762db93fd25ba6d8756f4 Mon Sep 17 00:00:00 2001 From: Vojta Mrazek Date: Wed, 22 Feb 2023 09:43:24 +0100 Subject: [PATCH] fix bug in python interpretation --- ariths_gen/wire_components/buses.py | 4 +-- ariths_gen/wire_components/wires.py | 6 ++--- tests/test_all.py | 42 ++++++++++++++++++++++++++++- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/ariths_gen/wire_components/buses.py b/ariths_gen/wire_components/buses.py index 5128f58..e7589a1 100644 --- a/ariths_gen/wire_components/buses.py +++ b/ariths_gen/wire_components/buses.py @@ -127,7 +127,7 @@ class Bus(): # Ensures correct binding between the bus wire index and the wire itself # It is used for the case when multiple of the same wire (e.g. `ContantWireValue0()`) are present in the bus (its id would otherwise be incorrect when using `self.bus.index(_)`) mapped_positions = [(w_id, self.bus[w_id]) for w_id in range(self.N)] - return "".join([f" {self.prefix} = 0\n"] + [f" {self.prefix} |= {w[1].return_wire_value_python_flat(offset=w[0])}" for w in mapped_positions]) + return "".join([f" {self.prefix} = 0\n"] + [f" {self.prefix} = ({self.prefix}) | {w[1].return_wire_value_python_flat(offset=w[0])}" for w in mapped_positions]) def return_bus_wires_sign_extend_python_flat(self): """Sign extends the bus's corresponding Python variable (object) to ensure proper flat Python code variable signedness. @@ -137,7 +137,7 @@ class Bus(): """ if self.signed is True: last_bus_wire = self.bus[-1] - return "".join([f" {self.prefix} |= {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)]) + return "".join([f" {self.prefix} = ({self.prefix}) | {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)]) else: return "" diff --git a/ariths_gen/wire_components/wires.py b/ariths_gen/wire_components/wires.py index 645674c..a4f1b7a 100644 --- a/ariths_gen/wire_components/wires.py +++ b/ariths_gen/wire_components/wires.py @@ -40,13 +40,13 @@ class Wire(): str: Python code bitwise shift for storing (constant/variable) wire value at desired offset position. """ if self.is_const(): - return f"({self.c_const}) << {offset}\n" + return f"(({self.c_const}) << {offset})\n" # If wire is part of an input bus (where wire names are concatenated from bus prefix and their index position inside the bus in square brackets) # then the wire value is obtained from bitwise shifting the required wire from the parent bus ('parent_bus.prefix' is the same value as 'self.prefix') elif self.is_buswire(): - return f"(({self.prefix} >> {self.index}) & 0x01) << {offset}\n" + return f"((({self.prefix} >> {self.index}) & 0x01) << {offset})\n" else: - return f"(({self.name} >> 0) & 0x01) << {offset}\n" + return f"((({self.name} >> 0) & 0x01) << {offset})\n" """ C CODE GENERATION """ def get_declaration_c(self): diff --git a/tests/test_all.py b/tests/test_all.py index 5bb9f46..4c63e31 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -35,6 +35,17 @@ from ariths_gen.multi_bit_circuits.approximate_multipliers import ( UnsignedBrokenArrayMultiplier, UnsignedBrokenCarrySaveMultiplier ) + +from ariths_gen.one_bit_circuits.logic_gates import ( + AndGate, + NandGate, + OrGate, + NorGate, + XorGate, + XnorGate, + NotGate +) + import numpy as np @@ -163,4 +174,33 @@ def test_mac(): r = mymac(av, bv, cv) expected = (av * bv) + cv - np.testing.assert_array_equal(r, expected) \ No newline at end of file + np.testing.assert_array_equal(r, expected) + +def test_direct(): + class err_circuit(GeneralCircuit): + def __init__(self, prefix: str = "", name: str = "adder", inner_component: bool = True, a: Bus = Bus(), b: Bus = Bus()): + super().__init__(prefix = prefix, name=name, out_N = (a.N + 1), inner_component=inner_component, inputs = [a, b]) + self.N = 1 + self.prefix = prefix + self.a = Bus(prefix=a.prefix, wires_list=a.bus) + self.b = Bus(prefix=b.prefix, wires_list=b.bus) + self.out = Bus(self.prefix+"_out", self.N+1) + + a_0 = self.a.get_wire(0) + b_0 = self.b.get_wire(0) + + or_1 = OrGate(a_0, b_0, prefix=self.prefix+"_or"+str(self.get_instance_num(cls=OrGate)), parent_component=self) + self.add_component(or_1) + + self.out.connect(0, a_0) + self.out.connect(1, or_1.out) + + + av = np.arange(0, 4).reshape(1, -1) + bv = np.arange(0, 4).reshape(-1, 1) + example = err_circuit(prefix = "err_circuit", a = Bus("a", 2) , b = Bus("b", 2)) + + r = example(av, bv) + expected = np.array([[0, 3, 0, 3], [2, 3 ,2, 3], [0, 3, 0, 3], [2, 3, 2, 3]]) + np.testing.assert_equal(r, expected) + print(r) \ No newline at end of file