ariths-gen/tests/test_shuffle.py
Vojta 0778177106
Some checks failed
BUILD / build (push) Failing after 28s
BUILD / test (push) Has been skipped
BUILD / Python ${{ matrix.python-version }} test (3.10) (push) Has been skipped
BUILD / Python ${{ matrix.python-version }} test (3.11) (push) Has been skipped
BUILD / Python ${{ matrix.python-version }} test (3.12) (push) Has been skipped
BUILD / Python ${{ matrix.python-version }} test (3.9) (push) Has been skipped
BUILD / documentation (push) Has been skipped
tools folder; shuffle circuit
2025-01-30 14:17:08 +01:00

73 lines
1.9 KiB
Python

from ariths_gen.multi_bit_circuits.multipliers import UnsignedArrayMultiplier
import numpy as np
from ariths_gen.tools.shuffle_circuit import ShuffleCircuit
from ariths_gen.wire_components import Bus
from io import StringIO
def test_shuffle_circuit():
    """Check that a circuit shuffled via ``from_circuit`` still multiplies.

    Builds a 4x4-bit unsigned array multiplier, verifies it against plain
    integer multiplication for every operand pair, then wraps it with
    ``ShuffleCircuit.from_circuit(..., strategy="random")`` and repeats the
    same exhaustive comparison on the shuffled circuit.
    """
    width = 4
    bus_a = Bus(N=width, prefix="a")
    bus_b = Bus(N=width, prefix="b")
    multiplier = UnsignedArrayMultiplier(bus_a, bus_b, prefix="m")

    # A column vector against a row vector broadcasts to a 16x16 grid, so a
    # single call covers every (a, b) combination of 4-bit operands.
    col = np.arange(0, 2**width).reshape(-1, 1)
    row = np.arange(0, 2**width).reshape(1, -1)

    reference = multiplier(col, row)
    assert np.all(reference == col * row)

    # The flat CGP export goes to a throwaway buffer; its text is not read
    # in this test.
    # NOTE(review): presumably kept as a smoke test of the export (or for a
    # side effect on the circuit) - confirm before removing.
    sink = StringIO()
    multiplier.get_cgp_code_flat(sink)

    shuffled = ShuffleCircuit.from_circuit(multiplier, strategy="random")
    result = shuffled(col, row)
    assert np.all(result == col * row)
def test_shuffle_cgp():
    """Check that a ShuffleCircuit built from CGP text still multiplies.

    Exports a 4x4-bit unsigned array multiplier as flat CGP code, feeds the
    stripped text to ``ShuffleCircuit`` (with no explicit strategy argument)
    and compares the result with integer multiplication for every operand
    pair.
    """
    width = 4
    bus_a = Bus(N=width, prefix="a")
    bus_b = Bus(N=width, prefix="b")
    multiplier = UnsignedArrayMultiplier(bus_a, bus_b, prefix="m")

    # Column x row broadcasting yields all 16x16 operand combinations.
    col = np.arange(0, 2**width).reshape(-1, 1)
    row = np.arange(0, 2**width).reshape(1, -1)

    # Sanity-check the source circuit before exporting it.
    reference = multiplier(col, row)
    assert np.all(reference == col * row)

    buffer = StringIO()
    multiplier.get_cgp_code_flat(buffer)
    cgp_code = buffer.getvalue().strip()

    shuffled = ShuffleCircuit(code=cgp_code, input_widths=[width, width])
    result = shuffled(col, row)
    assert np.all(result == col * row)
def test_shuffle_strategies():
    """Check that each shuffle strategy keeps the multiplier correct.

    Exports a 4x4-bit unsigned array multiplier as flat CGP code, then
    builds a ``ShuffleCircuit`` from that code once per strategy
    (``"min"``, ``"max"``, ``"random"``) and compares each against integer
    multiplication for every operand pair.
    """
    width = 4
    bus_a = Bus(N=width, prefix="a")
    bus_b = Bus(N=width, prefix="b")
    multiplier = UnsignedArrayMultiplier(bus_a, bus_b, prefix="m")

    # Column x row broadcasting yields all 16x16 operand combinations.
    col = np.arange(0, 2**width).reshape(-1, 1)
    row = np.arange(0, 2**width).reshape(1, -1)

    # Sanity-check the source circuit before exporting it.
    reference = multiplier(col, row)
    assert np.all(reference == col * row)

    buffer = StringIO()
    multiplier.get_cgp_code_flat(buffer)
    cgp_code = buffer.getvalue().strip()

    for strategy in ("min", "max", "random"):
        shuffled = ShuffleCircuit(
            code=cgp_code,
            input_widths=[width, width],
            strategy=strategy,
        )
        result = shuffled(col, row)
        assert np.all(result == col * row)
if __name__ == "__main__":
    # Run every check in turn when the file is executed directly as a
    # script, so the tests can be exercised without pytest.
    for check in (
        test_shuffle_cgp,
        test_shuffle_circuit,
        test_shuffle_strategies,
    ):
        check()