Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python_tests/memo_test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,6 @@ class MemoBlob:
def __init__(self, size_bytes: int):
assert size_bytes <= len(RANDOM_BYTES)
self.data = RANDOM_BYTES[:size_bytes]

def func(x):
return x * 2
206 changes: 206 additions & 0 deletions python_tests/test_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import builtins
import pytest
import math
from functools import partial
from dbzero import db0
from .memo_test_types import MemoTestClass, func as func_from_other_module
from .conftest import DB0_DIR

def simple_funcion(x):
return x + 1

def some_func(x):
return x - 1

multiplier = 10
def closure_func(x):
return x * multiplier

def fibonacci(n):
if n <= 1:
return n
return fibonacci(n-1) + fibonacci(n-2)

def generator_func(n):
for i in range(n):
yield i

def annotated_func(x: int) -> int:
return x + 1

def test_can_store_callable_as_member(db0_fixture):
obj = MemoTestClass(simple_funcion)
value = obj.value(1)
assert value == 2

def test_can_store_callable_from_other_module(db0_fixture):
obj = MemoTestClass(func_from_other_module)
value = obj.value(3)
assert value == 6

def test_can_store_callable_in_singleton(db0_fixture):
@db0.memo(singleton=True)
class SingletonWithCallable:
def __init__(self, func):
self.func = func

obj = SingletonWithCallable(func_from_other_module)
value = obj.func(4)
assert value == 8
prefix_name = db0.get_prefix_of(obj).name
db0.commit()
db0.close()

# reopen and check again
db0.init(DB0_DIR)
db0.open(prefix_name, "rw")

obj = SingletonWithCallable()
value = obj.func(5)
assert value == 10

def test_can_store_lambda_as_member(db0_fixture):
with pytest.raises(Exception):
obj = MemoTestClass(lambda x: x * 3)

def test_can_store_class_method_as_member(db0_fixture):
class Helper:
def method(self, x):
return x * 2

helper = Helper()
with pytest.raises(Exception):
obj = MemoTestClass(helper.method)

def test_callable_with_nested_class(db0_fixture):
@db0.memo()
class Container:
def __init__(self, inner_obj):
self.inner = inner_obj

inner = MemoTestClass(func_from_other_module)
outer = Container(inner)

result = outer.inner.value(6)
assert result == 12

def test_callable_replacement(db0_fixture):
@db0.memo()
class DynamicCallable:
def __init__(self, func):
self.func = func

def set_func(self, new_func):
self.func = new_func

obj = DynamicCallable(simple_funcion)
assert obj.func(5) == 6

obj.set_func(func_from_other_module)
assert obj.func(5) == 10


def test_builtin_c_function_not_allowed(db0_fixture):
# A built-in C function should throw an exception
builtin_func = builtins.len

with pytest.raises(AttributeError) as exc_info:
obj = MemoTestClass(builtin_func)

def multi_arg_func(x, y, z):
return x + y + z

def test_callable_with_multiple_args(db0_fixture):
"""Test callable that accepts multiple arguments"""

obj = MemoTestClass(multi_arg_func)
value = obj.value(1, 2, 3)
assert value == 6

def decorator(func):
def wrapper(x):
return func(x) * 2
return wrapper

@decorator
def decorated_func(x):
return x + 1

def test_callable_with_decorator(db0_fixture):
"""Test storing a decorated function"""
with pytest.raises(AttributeError):
obj = MemoTestClass(decorated_func)

def test_callable_with_local_func(db0_fixture):
"""Test storing a decorated function"""
def some_func(x):
return x - 1
with pytest.raises(AttributeError):
obj = MemoTestClass(some_func)


def test_callable_list_storage(db0_fixture):
"""Test storing multiple callables in a list"""
@db0.memo()
class CallableList:
def __init__(self, funcs):
self.funcs = funcs

obj = CallableList([simple_funcion, func_from_other_module])
fnct = obj.funcs[0]
assert fnct(10) == 11

