diff --git a/python_tests/memo_test_types.py b/python_tests/memo_test_types.py
index fd77ab74..651a1f33 100644
--- a/python_tests/memo_test_types.py
+++ b/python_tests/memo_test_types.py
@@ -234,3 +234,6 @@ class MemoBlob:
     def __init__(self, size_bytes: int):
         assert size_bytes <= len(RANDOM_BYTES)
         self.data = RANDOM_BYTES[:size_bytes]
+
+def func(x):
+    return x * 2
diff --git a/python_tests/test_callable.py b/python_tests/test_callable.py
new file mode 100644
index 00000000..9ca37c5f
--- /dev/null
+++ b/python_tests/test_callable.py
@@ -0,0 +1,215 @@
+"""Tests for storing Python callables as members of dbzero memo objects"""
+
+import builtins
+import pytest
+import math
+from functools import partial
+from dbzero import db0
+from .memo_test_types import MemoTestClass, func as func_from_other_module
+from .conftest import DB0_DIR
+
+def simple_function(x):
+    return x + 1
+
+def some_func(x):
+    return x - 1
+
+multiplier = 10
+def closure_func(x):
+    return x * multiplier
+
+def fibonacci(n):
+    if n <= 1:
+        return n
+    return fibonacci(n-1) + fibonacci(n-2)
+
+def generator_func(n):
+    for i in range(n):
+        yield i
+
+def annotated_func(x: int) -> int:
+    return x + 1
+
+def test_can_store_callable_as_member(db0_fixture):
+    """Test storing a module-level function and calling it through the member"""
+    obj = MemoTestClass(simple_function)
+    value = obj.value(1)
+    assert value == 2
+
+def test_can_store_callable_from_other_module(db0_fixture):
+    """Test storing a function imported from another module"""
+    obj = MemoTestClass(func_from_other_module)
+    value = obj.value(3)
+    assert value == 6
+
+def test_can_store_callable_in_singleton(db0_fixture):
+    """Test that a callable stored in a singleton still works after close and reopen"""
+    @db0.memo(singleton=True)
+    class SingletonWithCallable:
+        def __init__(self, func):
+            self.func = func
+
+    obj = SingletonWithCallable(func_from_other_module)
+    value = obj.func(4)
+    assert value == 8
+    prefix_name = db0.get_prefix_of(obj).name
+    db0.commit()
+    db0.close()
+
+    # reopen and check again
+    db0.init(DB0_DIR)
+    db0.open(prefix_name, "rw")
+
+    obj = SingletonWithCallable()
+    value = obj.func(5)
+    assert value == 10
+
+def test_can_store_lambda_as_member(db0_fixture):
+    """Test that lambdas are rejected"""
+    with pytest.raises(AttributeError):
+        obj = MemoTestClass(lambda x: x * 3)
+
+def test_can_store_class_method_as_member(db0_fixture):
+    """Test that bound instance methods are rejected"""
+    class Helper:
+        def method(self, x):
+            return x * 2
+
+    helper = Helper()
+    with pytest.raises(AttributeError):
+        obj = MemoTestClass(helper.method)
+
+def test_callable_with_nested_class(db0_fixture):
+    """Test calling a stored callable through a memo object nested in another memo object"""
+    @db0.memo()
+    class Container:
+        def __init__(self, inner_obj):
+            self.inner = inner_obj
+
+    inner = MemoTestClass(func_from_other_module)
+    outer = Container(inner)
+
+    result = outer.inner.value(6)
+    assert result == 12
+
+def test_callable_replacement(db0_fixture):
+    """Test replacing a stored callable with another one"""
+    @db0.memo()
+    class DynamicCallable:
+        def __init__(self, func):
+            self.func = func
+
+        def set_func(self, new_func):
+            self.func = new_func
+
+    obj = DynamicCallable(simple_function)
+    assert obj.func(5) == 6
+
+    obj.set_func(func_from_other_module)
+    assert obj.func(5) == 10
+
+
+def test_builtin_c_function_not_allowed(db0_fixture):
+    """Test that built-in C functions are rejected"""
+    builtin_func = builtins.len
+
+    with pytest.raises(AttributeError):
+        obj = MemoTestClass(builtin_func)
+
+def multi_arg_func(x, y, z):
+    return x + y + z
+
+def test_callable_with_multiple_args(db0_fixture):
+    """Test callable that accepts multiple arguments"""
+
+    obj = MemoTestClass(multi_arg_func)
+    value = obj.value(1, 2, 3)
+    assert value == 6
+
+def decorator(func):
+    def wrapper(x):
+        return func(x) * 2
+    return wrapper
+
+@decorator
+def decorated_func(x):
+    return x + 1
+
+def test_callable_with_decorator(db0_fixture):
+    """Test storing a decorated function"""
+    with pytest.raises(AttributeError):
+        obj = MemoTestClass(decorated_func)
+
+def test_callable_with_local_func(db0_fixture):
+    """Test that a locally defined (nested) function is rejected"""
+    def some_func(x):
+        return x - 1
+    with pytest.raises(AttributeError):
+        obj = MemoTestClass(some_func)
+
+
+def test_callable_list_storage(db0_fixture):
+    """Test storing multiple callables in a list"""
+    @db0.memo()
+    class CallableList:
+        def __init__(self, funcs):
+            self.funcs = funcs
+
+    obj = CallableList([simple_function, func_from_other_module])
+    fnct = obj.funcs[0]
+    assert fnct(10) == 11
+
+class Helper:
+    @staticmethod
+    def static_func(x):
+        return x * 3
+
+def test_staticmethod_as_callable(db0_fixture):
+    """Test storing a static method as callable"""
+    obj = MemoTestClass(Helper.static_func)
+    value = obj.value(4)
+    assert value == 12
+
+
+class HelperClassMethod:
+    @classmethod
+    def class_func(cls, x):
+        return x * 4
+
+def test_classmethod_should_fail(db0_fixture):
+    """Test that class methods raise an exception"""
+    with pytest.raises(AttributeError):
+        obj = MemoTestClass(HelperClassMethod.class_func)
+
+def test_callable_with_closure(db0_fixture):
+    """Test function with closure variables"""
+    obj = MemoTestClass(closure_func)
+    value = obj.value(5)
+    assert value == 50
+
+
+def test_callable_recursive_function(db0_fixture):
+    """Test storing recursive function"""
+    obj = MemoTestClass(fibonacci)
+    value = obj.value(6)
+    assert value == 8
+
+def test_callable_generator_function(db0_fixture):
+    """Test storing generator function"""
+    obj = MemoTestClass(generator_func)
+    gen = obj.value(5)
+    result = list(gen)
+    assert result == [0, 1, 2, 3, 4]
+
+def test_callable_partial_function(db0_fixture):
+    """Test storing functools.partial object"""
+    partial_func = partial(multi_arg_func, 1, 2)
+
+    with pytest.raises(AttributeError):
+        obj = MemoTestClass(partial_func)
+
+def test_callable_function_with_annotations(db0_fixture):
+    """Test storing function with type annotations"""
+    obj = MemoTestClass(annotated_func)
+    value = obj.value(7)
+    assert value == 8
diff --git a/src/dbzero/bindings/TypeId.hpp b/src/dbzero/bindings/TypeId.hpp
index f99c5be0..54e00165 100644
--- a/src/dbzero/bindings/TypeId.hpp
+++ b/src/dbzero/bindings/TypeId.hpp
@@ -36,6 +36,7 @@ namespace db0::bindings
         BYTES_ARRAY = 16,
         BOOLEAN = 17,
         DECIMAL = 18,
