fix bug in python interpretation

This commit is contained in:
Vojta Mrazek 2023-02-22 09:43:24 +01:00
parent f17e87738e
commit 35240abc63
3 changed files with 46 additions and 6 deletions

View File

@ -127,7 +127,7 @@ class Bus():
# Ensures correct binding between the bus wire index and the wire itself # Ensures correct binding between the bus wire index and the wire itself
# It is used for the case when multiple of the same wire (e.g. `ContantWireValue0()`) are present in the bus (its id would otherwise be incorrect when using `self.bus.index(_)`) # It is used for the case when multiple of the same wire (e.g. `ContantWireValue0()`) are present in the bus (its id would otherwise be incorrect when using `self.bus.index(_)`)
mapped_positions = [(w_id, self.bus[w_id]) for w_id in range(self.N)] mapped_positions = [(w_id, self.bus[w_id]) for w_id in range(self.N)]
return "".join([f" {self.prefix} = 0\n"] + [f" {self.prefix} |= {w[1].return_wire_value_python_flat(offset=w[0])}" for w in mapped_positions]) return "".join([f" {self.prefix} = 0\n"] + [f" {self.prefix} = ({self.prefix}) | {w[1].return_wire_value_python_flat(offset=w[0])}" for w in mapped_positions])
def return_bus_wires_sign_extend_python_flat(self): def return_bus_wires_sign_extend_python_flat(self):
"""Sign extends the bus's corresponding Python variable (object) to ensure proper flat Python code variable signedness. """Sign extends the bus's corresponding Python variable (object) to ensure proper flat Python code variable signedness.
@ -137,7 +137,7 @@ class Bus():
""" """
if self.signed is True: if self.signed is True:
last_bus_wire = self.bus[-1] last_bus_wire = self.bus[-1]
return "".join([f" {self.prefix} |= {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)]) return "".join([f" {self.prefix} = ({self.prefix}) | {last_bus_wire.return_wire_value_python_flat(offset=i)}" for i in range(len(self.bus), 64)])
else: else:
return "" return ""

View File

@ -40,13 +40,13 @@ class Wire():
str: Python code bitwise shift for storing (constant/variable) wire value at desired offset position. str: Python code bitwise shift for storing (constant/variable) wire value at desired offset position.
""" """
if self.is_const(): if self.is_const():
return f"({self.c_const}) << {offset}\n" return f"(({self.c_const}) << {offset})\n"
# If wire is part of an input bus (where wire names are concatenated from bus prefix and their index position inside the bus in square brackets) # If wire is part of an input bus (where wire names are concatenated from bus prefix and their index position inside the bus in square brackets)
# then the wire value is obtained from bitwise shifting the required wire from the parent bus ('parent_bus.prefix' is the same value as 'self.prefix') # then the wire value is obtained from bitwise shifting the required wire from the parent bus ('parent_bus.prefix' is the same value as 'self.prefix')
elif self.is_buswire(): elif self.is_buswire():
return f"(({self.prefix} >> {self.index}) & 0x01) << {offset}\n" return f"((({self.prefix} >> {self.index}) & 0x01) << {offset})\n"
else: else:
return f"(({self.name} >> 0) & 0x01) << {offset}\n" return f"((({self.name} >> 0) & 0x01) << {offset})\n"
""" C CODE GENERATION """ """ C CODE GENERATION """
def get_declaration_c(self): def get_declaration_c(self):

View File

@ -35,6 +35,17 @@ from ariths_gen.multi_bit_circuits.approximate_multipliers import (
UnsignedBrokenArrayMultiplier, UnsignedBrokenArrayMultiplier,
UnsignedBrokenCarrySaveMultiplier UnsignedBrokenCarrySaveMultiplier
) )
from ariths_gen.one_bit_circuits.logic_gates import (
AndGate,
NandGate,
OrGate,
NorGate,
XorGate,
XnorGate,
NotGate
)
import numpy as np import numpy as np
@ -164,3 +175,32 @@ def test_mac():
r = mymac(av, bv, cv) r = mymac(av, bv, cv)
expected = (av * bv) + cv expected = (av * bv) + cv
np.testing.assert_array_equal(r, expected) np.testing.assert_array_equal(r, expected)
def test_direct():
class err_circuit(GeneralCircuit):
def __init__(self, prefix: str = "", name: str = "adder", inner_component: bool = True, a: Bus = Bus(), b: Bus = Bus()):
super().__init__(prefix = prefix, name=name, out_N = (a.N + 1), inner_component=inner_component, inputs = [a, b])
self.N = 1
self.prefix = prefix
self.a = Bus(prefix=a.prefix, wires_list=a.bus)
self.b = Bus(prefix=b.prefix, wires_list=b.bus)
self.out = Bus(self.prefix+"_out", self.N+1)
a_0 = self.a.get_wire(0)
b_0 = self.b.get_wire(0)
or_1 = OrGate(a_0, b_0, prefix=self.prefix+"_or"+str(self.get_instance_num(cls=OrGate)), parent_component=self)
self.add_component(or_1)
self.out.connect(0, a_0)
self.out.connect(1, or_1.out)
av = np.arange(0, 4).reshape(1, -1)
bv = np.arange(0, 4).reshape(-1, 1)
example = err_circuit(prefix = "err_circuit", a = Bus("a", 2) , b = Bus("b", 2))
r = example(av, bv)
expected = np.array([[0, 3, 0, 3], [2, 3 ,2, 3], [0, 3, 0, 3], [2, 3, 2, 3]])
np.testing.assert_equal(r, expected)
print(r)