class Helper:
@staticmethod
def static_func(x):
return x * 3

def test_staticmethod_as_callable(db0_fixture):
"""Test storing a static method as callable"""
obj = MemoTestClass(Helper.static_func)
value = obj.value(4)
assert value == 12


class HelperClassMethod:
@classmethod
def class_func(cls, x):
return x * 4

def test_classmethod_should_fail(db0_fixture):
"""Test that class methods raise an exception"""
with pytest.raises(AttributeError):
obj = MemoTestClass(HelperClassMethod.class_func)

def test_callable_with_closure(db0_fixture):
"""Test function with closure variables"""
obj = MemoTestClass(closure_func)
value = obj.value(5)
assert value == 50


def test_callable_recursive_function(db0_fixture):
"""Test storing recursive function"""
obj = MemoTestClass(fibonacci)
value = obj.value(6)
assert value == 8

def test_callable_generator_function(db0_fixture):
"""Test storing generator function"""
obj = MemoTestClass(generator_func)
gen = obj.value(5)
result = list(gen)
assert result == [0, 1, 2, 3, 4]

def test_callable_partial_function(db0_fixture):
"""Test storing functools.partial object"""
partial_func = partial(multi_arg_func, 1, 2)

with pytest.raises(AttributeError):
obj = MemoTestClass(partial_func)

def test_callable_function_with_annotations(db0_fixture):
"""Test storing function with type annotations"""
obj = MemoTestClass(annotated_func)
value = obj.value(7)
assert value == 8
1 change: 1 addition & 0 deletions src/dbzero/bindings/TypeId.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace db0::bindings
BYTES_ARRAY = 16,
BOOLEAN = 17,
DECIMAL = 18,
FUNCTION = 19,
// dbzero wrappers of common language types
MEMO_OBJECT = 100,
DB0_LIST = 101,
Expand Down
111 changes: 111 additions & 0 deletions src/dbzero/bindings/python/PyToolkit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,117 @@ namespace db0::python
PyToolkit::PyWorkspace PyToolkit::m_py_workspace;
SafeRMutex PyToolkit::m_api_mutex;

void PyToolkit::throwErrorWithPyErrorCheck(const std::string& message, const std::string& error_detail) {
if (PyErr_Occurred()) {
PyObject *ptype, *pvalue, *ptraceback;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyObject* str_repr = PyObject_Str(pvalue);
const char* error_msg = str_repr ? PyUnicode_AsUTF8(str_repr) : "Unknown Python error";
std::string error_str(error_msg);
Py_XDECREF(str_repr);
Py_XDECREF(ptype);
Py_XDECREF(pvalue);
Py_XDECREF(ptraceback);
THROWF(db0::InputException) << message << error_str << THROWF_END;
} else {
THROWF(db0::InputException) << message << error_detail << THROWF_END;
}
}

std::string PyToolkit::getFullyQualifiedName(ObjectPtr func_obj) {
if (!func_obj) {
THROWF(db0::InputException) << "Null function object" << THROWF_END;
}

// Reject bound/unbound methods
if (PyMethod_Check(func_obj)) {
THROWF(db0::InputException) << "Methods are not allowed as FUNCTION members" << THROWF_END;
}

// Reject built-in C functions
if (PyCFunction_Check(func_obj)) {
THROWF(db0::InputException) << "Built-in C functions are not allowed as FUNCTION members" << THROWF_END;
}

// Get function's __name__, __qualname__, and __module__
auto name_obj = Py_OWN(PyObject_GetAttrString(func_obj, "__name__"));
auto qualname = Py_OWN(PyObject_GetAttrString(func_obj, "__qualname__"));
auto module_obj = Py_OWN(PyObject_GetAttrString(func_obj, "__module__"));

if (!name_obj || !qualname || !module_obj) {
THROWF(db0::InputException) << "Failed to get function name, qualname, or module" << THROWF_END;
}

// Decode UTF-8 strings
const char* name_cstr = PyUnicode_AsUTF8(*name_obj);
const char* qual_cstr = PyUnicode_AsUTF8(*qualname);
const char* module_cstr = PyUnicode_AsUTF8(*module_obj);

if (!name_cstr || !qual_cstr || !module_cstr) {
THROWF(db0::InputException) << "Failed to decode function attributes as UTF-8" << THROWF_END;
}

// Reject lambdas
if (strcmp(name_cstr, "<lambda>") == 0) {
THROWF(db0::InputException) << "Lambda functions are not allowed as FUNCTION members" << THROWF_END;
}

// Reject decorated or nested functions (qualname contains <locals>)
if (strstr(qual_cstr, "<locals>") != nullptr) {
THROWF(db0::InputException) << "Decorated or nested functions are not allowed as FUNCTION members" << THROWF_END;
}

// Construct fully qualified name: module.qualname
std::stringstream fqn_ss;
fqn_ss << module_cstr << "." << qual_cstr;
return fqn_ss.str();
}