+        FUNCTION = 19,
         // dbzero wrappers of common language types
         MEMO_OBJECT = 100,
DB0_LIST = 101,
diff --git a/src/dbzero/bindings/python/PyToolkit.cpp b/src/dbzero/bindings/python/PyToolkit.cpp
index c25db036..9580d86a 100644
--- a/src/dbzero/bindings/python/PyToolkit.cpp
+++ b/src/dbzero/bindings/python/PyToolkit.cpp
@@ -38,6 +38,144 @@ namespace db0::python
     PyToolkit::PyWorkspace PyToolkit::m_py_workspace;
     SafeRMutex PyToolkit::m_api_mutex;
 
+    // Throws db0::InputException whose text is `message` followed by the string form of the
+    // pending Python error (which is fetched and cleared) or, if none is pending, `error_detail`.
+    void PyToolkit::throwErrorWithPyErrorCheck(const std::string& message, const std::string& error_detail) {
+        if (PyErr_Occurred()) {
+            PyObject *ptype, *pvalue, *ptraceback;
+            PyErr_Fetch(&ptype, &pvalue, &ptraceback);
+            PyObject* str_repr = PyObject_Str(pvalue);
+            // PyUnicode_AsUTF8 returns NULL on failure; std::string must never be built from NULL
+            const char* error_msg = str_repr ? PyUnicode_AsUTF8(str_repr) : nullptr;
+            std::string error_str(error_msg ? error_msg : "Unknown Python error");
+            // Discard any secondary error raised while formatting the original one
+            PyErr_Clear();
+            Py_XDECREF(str_repr);
+            Py_XDECREF(ptype);
+            Py_XDECREF(pvalue);
+            Py_XDECREF(ptraceback);
+            THROWF(db0::InputException) << message << error_str << THROWF_END;
+        } else {
+            THROWF(db0::InputException) << message << error_detail << THROWF_END;
+        }
+    }
+
+    // Validates that `func_obj` can be looked up again by name and returns its __module__ and
+    // __qualname__ joined by a dot. Throws db0::InputException for methods, built-in C functions,
+    // lambdas and functions defined inside another function (e.g. wrappers returned by decorators).
+    std::string PyToolkit::getFullyQualifiedName(ObjectPtr func_obj) {
+        if (!func_obj) {
+            THROWF(db0::InputException) << "Null function object" << THROWF_END;
+        }
+
+        // Reject bound/unbound methods
+        if (PyMethod_Check(func_obj)) {
+            THROWF(db0::InputException) << "Methods are not allowed as FUNCTION members" << THROWF_END;
+        }
+
+        // Reject built-in C functions
+        if (PyCFunction_Check(func_obj)) {
+            THROWF(db0::InputException) << "Built-in C functions are not allowed as FUNCTION members" << THROWF_END;
+        }
+
+        // Get function's __name__, __qualname__, and __module__
+        auto name_obj = Py_OWN(PyObject_GetAttrString(func_obj, "__name__"));
+        auto qualname = Py_OWN(PyObject_GetAttrString(func_obj, "__qualname__"));
+        auto module_obj = Py_OWN(PyObject_GetAttrString(func_obj, "__module__"));
+
+        if (!name_obj || !qualname || !module_obj) {
+            THROWF(db0::InputException) << "Failed to get function name, qualname, or module" << THROWF_END;
+        }
+
+        // Decode UTF-8 strings
+        const char* name_cstr = PyUnicode_AsUTF8(*name_obj);
+        const char* qual_cstr = PyUnicode_AsUTF8(*qualname);
+        const char* module_cstr = PyUnicode_AsUTF8(*module_obj);
+
+        if (!name_cstr || !qual_cstr || !module_cstr) {
+            THROWF(db0::InputException) << "Failed to decode function attributes as UTF-8" << THROWF_END;
+        }
+
+        // Reject lambdas (their __name__ is always "<lambda>")
+        if (strcmp(name_cstr, "<lambda>") == 0) {
+            THROWF(db0::InputException) << "Lambda functions are not allowed as FUNCTION members" << THROWF_END;
+        }
+
+        // Reject decorated or nested functions (qualname contains "<locals>")
+        if (strstr(qual_cstr, "<locals>") != nullptr) {
+            THROWF(db0::InputException) << "Decorated or nested functions are not allowed as FUNCTION members" << THROWF_END;
+        }
+
+        // Construct fully qualified name: module.qualname
+        std::stringstream fqn_ss;
+        fqn_ss << module_cstr << "." << qual_cstr;
+        return fqn_ss.str();
+    }
+
+    // Resolves a name produced by getFullyQualifiedName (`size` bytes at `fqn`) back to a Python
+    // object: imports the leading module component, then walks the remaining dot-separated
+    // attributes, importing submodules on demand. Throws db0::InputException on any failure.
+    typename PyToolkit::ObjectSharedPtr PyToolkit::getFunctionFromFullyQualifiedName(const char* fqn, size_t size) {
+        // Make a copy to tokenize
+        char* copy = strndup(fqn, size);
+        if (!copy) {
+            THROWF(db0::InputException) << "Failed to unload CALLABLE: memory allocation failed" << THROWF_END;
+        }
+
+        // First token is the module root
+        char* p = strchr(copy, '.');
+        if (!p) {  // No dot = not fully qualified
+            free(copy);
+            THROWF(db0::InputException) << "Failed to unload CALLABLE: not a fully qualified name" << THROWF_END;
+        }
+        *p = '\0';
+        const char* root = copy;
+
+        // Import the module
+        auto module = Py_OWN(PyImport_ImportModule(root));
+        if (!module) {
+            free(copy);
+            throwErrorWithPyErrorCheck("Failed to unload CALLABLE: ",
+                "could not import module");
+        }
+
+        auto obj = module;    // Start walking attributes
+
+        char* attr = p + 1;
+        while (attr && *attr) {
+            char* dot = strchr(attr, '.');
+            if (dot) *dot = '\0';
+
+            auto next = Py_OWN(PyObject_GetAttrString(obj.get(), attr));
+            if (!next && PyModule_Check(obj.get())) {
+                // A package only exposes a submodule as an attribute once that submodule has been
+                // imported, so retry by importing the dotted path up to and including this token.
+                PyObject *etype, *evalue, *etraceback;
+                PyErr_Fetch(&etype, &evalue, &etraceback);
+                std::string module_path(fqn, static_cast<size_t>(attr - copy) + strlen(attr));
+                next = Py_OWN(PyImport_ImportModule(module_path.c_str()));
+                if (next) {
+                    Py_XDECREF(etype);
+                    Py_XDECREF(evalue);
+                    Py_XDECREF(etraceback);
+                } else {
+                    // Report the original attribute error rather than the import error
+                    PyErr_Restore(etype, evalue, etraceback);
+                }
+            }
+
+            if (!next) {  // Attribute missing
+                free(copy);
+                throwErrorWithPyErrorCheck("Failed to unload CALLABLE: ",
+                    "attribute missing");
+            }
+            obj = next;
+            attr = dot ? dot + 1 : NULL;
+        }
+        free(copy);
+        return obj;  // New ref; caller DECREFs
+    }
+
     std::string PyToolkit::getTypeName(ObjectPtr py_object) {
         return getTypeName(Py_TYPE(py_object));
     }
