diff --git a/comtypes/test/monikers_helper.py b/comtypes/test/monikers_helper.py index 6ad3671b..52a8e4ac 100644 --- a/comtypes/test/monikers_helper.py +++ b/comtypes/test/monikers_helper.py @@ -4,8 +4,12 @@ from comtypes import GUID, IUnknown # https://learn.microsoft.com/en-us/windows/win32/api/objidl/ne-objidl-mksys +MKSYS_GENERICCOMPOSITE = 1 +MKSYS_FILEMONIKER = 2 MKSYS_ITEMMONIKER = 4 +CLSID_CompositeMoniker = GUID("{00000309-0000-0000-c000-000000000046}") +CLSID_FileMoniker = GUID("{00000303-0000-0000-C000-000000000046}") CLSID_AntiMoniker = GUID("{00000305-0000-0000-c000-000000000046}") CLSID_ItemMoniker = GUID("{00000304-0000-0000-c000-000000000046}") @@ -15,6 +19,18 @@ _ole32 = OleDLL("ole32") +_CreateGenericComposite = _ole32.CreateGenericComposite +_CreateGenericComposite.argtypes = [ + POINTER(IUnknown), # pmkFirst + POINTER(IUnknown), # pmkRest + POINTER(POINTER(IUnknown)), # ppmkComposite +] +_CreateGenericComposite.restype = HRESULT + +_CreateFileMoniker = _ole32.CreateFileMoniker +_CreateFileMoniker.argtypes = [LPCOLESTR, POINTER(POINTER(IUnknown))] +_CreateFileMoniker.restype = HRESULT + _CreateItemMoniker = _ole32.CreateItemMoniker _CreateItemMoniker.argtypes = [LPCOLESTR, LPCOLESTR, POINTER(POINTER(IUnknown))] _CreateItemMoniker.restype = HRESULT @@ -28,4 +44,5 @@ _GetRunningObjectTable.restype = HRESULT # Common COM Errors from Moniker/Binding Context operations +MK_E_NEEDGENERIC = -2147221022 # 0x800401E2 MK_E_UNAVAILABLE = -2147221021 # 0x800401E3 diff --git a/comtypes/test/test_bctx.py b/comtypes/test/test_bctx.py index adb88ba8..7f0b364d 100644 --- a/comtypes/test/test_bctx.py +++ b/comtypes/test/test_bctx.py @@ -1,9 +1,9 @@ import contextlib import unittest from _ctypes import COMError -from ctypes import POINTER, byref +from ctypes import POINTER, byref, sizeof -from comtypes import GUID, hresult +from comtypes import GUID, hresult, tagBIND_OPTS2 from comtypes.client import CreateObject, GetModule from comtypes.test.monikers_helper import ( ROTFLAGS_ALLOWANYCLIENT, @@ -64,3 +64,80 @@ def test_returns_rot(self): rot_from_func.Revoke(dw_reg) # After revoking: should NOT be running again self.assertEqual(rot_from_bctx, rot_from_func) + + +class Test_Register_Revoke_Release_ObjectBound(unittest.TestCase): + def test_register_and_revoke(self): + bctx = _create_bctx() + vidctl = CreateObject(msvidctl.MSVidCtl, interface=msvidctl.IMSVidCtl) + # Binds the object to the bind context, ensuring it stays alive during + # the binding operation. + hr = bctx.RegisterObjectBound(vidctl) + self.assertEqual(hr, hresult.S_OK) + # At this point, `bctx` holds a reference to `vidctl`. + # Unlike `RegisterObjectParam`, there is no public API to retrieve + # objects registered via `RegisterObjectBound` from `IBindCtx`. + # Therefore, direct testing of `vidctl`'s accessibility via `bctx` + # after binding (similar to `GetObjectParam`) is not possible. + # Releases the reference to the object previously registered. + hr = bctx.RevokeObjectBound(vidctl) + self.assertEqual(hr, hresult.S_OK) + # `bctx` holds a reference to `vidctl` again. + # Releases all object references currently held by the bind context. + bctx.RegisterObjectBound(vidctl) + hr = bctx.ReleaseBoundObjects() + self.assertEqual(hr, hresult.S_OK) + + +class Test_Get_Register_Revoke_ObjectParam(unittest.TestCase): + def test_get_and_register_and_revoke(self): + bctx = _create_bctx() + key = str(GUID.create_new()) + vidctl = CreateObject(msvidctl.MSVidCtl, interface=msvidctl.IMSVidCtl) + # `GetObjectParam` should fail as it's NOT registered yet + with self.assertRaises(COMError) as cm: + bctx.GetObjectParam(key) + self.assertEqual(cm.exception.hresult, hresult.E_FAIL) + # Register object + hr = bctx.RegisterObjectParam(key, vidctl) + self.assertEqual(hr, hresult.S_OK) + # `GetObjectParam` should succeed now + ret_obj = bctx.GetObjectParam(key) + self.assertEqual(ret_obj.QueryInterface(msvidctl.IMSVidCtl), vidctl) + # Revoke object + hr = bctx.RevokeObjectParam(key) + self.assertEqual(hr, hresult.S_OK) + # `GetObjectParam` should fail again after revoke + with self.assertRaises(COMError) as cm: + bctx.GetObjectParam(key) + self.assertEqual(cm.exception.hresult, hresult.E_FAIL) + + +class Test_Set_Get_BindOptions(unittest.TestCase): + def test_set_get_bind_options(self): + bctx = _create_bctx() + # Create an instance of `BIND_OPTS2` and set some values. + # In comtypes, instances of Structure subclasses like `tagBIND_OPTS2` + # can be passed directly as arguments where COM methods expect a + # pointer to the structure. + hr = bctx.RemoteSetBindOptions( + tagBIND_OPTS2( + cbStruct=sizeof(tagBIND_OPTS2), + grfFlags=0x11223344, + grfMode=0x55667788, + dwTickCountDeadline=12345, + ) + ) + self.assertEqual(hr, hresult.S_OK) + # Create a new instance for retrieval. + # The `cbStruct` field is crucial in COM as it indicates the size of + # the structure to the COM component, allowing it to handle different + # versions of the structure (for backward and forward compatibility). + # https://learn.microsoft.com/en-us/windows/win32/api/objidl/nf-objidl-ibindctx-getbindoptions#notes-to-callers + bind_opts = tagBIND_OPTS2(cbStruct=sizeof(tagBIND_OPTS2)) + ret = bctx.RemoteGetBindOptions(bind_opts) + self.assertIsInstance(ret, tagBIND_OPTS2) + self.assertEqual(bind_opts.cbStruct, sizeof(tagBIND_OPTS2)) + self.assertEqual(bind_opts.grfFlags, 0x11223344) + self.assertEqual(bind_opts.grfMode, 0x55667788) + self.assertEqual(bind_opts.dwTickCountDeadline, 12345) diff --git a/comtypes/test/test_moniker.py b/comtypes/test/test_moniker.py index f4999c74..2709b7db 100644 --- a/comtypes/test/test_moniker.py +++ b/comtypes/test/test_moniker.py @@ -1,15 +1,29 @@ import contextlib +import ctypes +import os +import tempfile import unittest -from ctypes import POINTER, byref +from _ctypes import COMError +from ctypes import POINTER, WinDLL, byref +from ctypes.wintypes import DWORD, LPCWSTR, LPWSTR, MAX_PATH +from pathlib import Path from comtypes import GUID, hresult from comtypes.client import CreateObject, GetModule +from comtypes.persist import IPersistFile from comtypes.test.monikers_helper import ( + MK_E_NEEDGENERIC, + MKSYS_FILEMONIKER, + MKSYS_GENERICCOMPOSITE, MKSYS_ITEMMONIKER, ROTFLAGS_ALLOWANYCLIENT, CLSID_AntiMoniker, + CLSID_CompositeMoniker, + CLSID_FileMoniker, CLSID_ItemMoniker, _CreateBindCtx, + _CreateFileMoniker, + _CreateGenericComposite, _CreateItemMoniker, _GetRunningObjectTable, ) @@ -17,7 +31,37 @@ with contextlib.redirect_stdout(None): # supress warnings GetModule("msvidctl.dll") from comtypes.gen import MSVidCtlLib as msvidctl -from comtypes.gen.MSVidCtlLib import IBindCtx, IMoniker, IRunningObjectTable +from comtypes.gen.MSVidCtlLib import ( + IBindCtx, + IEnumMoniker, + IMoniker, + IRunningObjectTable, +) + +_kernel32 = WinDLL("kernel32") + +_GetLongPathNameW = _kernel32.GetLongPathNameW +_GetLongPathNameW.argtypes = [LPCWSTR, LPWSTR, DWORD] +_GetLongPathNameW.restype = DWORD + + +def _get_long_path_name(path: str) -> str: + """Converts a path to its long form using GetLongPathNameW.""" + buffer = ctypes.create_unicode_buffer(MAX_PATH) + length = _GetLongPathNameW(path, buffer, MAX_PATH) + return buffer.value[:length] + + +def _create_generic_composite(mk_first: IMoniker, mk_rest: IMoniker) -> IMoniker: + mon = POINTER(IMoniker)() + _CreateGenericComposite(mk_first, mk_rest, byref(mon)) + return mon # type: ignore + + +def _create_file_moniker(path: str) -> IMoniker: + mon = POINTER(IMoniker)() + _CreateFileMoniker(path, byref(mon)) + return mon # type: ignore def _create_item_moniker(delim: str, item: str) -> IMoniker: @@ -41,6 +85,35 @@ def _create_rot() -> IRunningObjectTable: class Test_IsSystemMoniker_GetDisplayName_Inverse(unittest.TestCase): + def test_generic_composite(self): + item_id1 = str(GUID.create_new()) + item_id2 = str(GUID.create_new()) + mon = _create_generic_composite( + _create_item_moniker("!", item_id1), + _create_item_moniker("!", item_id2), + ) + self.assertEqual(mon.IsSystemMoniker(), MKSYS_GENERICCOMPOSITE) + bctx = _create_bctx() + self.assertEqual(mon.GetDisplayName(bctx, None), f"!{item_id1}!{item_id2}") + self.assertEqual(mon.GetClassID(), CLSID_CompositeMoniker) + self.assertEqual(mon.Inverse().GetClassID(), CLSID_CompositeMoniker) + + def test_file(self): + with tempfile.NamedTemporaryFile() as f: + mon = _create_file_moniker(f.name) + self.assertEqual(mon.IsSystemMoniker(), MKSYS_FILEMONIKER) + bctx = _create_bctx() + self.assertEqual( + os.path.normcase( + os.path.normpath( + _get_long_path_name(mon.GetDisplayName(bctx, None)) + ) + ), + os.path.normcase(os.path.normpath(_get_long_path_name(f.name))), + ) + self.assertEqual(mon.GetClassID(), CLSID_FileMoniker) + self.assertEqual(mon.Inverse().GetClassID(), CLSID_AntiMoniker) + def test_item(self): item_id = str(GUID.create_new()) mon = _create_item_moniker("!", item_id) @@ -51,6 +124,41 @@ def test_item(self): self.assertEqual(mon.Inverse().GetClassID(), CLSID_AntiMoniker) +class Test_ComposeWith(unittest.TestCase): + def test_item(self): + item_id = str(GUID.create_new()) + mon = _create_item_moniker("!", item_id) + item_mon2 = _create_item_moniker("!", str(GUID.create_new())) + self.assertEqual( + mon.ComposeWith(item_mon2, False).GetClassID(), + CLSID_CompositeMoniker, + ) + with self.assertRaises(COMError) as cm: + mon.ComposeWith(item_mon2, True) + self.assertEqual(cm.exception.hresult, MK_E_NEEDGENERIC) + + +class Test_IsEqual(unittest.TestCase): + def test_item(self): + item_id = str(GUID.create_new()) + mon1 = _create_item_moniker("!", item_id) + mon2 = _create_item_moniker("!", item_id) # Should be equal + mon3 = _create_item_moniker("!", str(GUID.create_new())) # Should not be equal + self.assertEqual(mon1.IsEqual(mon2), hresult.S_OK) + self.assertEqual(mon1.IsEqual(mon3), hresult.S_FALSE) + + +class Test_Hash(unittest.TestCase): + def test_item(self): + item_id = str(GUID.create_new()) + mon1 = _create_item_moniker("!", item_id) + mon2 = _create_item_moniker("!", item_id) # Should be equal + mon3 = _create_item_moniker("!", str(GUID.create_new())) # Should not be equal + self.assertEqual(mon1.Hash(), mon2.Hash()) + self.assertNotEqual(mon1.Hash(), mon3.Hash()) + self.assertNotEqual(mon2.Hash(), mon3.Hash()) + + class Test_IsRunning(unittest.TestCase): def test_item(self): vidctl = CreateObject(msvidctl.MSVidCtl, interface=msvidctl.IMSVidCtl) @@ -66,3 +174,96 @@ def test_item(self): rot.Revoke(dw_reg) # After revoking: should NOT be running again self.assertEqual(mon.IsRunning(bctx, None, None), hresult.S_FALSE) + + +class Test_CommonPrefixWith(unittest.TestCase): + def test_file(self): + bctx = _create_bctx() + # Create temporary directories and files for realistic File Monikers + with tempfile.TemporaryDirectory() as t: + tmpdir = Path(t) + dir_a = tmpdir / "dir_a" + dir_b = tmpdir / "dir_a" / "dir_b" + dir_b.mkdir(parents=True) + file1 = dir_a / "file1.txt" + file2 = dir_b / "file2.txt" + file3 = tmpdir / "file3.txt" + mon1 = _create_file_moniker(str(file1)) # tmpdir/dir_a/file1.txt + mon2 = _create_file_moniker(str(file2)) # tmpdir/dir_a/dir_b/file2.txt + mon3 = _create_file_moniker(str(file3)) # tmpdir/file3.txt + # Common prefix between mon1 and mon2 (tmpdir/dir_a) + self.assertEqual( + os.path.normcase( + os.path.normpath( + mon1.CommonPrefixWith(mon2).GetDisplayName(bctx, None) + ) + ), + os.path.normcase(os.path.normpath(dir_a)), + ) + # Common prefix between mon1 and mon3 (tmpdir) + self.assertEqual( + os.path.normcase( + os.path.normpath( + mon1.CommonPrefixWith(mon3).GetDisplayName(bctx, None) + ) + ), + os.path.normcase(os.path.normpath(tmpdir)), + ) + + +class Test_RelativePathTo(unittest.TestCase): + def test_file(self): + bctx = _create_bctx() + with tempfile.TemporaryDirectory() as t: + tmpdir = Path(t) + dir_a = tmpdir / "dir_a" + dir_b = tmpdir / "dir_b" + dir_a.mkdir() + dir_b.mkdir() + file1 = dir_a / "file1.txt" + file2 = dir_b / "file2.txt" + mon_from = _create_file_moniker(str(file1)) # tmpdir/dir_a/file1.txt + mon_to = _create_file_moniker(str(file2)) # tmpdir/dir_b/file2.txt + # The COM API returns paths with backslashes on Windows, so we normalize. + self.assertEqual( + # Check the display name of the relative moniker + # The moniker's `RelativePathTo` method calculates the path from + # the base of the `mon_from` to the target `mon_to`. + os.path.normcase( + os.path.normpath( + mon_from.RelativePathTo(mon_to).GetDisplayName(bctx, None) + ) + ), + # Calculate the relative path from the directory of file1 to file2 + os.path.normcase(os.path.normpath("..\\..\\dir_b\\file2.txt")), + ) + + +class Test_Enum(unittest.TestCase): + def test_generic_composite(self): + item_id1 = str(GUID.create_new()) + item_id2 = str(GUID.create_new()) + item_mon1 = _create_item_moniker("!", item_id1) + item_mon2 = _create_item_moniker("!", item_id2) + # Create a composite moniker to ensure multiple elements for enumeration + comp_mon = _create_generic_composite(item_mon1, item_mon2) + enum_moniker = comp_mon.Enum(True) # True for forward enumeration + self.assertIsInstance(enum_moniker, IEnumMoniker) + + +class Test_RemoteBindToObject(unittest.TestCase): + def test_file(self): + bctx = _create_bctx() + with tempfile.TemporaryDirectory() as t: + tmpdir = Path(t) + tmpfile = tmpdir / "tmp.lnk" + tmpfile.touch() + mon = _create_file_moniker(str(tmpfile)) + bound_obj = mon.RemoteBindToObject(bctx, None, IPersistFile._iid_) + pf = bound_obj.QueryInterface(IPersistFile) + self.assertEqual( + os.path.normcase(os.path.normpath(_get_long_path_name(str(tmpfile)))), + os.path.normcase( + os.path.normpath(_get_long_path_name(pf.GetCurFile())) + ), + ) diff --git a/comtypes/test/test_rot.py b/comtypes/test/test_rot.py index 88813ad1..e230f6b5 100644 --- a/comtypes/test/test_rot.py +++ b/comtypes/test/test_rot.py @@ -16,7 +16,12 @@ with contextlib.redirect_stdout(None): # supress warnings GetModule("msvidctl.dll") from comtypes.gen import MSVidCtlLib as msvidctl -from comtypes.gen.MSVidCtlLib import IBindCtx, IMoniker, IRunningObjectTable +from comtypes.gen.MSVidCtlLib import ( + IBindCtx, + IEnumMoniker, + IMoniker, + IRunningObjectTable, +) def _create_item_moniker(delim: str, item: str) -> IMoniker: @@ -62,3 +67,10 @@ def test_item(self): with self.assertRaises(COMError) as cm: rot.GetObject(mon) self.assertEqual(cm.exception.hresult, MK_E_UNAVAILABLE) + + +class Test_EnumRunning(unittest.TestCase): + def test_returns_enum_moniker(self): + rot = _create_rot() + enum_moniker = rot.EnumRunning() + self.assertIsInstance(enum_moniker, IEnumMoniker)