typename PyToolkit::ObjectSharedPtr PyToolkit::getFunctionFromFullyQualifiedName(const char* fqn, size_t size) {
// Make a copy to tokenize
char* copy = strndup(fqn, size);
if (!copy) {
THROWF(db0::InputException) << "Failed to unload CALLABLE: memory allocation failed" << THROWF_END;
}

// First token is the module root
char* p = strchr(copy, '.');
if (!p) { // No dot = not fully qualified
free(copy);
THROWF(db0::InputException) << "Failed to unload CALLABLE: not a fully qualified name" << THROWF_END;
}
*p = '\0';
const char* root = copy;

// Import the module
auto module = Py_OWN(PyImport_ImportModule(root));
if (!module) {
free(copy);
throwErrorWithPyErrorCheck("Failed to unload CALLABLE: ",
"could not import module");
}

auto obj = module; // Start walking attributes

char* attr = p + 1;
while (attr && *attr) {
char* dot = strchr(attr, '.');
if (dot) *dot = '\0';

auto next = Py_OWN(PyObject_GetAttrString(obj.get(), attr));

if (!next) { // Attribute missing
free(copy);
throwErrorWithPyErrorCheck("Failed to unload CALLABLE: ",
"attribute missing");
}
obj = next;
attr = dot ? dot + 1 : NULL;
}
free(copy);
return obj; // New ref; caller DECREFs
}

std::string PyToolkit::getTypeName(ObjectPtr py_object) {
return getTypeName(Py_TYPE(py_object));
}
Expand Down
9 changes: 9 additions & 0 deletions src/dbzero/bindings/python/PyToolkit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ namespace db0::python
// indicate failed operation with a specific value/code
static void setError(ObjectPtr err_obj, std::uint64_t err_value);

// Throw exception with Python error details if available
static void throwErrorWithPyErrorCheck(const std::string& message, const std::string& error_detail = "");

// Get fully qualified name of a Python function (validates and rejects invalid function types)
static std::string getFullyQualifiedName(ObjectPtr func_obj);

// Reconstruct a Python function from its fully qualified name
static ObjectSharedPtr getFunctionFromFullyQualifiedName(const char* fqn, size_t size);

// Check if the object has references from other language objects (other than LangCache)
static bool hasLangRefs(ObjectPtr);
// Check if there exist any references except specific number of external references
Expand Down
2 changes: 2 additions & 0 deletions src/dbzero/bindings/python/PyTypeManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ namespace db0::python
addStaticSimpleType(&PyBool_Type, TypeId::BOOLEAN);
addStaticSimpleType(Py_TYPE(Py_None), TypeId::NONE);
addStaticSimpleType(&PyUnicode_Type, TypeId::STRING);
addStaticSimpleType(&PyFunction_Type, TypeId::FUNCTION);

// add python list type
addStaticType(&PyList_Type, TypeId::LIST);
addStaticType(&PySet_Type, TypeId::SET);
Expand Down
Loading