diff --git a/src/dbzero/bindings/python/PyToolkit.hpp b/src/dbzero/bindings/python/PyToolkit.hpp
index 6b0b71f3..a0f21cd0 100644
--- a/src/dbzero/bindings/python/PyToolkit.hpp
+++ b/src/dbzero/bindings/python/PyToolkit.hpp
@@ -219,6 +219,15 @@ namespace db0::python
         // indicate failed operation with a specific value/code
         static void setError(ObjectPtr err_obj, std::uint64_t err_value);
 
+        // Throw exception with Python error details if available
+        static void throwErrorWithPyErrorCheck(const std::string& message, const std::string& error_detail = "");
+
+        // Get fully qualified name of a Python function (validates and rejects invalid function types)
+        static std::string getFullyQualifiedName(ObjectPtr func_obj);
+
+        // Reconstruct a Python function from its fully qualified name
+        static ObjectSharedPtr getFunctionFromFullyQualifiedName(const char* fqn, size_t size);
+
         // Check if the object has references from other language objects (other than LangCache)
         static bool hasLangRefs(ObjectPtr);
         // Check if there exist any references except specific number of external references
diff --git a/src/dbzero/bindings/python/PyTypeManager.cpp b/src/dbzero/bindings/python/PyTypeManager.cpp
index 25b342b3..d971d98f 100644
--- a/src/dbzero/bindings/python/PyTypeManager.cpp
+++ b/src/dbzero/bindings/python/PyTypeManager.cpp
@@ -58,6 +58,8 @@ namespace db0::python
         addStaticSimpleType(&PyBool_Type, TypeId::BOOLEAN);
         addStaticSimpleType(Py_TYPE(Py_None), TypeId::NONE);
         addStaticSimpleType(&PyUnicode_Type, TypeId::STRING);
+        addStaticSimpleType(&PyFunction_Type, TypeId::FUNCTION);
+
         // add python list type
         addStaticType(&PyList_Type, TypeId::LIST);
         addStaticType(&PySet_Type, TypeId::SET);
diff --git a/src/dbzero/object_model/value/Member.cpp
b/src/dbzero/object_model/value/Member.cpp
index e9c09ada..007a2d14 100644
--- a/src/dbzero/object_model/value/Member.cpp
+++ b/src/dbzero/object_model/value/Member.cpp
@@ -330,6 +330,20 @@ namespace db0::object_model
         type->incRef(false);
         return type->getUniqueAddress();
     }
+
+    // FUNCTION specialization: a Python function is persisted as its fully qualified name string; the member value is the address of that stored string
+    template <> Value createMember(
+        db0::swine_ptr &fixture,
+        PyObjectPtr obj_ptr,
+        StorageClass,
+        AccessFlags access_mode)
+    {
+        // Get and validate fully qualified name (PyToolkit::getFullyQualifiedName throws for callables it does not accept)
+        auto fqn_str = PyToolkit::getFullyQualifiedName(obj_ptr);
+
+        // Create a v_object holding the name in the fixture and return its address
+        return db0::v_object(*fixture, fqn_str, access_mode).getAddress();
+    }
 
     template <> void registerCreateMemberFunctions(
         std::vector &, PyObjectPtr, StorageClass, AccessFlags)> &functions)
@@ -368,6 +382,7 @@ namespace db0::object_model
         functions[static_cast(TypeId::DB0_BYTES_ARRAY)] = createMember;
         functions[static_cast(TypeId::DB0_WEAK_PROXY)] = createMember;
         functions[static_cast(TypeId::MEMO_TYPE)] = createMember;
+        functions[static_cast(TypeId::FUNCTION)] = createMember;
     }
 
     // STRING_REF specialization
@@ -616,6 +631,18 @@ namespace db0::object_model
         return PyToolkit::getTypeManager().getLangConstant(val_code);
     }
 
+    // CALLABLE specialization: opens the string object at the address held in `value` and resolves the stored fully qualified name back to a Python object
+    template <> typename PyToolkit::ObjectSharedPtr unloadMember(
+        db0::swine_ptr &fixture, Value value, unsigned int, AccessFlags access_mode)
+    {
+        db0::v_object string_ref(fixture->myPtr(value.asAddress()), access_mode);
+        auto str_ptr = string_ref->get();
+
+        // Reconstruct function from its fully qualified name (passed as raw pointer + size)
+        return PyToolkit::getFunctionFromFullyQualifiedName(str_ptr.get_raw(), str_ptr.size());
+    }
+
+
     template <> void registerUnloadMemberFunctions(
         std::vector &, Value, unsigned int, AccessFlags)> &functions)
     {
@@ -646,6 +673,7 @@ namespace db0::object_model
         functions[static_cast(StorageClass::OBJECT_WEAK_REF)] = unloadMember;
         functions[static_cast(StorageClass::OBJECT_LONG_WEAK_REF)] = unloadMember;
         functions[static_cast(StorageClass::PACK_2)] = unloadMember;
+        functions[static_cast(StorageClass::CALLABLE)] = unloadMember;
     }
 
     template
diff --git a/src/dbzero/object_model/value/StorageClass.cpp b/src/dbzero/object_model/value/StorageClass.cpp
index f18e67b4..b52341bc 100644
--- a/src/dbzero/object_model/value/StorageClass.cpp
+++ b/src/dbzero/object_model/value/StorageClass.cpp
@@ -51,6 +51,7 @@ namespace db0::object_model
         addMapping(TypeId::DB0_BYTES_ARRAY, PreStorageClass::DB0_BYTES_ARRAY);
         // Note: DB0_WEAK_PROXY by default maps to OBJECT_WEAK_REF but can also be OBJECT_LONG_WEAK_REF which needs to be checked
         addMapping(TypeId::DB0_WEAK_PROXY, PreStorageClass::OBJECT_WEAK_REF);
+        addMapping(TypeId::FUNCTION, PreStorageClass::CALLABLE);
     }
 
     PreStorageClass StorageClassMapper::getPreStorageClass(TypeId type_id, bool allow_packed) const
diff --git a/src/dbzero/object_model/value/StorageClass.hpp b/src/dbzero/object_model/value/StorageClass.hpp
index a9eebeb8..05b3a3b6 100644
--- a/src/dbzero/object_model/value/StorageClass.hpp
+++ b/src/dbzero/object_model/value/StorageClass.hpp
@@ -62,6 +62,8 @@ namespace db0::object_model
         OBJECT_WEAK_REF = 30,
         // deleted value (placeholder)
         DELETED = 31,
+        CALLABLE = 32,
+
         COUNT = std::numeric_limits::max() - 32,
         // invalid / reserved value, never used in objects
         INVALID = std::numeric_limits::max()
@@ -111,6 +113,7 @@ namespace db0::object_model
         // weak reference to other (Memo) instance on the same prefix
         OBJECT_WEAK_REF = static_cast(PreStorageClass::OBJECT_WEAK_REF),
         DELETED = static_cast(PreStorageClass::DELETED),
+        CALLABLE = static_cast(PreStorageClass::CALLABLE),
         // weak reference to other (Memo) instance from a foreign prefix
         OBJECT_LONG_WEAK_REF = static_cast(PreStorageClass::COUNT),
         // COUNT used to determine size of the StorageClass associated arrays